tuyendragon committed
Commit
6bb8996
1 Parent(s): de8099e

Upload 4 files

Files changed (4)
  1. ingest.py +159 -0
  2. localGPT_UI.py +119 -0
  3. run_localGPT.py +246 -0
  4. run_localGPT_API.py +173 -0
ingest.py ADDED
@@ -0,0 +1,159 @@
import logging
import os
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed

import click
import torch
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.text_splitter import Language, RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

from constants import (
    CHROMA_SETTINGS,
    DOCUMENT_MAP,
    EMBEDDING_MODEL_NAME,
    INGEST_THREADS,
    PERSIST_DIRECTORY,
    SOURCE_DIRECTORY,
)


def load_single_document(file_path: str) -> Document:
    # Loads a single document from a file path
    file_extension = os.path.splitext(file_path)[1]
    loader_class = DOCUMENT_MAP.get(file_extension)
    if loader_class:
        loader = loader_class(file_path)
    else:
        raise ValueError("Document type is undefined")
    return loader.load()[0]


def load_document_batch(filepaths):
    logging.info("Loading document batch")
    # create a thread pool
    with ThreadPoolExecutor(len(filepaths)) as exe:
        # load files
        futures = [exe.submit(load_single_document, name) for name in filepaths]
        # collect data
        data_list = [future.result() for future in futures]
        # return data and file paths
        return (data_list, filepaths)


def load_documents(source_dir: str) -> list[Document]:
    # Loads all documents from the source documents directory
    all_files = os.listdir(source_dir)
    paths = []
    for file_path in all_files:
        file_extension = os.path.splitext(file_path)[1]
        source_file_path = os.path.join(source_dir, file_path)
        if file_extension in DOCUMENT_MAP.keys():
            paths.append(source_file_path)

    # Have at least one worker and at most INGEST_THREADS workers
    n_workers = min(INGEST_THREADS, max(len(paths), 1))
    chunksize = round(len(paths) / n_workers)
    docs = []
    with ProcessPoolExecutor(n_workers) as executor:
        futures = []
        # split the load operations into chunks
        for i in range(0, len(paths), chunksize):
            # select a chunk of filenames
            filepaths = paths[i : (i + chunksize)]
            # submit the task
            future = executor.submit(load_document_batch, filepaths)
            futures.append(future)
        # process all results
        for future in as_completed(futures):
            # collect the loaded documents
            contents, _ = future.result()
            docs.extend(contents)

    return docs


def split_documents(documents: list[Document]) -> tuple[list[Document], list[Document]]:
    # Splits documents for the correct Text Splitter
    text_docs, python_docs = [], []
    for doc in documents:
        file_extension = os.path.splitext(doc.metadata["source"])[1]
        if file_extension == ".py":
            python_docs.append(doc)
        else:
            text_docs.append(doc)

    return text_docs, python_docs


@click.command()
@click.option(
    "--device_type",
    default="cuda" if torch.cuda.is_available() else "cpu",
    type=click.Choice(
        [
            "cpu",
            "cuda",
            "ipu",
            "xpu",
            "mkldnn",
            "opengl",
            "opencl",
            "ideep",
            "hip",
            "ve",
            "fpga",
            "ort",
            "xla",
            "lazy",
            "vulkan",
            "mps",
            "meta",
            "hpu",
            "mtia",
        ],
    ),
    help="Device to run on. (Default is cuda)",
)
def main(device_type):
    # Load documents and split them into chunks
    logging.info(f"Loading documents from {SOURCE_DIRECTORY}")
    documents = load_documents(SOURCE_DIRECTORY)
    text_documents, python_documents = split_documents(documents)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    python_splitter = RecursiveCharacterTextSplitter.from_language(
        language=Language.PYTHON, chunk_size=880, chunk_overlap=200
    )
    texts = text_splitter.split_documents(text_documents)
    texts.extend(python_splitter.split_documents(python_documents))
    logging.info(f"Loaded {len(documents)} documents from {SOURCE_DIRECTORY}")
    logging.info(f"Split into {len(texts)} chunks of text")

    # Create embeddings
    embeddings = HuggingFaceInstructEmbeddings(
        model_name=EMBEDDING_MODEL_NAME,
        model_kwargs={"device": device_type},
    )
    # Change the embedding type here if you are running into issues.
    # These are much smaller embeddings and will work for most applications.
    # If you use HuggingFaceEmbeddings, make sure to also use the same in the
    # run_localGPT.py file.
    # embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)

    db = Chroma.from_documents(
        texts,
        embeddings,
        persist_directory=PERSIST_DIRECTORY,
        client_settings=CHROMA_SETTINGS,
    )
    db.persist()
    db = None


if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s - %(message)s", level=logging.INFO
    )
    main()
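Note: ingest.py depends on DOCUMENT_MAP from constants.py, which is not part of this commit; it maps a file extension to the LangChain loader class used in load_single_document. A minimal sketch of what such a mapping could look like, assuming the stock LangChain loaders (the actual constants.py in the repo may cover more extensions):

# Hypothetical illustration only -- the real DOCUMENT_MAP lives in constants.py.
from langchain.document_loaders import CSVLoader, PDFMinerLoader, TextLoader

DOCUMENT_MAP = {
    ".txt": TextLoader,
    ".md": TextLoader,
    ".py": TextLoader,
    ".pdf": PDFMinerLoader,
    ".csv": CSVLoader,
}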
localGPT_UI.py ADDED
@@ -0,0 +1,119 @@
import torch
import subprocess
import streamlit as st
from run_localGPT import load_model
from langchain.vectorstores import Chroma
from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.chains import RetrievalQA
from streamlit_extras.add_vertical_space import add_vertical_space
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory


def model_memory():
    # Adding history to the model.
    template = """Use the following pieces of context to answer the question at the end. If you don't know the answer,\
just say that you don't know, don't try to make up an answer.

{context}

{history}
Question: {question}
Helpful Answer:"""

    prompt = PromptTemplate(input_variables=["history", "context", "question"], template=template)
    memory = ConversationBufferMemory(input_key="question", memory_key="history")

    return prompt, memory


# Sidebar contents
with st.sidebar:
    st.title('🤗💬 Converse with your Data')
    st.markdown('''
    ## About
    This app is an LLM-powered chatbot built using:
    - [Streamlit](https://streamlit.io/)
    - [LangChain](https://python.langchain.com/)
    - [LocalGPT](https://github.com/PromtEngineer/localGPT)
    ''')
    add_vertical_space(5)
    st.write('Made with ❤️ by [Prompt Engineer](https://youtube.com/@engineerprompt)')


DEVICE_TYPE = "cuda" if torch.cuda.is_available() else "cpu"


if "result" not in st.session_state:
    # Run the document ingestion process.
    run_langest_commands = ["python", "ingest.py"]
    run_langest_commands.append("--device_type")
    run_langest_commands.append(DEVICE_TYPE)

    result = subprocess.run(run_langest_commands, capture_output=True)
    st.session_state.result = result

# Define the retriever
# load the vectorstore
if "EMBEDDINGS" not in st.session_state:
    EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": DEVICE_TYPE})
    st.session_state.EMBEDDINGS = EMBEDDINGS

if "DB" not in st.session_state:
    DB = Chroma(
        persist_directory=PERSIST_DIRECTORY,
        embedding_function=st.session_state.EMBEDDINGS,
        client_settings=CHROMA_SETTINGS,
    )
    st.session_state.DB = DB

if "RETRIEVER" not in st.session_state:
    RETRIEVER = DB.as_retriever()
    st.session_state.RETRIEVER = RETRIEVER

if "LLM" not in st.session_state:
    LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME)
    st.session_state["LLM"] = LLM


if "QA" not in st.session_state:
    prompt, memory = model_memory()

    QA = RetrievalQA.from_chain_type(
        llm=LLM,
        chain_type="stuff",
        retriever=RETRIEVER,
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt, "memory": memory},
    )
    st.session_state["QA"] = QA

st.title('LocalGPT App 💬')
# Create a text input box for the user
prompt = st.text_input('Input your prompt here')

# If the user hits enter
if prompt:
    # Then pass the prompt to the LLM
    response = st.session_state["QA"](prompt)
    answer, docs = response["result"], response["source_documents"]
    # ...and write it out to the screen
    st.write(answer)

    # With a streamlit expander
    with st.expander('Document Similarity Search'):
        # Find the relevant pages
        search = st.session_state.DB.similarity_search_with_score(prompt)
        # Write out the results
        for i, doc in enumerate(search):
            # print(doc)
            st.write(f"Source Document # {i+1} : {doc[0].metadata['source'].split('/')[-1]}")
            st.write(doc[0].page_content)
            st.write("--------------------------------")
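Worth noting: the UI stores the result of the ingest.py subprocess in st.session_state but never checks its return code, so a failed ingestion passes silently. A small optional hardening sketch (not in this commit) that surfaces the failure in the page, using the standard st.error and st.stop calls:

# Optional variant of the ingestion block above (assumption: you want ingestion
# failures to stop the app rather than continue with an empty vectorstore).
result = subprocess.run(run_langest_commands, capture_output=True)
if result.returncode != 0:
    st.error("Ingestion failed:\n" + result.stderr.decode("utf-8"))
    st.stop()
st.session_state.result = result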
run_localGPT.py ADDED
@@ -0,0 +1,246 @@
import logging

import click
import torch
from auto_gptq import AutoGPTQForCausalLM
from huggingface_hub import hf_hub_download
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.llms import HuggingFacePipeline, LlamaCpp
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate

# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import Chroma
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    LlamaForCausalLM,
    LlamaTokenizer,
    pipeline,
)

from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME


def load_model(device_type, model_id, model_basename=None):
    """
    Select a model for text generation using the HuggingFace library.
    If you are running this for the first time, it will download a model for you.
    Subsequent runs will use the model from the disk.

    Args:
        device_type (str): Type of device to use, e.g., "cuda" for GPU or "cpu" for CPU.
        model_id (str): Identifier of the model to load from HuggingFace's model hub.
        model_basename (str, optional): Basename of the model if using quantized models.
            Defaults to None.

    Returns:
        HuggingFacePipeline: A pipeline object for text generation using the loaded model.

    Raises:
        ValueError: If an unsupported model or device type is provided.
    """
    logging.info(f"Loading Model: {model_id}, on: {device_type}")
    logging.info("This action can take a few minutes!")

    if model_basename is not None:
        if ".ggml" in model_basename:
            logging.info("Using LlamaCpp for GGML quantized models")
            model_path = hf_hub_download(repo_id=model_id, filename=model_basename)
            max_ctx_size = 2048
            kwargs = {
                "model_path": model_path,
                "n_ctx": max_ctx_size,
                "max_tokens": max_ctx_size,
            }
            if device_type.lower() == "mps":
                kwargs["n_gpu_layers"] = 1000
            if device_type.lower() == "cuda":
                kwargs["n_gpu_layers"] = 1000
                kwargs["n_batch"] = max_ctx_size
            return LlamaCpp(**kwargs)

        else:
            # The code supports all huggingface models that end with GPTQ and have some variation
            # of .no-act.order or .safetensors in their HF repo.
            logging.info("Using AutoGPTQForCausalLM for quantized models")

            if ".safetensors" in model_basename:
                # Remove the ".safetensors" ending if present
                model_basename = model_basename.replace(".safetensors", "")

            tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
            logging.info("Tokenizer loaded")

            model = AutoGPTQForCausalLM.from_quantized(
                model_id,
                model_basename=model_basename,
                use_safetensors=True,
                trust_remote_code=True,
                device="cuda:0",
                use_triton=False,
                quantize_config=None,
            )
    elif (
        device_type.lower() == "cuda"
    ):  # The code supports all huggingface models that end with -HF or which have a .bin
        # file in their HF repo.
        logging.info("Using AutoModelForCausalLM for full models")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        logging.info("Tokenizer loaded")

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            # max_memory={0: "15GB"}  # Uncomment this line if you encounter CUDA out of memory errors
        )
        model.tie_weights()
    else:
        logging.info("Using LlamaTokenizer")
        tokenizer = LlamaTokenizer.from_pretrained(model_id)
        model = LlamaForCausalLM.from_pretrained(model_id)

    # Load configuration from the model to avoid warnings
    generation_config = GenerationConfig.from_pretrained(model_id)
    # see here for details:
    # https://huggingface.co/docs/transformers/
    # main_classes/text_generation#transformers.GenerationConfig.from_pretrained.returns

    # Create a pipeline for text generation
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=2048,
        temperature=0,
        top_p=0.95,
        repetition_penalty=1.15,
        generation_config=generation_config,
    )

    local_llm = HuggingFacePipeline(pipeline=pipe)
    logging.info("Local LLM Loaded")

    return local_llm


# Choose the device type to run on, as well as whether to show source documents.
@click.command()
@click.option(
    "--device_type",
    default="cuda" if torch.cuda.is_available() else "cpu",
    type=click.Choice(
        [
            "cpu",
            "cuda",
            "ipu",
            "xpu",
            "mkldnn",
            "opengl",
            "opencl",
            "ideep",
            "hip",
            "ve",
            "fpga",
            "ort",
            "xla",
            "lazy",
            "vulkan",
            "mps",
            "meta",
            "hpu",
            "mtia",
        ],
    ),
    help="Device to run on. (Default is cuda)",
)
@click.option(
    "--show_sources",
    "-s",
    is_flag=True,
    help="Show sources along with answers (Default is False)",
)
def main(device_type, show_sources):
    """
    This function implements the information retrieval task.

    1. Loads an embedding model, which can be HuggingFaceInstructEmbeddings or HuggingFaceEmbeddings.
    2. Loads the existing vectorstore that was created by ingest.py.
    3. Loads the local LLM using the load_model function - You can now set different LLMs.
    4. Sets up the question-answer retrieval chain.
    5. Answers questions interactively.
    """

    logging.info(f"Running on: {device_type}")
    logging.info(f"Display Source Documents set to: {show_sources}")

    embeddings = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": device_type})

    # uncomment the following line if you used HuggingFaceEmbeddings in ingest.py
    # embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)

    # load the vectorstore
    db = Chroma(
        persist_directory=PERSIST_DIRECTORY,
        embedding_function=embeddings,
        client_settings=CHROMA_SETTINGS,
    )
    retriever = db.as_retriever()

    template = """Use the following pieces of context to answer the question at the end. If you don't know the answer,\
just say that you don't know, don't try to make up an answer.

{context}

{history}
Question: {question}
Helpful Answer:"""

    prompt = PromptTemplate(input_variables=["history", "context", "question"], template=template)
    memory = ConversationBufferMemory(input_key="question", memory_key="history")

    llm = load_model(device_type, model_id=MODEL_ID, model_basename=MODEL_BASENAME)

    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt, "memory": memory},
    )
    # Interactive questions and answers
    while True:
        query = input("\nEnter a query: ")
        if query == "exit":
            break
        # Get the answer from the chain
        res = qa(query)
        answer, docs = res["result"], res["source_documents"]

        # Print the result
        print("\n\n> Question:")
        print(query)
        print("\n> Answer:")
        print(answer)

        if show_sources:  # this flag controls whether source documents are printed
            # Print the relevant sources used for the answer
            print("----------------------------------SOURCE DOCUMENTS---------------------------")
            for document in docs:
                print("\n> " + document.metadata["source"] + ":")
                print(document.page_content)
            print("----------------------------------SOURCE DOCUMENTS---------------------------")


if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s - %(message)s", level=logging.INFO
    )
    main()
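Because load_model() returns a LangChain LLM (HuggingFacePipeline or LlamaCpp), it can also be exercised outside the interactive loop. A minimal smoke-test sketch, assuming constants.py provides MODEL_ID and MODEL_BASENAME and the vectorstore is not needed (this snippet is not part of the commit):

# Quick sanity check of load_model() on its own; the prompt text is arbitrary.
from run_localGPT import load_model
from constants import MODEL_ID, MODEL_BASENAME

llm = load_model("cuda", model_id=MODEL_ID, model_basename=MODEL_BASENAME)
print(llm("In one sentence, what does retrieval-augmented generation do?"))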
run_localGPT_API.py ADDED
@@ -0,0 +1,173 @@
import logging
import os
import shutil
import subprocess

import torch
from auto_gptq import AutoGPTQForCausalLM
from flask import Flask, jsonify, request
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceInstructEmbeddings

# from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from run_localGPT import load_model

# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import Chroma
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    LlamaForCausalLM,
    LlamaTokenizer,
    pipeline,
)
from werkzeug.utils import secure_filename

from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME

DEVICE_TYPE = "cuda" if torch.cuda.is_available() else "cpu"
SHOW_SOURCES = True
logging.info(f"Running on: {DEVICE_TYPE}")
logging.info(f"Display Source Documents set to: {SHOW_SOURCES}")

EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": DEVICE_TYPE})

# uncomment the following line if you used HuggingFaceEmbeddings in ingest.py
# EMBEDDINGS = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
if os.path.exists(PERSIST_DIRECTORY):
    try:
        shutil.rmtree(PERSIST_DIRECTORY)
    except OSError as e:
        print(f"Error: {e.filename} - {e.strerror}.")
else:
    print("The directory does not exist")

run_langest_commands = ["python", "ingest.py"]
if DEVICE_TYPE == "cpu":
    run_langest_commands.append("--device_type")
    run_langest_commands.append(DEVICE_TYPE)

result = subprocess.run(run_langest_commands, capture_output=True)
if result.returncode != 0:
    raise FileNotFoundError(
        "No files were found inside SOURCE_DOCUMENTS, please put a starter file inside before starting the API!"
    )

# load the vectorstore
DB = Chroma(
    persist_directory=PERSIST_DIRECTORY,
    embedding_function=EMBEDDINGS,
    client_settings=CHROMA_SETTINGS,
)

RETRIEVER = DB.as_retriever()

LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME)

QA = RetrievalQA.from_chain_type(
    llm=LLM, chain_type="stuff", retriever=RETRIEVER, return_source_documents=SHOW_SOURCES
)

app = Flask(__name__)


@app.route("/api/delete_source", methods=["GET"])
def delete_source_route():
    folder_name = "SOURCE_DOCUMENTS"

    if os.path.exists(folder_name):
        shutil.rmtree(folder_name)

    os.makedirs(folder_name)

    return jsonify({"message": f"Folder '{folder_name}' successfully deleted and recreated."})


@app.route("/api/save_document", methods=["GET", "POST"])
def save_document_route():
    if "document" not in request.files:
        return "No document part", 400
    file = request.files["document"]
    if file.filename == "":
        return "No selected file", 400
    if file:
        filename = secure_filename(file.filename)
        folder_path = "SOURCE_DOCUMENTS"
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)
        file_path = os.path.join(folder_path, filename)
        file.save(file_path)
        return "File saved successfully", 200


@app.route("/api/run_ingest", methods=["GET"])
def run_ingest_route():
    global DB
    global RETRIEVER
    global QA
    try:
        if os.path.exists(PERSIST_DIRECTORY):
            try:
                shutil.rmtree(PERSIST_DIRECTORY)
            except OSError as e:
                print(f"Error: {e.filename} - {e.strerror}.")
        else:
            print("The directory does not exist")

        run_langest_commands = ["python", "ingest.py"]
        if DEVICE_TYPE == "cpu":
            run_langest_commands.append("--device_type")
            run_langest_commands.append(DEVICE_TYPE)

        result = subprocess.run(run_langest_commands, capture_output=True)
        if result.returncode != 0:
            return "Script execution failed: {}".format(result.stderr.decode("utf-8")), 500
        # load the vectorstore
        DB = Chroma(
            persist_directory=PERSIST_DIRECTORY,
            embedding_function=EMBEDDINGS,
            client_settings=CHROMA_SETTINGS,
        )
        RETRIEVER = DB.as_retriever()

        QA = RetrievalQA.from_chain_type(
            llm=LLM, chain_type="stuff", retriever=RETRIEVER, return_source_documents=SHOW_SOURCES
        )
        return "Script executed successfully: {}".format(result.stdout.decode("utf-8")), 200
    except Exception as e:
        return f"Error occurred: {str(e)}", 500


@app.route("/api/prompt_route", methods=["GET", "POST"])
def prompt_route():
    global QA
    user_prompt = request.form.get("user_prompt")
    if user_prompt:
        # print(f'User Prompt: {user_prompt}')
        # Get the answer from the chain
        res = QA(user_prompt)
        answer, docs = res["result"], res["source_documents"]

        prompt_response_dict = {
            "Prompt": user_prompt,
            "Answer": answer,
        }

        prompt_response_dict["Sources"] = []
        for document in docs:
            prompt_response_dict["Sources"].append(
                (os.path.basename(str(document.metadata["source"])), str(document.page_content))
            )

        return jsonify(prompt_response_dict), 200
    else:
        return "No user prompt received", 400


if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s - %(message)s", level=logging.INFO
    )
    app.run(debug=False, port=5110)
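For reference, the Flask routes above can be exercised with any HTTP client. A small client sketch using the requests library (an extra dependency, not imported by this commit) against the default port 5110; "my_notes.pdf" is a placeholder filename:

# Upload a document, re-run ingestion, then ask a question against the API.
import requests

BASE = "http://localhost:5110"

with open("my_notes.pdf", "rb") as f:  # placeholder file
    requests.post(f"{BASE}/api/save_document", files={"document": f})
requests.get(f"{BASE}/api/run_ingest")

resp = requests.post(f"{BASE}/api/prompt_route", data={"user_prompt": "Summarize my notes."})
payload = resp.json()
print(payload["Answer"])
for source_name, content in payload["Sources"]:
    print(source_name)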