Suku0 commited on
Commit
78b4546
1 Parent(s): 8818ed5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -60
app.py CHANGED
@@ -1,60 +1,60 @@
1
- import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
- from sentence_transformers import SentenceTransformer
4
- from qdrant_client import QdrantClient
5
- import torch
6
- from llama_cpp import Llama
7
-
8
- llm = Llama.from_pretrained(
9
- repo_id="Suku0/mistral-7b-instruct-v0.3-bnb-4bit-GGUF",
10
- filename="mistral-7b-instruct-v0.3-bnb-4bit.Q4_K_M.gguf",
11
- n_ctx=16384
12
- )
13
- embedding_model = SentenceTransformer('nomic-ai/nomic-embed-text-v1.5', trust_remote_code=True)
14
- qdrant_client = QdrantClient(
15
- url="https://9a5cbf91-7dac-4dd0-80f6-13e512da1060.europe-west3-0.gcp.cloud.qdrant.io:6333",
16
- api_key="1M-sCCVolJOOJeRXMBUh4wHfj8bkY4nZyHiau0LBllFr1vsXb1oDPg",
17
- )
18
-
19
- def retrieve_context(query):
20
- query_vector = embedding_model.encode(query).tolist()
21
-
22
- search_result = qdrant_client.search(
23
- collection_name="ctx_collection",
24
- query_vector=query_vector,
25
- limit=10,
26
- with_payload=True
27
- )
28
-
29
- context = " ".join([hit.payload["text"] for hit in search_result])
30
- return context
31
-
32
- def respond(message, history, system_message, max_tokens, temperature, top_p):
33
- context = retrieve_context(message)
34
- prompt = f"""You are a helpful assistant. Please answer the user's question based on the given context. If the context doesn't provide any answer, say the context doesn't provide the answer.
35
-
36
- ### Context:
37
- {context}
38
-
39
- ### Question:
40
- {message}
41
-
42
- ### Answer:
43
- """
44
-
45
- response = llm(prompt.format(ctx=context, question=message), max_tokens=243)
46
-
47
- return response["choices"][0]["text"]
48
-
49
- demo = gr.ChatInterface(
50
- respond,
51
- additional_inputs=[
52
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
53
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
54
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
55
- gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
56
- ]
57
- )
58
-
59
- if __name__ == "__main__":
60
- demo.launch()
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ from sentence_transformers import SentenceTransformer
4
+ from qdrant_client import QdrantClient
5
+ import torch
6
+ from llama_cpp import Llama
7
+
8
+ llm = Llama.from_pretrained(
9
+ repo_id="Suku0/mistral-7b-instruct-v0.3-bnb-4bit-GGUF",
10
+ filename="mistral-7b-instruct-v0.3-bnb-4bit.Q4_K_M.gguf",
11
+ n_ctx=16384
12
+ )
13
+ embedding_model = SentenceTransformer('nomic-ai/nomic-embed-text-v1.5', trust_remote_code=True)
14
+ qdrant_client = QdrantClient(
15
+ url="https://9a5cbf91-7dac-4dd0-80f6-13e512da1060.europe-west3-0.gcp.cloud.qdrant.io:6333",
16
+ api_key="1M-sCCVolJOOJeRXMBUh4wHfj8bkY4nZyHiau0LBllFr1vsXb1oDPg",
17
+ )
18
+
19
+ def retrieve_context(query):
20
+ query_vector = embedding_model.encode(query).tolist()
21
+
22
+ search_result = qdrant_client.search(
23
+ collection_name="ctx_collection",
24
+ query_vector=query_vector,
25
+ limit=10,
26
+ with_payload=True
27
+ )
28
+
29
+ context = " ".join([hit.payload["text"] for hit in search_result])
30
+ return context
31
+
32
+ def respond(message, history, system_message, max_tokens, temperature, top_p):
33
+ context = retrieve_context(message)
34
+ prompt = f"""You are a helpful assistant. Please answer the user's question based on the given context. If the context doesn't provide any answer, say the context doesn't provide the answer.
35
+
36
+ ### Context:
37
+ {context}
38
+
39
+ ### Question:
40
+ {message}
41
+
42
+ ### Answer:
43
+ """
44
+
45
+ response = llm(prompt.format(ctx=context, question=message), max_tokens=243)
46
+
47
+ return response["choices"][0]["text"]
48
+
49
+ app = gr.ChatInterface(
50
+ respond,
51
+ additional_inputs=[
52
+ gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
53
+ gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
54
+ gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
55
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
56
+ ]
57
+ )
58
+
59
+ if __name__ == "__main__":
60
+ app.launch()