"""HTTP service that screens prompts for prompt-injection with a HF classifier.

Exposes ``POST /classify/`` which runs the submitted prompt through a Hugging
Face ``text-classification`` pipeline and returns the pipeline's raw output.
"""

import logging

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

logger = logging.getLogger(__name__)

# Hugging Face model id loaded by default.
DEFAULT_MODEL_NAME = "ProtectAI/deberta-v3-base-prompt-injection"


class Guardrail:
    """Wrap a Hugging Face text-classification pipeline used to screen prompts."""

    def __init__(self, model_name: str = DEFAULT_MODEL_NAME, max_length: int = 512):
        """Load the tokenizer and model and build the classification pipeline.

        Args:
            model_name: Hugging Face model id (or local path) to load.
            max_length: Token limit passed to the pipeline; longer inputs are
                truncated because ``truncation=True`` is set.
        """
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.classifier = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            # Truncate over-long prompts instead of erroring on them.
            truncation=True,
            max_length=max_length,
            # Use the GPU when one is visible to torch, otherwise the CPU.
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )

    def guard(self, prompt: str):
        """Classify *prompt* and return the pipeline's output unchanged."""
        return self.classifier(prompt)


class TextPrompt(BaseModel):
    """Request body for ``/classify/``."""

    prompt: str


app = FastAPI()
# Loaded once at import time so every request reuses the same pipeline.
guardrail = Guardrail()


@app.post("/classify/")
def classify_text(text_prompt: TextPrompt):
    """Classify the submitted prompt.

    Returns:
        Whatever ``Guardrail.guard`` returns for the prompt.

    Raises:
        HTTPException: status 500 with a generic detail if classification
            fails; the underlying error is logged server-side only.
    """
    try:
        return guardrail.guard(text_prompt.prompt)
    except Exception as exc:
        # Keep the traceback in server logs; do not echo internal error text
        # (paths, model/tokenizer internals) back to the HTTP client.
        logger.exception("Prompt classification failed")
        raise HTTPException(status_code=500, detail="Classification failed") from exc


if __name__ == "__main__":
    import uvicorn

    # Binds on all interfaces, as in the original launcher.
    uvicorn.run(app, host="0.0.0.0", port=8000)