import os
import gradio as gr
import torch
from transformers import TextStreamer, AutoModelForCausalLM, AutoTokenizer
import spaces

# Define the model configuration
model_config = {
    "model_name": "admincybers2/sentinal",
    "max_seq_length": 1024,
    "dtype": torch.float16,
    "load_in_4bit": True
}
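
# Note: max_seq_length and load_in_4bit are recorded here but never passed to
# from_pretrained() below; actually loading in 4-bit would additionally require
# a bitsandbytes BitsAndBytesConfig (not wired in here).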

# Hugging Face token
hf_token = os.getenv("HF_TOKEN")

# Cache the model and tokenizer so they are only loaded once per process
loaded_model = None
loaded_tokenizer = None

def load_model():
    global loaded_model, loaded_tokenizer
    if loaded_model is None:
        model = AutoModelForCausalLM.from_pretrained(
            model_config["model_name"],
            torch_dtype=model_config["dtype"],
            device_map="auto",
            token=hf_token
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_config["model_name"],
            token=hf_token
        )
        loaded_model = model
        loaded_tokenizer = tokenizer
    return loaded_model, loaded_tokenizer

# Vulnerability prompt template
vulnerability_prompt = """Identify the specific line of code that is vulnerable and describe the type of software vulnerability.
### Vulnerable Line:
{}
### Vulnerability Description:
"""

@spaces.GPU(duration=120)
def predict(prompt):
    model, tokenizer = load_model()
    formatted_prompt = vulnerability_prompt.format(prompt)  # Fill the single {} placeholder with the user's snippet
    inputs = tokenizer([formatted_prompt], return_tensors="pt").to(model.device)
    text_streamer = TextStreamer(tokenizer)

    output = model.generate(
        **inputs,
        streamer=text_streamer,
        use_cache=True,
        do_sample=True,  # Required for temperature/top_k/top_p/min_p to take effect
        temperature=0.4,
        top_k=50,  # Consider only the 50 most likely next tokens
        top_p=0.9,  # Nucleus sampling over the smallest token set with cumulative probability 0.9
        min_p=0.01,  # Drop tokens whose probability is below 1% of the top token's probability
        typical_p=0.95,  # Favor tokens that are typical given the context
        repetition_penalty=1.2,  # Penalize repetitive sequences to improve diversity
        no_repeat_ngram_size=3,  # Prevent the same 3-gram from repeating
        renormalize_logits=True,  # Renormalize logits after the sampling filters are applied
        max_new_tokens=640
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
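
# Example call (assumes the Space has GPU access and HF_TOKEN is set):
#   predict("$buff = 'A' x 10000; open(myfile, '>>PASS.PK2');")
# streams tokens to stdout via TextStreamer and returns the full decoded text.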

theme = gr.themes.Default(
    primary_hue=gr.themes.colors.rose,
    secondary_hue=gr.themes.colors.blue,
    font=gr.themes.GoogleFont("Source Sans Pro")
)

# Pre-load the model
load_model()

with gr.Blocks(theme=theme) as demo:
    prompt = gr.Textbox(lines=5, placeholder="Enter your code snippet or topic here...", label="Prompt")
    generated_text = gr.Textbox(label="Generated Text")
    generate_button = gr.Button("Generate")
    generate_button.click(predict, inputs=[prompt], outputs=generated_text)

    gr.Examples(
        examples=[
            ["$buff = 'A' x 10000;\nopen(myfile, '>>PASS.PK2');\nprint myfile $buff;\nclose(myfile);"]
        ],
        inputs=[prompt]
    )

demo.queue(default_concurrency_limit=10).launch(
    server_name="0.0.0.0",
    allowed_paths=["/"]
)