File size: 5,496 Bytes
be63200
1dc9fa7
375bd04
 
24bec11
 
1dc9fa7
 
79fbe78
 
a850fbe
444d231
 
 
bbbffce
a850fbe
be63200
 
 
 
 
 
 
 
5df5027
be63200
 
 
 
 
 
 
 
 
a850fbe
be63200
 
 
 
 
 
 
 
 
a850fbe
be63200
 
 
 
a850fbe
be63200
 
 
 
5435ca6
 
 
a850fbe
5435ca6
be63200
 
 
 
 
 
 
 
de20d93
691deb8
a850fbe
24bec11
a850fbe
be63200
24bec11
691deb8
a850fbe
 
 
 
24bec11
 
220b4de
a850fbe
be63200
 
a850fbe
be63200
 
 
 
a850fbe
be63200
 
 
a850fbe
 
 
 
 
 
be63200
 
444d231
 
 
 
24bec11
444d231
 
a850fbe
444d231
 
 
 
 
24bec11
444d231
be63200
 
 
 
 
 
444d231
be63200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375bd04
 
 
 
 
 
 
 
 
 
 
a850fbe
 
 
 
375bd04
 
 
 
 
 
a850fbe
375bd04
 
 
a850fbe
375bd04
 
 
a850fbe
375bd04
a850fbe
375bd04
 
 
a850fbe
 
 
be63200
a850fbe
be63200
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import os
import tempfile

import streamlit as st
from langchain.chains import ConversationalRetrievalChain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader, TextLoader
from langchain_community.vectorstores.chroma import Chroma
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from chat_profile import User, Assistant, ChatProfileRoleEnum
from token_stream_handler import StreamHandler

# Browser-tab title and icon. The icon is spelled as a named escape so it
# cannot be garbled if this file is re-encoded (a mis-decoded byte sequence
# is not a valid emoji for page_icon).
st.set_page_config(page_title="InkChatGPT", page_icon="\N{BOOKS}")


def load_and_process_file(file_data):
    """
    Load and process the uploaded file.

    The upload is copied into a private temporary directory under its base
    name only, so a crafted file name cannot point outside that directory.
    It is loaded with the loader matching its extension (case-insensitive),
    split into overlapping chunks and embedded into a Chroma vector store.
    The temporary copy is deleted as soon as the document has been loaded.

    Args:
        file_data: The uploaded file object; its ``name`` attribute and
            ``getvalue()`` bytes are used.

    Returns:
        A vector store containing the embedded chunks of the file, or
        ``None`` (after showing an error) if the format is not supported.
    """
    loaders = {
        ".pdf": PyPDFLoader,
        ".docx": Docx2txtLoader,
        ".txt": TextLoader,
    }

    base_name = os.path.basename(file_data.name)
    # Lower-case so "REPORT.PDF" is treated the same as "report.pdf".
    extension = os.path.splitext(base_name)[1].lower()
    loader_cls = loaders.get(extension)
    if loader_cls is None:
        # Reject before writing anything to disk.
        st.error("This document format is not supported!")
        return None

    # The loaders need a real path; keep the copy only while loading so no
    # uploaded files accumulate in the working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        file_path = os.path.join(tmp_dir, base_name)
        with open(file_path, "wb") as f:
            f.write(file_data.getvalue())
        documents = loader_cls(file_path).load()

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
    )
    chunks = text_splitter.split_documents(documents)
    embeddings = OpenAIEmbeddings(openai_api_key=st.session_state.api_key)
    return Chroma.from_documents(chunks, embeddings)


def initialize_chat_model(vector_store):
    """
    Build the conversational question-answering chain for a document.

    Args:
        vector_store: Store holding the document's embedded chunks; its
            retriever supplies context passages to the chain.

    Returns:
        A ``ConversationalRetrievalChain`` that uses a ``gpt-3.5-turbo`` chat
        model (temperature 0) authenticated with ``st.session_state.api_key``.
    """
    chat_llm = ChatOpenAI(
        openai_api_key=st.session_state.api_key,
        model="gpt-3.5-turbo",
        temperature=0,
    )
    return ConversationalRetrievalChain.from_llm(
        chat_llm,
        vector_store.as_retriever(),
    )


