Deepak7376 committed on
Commit
317f434
•
1 Parent(s): 8515a17

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -8
app.py CHANGED
@@ -12,11 +12,7 @@ from streamlit_chat import message
12
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
13
  import torch
14
 
15
- # Constants
16
- CHECKPOINT = "MBZUAI/LaMini-T5-738M"
17
- TOKENIZER = AutoTokenizer.from_pretrained(CHECKPOINT)
18
- BASE_MODEL = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT, device_map=torch.device('cpu'), torch_dtype=torch.float32)
19
-
20
 
21
  def process_answer(instruction, qa_chain):
22
  response = ''
@@ -50,7 +46,11 @@ def data_ingestion():
50
 
51
 
52
  @st.cache_resource
53
- def initialize_qa_chain():
 
 
 
 
54
  pipe = pipeline(
55
  'text2text-generation',
56
  model=BASE_MODEL,
@@ -101,7 +101,10 @@ def display_conversation(history):
101
 
102
 
103
  def main():
104
- st.set_page_config(layout="wide")
 
 
 
105
  st.markdown("<h1 style='text-align: center; color: blue;'>Custom PDF Chatbot 🦜📄 </h1>", unsafe_allow_html=True)
106
  st.markdown("<h2 style='text-align: center; color:red;'>Upload your PDF, and Ask Questions 👇</h2>", unsafe_allow_html=True)
107
 
@@ -125,6 +128,7 @@ def main():
125
  pdf_view = display_pdf(filepath)
126
 
127
  with col2:
 
128
  with st.spinner('Embeddings are in process...'):
129
  ingested_data = data_ingestion()
130
  st.success('Embeddings are created successfully!')
@@ -140,7 +144,7 @@ def main():
140
 
141
  # Search the database for a response based on user input and update session state
142
  if user_input:
143
- answer = process_answer({'query': user_input}, initialize_qa_chain())
144
  st.session_state["past"].append(user_input)
145
  response = answer
146
  st.session_state["generated"].append(response)
 
12
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
13
  import torch
14
 
15
+ st.set_page_config(layout="wide")
 
 
 
 
16
 
17
  def process_answer(instruction, qa_chain):
18
  response = ''
 
46
 
47
 
48
  @st.cache_resource
49
+ def initialize_qa_chain(selected_model):
50
+ # Constants
51
+ CHECKPOINT = selected_model
52
+ TOKENIZER = AutoTokenizer.from_pretrained(CHECKPOINT)
53
+ BASE_MODEL = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT, device_map=torch.device('cpu'), torch_dtype=torch.float32)
54
  pipe = pipeline(
55
  'text2text-generation',
56
  model=BASE_MODEL,
 
101
 
102
 
103
  def main():
104
+ # Add a sidebar for model selection
105
+ model_options = ["MBZUAI/LaMini-T5-738M", "google/flan-t5-base", "google/flan-t5-small"]
106
+ selected_model = st.sidebar.selectbox("Select Model", model_options)
107
+
108
  st.markdown("<h1 style='text-align: center; color: blue;'>Custom PDF Chatbot 🦜📄 </h1>", unsafe_allow_html=True)
109
  st.markdown("<h2 style='text-align: center; color:red;'>Upload your PDF, and Ask Questions 👇</h2>", unsafe_allow_html=True)
110
 
 
128
  pdf_view = display_pdf(filepath)
129
 
130
  with col2:
131
+ st.success(f'model selected successfully: {selected_model}')
132
  with st.spinner('Embeddings are in process...'):
133
  ingested_data = data_ingestion()
134
  st.success('Embeddings are created successfully!')
 
144
 
145
  # Search the database for a response based on user input and update session state
146
  if user_input:
147
+ answer = process_answer({'query': user_input}, initialize_qa_chain(selected_model))
148
  st.session_state["past"].append(user_input)
149
  response = answer
150
  st.session_state["generated"].append(response)