"""Streamlit front end for the "Info Assistant" retrieval-augmented chatbot.

The script wires a LangGraph state machine (toxicity check -> retrieval ->
document grading -> optional web search -> generation -> answer grading)
around a Groq-hosted LLM and a Chroma-backed retriever, and exposes it
through a single Streamlit text input.
"""

import os
import asyncio
import uuid
from typing import Optional

import streamlit as st
from typing_extensions import TypedDict, List
from IPython.display import Image, display
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain.schema import Document
from langgraph.graph import START, END, StateGraph
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq
from langchain_community.utilities import GoogleSerperAPIWrapper
from langchain_chroma import Chroma
from langchain_community.document_loaders import NewsURLLoader
from langchain_community.retrievers.wikipedia import WikipediaRetriever
from sentence_transformers import SentenceTransformer
from langchain.vectorstores import Chroma
from langchain_community.document_loaders import UnstructuredURLLoader, NewsURLLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers import JsonOutputParser
from langchain_community.vectorstores.utils import filter_complex_metadata
from langchain.schema import Document
from langgraph.graph import START, END, StateGraph
from langchain_community.document_loaders.directory import DirectoryLoader
from langchain.document_loaders import TextLoader

# Kept last on purpose: the star import may rebind names imported above.
from functions import *


def _export_env(name: str, value: Optional[str]) -> None:
    """Copy *value* into ``os.environ[name]`` unless it is ``None``.

    ``os.environ`` only accepts strings, so assigning the ``None`` returned by
    ``os.getenv`` for an unset variable would raise ``TypeError`` at import
    time. Skipping the assignment leaves any pre-existing value untouched.
    """
    if value is not None:
        os.environ[name] = value


lang_api_key = os.getenv("lang_api_key")
SERPER_API_KEY = os.getenv("SERPER_API_KEY")
groq_api_key = os.getenv("groq_api_key")

os.environ["LANGCHAIN_TRACING_V2"] = "true"
# NOTE(review): this looks like the legacy LangSmith endpoint - confirm it is
# still the one this deployment should report traces to.
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.langchain.plus"
os.environ["LANGCHAIN_PROJECT"] = "Info Assistant"
_export_env("LANGCHAIN_API_KEY", lang_api_key)
_export_env("GROQ_API_KEY", groq_api_key)
_export_env("SERPER_API_KEY", SERPER_API_KEY)


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: question
        generation: LLM generation
        search: whether to add search
        documents: list of documents
        steps: list of step labels (presumably the nodes visited so far;
            verify against the node functions in ``functions``)
        generation_count: generations count
    """

    question: str
    generation: str
    search: str
    documents: List[str]
    steps: List[str]
    generation_count: int


def _build_graph(llm, retriever):
    """Assemble and compile the question-answering LangGraph workflow.

    Args:
        llm: chat model handed to every chain factory used by the nodes.
        retriever: retriever handed to the ``ask_question`` and ``retrieve``
            nodes.

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """
    workflow = StateGraph(GraphState)

    # Define the nodes. The chain factories are called inside the lambdas, so
    # each chain is created when its node runs rather than at build time.
    workflow.add_node("ask_question", lambda state: ask_question(state, retriever))
    workflow.add_node("retrieve", lambda state: retrieve(state, retriever))
    # NOTE(review): `retrieval_grader_grader` is resolved through the star
    # import from `functions`; the doubled suffix looks unusual - confirm the
    # helper really has this name.
    workflow.add_node(
        "grade_documents",
        lambda state: grade_documents(state, retrieval_grader_grader(llm)),
    )  # grade documents
    workflow.add_node(
        "generate",
        lambda state: generate(state, QA_chain(llm)),
    )  # generate
    workflow.add_node("web_search", web_search)  # web search
    workflow.add_node(
        "transform_query",
        lambda state: transform_query(state, create_question_rewriter(llm)),
    )

    # Build graph
    workflow.set_entry_point("ask_question")
    workflow.add_conditional_edges(
        "ask_question",
        lambda state: grade_question_toxicity(state, create_toxicity_checker(llm)),
        {
            "good": "retrieve",
            "bad": END,
        },
    )
    workflow.add_edge("retrieve", "grade_documents")
    workflow.add_conditional_edges(
        "grade_documents",
        decide_to_generate,
        {
            "search": "web_search",
            "generate": "generate",
        },
    )
    workflow.add_edge("web_search", "generate")
    workflow.add_conditional_edges(
        "generate",
        lambda state: grade_generation_v_documents_and_question(
            state,
            create_hallucination_checker(llm),
            create_helpfulness_checker(llm),
        ),
        {
            "not supported": "generate",
            "useful": END,
            "not useful": "transform_query",
        },
    )
    workflow.add_edge("transform_query", "retrieve")

    return workflow.compile()


def main():
    """Render the Streamlit page and answer the user's question.

    Draws the page chrome and the retrieval controls, builds the LLM, the
    retriever and the compiled graph, and, once the user has typed a
    question, runs ``handle_userinput`` on it.
    """
    st.set_page_config(page_title="Info Assistant: ", page_icon=":books:")
    st.header("Info Assistant :" ":books:")

    logo_path = "digital-a-high-resolution-logo-transparent (2).png"
    link = "https://digitala.lt/"
    st.logo(logo_path, link=link)

    st.markdown(
        """ ###### Get support of **"Info Assistant"**, who has in memory a lot of """
        """Data Science related articles. If it can't answer based on its knowledge base, """
        """information will be found on the internet :books: """
    )

    if "messages" not in st.session_state:
        # NOTE(review): this greeting mentions Lithuanian law documents while
        # the page text above describes Data Science articles - confirm which
        # one is intended. The string is left unchanged here.
        st.session_state["messages"] = [
            {
                "role": "assistant",
                "content": "Hi, I'm a chatbot who is based on respublic of Lithuania law documents. How can I help you?",
            }
        ]

    search_type = st.selectbox(
        "Choose search type. Options are [Max marginal relevance search (mmr) , Similarity search (similarity). Default value (similarity)]",
        options=["mmr", "similarity"],
        index=1,
    )

    k = st.select_slider(
        "Select amount of documents to be retrieved. Default value (4): ",
        options=list(range(2, 6)),
        value=4,
    )

    # Stop with a readable message when no GROQ key is available, rather than
    # letting the model client fail on construction.
    if not os.environ.get("GROQ_API_KEY"):
        st.error(
            "GROQ API key is missing. Set the 'groq_api_key' environment "
            "variable and restart the app."
        )
        st.stop()

    llm = ChatGroq(
        model="gemma2-9b-it",  # Specify the Gemma2 9B model
        temperature=0.0,
        max_tokens=400,
        max_retries=3,
    )

    retriever = create_retriever_from_chroma(
        vectorstore_path="docs/chroma/",
        search_type=search_type,
        k=k,
        chunk_size=550,
        chunk_overlap=40,
    )

    custom_graph = _build_graph(llm, retriever)

    if user_question := st.text_input("Ask a question about your documents:"):
        asyncio.run(handle_userinput(user_question, custom_graph))


if __name__ == "__main__":
    main()