Deepak7376 committed
Commit 8515a17
1 Parent(s): e5c1f95

Update app.py

Files changed (1): app.py +98 -102
app.py CHANGED
@@ -1,32 +1,35 @@
-import streamlit as st
-import os
 import base64
-import time
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-from transformers import pipeline
-import torch
-import textwrap
-from langchain.document_loaders import PyPDFLoader, DirectoryLoader, PDFMinerLoader
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.embeddings import SentenceTransformerEmbeddings
-from langchain.vectorstores import FAISS
+import os
+
+import streamlit as st
+from langchain.chains import RetrievalQA
+from langchain.document_loaders import PDFMinerLoader
+from langchain.embeddings import SentenceTransformerEmbeddings
 from langchain.llms import HuggingFacePipeline
-from langchain.chains import RetrievalQA
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.vectorstores import FAISS
 from streamlit_chat import message
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
+import torch
 
-st.set_page_config(layout="wide")
+# Constants
+CHECKPOINT = "MBZUAI/LaMini-T5-738M"
+TOKENIZER = AutoTokenizer.from_pretrained(CHECKPOINT)
+BASE_MODEL = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT, device_map=torch.device('cpu'), torch_dtype=torch.float32)
 
-device = torch.device('cpu')
 
-checkpoint = "MBZUAI/LaMini-T5-738M"
-tokenizer = AutoTokenizer.from_pretrained(checkpoint)
-base_model = AutoModelForSeq2SeqLM.from_pretrained(
-    checkpoint,
-    device_map=device,
-    torch_dtype=torch.float32
-)
+def process_answer(instruction, qa_chain):
+    response = ''
+    generated_text = qa_chain.run(instruction)
+    return generated_text
+
+
+def get_file_size(file):
+    file.seek(0, os.SEEK_END)
+    file_size = file.tell()
+    file.seek(0)
+    return file_size
 
-persist_directory = "db"
 
 @st.cache_resource
 def data_ingestion():
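One consequence of this hunk worth illustrating: `TOKENIZER` and `BASE_MODEL` now load at module level, and Streamlit re-executes the whole script on every user interaction, so that load is repaid on each rerun. A common alternative is to defer the load behind `st.cache_resource`; the sketch below is not part of the commit, and the helper name `load_model` is hypothetical:

# Sketch (not in this commit): cache the heavyweight objects across Streamlit reruns.
@st.cache_resource
def load_model(checkpoint: str = "MBZUAI/LaMini-T5-738M"):
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, torch_dtype=torch.float32)
    return tokenizer, model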
@@ -35,36 +38,35 @@ def data_ingestion():
             if file.endswith(".pdf"):
                 print(file)
                 loader = PDFMinerLoader(os.path.join(root, file))
-
+
     documents = loader.load()
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=500)
-    splits = text_splitter.split_documents(document)
-
-    #create embeddings here
+    splits = text_splitter.split_documents(documents)
+
+    # create embeddings here
     embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
     vectordb = FAISS.from_documents(splits, embeddings)
     vectordb.save_local("faiss_index")
 
-
+
 @st.cache_resource
-def qa_llm():
+def initialize_qa_chain():
     pipe = pipeline(
         'text2text-generation',
-        model = base_model,
-        tokenizer = tokenizer,
-        max_length = 256,
-        do_sample = True,
-        temperature = 0.3,
-        top_p= 0.95,
-        device=device
+        model=BASE_MODEL,
+        tokenizer=TOKENIZER,
+        max_length=256,
+        do_sample=True,
+        temperature=0.3,
+        top_p=0.95,
+        # device=torch.device('cpu')
     )
-
+
     llm = HuggingFacePipeline(pipeline=pipe)
     embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
-
+
     vectordb = FAISS.load_local("faiss_index", embeddings)
-    retriever = db.as_retriever()
-
+
     # Build a QA chain
     qa_chain = RetrievalQA.from_chain_type(
         llm=llm,
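A version note on the `FAISS.load_local(...)` call in this hunk: the two-argument form matches the LangChain releases this Space pins, but newer releases (where the integration lives in `langchain_community`) refuse to unpickle a saved index unless deserialization is explicitly opted into. If the dependencies were ever upgraded, the call would presumably need to become:

# Assumption: only applies to newer langchain/langchain_community releases.
vectordb = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)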
@@ -73,44 +75,35 @@ def qa_llm():
     )
     return qa_chain
 
-def process_answer(instruction):
-    response = ''
-    instruction = instruction
-    qa_chain = qa_llm()
-
-    generated_text = qa_chain.run(instruction)
-    return generated_text
-
-def get_file_size(file):
-    file.seek(0, os.SEEK_END)
-    file_size = file.tell()
-    file.seek(0)
-    return file_size
 
 @st.cache_data
-#function to display the PDF of a given file
-def displayPDF(file):
-    # Opening file from file path
-    with open(file, "rb") as f:
-        base64_pdf = base64.b64encode(f.read()).decode('utf-8')
+# function to display the PDF of a given file
+def display_pdf(file):
+    try:
+        # Opening file from file path
+        with open(file, "rb") as f:
+            base64_pdf = base64.b64encode(f.read()).decode('utf-8')
 
-    # Embedding PDF in HTML
-    pdf_display = F'<iframe src="data:application/pdf;base64,{base64_pdf}" width="100%" height="600" type="application/pdf"></iframe>'
+        # Embedding PDF in HTML
+        pdf_display = f'<iframe src="data:application/pdf;base64,{base64_pdf}" width="100%" height="600" type="application/pdf"></iframe>'
 
-    # Displaying File
-    st.markdown(pdf_display, unsafe_allow_html=True)
+        # Displaying File
+        st.markdown(pdf_display, unsafe_allow_html=True)
+    except Exception as e:
+        st.error(f"An error occurred while displaying the PDF: {e}")
 
 # Display conversation history using Streamlit messages
 def display_conversation(history):
     for i in range(len(history["generated"])):
-        message(history["past"][i], is_user=True, key=str(i) + "_user")
-        message(history["generated"][i],key=str(i))
+        message(history["past"][i], is_user=True, key=f"{i}_user")
+        message(history["generated"][i], key=str(i))
 
-def main():
-    st.markdown("<h1 style='text-align: center; color: blue;'>Chat with your PDF 🦜📄 </h1>", unsafe_allow_html=True)
-    st.markdown("<h3 style='text-align: center; color: grey;'>Built by <a href='https://github.com/AIAnytime'>AI Anytime with ❤️ </a></h3>", unsafe_allow_html=True)
-
-    st.markdown("<h2 style='text-align: center; color:red;'>Upload your PDF 👇</h2>", unsafe_allow_html=True)
+def main():
+    st.set_page_config(layout="wide")
+    st.markdown("<h1 style='text-align: center; color: blue;'>Custom PDF Chatbot 🦜📄 </h1>", unsafe_allow_html=True)
+    st.markdown("<h2 style='text-align: center; color:red;'>Upload your PDF, and Ask Questions 👇</h2>", unsafe_allow_html=True)
 
     uploaded_file = st.file_uploader("", type=["pdf"])
 
@@ -119,43 +112,46 @@ def main():
             "Filename": uploaded_file.name,
             "File size": get_file_size(uploaded_file)
         }
-        filepath = "docs/"+uploaded_file.name
-        with open(filepath, "wb") as temp_file:
+        filepath = os.path.join("docs", uploaded_file.name)
+        try:
+            with open(filepath, "wb") as temp_file:
                 temp_file.write(uploaded_file.read())
 
-        col1, col2= st.columns([1,2])
-        with col1:
-            st.markdown("<h4 style color:black;'>File details</h4>", unsafe_allow_html=True)
-            st.json(file_details)
-            st.markdown("<h4 style color:black;'>File preview</h4>", unsafe_allow_html=True)
-            pdf_view = displayPDF(filepath)
-
-        with col2:
-            with st.spinner('Embeddings are in process...'):
-                ingested_data = data_ingestion()
-            st.success('Embeddings are created successfully!')
-            st.markdown("<h4 style color:black;'>Chat Here</h4>", unsafe_allow_html=True)
-
-
-        user_input = st.text_input("", key="input")
-
-        # Initialize session state for generated responses and past messages
-        if "generated" not in st.session_state:
-            st.session_state["generated"] = ["I am ready to help you"]
-        if "past" not in st.session_state:
-            st.session_state["past"] = ["Hey there!"]
-
-        # Search the database for a response based on user input and update session state
-        if user_input:
-            answer = process_answer({'query': user_input})
-            st.session_state["past"].append(user_input)
-            response = answer
-            st.session_state["generated"].append(response)
-
-        # Display conversation history using Streamlit messages
-        if st.session_state["generated"]:
-            display_conversation(st.session_state)
-
+            col1, col2 = st.columns([1, 2])
+            with col1:
+                st.markdown("<h4 style='color:black;'>File details</h4>", unsafe_allow_html=True)
+                st.json(file_details)
+                st.markdown("<h4 style='color:black;'>File preview</h4>", unsafe_allow_html=True)
+                pdf_view = display_pdf(filepath)
+
+            with col2:
+                with st.spinner('Embeddings are in process...'):
+                    ingested_data = data_ingestion()
+                st.success('Embeddings are created successfully!')
+                st.markdown("<h4 style='color:black;'>Chat Here</h4>", unsafe_allow_html=True)
+
+                user_input = st.text_input("", key="input")
+
+                # Initialize session state for generated responses and past messages
+                if "generated" not in st.session_state:
+                    st.session_state["generated"] = ["I am ready to help you"]
+                if "past" not in st.session_state:
+                    st.session_state["past"] = ["Hey there!"]
+
+                # Search the database for a response based on user input and update session state
+                if user_input:
+                    answer = process_answer({'query': user_input}, initialize_qa_chain())
+                    st.session_state["past"].append(user_input)
+                    response = answer
+                    st.session_state["generated"].append(response)
+
+                # Display conversation history using Streamlit messages
+                if st.session_state["generated"]:
+                    display_conversation(st.session_state)
+
+        except Exception as e:
+            st.error(f"An error occurred: {e}")
+
 
 if __name__ == "__main__":
     main()
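The hunks above cut off inside the `RetrievalQA.from_chain_type(` call, so its remaining arguments are hidden as unchanged context. For orientation only, a chain over a FAISS store is typically wired up as below; the `chain_type` value and retriever construction are assumptions, not the commit's hidden lines:

# Illustrative shape of a RetrievalQA chain over FAISS; not the commit's exact code.
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,                            # the HuggingFacePipeline built above
    chain_type="stuff",                 # assumed: stuff retrieved chunks into one prompt
    retriever=vectordb.as_retriever(),  # assumed: FAISS index exposed as a retriever
)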
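Finally, a hypothetical smoke test for the refactored functions outside the Streamlit UI. Assumptions: the file is importable as `app`, a `docs/` folder with at least one PDF exists, and the question string is made up; Streamlit's cache decorators fall back to plain execution (with a warning) when no Streamlit session is running.

# Hypothetical driver (not part of the commit); importing app also loads the model.
from app import data_ingestion, initialize_qa_chain, process_answer

data_ingestion()                      # walk docs/, embed chunks, save "faiss_index"
qa_chain = initialize_qa_chain()      # load the index and build the RetrievalQA chain
print(process_answer({'query': "What is this PDF about?"}, qa_chain))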