Chris4K commited on
Commit
17b66cc
1 Parent(s): e7e4d1a

Update sentiment_analysis.py

Browse files
Files changed (1) hide show
  1. sentiment_analysis.py +17 -17
sentiment_analysis.py CHANGED
@@ -8,12 +8,8 @@ class SentimentAnalysisTool(Tool):
8
  description = "This tool analyses the sentiment of a given text input."
9
 
10
  inputs = ["text"] # Adding an empty list for inputs
11
-
12
  outputs = ["json"]
13
 
14
- def __call__(self, inputs: str):
15
- return predicto(str)
16
-
17
  model_id_1 = "nlptown/bert-base-multilingual-uncased-sentiment"
18
  model_id_2 = "microsoft/deberta-xlarge-mnli"
19
  model_id_3 = "distilbert-base-uncased-finetuned-sst-2-english"
@@ -22,25 +18,29 @@ class SentimentAnalysisTool(Tool):
22
  model_id_6 = "sbcBI/sentiment_analysis_model"
23
  model_id_7 = "models/oliverguhr/german-sentiment-bert"
24
 
25
- def parse_output(output_json):
26
- list_pred=[]
 
 
 
27
  for i in range(len(output_json[0])):
28
  label = output_json[0][i]['label']
29
  score = output_json[0][i]['score']
30
  list_pred.append((label, score))
31
  return list_pred
32
 
33
- def get_prediction(model_id):
34
  classifier = pipeline("text-classification", model=model_id, return_all_scores=True)
35
-
36
- def predicto(review):
37
- classifier = get_prediction(model_id_3)
38
- prediction = classifier(review)
39
- print(prediction)
40
- return parse_output(prediction)
41
-
42
 
 
 
43
 
44
-
45
-
46
-
 
8
  description = "This tool analyses the sentiment of a given text input."
9
 
10
  inputs = ["text"] # Adding an empty list for inputs
 
11
  outputs = ["json"]
12
 
 
 
 
13
  model_id_1 = "nlptown/bert-base-multilingual-uncased-sentiment"
14
  model_id_2 = "microsoft/deberta-xlarge-mnli"
15
  model_id_3 = "distilbert-base-uncased-finetuned-sst-2-english"
 
18
  model_id_6 = "sbcBI/sentiment_analysis_model"
19
  model_id_7 = "models/oliverguhr/german-sentiment-bert"
20
 
21
+ def __call__(self, inputs: str):
22
+ return self.predicto(inputs)
23
+
24
+ def parse_output(self, output_json):
25
+ list_pred = []
26
  for i in range(len(output_json[0])):
27
  label = output_json[0][i]['label']
28
  score = output_json[0][i]['score']
29
  list_pred.append((label, score))
30
  return list_pred
31
 
32
+ def get_prediction(self, model_id):
33
  classifier = pipeline("text-classification", model=model_id, return_all_scores=True)
34
+ return classifier
35
+
36
+ def predicto(self, review):
37
+ classifier = self.get_prediction(self.model_id_3)
38
+ prediction = classifier(review)
39
+ print(prediction)
40
+ return self.parse_output(prediction)
41
 
42
+ # Create an instance of the SentimentAnalysisTool class
43
+ sentiment_analysis_tool = SentimentAnalysisTool()
44
 
45
+ # Create the Gradio interface
46
+ gr.Interface(fn=sentiment_analysis_tool, inputs=sentiment_analysis_tool.inputs, outputs=sentiment_analysis_tool.outputs).launch()