"""Gradio demo that classifies the emotion expressed in a tweet.

Two Hugging Face text-classification checkpoints can be selected from the UI:
the model passed to ``TwitterEmotionClassifier`` at construction time (exposed
as "bertweet") and a DeBERTa checkpoint that is loaded lazily on first use.
"""

import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

# Checkpoint loaded the first time the "deberta" option is requested.
_DEBERTA_MODEL_NAME = "Emanuel/twitter-emotion-deberta-v3-base"

# UI labels in class-index order (index 0 -> sadness ... index 5 -> surprise),
# mirroring the LABEL_n mapping stored in ``TwitterEmotionClassifier.emotions``.
_DISPLAY_LABELS = (
    "Sadness 😢",
    "Joy 😂",
    "Love 💛",
    "Anger 😠",
    "Fear 😱",
    "Surprise 😮",
)


class TwitterEmotionClassifier:
    """Wrap the emotion-classification pipelines used by the demo.

    Attributes:
        is_gpu: Whether pipelines are placed on CUDA device 0 (hard-coded to
            ``False`` here, i.e. CPU).
        model_type: The model type this instance was constructed with.
        bertweet: Pipeline built from the ``model_name`` given to ``__init__``.
        deberta: Lazily created DeBERTa pipeline, ``None`` until first use.
        emotions: Mapping from ``LABEL_n`` ids to emotion names. No method in
            this file reads it; it is kept for external users of the class.
    """

    def __init__(self, model_name: str, model_type: str):
        """Load ``model_name`` eagerly and prepare the lazy DeBERTa slot.

        Args:
            model_name: Hugging Face checkpoint used for the default pipeline.
            model_type: Identifier of that checkpoint; ``get_model`` only
                serves ``self.bertweet`` when this equals ``"bertweet"``.
        """
        self.is_gpu = False
        self.model_type = model_type
        self.bertweet = self._build_pipeline(model_name)
        self.deberta = None
        self.emotions = {
            "LABEL_0": "sadness",
            "LABEL_1": "joy",
            "LABEL_2": "love",
            "LABEL_3": "anger",
            "LABEL_4": "fear",
            "LABEL_5": "surprise",
        }

    def _build_pipeline(self, model_name: str):
        """Create a text-classification pipeline for ``model_name``.

        Both checkpoints go through this helper so they get the same device
        placement and are both switched to inference mode.
        """
        device = torch.device("cuda") if self.is_gpu else torch.device("cpu")
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model.to(device)
        model.eval()
        return pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            # The pipeline API takes a device ordinal: -1 is CPU, 0 the first GPU.
            device=0 if self.is_gpu else -1,
        )

    def get_model(self, model_type: str):
        """Return the pipeline registered for ``model_type``.

        Args:
            model_type: ``"bertweet"`` or ``"deberta"``.

        Returns:
            The matching pipeline. The DeBERTa pipeline is created on the
            first request and cached on ``self.deberta`` afterwards.

        Raises:
            ValueError: If ``model_type`` is not supported, or if
                ``"bertweet"`` is requested on an instance that was not
                constructed with ``model_type="bertweet"``.
        """
        if model_type == "bertweet" and self.model_type == "bertweet":
            return self.bertweet
        if model_type == "deberta":
            if self.deberta is None:
                self.deberta = self._build_pipeline(_DEBERTA_MODEL_NAME)
            return self.deberta
        # Fail loudly here instead of handing ``None`` to the caller, which
        # would only surface later as "'NoneType' object is not callable".
        raise ValueError(
            f"Unsupported model type {model_type!r} for a classifier "
            f"constructed with model_type={self.model_type!r}; "
            "expected 'bertweet' or 'deberta'."
        )

    def predict(self, twitter: str, model_type: str):
        """Score ``twitter`` against the six emotions.

        Args:
            twitter: Tweet text to classify.
            model_type: Which pipeline to use; see ``get_model``.

        Returns:
            A dict mapping each display label to its score, or ``None`` when
            the pipeline returns an empty result.

        Raises:
            ValueError: Propagated from ``get_model`` for an unsupported
                ``model_type``.
        """
        classifier = self.get_model(model_type)
        # NOTE(review): newer transformers releases deprecate
        # ``return_all_scores`` in favour of ``top_k=None``, whose output is
        # shaped and ordered differently. The positional lookup below relies
        # on the legacy class-index ordering -- confirm before migrating.
        preds = classifier(twitter, return_all_scores=True)
        if not preds:
            return None
        scores = preds[0]
        return {
            label: scores[index]["score"]
            for index, label in enumerate(_DISPLAY_LABELS)
        }


def main():
    """Build the Gradio interface around the classifier and launch it."""
    model = TwitterEmotionClassifier("Emanuel/bertweet-emotion-base", "bertweet")
    # NOTE(review): ``gr.inputs`` / ``gr.outputs`` and the ``verbose`` / ``theme``
    # arguments follow an older Gradio API; verify against the pinned version.
    interface = gr.Interface(
        fn=model.predict,
        inputs=[
            gr.inputs.Textbox(
                placeholder="What's happenning?", label="Tweet content", lines=5
            ),
            gr.inputs.Radio(["bertweet", "deberta"], label="Model"),
        ],
        outputs=gr.outputs.Label(num_top_classes=6, label="Emotions of this tweet is "),
        verbose=True,
        examples=[
            ["This GOT show just remember LOTR times!", "bertweet"],
            [
                "Man, can't believe that my 30 days of training just got a NaN loss",
                "bertweet",
            ],
            ["I couldn't see 3 Tom Hollands coming...", "bertweet"],
            [
                "There is nothing better than a soul-warming coffee in the morning",
                "bertweet",
            ],
            ["I fear the vanishing gradient", "deberta"],
        ],
        title="Emotion classification 🤖",
        description="",
        theme="huggingface",
    )
    interface.launch()


if __name__ == "__main__":
    main()