# Obsidian QA bot — RAG pipeline (Gradio app for Hugging Face Spaces).
import logging
import os
import platform

import gradio as gr
from langchain.embeddings import CacheBackedEmbeddings
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.storage import LocalFileStore
from langchain_cohere import CohereRerank
from langchain_community.document_loaders import ObsidianLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import ConfigurableField, RunnablePassthrough
from langchain_google_genai import GoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter

from prompt_template import PROMPT_TEMPLATE
# Obsidian vaults indexed by the bot (user help docs + plugin developer docs).
DIRECTORIES = ["./docs/obsidian-help", "./docs/obsidian-developer"]

# On-disk folder for the persisted FAISS vector index.
FAISS_DB_INDEX = "db_index"
def load_and_process_documents(directories):
    """Load Obsidian markdown notes from *directories* and split them into chunks.

    Each directory is loaded best-effort: a missing or unreadable vault is
    logged and skipped rather than aborting startup (the original silently
    swallowed these failures with ``except: pass``).

    Args:
        directories: Iterable of filesystem paths to Obsidian vaults.

    Returns:
        List of markdown-aware document chunks (~2000 chars, 200 overlap).
    """
    md_docs = []
    for directory in directories:
        try:
            loader = ObsidianLoader(directory, encoding="utf-8")
            md_docs.extend(loader.load())
        except Exception:
            # Keep going on a bad vault, but leave a trace for the operator.
            logging.getLogger(__name__).warning(
                "Failed to load documents from %r", directory, exc_info=True
            )
    md_splitter = RecursiveCharacterTextSplitter.from_language(
        language=Language.MARKDOWN,
        chunk_size=2000,
        chunk_overlap=200,
    )
    return md_splitter.split_documents(md_docs)
def setup_retrieval_system(splitted_docs):
    """Build the hybrid retrieval stack: BM25 + FAISS, reranked by Cohere.

    Args:
        splitted_docs: Pre-chunked documents used to build the BM25 index and,
            when no cached FAISS index exists on disk, the vector index too.

    Returns:
        A ContextualCompressionRetriever that fuses keyword (BM25) and vector
        (FAISS/MMR) results 50/50 and reranks them with a multilingual Cohere
        model, keeping the top 5.
    """
    # Apple Silicon can run the embedding model on the Metal (mps) backend.
    device = "mps" if platform.system() == "Darwin" else "cpu"
    embeddings = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-m3",
        model_kwargs={"device": device},
        encode_kwargs={"normalize_embeddings": True},
    )
    # Cache computed embeddings on disk so restarts skip re-encoding.
    store = LocalFileStore("./.cache/")
    cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
        embeddings,
        store,
        namespace=embeddings.model_name,
    )
    if os.path.exists(FAISS_DB_INDEX):
        db = FAISS.load_local(
            FAISS_DB_INDEX,
            cached_embeddings,
            # The index is produced locally by this app (below), so the
            # pickled payload is trusted.
            allow_dangerous_deserialization=True,
        )
    else:
        db = FAISS.from_documents(splitted_docs, cached_embeddings)
        db.save_local(folder_path=FAISS_DB_INDEX)
    faiss_retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 10})
    bm25_retriever = BM25Retriever.from_documents(splitted_docs)
    bm25_retriever.k = 10
    # FIX: the original also passed search_type="mmr" here, but
    # EnsembleRetriever defines no such field — MMR is already configured on
    # the FAISS retriever above, so the stray kwarg is dropped.
    ensemble_retriever = EnsembleRetriever(
        retrievers=[bm25_retriever, faiss_retriever],
        weights=[0.5, 0.5],
    )
    compressor = CohereRerank(model="rerank-multilingual-v3.0", top_n=5)
    return ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=ensemble_retriever,
    )
def setup_language_model():
    """Return the default Groq Llama 3 chat model with Gemini Pro available
    as a runtime-selectable alternative under the ``llm`` config key."""
    gemini_llm = GoogleGenerativeAI(
        model="gemini-pro",
        temperature=0,
    )
    groq_llm = ChatGroq(
        model_name="llama3-70b-8192",
        temperature=0,
    )
    # default_key "llama3" keeps Groq as the model used unless the caller
    # switches via with_config(configurable={"llm": "gemini"}).
    return groq_llm.configurable_alternatives(
        ConfigurableField(id="llm"),
        default_key="llama3",
        gemini=gemini_llm,
    )
def format_docs(docs):
    """Render retrieved documents as one prompt-ready string.

    Each document contributes its page content plus, when present and
    non-empty, its ``source`` metadata; entries are separated by ``---``.
    """
    def _render(doc):
        text = f"Page Content:\n{doc.page_content}\n"
        source = doc.metadata.get("source")
        if source:
            text += f"Source: {source}\n"
        return text

    return "\n---\n".join(_render(doc) for doc in docs)
def main():
    """Assemble the RAG chain and serve it through a Gradio chat interface."""
    splitted_docs = load_and_process_documents(DIRECTORIES)
    compression_retriever = setup_retrieval_system(splitted_docs)
    llm = setup_language_model()

    # Retrieval output feeds "context"; the raw user message passes straight
    # through as "question".
    chain_inputs = {
        "context": compression_retriever | format_docs,
        "question": RunnablePassthrough(),
    }
    rag_chain = chain_inputs | PROMPT_TEMPLATE | llm | StrOutputParser()

    def predict(message, history=None):
        # `history` is supplied by Gradio but unused — the chain is stateless.
        return rag_chain.invoke(message)

    chat_ui = gr.ChatInterface(
        predict,
        title="옵시디언 노트앱 및 플러그인 개발에 대해서 물어보세요!",
        description="안녕하세요!\n저는 옵시디언 노트앱과 플러그인 개발에 대한 인공지능 QA봇입니다. 옵시디언 노트앱의 사용법, 고급 기능, 플러그인 및 테마 개발에 대해 깊은 지식을 가지고 있어요. 문서 작업, 정보 정리 또는 개발에 관한 도움이 필요하시면 언제든지 질문해주세요!",
    )
    chat_ui.launch()


if __name__ == "__main__":
    main()