kxx-kkk's picture
Update app.py
4ef7de6 verified
raw
history blame contribute delete
No virus
8.57 kB
import html
import os
import re
import tempfile

import PyPDF2
import pytesseract
import streamlit as st
from pdf2image import convert_from_path
from PIL import Image
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
from transformers import pipeline
# Browser tab title for the app.
st.set_page_config(page_title="Automated Question Answering System")

# Centered page heading.
st.markdown(
    "<h2 style='text-align: center;'>Question Answering on Academic Essays</h2>",
    unsafe_allow_html=True,
)

# Project description. The bold element is closed with </b>; it was
# previously "closed" with a second <b>, leaving the tag unbalanced.
st.markdown(
    "<h3 style='text-align: left; color:#F63366; font-size:18px;'><b>What is this project about?</b></h3>",
    unsafe_allow_html=True,
)
st.write("This project is to develop a web-based automated question-and-answer system for academic essays using natural language processing (NLP). Users can enter the essay and ask questions about it, and the system will automatically create answers.")
# NOTE(review): the first characters of this string look like a mis-encoded
# emoji; the intended character cannot be recovered from the file, so the
# text is left untouched here - confirm and replace.
st.write("πŸ‘ Click 'Input Text' or 'Upload File' to start experience the system. ")
# Keep the loaded pipeline in Streamlit's resource cache so it is not rebuilt
# on every call (ref: https://docs.streamlit.io/library/advanced-features/caching)
@st.cache_resource(show_spinner=False)
def question_model():
    """Build the question-answering pipeline for "kxx-kkk/FYP_qa_final".

    Returns:
        A transformers "question-answering" pipeline created with
        handle_impossible_answer=True.
    """
    checkpoint = "kxx-kkk/FYP_qa_final"
    with st.spinner(text="Loading question model..."):
        qa_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        qa_network = AutoModelForQuestionAnswering.from_pretrained(checkpoint)
        qa_pipeline = pipeline(
            "question-answering",
            model=qa_network,
            tokenizer=qa_tokenizer,
            handle_impossible_answer=True,
        )
        print("QA model is dowloaded and ready to use")
    return qa_pipeline


# Module-level handle to the pipeline, used by the answering code below.
qamodel = question_model()
@st.cache_data(show_spinner=False)
def extract_text(file_path):
    """Return the OCR text of the PDF stored at ``file_path``.

    The PDF is converted to images, each image is run through Tesseract, and
    the recognised text is concatenated in order. Newlines that stand alone
    are then replaced by spaces so wrapped lines join up; runs of two or more
    newlines are left as they are.

    Args:
        file_path: Path of the PDF file to read.

    Returns:
        The recognised text as a single string.
    """
    with st.spinner(text="Extracting text from file..."):
        # Only the OCR output is returned. The embedded text layer used to be
        # read with PyPDF2 as well, but that result was always overwritten by
        # the OCR text, so the redundant pass over the document is dropped.
        page_images = convert_from_path(file_path)  # Convert PDF pages to images
        ocr_text = "".join(
            pytesseract.image_to_string(page_image) for page_image in page_images
        )
    # Replace a newline only when there is no newline directly before or
    # after it.
    return re.sub(r"(?<!\n)\n(?!\n)", " ", ocr_text)
# get the answer by passing the context & question to the model
def question_answering(context, question):
    """Ask the QA model ``question`` about ``context`` and display the result.

    Args:
        context: Essay text passed to the model as the context.
        question: Question passed to the model.

    Returns:
        None. The answer and its score are written into a bordered container
        on the page; an empty answer is shown as "CANNOT ANSWER".
    """
    with st.spinner(text="Getting answer..."):
        result = qamodel(context=context, question=question)
    print(result)

    answer_score = str(result["score"])
    # An empty answer string is reported explicitly instead of a blank box.
    answer_text = result["answer"] or "CANNOT ANSWER"

    # The answer is derived from user-supplied text and is embedded in markup
    # rendered with unsafe_allow_html=True, so escape it to keep any HTML in
    # the essay from being interpreted by the browser.
    safe_answer = html.escape(answer_text)

    # NOTE(review): the value labelled "F1 score" is the pipeline's "score"
    # field - confirm the label is what should be shown to users.
    container = st.container(border=True)
    container.write(
        "<h5><b>Answer:</b></h5>" + safe_answer
        + "<p><small>(F1 score: " + answer_score + ")</small></p><br>",
        unsafe_allow_html=True,
    )
