AISimplyExplained commited on
Commit
4fa87d4
1 Parent(s): 5c19b8d

added labels as in input

Browse files
Files changed (1) hide show
  1. main.py +9 -5
main.py CHANGED
@@ -5,6 +5,7 @@ import torch
5
  from detoxify import Detoxify
6
  import asyncio
7
  from fastapi.concurrency import run_in_threadpool
 
8
 
9
  class Guardrail:
10
  def __init__(self):
@@ -60,17 +61,20 @@ class TopicBannerClassifier:
60
  device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
61
  )
62
  self.hypothesis_template = "This text is about {}"
63
- self.classes_verbalized = ["politics", "economy", "entertainment", "environment"]
64
 
65
- async def classify(self, text):
66
  return await run_in_threadpool(
67
  self.classifier,
68
  text,
69
- self.classes_verbalized,
70
  hypothesis_template=self.hypothesis_template,
71
  multi_label=False
72
  )
73
 
 
 
 
 
74
  class TopicBannerResult(BaseModel):
75
  sequence: str
76
  labels: list
@@ -108,9 +112,9 @@ async def classify_text(text_prompt: TextPrompt):
108
  raise HTTPException(status_code=500, detail=str(e))
109
 
110
  @app.post("/api/models/TopicBanner/classify", response_model=TopicBannerResult)
111
- async def classify_topic_banner(text_prompt: TextPrompt):
112
  try:
113
- result = await topic_banner_classifier.classify(text_prompt.prompt)
114
  return {
115
  "sequence": result["sequence"],
116
  "labels": result["labels"],
 
5
  from detoxify import Detoxify
6
  import asyncio
7
  from fastapi.concurrency import run_in_threadpool
8
+ from typing import List
9
 
10
  class Guardrail:
11
  def __init__(self):
 
61
  device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
  )
63
  self.hypothesis_template = "This text is about {}"
 
64
 
65
+ async def classify(self, text, labels):
66
  return await run_in_threadpool(
67
  self.classifier,
68
  text,
69
+ labels,
70
  hypothesis_template=self.hypothesis_template,
71
  multi_label=False
72
  )
73
 
74
+ class TopicBannerRequest(BaseModel):
75
+ prompt: str
76
+ labels: List[str]
77
+
78
  class TopicBannerResult(BaseModel):
79
  sequence: str
80
  labels: list
 
112
  raise HTTPException(status_code=500, detail=str(e))
113
 
114
  @app.post("/api/models/TopicBanner/classify", response_model=TopicBannerResult)
115
+ async def classify_topic_banner(request: TopicBannerRequest):
116
  try:
117
+ result = await topic_banner_classifier.classify(request.prompt, request.labels)
118
  return {
119
  "sequence": result["sequence"],
120
  "labels": result["labels"],