AIdeaText commited on
Commit
904f07a
1 Parent(s): 9633a42

Update modules/chatbot.py

Browse files
Files changed (1) hide show
  1. modules/chatbot.py +13 -53
modules/chatbot.py CHANGED
@@ -1,59 +1,19 @@
1
  from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
2
  import torch
3
 
4
- def initialize_chatbot():
5
- model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
6
- tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
7
- return model, tokenizer
8
-
9
- def get_chatbot_response(model, tokenizer, prompt, src_lang):
10
- tokenizer.src_lang = src_lang
11
- encoded_input = tokenizer(prompt, return_tensors="pt")
12
- generated_tokens = model.generate(**encoded_input, max_length=100)
13
- return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
14
-
15
- def display_chatbot_interface(lang_code):
16
- translations = {
17
- 'es': {
18
- 'title': "AIdeaText - Chatbot Multilingüe",
19
- 'input_placeholder': "Escribe tu mensaje aquí...",
20
- 'send_button': "Enviar",
21
- },
22
- 'en': {
23
- 'title': "AIdeaText - Multilingual Chatbot",
24
- 'input_placeholder': "Type your message here...",
25
- 'send_button': "Send",
26
- },
27
- 'fr': {
28
- 'title': "AIdeaText - Chatbot Multilingue",
29
- 'input_placeholder': "Écrivez votre message ici...",
30
- 'send_button': "Envoyer",
31
- }
32
- }
33
-
34
- t = translations[lang_code]
35
-
36
- st.header(t['title'])
37
 
38
- if 'chatbot' not in st.session_state:
39
- st.session_state.chatbot, st.session_state.tokenizer = initialize_chatbot()
 
 
 
40
 
41
- if 'messages' not in st.session_state:
42
- st.session_state.messages = []
43
-
44
- for message in st.session_state.messages:
45
- with st.chat_message(message["role"]):
46
- st.markdown(message["content"])
47
-
48
- if prompt := st.chat_input(t['input_placeholder']):
49
- st.session_state.messages.append({"role": "user", "content": prompt})
50
- with st.chat_message("user"):
51
- st.markdown(prompt)
52
-
53
- with st.chat_message("assistant"):
54
- response = get_chatbot_response(st.session_state.chatbot, st.session_state.tokenizer, prompt, lang_code)
55
- st.markdown(response)
56
- st.session_state.messages.append({"role": "assistant", "content": response})
57
 
58
- # Guardar la conversación en la base de datos
59
- store_chat_history(st.session_state.username, st.session_state.messages)
 
1
  from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
2
  import torch
3
 
4
class MultilingualChatbot:
    """Multilingual chatbot backed by Facebook's mBART-50 many-to-many model.

    Holds one model/tokenizer pair; `generate_response` translates/generates
    a reply for a prompt written in any mBART-50-supported language.
    """

    # Single source of truth for the Hugging Face model id (was duplicated
    # in the two from_pretrained calls).
    MODEL_NAME = "facebook/mbart-large-50-many-to-many-mmt"

    def __init__(self):
        """Load (downloading on first use) the mBART-50 model and tokenizer."""
        self.model = MBartForConditionalGeneration.from_pretrained(self.MODEL_NAME)
        self.tokenizer = MBart50TokenizerFast.from_pretrained(self.MODEL_NAME)

    def generate_response(self, prompt, src_lang, max_length=100):
        """Generate a reply for `prompt` written in `src_lang`.

        Args:
            prompt: User message text.
            src_lang: mBART-50 language code of the prompt (e.g. "en_XX",
                "es_XX") — presumably supplied by the UI layer; confirm codes
                match the mBART-50 convention.
            max_length: Upper bound on generated sequence length. Defaults to
                100, the original hard-coded limit (kept for compatibility).

        Returns:
            The first decoded generation as a single string, with special
            tokens stripped.
        """
        self.tokenizer.src_lang = src_lang
        encoded_input = self.tokenizer(prompt, return_tensors="pt")
        # Pure inference: disable autograd tracking to cut memory and time.
        with torch.no_grad():
            generated_tokens = self.model.generate(**encoded_input, max_length=max_length)
        return self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
14
 
15
def initialize_chatbot():
    """Build and return the app's chatbot instance.

    Module-level factory kept so callers don't construct MultilingualChatbot
    directly; loading the model happens inside the constructor.
    """
    chatbot = MultilingualChatbot()
    return chatbot
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
def get_chatbot_response(chatbot, prompt, src_lang):
    """Return the chatbot's reply to `prompt` written in `src_lang`.

    Thin functional wrapper around `chatbot.generate_response`, preserved for
    callers that use the module-level API instead of the class directly.
    """
    reply = chatbot.generate_response(prompt, src_lang)
    return reply