import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import spaces

# Model configuration
model_config = {
    "model_name": "admincybers2/sentinal",
    "max_seq_length": 1024,  # kept for reference; not passed to the model below
    "dtype": torch.float16,
    "load_in_4bit": True     # kept for reference; not passed to from_pretrained below
}

# Hugging Face token
hf_token = os.getenv("HF_TOKEN")

# Lazily load the model and tokenizer, caching them in module-level globals
loaded_model = None
loaded_tokenizer = None

def load_model():
    global loaded_model, loaded_tokenizer
    if loaded_model is None:
        loaded_model = AutoModelForCausalLM.from_pretrained(
            model_config["model_name"],
            torch_dtype=model_config["dtype"],
            device_map="auto",
            token=hf_token  # `use_auth_token` is deprecated in recent transformers releases
        )
        loaded_tokenizer = AutoTokenizer.from_pretrained(
            model_config["model_name"],
            token=hf_token
        )
    return loaded_model, loaded_tokenizer

# Vulnerability prompt template (exactly one placeholder for the code snippet)
vulnerability_prompt = """Identify the specific line of code that is vulnerable and describe the type of software vulnerability.

### Vulnerable Line:
{}

### Vulnerability Description:
"""

@spaces.GPU(duration=120)
def predict(prompt):
    model, tokenizer = load_model()
    formatted_prompt = vulnerability_prompt.format(prompt)
    inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda")
    text_streamer = TextStreamer(tokenizer)
    output = model.generate(
        **inputs,
        streamer=text_streamer,
        use_cache=True,
        do_sample=True,           # required; without it the sampling parameters below are ignored
        temperature=0.4,
        top_k=50,                 # consider only the 50 most likely next tokens
        top_p=0.9,                # nucleus sampling: smallest token set covering 90% of probability mass
        min_p=0.01,               # filter tokens below this probability, scaled by the top token
        typical_p=0.95,           # locally typical sampling
        repetition_penalty=1.2,   # penalize repetitive sequences to improve diversity
        no_repeat_ngram_size=3,   # prevent any 3-gram from repeating
        renormalize_logits=True,  # renormalize logits after the processors above run
        max_new_tokens=640
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)

theme = gr.themes.Default(
    primary_hue=gr.themes.colors.rose,
    secondary_hue=gr.themes.colors.blue,
    font=gr.themes.GoogleFont("Source Sans Pro")
)

# Pre-load the model so the first request does not pay the initialization cost
load_model()

with gr.Blocks(theme=theme) as demo:
    prompt = gr.Textbox(lines=5, placeholder="Enter your code snippet or topic here...", label="Prompt")
    generated_text = gr.Textbox(label="Generated Text")
    generate_button = gr.Button("Generate")

    generate_button.click(predict, inputs=[prompt], outputs=generated_text)

    gr.Examples(
        examples=[
            ["$buff = 'A' x 10000;\nopen(myfile, '>>PASS.PK2');\nprint myfile $buff;\nclose(myfile);"]
        ],
        inputs=[prompt]
    )

demo.queue(default_concurrency_limit=10).launch(
    server_name="0.0.0.0",
    allowed_paths=["/"]
)
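
# --- Usage note (a sketch, not part of the app flow) ---
# To exercise `predict` without the UI, e.g. from a Python REPL on a machine
# with a CUDA device and HF_TOKEN exported (the model repo is assumed to be
# accessible with that token; the filename app.py is an assumption):
#
#     >>> from app import predict
#     >>> print(predict("strcpy(dest, src);"))
#
# Under Hugging Face ZeroGPU, the @spaces.GPU decorator allocates a GPU for
# each call, so `predict` is normally reached through the Gradio queue.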