from fastapi import FastAPI, HTTPException from pydantic import BaseModel from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline import torch from detoxify import Detoxify import asyncio from fastapi.concurrency import run_in_threadpool from typing import List class Guardrail: def __init__(self): tokenizer = AutoTokenizer.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection") model = AutoModelForSequenceClassification.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection") self.classifier = pipeline( "text-classification", model=model, tokenizer=tokenizer, truncation=True, max_length=512, device=torch.device("cuda" if torch.cuda.is_available() else "cpu") ) async def guard(self, prompt): return await run_in_threadpool(self.classifier, prompt) def determine_level(self, label, score): if label == "SAFE": return 0, "safe" else: if score > 0.9: return 4, "high" elif score > 0.75: return 3, "medium" elif score > 0.5: return 2, "low" else: return 1, "very low" class TextPrompt(BaseModel): prompt: str class ClassificationResult(BaseModel): label: str score: float level: int severity_label: str class ToxicityResult(BaseModel): toxicity: float severe_toxicity: float obscene: float threat: float insult: float identity_attack: float class TopicBannerClassifier: def __init__(self): self.classifier = pipeline( "zero-shot-classification", model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0", device=torch.device("cuda" if torch.cuda.is_available() else "cpu") ) self.hypothesis_template = "This text is about {}" async def classify(self, text, labels): return await run_in_threadpool( self.classifier, text, labels, hypothesis_template=self.hypothesis_template, multi_label=False ) class TopicBannerRequest(BaseModel): prompt: str labels: List[str] class TopicBannerResult(BaseModel): sequence: str labels: list scores: list app = FastAPI() guardrail = Guardrail() toxicity_classifier = Detoxify('original') topic_banner_classifier = TopicBannerClassifier() @app.post("/api/models/toxicity/classify", response_model=ToxicityResult) async def classify_toxicity(text_prompt: TextPrompt): try: result = await run_in_threadpool(toxicity_classifier.predict, text_prompt.prompt) return { "toxicity": result['toxicity'], "severe_toxicity": result['severe_toxicity'], "obscene": result['obscene'], "threat": result['threat'], "insult": result['insult'], "identity_attack": result['identity_attack'] } except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @app.post("/api/models/PromptInjection/classify", response_model=ClassificationResult) async def classify_text(text_prompt: TextPrompt): try: result = await guardrail.guard(text_prompt.prompt) label = result[0]['label'] score = result[0]['score'] level, severity_label = guardrail.determine_level(label, score) return {"label": label, "score": score, "level": level, "severity_label": severity_label} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @app.post("/api/models/TopicBanner/classify", response_model=TopicBannerResult) async def classify_topic_banner(request: TopicBannerRequest): try: result = await topic_banner_classifier.classify(request.prompt, request.labels) return { "sequence": result["sequence"], "labels": result["labels"], "scores": result["scores"] } except Exception as e: raise HTTPException(status_code=500, detail=str(e)) if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000)