cyberapi_1 / app.py
admincybers2's picture
Create app.py
daa0d0e verified
raw
history blame
No virus
3.17 kB
import os
import gradio as gr
import torch
from transformers import TextStreamer, AutoModelForCausalLM, AutoTokenizer
import spaces
# Settings for the fine-tuned causal LM served by this app.
# NOTE(review): only "model_name" and "dtype" are consumed by the code visible
# here; "max_seq_length" and "load_in_4bit" are never read, so the model is
# loaded without 4-bit quantization and inputs are not truncated to 1024
# tokens. Either wire these up or remove them.
model_config = dict(
    model_name="admincybers2/sentinal",
    max_seq_length=1024,
    dtype=torch.float16,
    load_in_4bit=True,
)

# Hugging Face access token taken from the environment (None when unset);
# forwarded to the from_pretrained() calls in load_model().
hf_token = os.environ.get("HF_TOKEN")

# Module-level cache for the model/tokenizer pair; both start empty and are
# populated by the first call to load_model().
loaded_model = loaded_tokenizer = None
def load_model():
    """Return the ``(model, tokenizer)`` pair, loading it on first use.

    The first call loads both objects with ``from_pretrained`` using
    ``model_config`` and ``hf_token`` and stores them in the module-level
    ``loaded_model`` / ``loaded_tokenizer`` globals. Later calls return the
    cached objects without reloading.

    Returns:
        tuple: ``(loaded_model, loaded_tokenizer)``.
    """
    global loaded_model, loaded_tokenizer
    if loaded_model is None:
        # `token=` replaces the deprecated `use_auth_token=` keyword of
        # from_pretrained().
        model = AutoModelForCausalLM.from_pretrained(
            model_config["model_name"],
            torch_dtype=model_config["dtype"],
            device_map="auto",
            token=hf_token,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_config["model_name"],
            token=hf_token,
        )
        # Publish both only after both loads succeed, so a failed tokenizer
        # load leaves the cache empty and the next call retries cleanly.
        loaded_model = model
        loaded_tokenizer = tokenizer
    return loaded_model, loaded_tokenizer
# Prompt template used by predict(). The single "{}" placeholder receives the
# user's code snippet. The template ends with an open "Vulnerability
# Description" header, presumably so the model continues with the description.
vulnerability_prompt = "\n".join(
    [
        "Identify the specific line of code that is vulnerable and describe the type of software vulnerability.",
        "### Vulnerable Line:",
        "{}",
        "### Vulnerability Description:",
        "",
    ]
)
@spaces.GPU(duration=120)
def predict(prompt):
    """Ask the model to describe the vulnerability in a code snippet.

    Args:
        prompt: Text from the UI textbox, substituted into
            ``vulnerability_prompt``.

    Returns:
        str: The decoded model output with special tokens removed. For this
        causal LM, ``output[0]`` holds the prompt tokens followed by the
        generated ones, so the text starts with the formatted prompt.
    """
    model, tokenizer = load_model()
    formatted_prompt = vulnerability_prompt.format(prompt)
    # Send inputs to wherever device_map="auto" placed the model rather than
    # hard-coding "cuda".
    inputs = tokenizer([formatted_prompt], return_tensors="pt").to(model.device)
    # Echoes tokens to the server's stdout as they are generated; the UI only
    # receives the final string returned below.
    text_streamer = TextStreamer(tokenizer)
    output = model.generate(
        **inputs,
        streamer=text_streamer,
        use_cache=True,
        # Unless the model's generation_config enables sampling, generate()
        # decodes greedily and ignores (with a warning) every sampling
        # parameter below. Enable sampling explicitly so they take effect.
        do_sample=True,
        temperature=0.4,
        top_k=50,  # sample only among the 50 most likely next tokens
        top_p=0.9,  # nucleus sampling over the top 90% probability mass
        min_p=0.01,  # drop tokens below 1% of the top token's probability
        typical_p=0.95,  # locally-typical sampling
        repetition_penalty=1.2,  # discourage repeating earlier tokens
        no_repeat_ngram_size=3,  # forbid any 3-gram from occurring twice
        renormalize_logits=True,  # renormalize after all logits processors
        max_new_tokens=640,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Gradio look-and-feel: rose primary hue, blue secondary hue, Source Sans Pro font.
theme = gr.themes.Default(
    primary_hue=gr.themes.colors.rose,
    secondary_hue=gr.themes.colors.blue,
    font=gr.themes.GoogleFont(
        "Source Sans Pro"
    ),
)

# Populate the model/tokenizer cache at import time so later predict()
# calls reuse the already-loaded objects.
load_model()
with gr.Blocks(theme=theme) as demo:
    # Input: code snippet (or free text) sent to predict().
    prompt = gr.Textbox(
        label="Prompt",
        lines=5,
        placeholder="Enter your code snippet or topic here...",
    )
    # Output: the decoded model response returned by predict().
    generated_text = gr.Textbox(label="Generated Text")
    generate_button = gr.Button("Generate")
    generate_button.click(fn=predict, inputs=[prompt], outputs=generated_text)

    # Clickable sample input (a Perl snippet) that fills the prompt box.
    gr.Examples(
        inputs=[prompt],
        examples=[
            [
                "$buff = 'A' x 10000;\nopen(myfile, '>>PASS.PK2');\nprint myfile $buff;\nclose(myfile);"
            ],
        ],
    )
# Enable the request queue (at most 10 events of the same kind run
# concurrently) and bind the server to all network interfaces.
# SECURITY: do not pass allowed_paths=["/"]. That setting permits Gradio's
# file-serving route to return any file on the host, which could expose
# secrets such as the HF_TOKEN environment variable (e.g. via
# /proc/self/environ). Whitelist only specific asset directories if needed.
demo.queue(default_concurrency_limit=10).launch(
    server_name="0.0.0.0",
)