from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import torch

class MultilingualChatbot:
    def __init__(self):
        # Load the mBART-50 many-to-many translation model and its tokenizer once.
        self.model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
        self.tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

    def generate_response(self, prompt, src_lang, tgt_lang=None):
        # Tell the tokenizer which language the prompt is written in.
        self.tokenizer.src_lang = src_lang
        encoded_input = self.tokenizer(prompt, return_tensors="pt")
        # mBART-50 expects the target-language id as the first generated token;
        # default to replying in the prompt's own language.
        tgt_lang = tgt_lang or src_lang
        generated_tokens = self.model.generate(
            **encoded_input,
            forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
            max_length=100,
        )
        return self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

def initialize_chatbot():
    # Build a chatbot instance (downloads the model weights on first use).
    return MultilingualChatbot()

def get_chatbot_response(chatbot, prompt, src_lang, tgt_lang=None):
    # Convenience wrapper around MultilingualChatbot.generate_response.
    return chatbot.generate_response(prompt, src_lang, tgt_lang)
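
# Minimal usage sketch, for illustration only: the prompt text and the mBART-50
# language codes ("fr_XX", "en_XX") below are assumptions, not part of the original code.
if __name__ == "__main__":
    chatbot = initialize_chatbot()

    # Reply in the prompt's own language (French in, French out).
    print(get_chatbot_response(chatbot, "Bonjour, comment allez-vous ?", "fr_XX"))

    # Request a different output language instead (French in, English out).
    print(get_chatbot_response(chatbot, "Bonjour, comment allez-vous ?", "fr_XX", tgt_lang="en_XX"))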