deberta_api / app.py
AISimplyExplained's picture
Create app.py
f228a1c verified
raw
history blame
No virus
1.2 kB
import logging
import threading

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
class Guardrail:
    """Prompt-injection classifier built on a Hugging Face text-classification pipeline.

    The tokenizer and model are loaded once in ``__init__`` and reused by every
    call to :meth:`guard`.
    """

    def __init__(self, model_name: str = "ProtectAI/deberta-v3-base-prompt-injection"):
        """Load ``model_name`` and build the classification pipeline.

        Args:
            model_name: Hugging Face Hub id (or local path) of a
                sequence-classification checkpoint. Defaults to the ProtectAI
                DeBERTa-v3 prompt-injection model this service was written for.
        """
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.classifier = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            # Inputs are truncated to at most 512 tokens.
            truncation=True,
            max_length=512,
            # Use the GPU when one is visible to torch, otherwise the CPU.
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )
        # One pipeline (and one fast tokenizer, whose truncation settings are
        # re-applied on every call) is shared by all callers. Fast tokenizers
        # can fail with "RuntimeError: Already borrowed" when used from several
        # threads at once, so calls are serialised through this lock.
        self._lock = threading.Lock()

    def guard(self, prompt):
        """Classify ``prompt`` and return the raw pipeline output.

        Args:
            prompt: Text to classify; passed to the pipeline unchanged.

        Returns:
            Whatever the pipeline returns for ``prompt`` (presumably a list of
            ``{"label": ..., "score": ...}`` dicts -- confirm for the model in use).

        Calls hold an instance lock, so concurrent callers (e.g. a threaded web
        server) run one classification at a time against the shared pipeline.
        """
        with self._lock:
            return self.classifier(prompt)
class TextPrompt(BaseModel):
    """JSON request body accepted by the ``/classify/`` endpoint."""

    # The text to screen; handed to the classifier as-is.
    prompt: str
# Build the classifier once, at import time, so every request reuses the
# same loaded model instead of reloading it per call.
guardrail = Guardrail()

# The ASGI application; route handlers below register themselves on it.
app = FastAPI()
logger = logging.getLogger(__name__)


@app.post("/classify/")
def classify_text(text_prompt: TextPrompt):
    """Classify ``text_prompt.prompt`` with the shared prompt-injection guardrail.

    Declared with plain ``def`` rather than ``async def``: FastAPI runs
    synchronous handlers in its threadpool, so the blocking model call does
    not stall the event loop.

    Returns:
        The result of ``guardrail.guard`` (presumably a list holding one
        ``{"label": ..., "score": ...}`` dict -- verify for the model in use).

    Raises:
        HTTPException: status 500 if classification fails. The underlying
            error is logged server-side, not echoed back to the client.
    """
    try:
        result = guardrail.guard(text_prompt.prompt)
    except Exception as exc:
        # Request boundary. Keep the full traceback for operators, but do not
        # expose internal exception text to untrusted callers.
        logger.exception("Prompt classification failed")
        raise HTTPException(status_code=500, detail="Classification failed") from exc
    return result
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on every
    # interface, port 8000. The import stays local so that merely importing
    # this module does not pull in the server.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=8000)