shripadbhat committed on
Commit
115644a
1 Parent(s): 7054cc7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pysbd
import streamlit as st
from sentence_transformers import CrossEncoder
from transformers import AutoModelForSeq2SeqLM, AutoModelWithLMHead, AutoTokenizer, pipeline
# --- Model setup (runs once at app start; downloads weights on first run) ---

# T5 checkpoint fine-tuned for abstractive question answering; used to
# rewrite the extracted answer from the supporting evidence sentences.
model_name = "MaRiOrOsSi/t5-base-finetuned-question-answering"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# AutoModelForSeq2SeqLM replaces the long-deprecated AutoModelWithLMHead
# for encoder-decoder checkpoints such as T5.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Sentence segmenter used to locate the evidence sentence(s) inside a passage.
sentence_segmenter = pysbd.Segmenter(language='en', clean=False)
# Cross-encoder re-ranker scoring (question, paragraph) relevance.
# NOTE(review): global name keeps the original "retreival" spelling because
# fetch_answers() below references it.
passage_retreival_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
# Extractive QA pipeline that pinpoints an answer span within a passage.
qa_model = pipeline("question-answering", 'a-ware/bart-squadv2')
def fetch_answers(question, document ):
    """Answer *question* from the raw *document* text.

    Pipeline: split the document into paragraphs, rank them against the
    question with a cross-encoder, keep the 5 most relevant, extract an
    answer span from each with an extractive QA model, collect the
    sentences containing that span as evidence, and finally rewrite a
    fluent answer from the evidence with a seq2seq (T5) model.

    Args:
        question: natural-language question string.
        document: document text; paragraphs are separated by newlines.

    Returns:
        A markdown-formatted string with one "# ANSWER n" / "REFERENCE"
        section per retrieved paragraph (best paragraph first), or an
        empty string when the document has no non-blank paragraphs.
    """
    document_paragraphs = document.splitlines()
    query_paragraph_list = [(question, para) for para in document_paragraphs
                            if len(para.strip()) > 0]
    # Guard: CrossEncoder.predict cannot score an empty batch, so an empty
    # or all-blank document previously failed here.
    if not query_paragraph_list:
        return ""

    scores = passage_retreival_model.predict(query_paragraph_list)
    # argsort is ascending; take the last 5 and reverse so the best-scoring
    # paragraph is processed (and displayed) first.
    top_5_indices = scores.argsort()[-5:]
    top_5_query_paragraph_list = [query_paragraph_list[i] for i in top_5_indices]
    top_5_query_paragraph_list.reverse()

    answer_sections = []
    for count, (query, passage) in enumerate(top_5_query_paragraph_list, start=1):
        passage_sentences = sentence_segmenter.segment(passage)
        answer = qa_model(question=query, context=passage)['answer']

        # Strip stray leading punctuation from the extracted span once,
        # before matching; the original re-checked this inside the
        # sentence loop even though it is loop-invariant.
        while answer.startswith('.') or answer.startswith(':'):
            answer = answer[1:].strip()

        # Evidence = every sentence of the passage containing the span.
        evidence_sentence = ""
        for sentence in passage_sentences:
            if answer in sentence:
                evidence_sentence = evidence_sentence + " " + sentence

        # Re-generate a fluent answer conditioned on the evidence only.
        model_input = f"question: {query} context: {evidence_sentence}"
        encoded_input = tokenizer([model_input],
                                  return_tensors='pt',
                                  max_length=512,
                                  truncation=True)
        output = model.generate(input_ids=encoded_input.input_ids,
                                attention_mask=encoded_input.attention_mask)
        output_answer = tokenizer.decode(output[0], skip_special_tokens=True)

        result_str = "# ANSWER " + str(count) + ": " + output_answer + "\n"
        result_str = result_str + "REFERENCE: " + evidence_sentence + "\n\n"
        answer_sections.append(result_str)

    # Join once instead of repeated string concatenation.
    return "".join(answer_sections)
# --- Streamlit UI ---
# Collect a question and the source document; on button click, run the QA
# pipeline and render the resulting markdown sections.
question_text = st.text_area("Query", "", height=25)
document_text = st.text_area("Document Text", "", height=100)

if st.button("Get Answers"):
    answers_markdown = fetch_answers(question_text, document_text)
    st.markdown(answers_markdown)