fschwartzer commited on
Commit
d8f6691
1 Parent(s): b13ea5d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -27
app.py CHANGED
@@ -2,8 +2,7 @@ import streamlit as st
2
  import pandas as pd
3
  import torch
4
  from transformers import pipeline
5
- #from transformers import TapasTokenizer, TapexTokenizer, BartForConditionalGeneration
6
- from transformers import AutoTokenizer, AutoModelForTableQuestionAnswering
7
  import datetime
8
 
9
  #df = pd.read_excel('discrepantes.xlsx', index_col='Unnamed: 0')
@@ -15,30 +14,29 @@ print(table_data.head())
15
  def response(user_question, table_data):
16
  a = datetime.datetime.now()
17
 
18
- model_name = "google/tapas-base-finetuned-wtq"
19
- model = AutoModelForTableQuestionAnswering.from_pretrained(model_name)
20
- tokenizer = AutoTokenizer.from_pretrained(model_name)
21
 
22
- # The query should be passed as a list
23
- encoding = tokenizer(table=table_data, queries=[user_question], padding=True, return_tensors="pt", truncation=True)
24
 
25
- # Instead of using generate, we pass the encoding through the model to get the logits
26
- outputs = model(**encoding)
27
 
28
- # Extract the answer coordinates
29
- predicted_answer_coordinates = outputs.logits.argmax(-1)
30
-
31
- # Decode the answer from the table using the coordinates
32
- answer = tokenizer.convert_logits_to_predictions(
33
- encoding.data,
34
- predicted_answer_coordinates
 
 
35
  )
36
 
37
- # Process the answer into a readable format
38
- answer_text = answer[0][0][0] if len(answer[0]) > 0 else "Não foi possível encontrar uma resposta"
39
 
40
  query_result = {
41
- "Resposta": answer_text
42
  }
43
 
44
  b = datetime.datetime.now()
@@ -46,6 +44,7 @@ def response(user_question, table_data):
46
 
47
  return query_result
48
 
 
49
  # Streamlit interface
50
  st.markdown("""
51
  <div style='display: flex; align-items: center;'>
@@ -65,15 +64,15 @@ user_question = st.text_input("Escreva sua questão aqui:", "")
65
 
66
  if user_question:
67
  # Add person emoji when typing question
68
- st.session_state['history'].append(('👤', user_question))
69
- st.markdown(f"**👤 {user_question}**")
70
 
71
  # Generate the response
72
  bot_response = response(user_question, table_data)
73
 
74
  # Add robot emoji when generating response and align to the right
75
- st.session_state['history'].append(('🤖', bot_response))
76
- st.markdown(f"<div style='text-align: right'>**🤖 {bot_response}**</div>", unsafe_allow_html=True)
77
 
78
  # Clear history button
79
  if st.button("Limpar"):
@@ -81,7 +80,7 @@ if st.button("Limpar"):
81
 
82
  # Display chat history
83
  for sender, message in st.session_state['history']:
84
- if sender == '👤':
85
- st.markdown(f"**👤 {message}**")
86
- elif sender == '🤖':
87
- st.markdown(f"<div style='text-align: right'>**🤖 {message}**</div>", unsafe_allow_html=True)
 
2
  import pandas as pd
3
  import torch
4
  from transformers import pipeline
5
+ from transformers import TapasTokenizer, TapexTokenizer, BartForConditionalGeneration
 
6
  import datetime
7
 
8
  #df = pd.read_excel('discrepantes.xlsx', index_col='Unnamed: 0')
 
14
  def response(user_question, table_data):
15
  a = datetime.datetime.now()
16
 
17
+ model_name = "microsoft/tapex-large-finetuned-wtq"
18
+ model = BartForConditionalGeneration.from_pretrained(model_name)
19
+ tokenizer = TapexTokenizer.from_pretrained(model_name)
20
 
21
+ queries = [user_question]
 
22
 
23
+ encoding = tokenizer(table=table_data, query=queries, padding=True, return_tensors="pt", truncation=True)
 
24
 
25
+ # Experiment with generation parameters
26
+ outputs = model.generate(
27
+ **encoding,
28
+ num_beams=5, # Beam search to generate more diverse responses
29
+ top_k=50, # Top-k sampling for diversity
30
+ top_p=0.95, # Nucleus sampling
31
+ temperature=0.7, # Temperature scaling (if supported by the model)
32
+ max_length=50, # Limit the length of the generated response
33
+ early_stopping=True # Stop generation when an end token is generated
34
  )
35
 
36
+ ans = tokenizer.batch_decode(outputs, skip_special_tokens=True)
 
37
 
38
  query_result = {
39
+ "Resposta": ans[0]
40
  }
41
 
42
  b = datetime.datetime.now()
 
44
 
45
  return query_result
46
 
47
+
48
  # Streamlit interface
49
  st.markdown("""
50
  <div style='display: flex; align-items: center;'>
 
64
 
65
  if user_question:
66
  # Add person emoji when typing question
67
+ st.session_state['history'].append(('👤', user_question))
68
+ st.markdown(f"**👤 {user_question}**")
69
 
70
  # Generate the response
71
  bot_response = response(user_question, table_data)
72
 
73
  # Add robot emoji when generating response and align to the right
74
+ st.session_state['history'].append(('🤖', bot_response))
75
+ st.markdown(f"<div style='text-align: right'>**🤖 {bot_response}**</div>", unsafe_allow_html=True)
76
 
77
  # Clear history button
78
  if st.button("Limpar"):
 
80
 
81
  # Display chat history
82
  for sender, message in st.session_state['history']:
83
+ if sender == '👤':
84
+ st.markdown(f"**👤 {message}**")
85
+ elif sender == '🤖':
86
+ st.markdown(f"<div style='text-align: right'>**🤖 {message}**</div>", unsafe_allow_html=True)