admincybers2 committed on
Commit
daa0d0e
1 Parent(s): b201b50

Create app.py

Files changed (1)
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
+ import os
+ import gradio as gr
+ import torch
+ from transformers import TextStreamer, AutoModelForCausalLM, AutoTokenizer
+ import spaces
+
+ # Define the model configuration
+ model_config = {
+     "model_name": "admincybers2/sentinal",
+     "max_seq_length": 1024,
+     "dtype": torch.float16,
+     "load_in_4bit": True
+ }
+
+ # Hugging Face token
+ hf_token = os.getenv("HF_TOKEN")
+
+ # Load the model when the application starts
+ loaded_model = None
+ loaded_tokenizer = None
+
+ def load_model():
+     global loaded_model, loaded_tokenizer
+     if loaded_model is None:
+         model = AutoModelForCausalLM.from_pretrained(
+             model_config["model_name"],
+             torch_dtype=model_config["dtype"],
+             device_map="auto",
+             use_auth_token=hf_token
+         )
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_config["model_name"],
+             use_auth_token=hf_token
+         )
+         loaded_model = model
+         loaded_tokenizer = tokenizer
+     return loaded_model, loaded_tokenizer
+
+ # Vulnerability prompt template
+ vulnerability_prompt = """Identify the specific line of code that is vulnerable and describe the type of software vulnerability.
+ ### Vulnerable Line:
+ {}
+ ### Vulnerability Description:
+ """
+
+ @spaces.GPU(duration=120)
+ def predict(prompt):
+     model, tokenizer = load_model()
+     formatted_prompt = vulnerability_prompt.format(prompt)  # The template has a single {} placeholder for the snippet
+     inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda")
+     text_streamer = TextStreamer(tokenizer)
+
+     output = model.generate(
+         **inputs,
+         streamer=text_streamer,
+         use_cache=True,
+         temperature=0.4,
+         top_k=50,                 # Sample only from the 50 most likely next tokens
+         top_p=0.9,                # Nucleus sampling: keep the smallest token set with cumulative probability 0.9
+         min_p=0.01,               # Discard tokens whose probability (relative to the top token) falls below this threshold
+         typical_p=0.95,           # Prefer tokens that are typical given the context
+         repetition_penalty=1.2,   # Penalize previously generated tokens to reduce repetition
+         no_repeat_ngram_size=3,   # Prevent the same 3-gram from repeating
+         renormalize_logits=True,  # Renormalize logits after all processors have been applied
+         max_new_tokens=640
+     )
+     return tokenizer.decode(output[0], skip_special_tokens=True)
+
+ theme = gr.themes.Default(
+     primary_hue=gr.themes.colors.rose,
+     secondary_hue=gr.themes.colors.blue,
+     font=gr.themes.GoogleFont("Source Sans Pro")
+ )
+
+ # Pre-load the model
+ load_model()
+
+ with gr.Blocks(theme=theme) as demo:
+     prompt = gr.Textbox(lines=5, placeholder="Enter your code snippet or topic here...", label="Prompt")
+     generated_text = gr.Textbox(label="Generated Text")
+     generate_button = gr.Button("Generate")
+     generate_button.click(predict, inputs=[prompt], outputs=generated_text)
+
+     gr.Examples(
+         examples=[
+             ["$buff = 'A' x 10000;\nopen(myfile, '>>PASS.PK2');\nprint myfile $buff;\nclose(myfile);"]
+         ],
+         inputs=[prompt]
+     )
+
+ demo.queue(default_concurrency_limit=10).launch(
+     server_name="0.0.0.0",
+     allowed_paths=["/"]
+ )