# def question_answering(context, question):
# with st.spinner(text="Loading question model..."):
# question_answerer = qamodel
# print("loading QA model...")
# with st.spinner(text="Getting answer..."):
# segment_size = 512
# overlap_size = 32
# text_length = len(context)
# segments = []
# # Split context into segments
# for i in range(0, text_length, segment_size - overlap_size):
# segment_start = i
# segment_end = i + segment_size
# segment = context[segment_start:segment_end]
# segments.append(segment)
# answers = {} # Dictionary to store answers for each segment
# # Get answers for each segment
# for i, segment in enumerate(segments):
# answer = question_answerer(context=segment, question=question)
# answers[i] = answer
# # Find the answer with the highest score
# highest_score = -1
# highest_answer = None
# for segment_index, answer in answers.items():
# print(answer)
# score = answer["score"]
# if score > highest_score:
# highest_score = score
# highest_answer = answer
# if highest_answer is not None:
# answer = highest_answer["answer"]
# if answer == "":
# answer = "CANNOT ANSWER"
# answer_score = str(highest_answer["score"])
# # Display the result in container
# container = st.container(border=True)
# container.write("<h5><b>Answer:</b></h5>" + answer + "<p><small>(F1 score: " + answer_score + ")</small></p><br>",
# unsafe_allow_html=True)
#-------------------- Main Webpage --------------------
# Two tabs, one per way of supplying the essay.
tab1, tab2 = st.tabs(["Input Text", "Upload File"])

#---------- input text ----------
# Essay typed or pasted directly into the page.
with tab1:
    # Bundled example: essay text read from sample.txt plus a fixed question.
    with open("sample.txt", "r") as sample_file:
        sample_text = sample_file.read()
    sample_question = "What is NLP?"

    # Clicking the example button swaps in the bundled sample; otherwise the
    # fields start from whatever is stored under the two widget keys.
    use_example = st.button("Try with example")
    if use_example:
        essay_default, question_default = sample_text, sample_question
    else:
        essay_default = st.session_state.get("contextInput", "")
        question_default = st.session_state.get("questionInput", "")

    # Essay and question inputs, pre-filled with the values chosen above.
    context = st.text_area(
        "Enter the essay below:",
        value=essay_default,
        key="contextInput",
        height=330,
    )
    question = st.text_input(
        label="Enter the question: ",
        value=question_default,
        key="questionInput",
    )

    # Answer only on an explicit click, and only when both fields are filled.
    if st.button("Get answer", key="textInput", type="primary"):
        if context != "" and question != "":
            question_answering(context, question)
        else:
            st.error("Please enter BOTH the context and the question", icon="🚨")
# ---------- upload file ----------
# Essay supplied as an uploaded PDF.
# NOTE(review): the nesting below was reconstructed from a copy of the file
# whose indentation was lost; the question/essay widgets are assumed to live
# inside the "file uploaded" branch - confirm against the deployed app.
with tab2:
    # provide upload place
    uploaded_file = st.file_uploader("Upload essay in PDF format:", type=["pdf"])

    # Session-level variable tracking the uploaded file
    if 'file' not in st.session_state:
        st.session_state.file = None
    # Session-level flag recording whether text extraction has been done
    if 'text_extracted' not in st.session_state:
        st.session_state.text_extracted = False

    # Initial values of context and question, taken from the widget keys
    context2 = st.session_state.get("contextInput2", "")
    question2 = st.session_state.get("questionInput2", "")

    # transfer file to context and allow ask question, then perform question answering
    if uploaded_file is not None:
        if st.session_state.file != uploaded_file:
            # A different file was uploaded: remember it and extract again.
            st.session_state.file = uploaded_file
            st.session_state.text_extracted = False

        if not st.session_state.text_extracted:
            # extract_text() takes a path, so the upload is spilled to a
            # temporary file. getvalue() returns the whole buffer regardless
            # of the current read position.
            with tempfile.NamedTemporaryFile(delete=False) as temp_file:
                temp_file.write(uploaded_file.getvalue())
            # The file is closed (and therefore flushed) before it is read
            # back, and removed afterwards so uploads do not pile up on disk.
            try:
                context2 = extract_text(temp_file.name)
            finally:
                os.unlink(temp_file.name)
            st.session_state.text_extracted = True

        question2 = st.text_input(label="Enter your question", value=question2, key="questionInput2")
        context2 = st.text_area("Your essay context: ", value=context2, height=330, key="contextInput2")

        # perform question answering when "get answer" button clicked
        button2 = st.button("Get answer", key="fileInput", type="primary")
        if button2:
            if context2 == "" or question2 == "":
                st.error("Please enter BOTH the context and the question", icon="🚨")
            else:
                question_answering(context2, question2)
# Footer. The copyright sign is written as an HTML entity: the literal
# character had been stored mis-encoded and rendered as two wrong characters.
st.markdown("<p style='text-align:center;'>&copy; 20069913D HUI Man Ki - Final Year Project</p>", unsafe_allow_html=True)
# Blank space below the footer.
st.markdown("<br><br><br><br><br>", unsafe_allow_html=True)