AISimplyExplained commited on
Commit
f228a1c
1 Parent(s): 9d0c88a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
4
+ import torch
5
+
6
+
7
+ class Guardrail:
8
+ def __init__(self):
9
+ tokenizer = AutoTokenizer.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection")
10
+ model = AutoModelForSequenceClassification.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection")
11
+
12
+ self.classifier = pipeline(
13
+ "text-classification",
14
+ model=model,
15
+ tokenizer=tokenizer,
16
+ truncation=True,
17
+ max_length=512,
18
+ device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ )
20
+
21
+ def guard(self, prompt):
22
+ return self.classifier(prompt)
23
+
24
+
25
+ class TextPrompt(BaseModel):
26
+ prompt: str
27
+
28
+
29
+ app = FastAPI()
30
+ guardrail = Guardrail()
31
+
32
+
33
+ @app.post("/classify/")
34
+ def classify_text(text_prompt: TextPrompt):
35
+ try:
36
+ result = guardrail.guard(text_prompt.prompt)
37
+ return result
38
+ except Exception as e:
39
+ raise HTTPException(status_code=500, detail=str(e))
40
+
41
+
42
+ if __name__ == "__main__":
43
+ import uvicorn
44
+
45
+ uvicorn.run(app, host="0.0.0.0", port=8000)