def main():
    """
    The main function that runs the chat area of the Streamlit app.

    The conversation lives in ``st.session_state["messages"]``. It is seeded
    with the greeting only the first time, and it is redrawn in full on every
    run so that earlier turns stay visible. A newly submitted prompt is
    recorded, shown, and passed to ``handle_question``.
    """
    if "messages" not in st.session_state:
        assistant_message = "Hello, you can upload a document and chat with me to ask questions related to its content."
        st.session_state["messages"] = [
            Assistant(message=assistant_message).build_message()
        ]

    for msg in st.session_state.messages:
        st.chat_message(msg.role).write(msg.content)

    if prompt := st.chat_input(
        placeholder="Chat with your document",
        disabled=(not st.session_state.get("api_key")),
    ):
        st.session_state.messages.append(User(message=prompt).build_message())
        st.chat_message(ChatProfileRoleEnum.User).write(prompt)

        handle_question(prompt)


def handle_question(question):
    """
    Answer ``question`` from the processed document and record the exchange.

    The answer comes from the ConversationalRetrievalChain in
    ``st.session_state.crc``, which was built from the uploaded document.
    The earlier (question, answer) pairs in ``st.session_state["history"]``
    are passed to it as conversational context. The answer is rendered as an
    assistant message and appended to that history. It is also appended to
    ``st.session_state.messages`` so that ``main`` redraws it on later runs.

    If no document has been processed yet, a warning is shown instead of
    raising.
    """
    crc = st.session_state.get("crc")
    if crc is None:
        st.warning("Please upload and process a document before asking questions.")
        return

    if "history" not in st.session_state:
        st.session_state["history"] = []

    # Show the document-grounded answer itself. main() has already drawn the
    # earlier messages, so only the new answer is rendered here.
    with st.chat_message(ChatProfileRoleEnum.Assistant):
        with st.spinner("Thinking..."):
            answer = crc.run(
                {
                    "question": question,
                    "chat_history": st.session_state["history"],
                }
            )
        st.write(answer)

    st.session_state["history"].append((question, answer))
    st.session_state.messages.append(Assistant(message=answer).build_message())


def display_chat_history():
    """
    Render each stored (question, answer) pair under a "Chat History" heading.

    Nothing is rendered when the session has no recorded history.
    """
    if "history" not in st.session_state:
        return

    st.markdown("## Chat History")
    for question, answer in st.session_state["history"]:
        st.markdown(f"**Question:** {question}")
        st.write(answer)
        st.write("---")


def clear_history():
    """
    Drop the recorded chat history from the session state, if present.
    """
    st.session_state.pop("history", None)


def process_data(uploaded_file, openai_api_key):
    """
    Embed ``uploaded_file`` and store a ready-to-use chat chain in the session.

    Nothing happens without a file or without a key. A non-empty key that
    does not start with ``"sk-"`` produces a warning instead of being silently
    ignored. On success, the chain is stored in ``st.session_state.crc`` and a
    confirmation is shown.

    Args:
        uploaded_file: The file chosen in the sidebar uploader, or ``None``.
        openai_api_key: The key typed in the sidebar; may be empty.
    """
    if not uploaded_file or not openai_api_key:
        # An empty key already gets an info prompt in the sidebar.
        return
    if not openai_api_key.startswith("sk-"):
        st.warning("The OpenAI API key should start with `sk-`; the file was not processed.")
        return

    with st.spinner("\N{THOUGHT BALLOON} Thinking..."):
        vector_store = load_and_process_file(uploaded_file)
        if vector_store is None:
            # Unsupported format; load_and_process_file already showed an error.
            return
        st.session_state.crc = initialize_chat_model(vector_store)
        st.success(f"File: `{uploaded_file.name}`, processed successfully!")


def build_sidebar():
    """
    Render the sidebar: title, API-key field, file uploader and submit button.

    The typed key is copied into ``st.session_state.api_key`` on every run so
    the rest of the app can read it. The uploaded file is handed to
    ``process_data`` only on the run triggered by pressing "Process File".
    """
    with st.sidebar:
        st.title("\N{BOOKS} InkChatGPT")

        with st.form(key="input_form"):
            openai_api_key = st.text_input(
                "OpenAI API Key",
                type="password",
                placeholder="Enter your OpenAI API key",
            )

            st.session_state.api_key = openai_api_key
            if not openai_api_key:
                st.info("Please add your OpenAI API key to continue.")

            uploaded_file = st.file_uploader(
                "Select a file", type=["pdf", "docx", "txt"], key="file_uploader"
            )

            submitted = st.form_submit_button("Process File")

        # Process only on submit. Calling process_data while building the
        # button would re-embed the whole document (and re-bill the
        # embeddings API) on every rerun, including each chat message.
        if submitted:
            process_data(uploaded_file=uploaded_file, openai_api_key=openai_api_key)


if __name__ == "__main__":
    # Build the sidebar first. It stores the entered key in
    # st.session_state.api_key and may create st.session_state.crc.
    # main() (and handle_question, which it calls) reads both.
    build_sidebar()
    main()