fschwartzer commited on
Commit
b13ea5d
1 Parent(s): cc0f753

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -12
app.py CHANGED
@@ -22,21 +22,23 @@ def response(user_question, table_data):
22
  # The query should be passed as a list
23
  encoding = tokenizer(table=table_data, queries=[user_question], padding=True, return_tensors="pt", truncation=True)
24
 
25
- # Experiment with generation parameters
26
- outputs = model.generate(
27
- **encoding,
28
- num_beams=5, # Beam search to generate more diverse responses
29
- top_k=50, # Top-k sampling for diversity
30
- top_p=0.95, # Nucleus sampling
31
- temperature=0.7, # Temperature scaling (if supported by the model)
32
- max_length=50, # Limit the length of the generated response
33
- early_stopping=True # Stop generation when an end token is generated
 
34
  )
35
 
36
- ans = tokenizer.batch_decode(outputs, skip_special_tokens=True)
 
37
 
38
  query_result = {
39
- "Resposta": ans[0]
40
  }
41
 
42
  b = datetime.datetime.now()
@@ -44,7 +46,6 @@ def response(user_question, table_data):
44
 
45
  return query_result
46
 
47
-
48
  # Streamlit interface
49
  st.markdown("""
50
  <div style='display: flex; align-items: center;'>
 
22
  # The query should be passed as a list
23
  encoding = tokenizer(table=table_data, queries=[user_question], padding=True, return_tensors="pt", truncation=True)
24
 
25
+ # Instead of using generate, we pass the encoding through the model to get the logits
26
+ outputs = model(**encoding)
27
+
28
+ # Extract the answer coordinates
29
+ predicted_answer_coordinates = outputs.logits.argmax(-1)
30
+
31
+ # Decode the answer from the table using the coordinates
32
+ answer = tokenizer.convert_logits_to_predictions(
33
+ encoding.data,
34
+ predicted_answer_coordinates
35
  )
36
 
37
+ # Process the answer into a readable format
38
+ answer_text = answer[0][0][0] if len(answer[0]) > 0 else "Não foi possível encontrar uma resposta"
39
 
40
  query_result = {
41
+ "Resposta": answer_text
42
  }
43
 
44
  b = datetime.datetime.now()
 
46
 
47
  return query_result
48
 
 
49
  # Streamlit interface
50
  st.markdown("""
51
  <div style='display: flex; align-items: center;'>