jonas commited on
Commit
3da458d
1 Parent(s): 325624a

Upload app.py

Browse files

Tried to add live streaming of an answers

Files changed (1) hide show
  1. app.py +58 -23
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  import pandas as pd
3
  import logging
 
4
  import os
5
  import re
6
  import json
@@ -14,6 +15,7 @@ from langchain.schema import (
14
  HumanMessage,
15
  SystemMessage,
16
  )
 
17
  from langchain_community.llms import HuggingFaceEndpoint
18
  from auditqa.process_chunks import load_chunks, getconfig
19
  from langchain_community.chat_models.huggingface import ChatHuggingFace
@@ -215,36 +217,69 @@ async def chat(query,history,sources,reports,subtype,year):
215
 
216
  ##-----------------------getting inference endpoints------------------------------
217
 
218
- #callbacks = [StreamingStdOutCallbackHandler()]
 
219
  llm_qa = HuggingFaceEndpoint(
220
- endpoint_url= model_config.get('reader','ENDPOINT'),
221
  max_new_tokens=512,
222
  repetition_penalty=1.03,
223
  timeout=70,
224
- huggingfacehub_api_token=HF_token,)
 
 
 
225
 
226
- # create RAG
227
  chat_model = ChatHuggingFace(llm=llm_qa)
228
-
229
- ##-------------------------- get answers ---------------------------------------
230
- answer_lst = []
231
- for question, context in zip(question_lst , context_retrieved_lst):
232
- answer = chat_model.invoke(messages)
233
- answer_lst.append(answer.content)
234
  docs_html = []
235
  for i, d in enumerate(context_retrieved, 1):
236
  docs_html.append(make_html_source(d, i))
237
  docs_html = "".join(docs_html)
238
 
239
- previous_answer = history[-1][1]
240
- previous_answer = previous_answer if previous_answer is not None else ""
241
- answer_yet = previous_answer + answer_lst[0]
242
- answer_yet = parse_output_llm_with_sources(answer_yet)
243
- history[-1] = (query,answer_yet)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
- history = [tuple(x) for x in history]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
- yield history,docs_html
248
 
249
  # logging the event
250
  try:
@@ -472,14 +507,14 @@ with gr.Blocks(title="Audit Q&A", css= "style.css", theme=theme,elem_id = "main-
472
  # using event listeners for 1. query box 2. click on example question
473
  # https://www.gradio.app/docs/gradio/textbox#event-listeners-arguments
474
  (textbox
475
- .submit(start_chat, [textbox,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_textbox")
476
- .then(chat, [textbox,chatbot, dropdown_sources,dropdown_reports,dropdown_category,dropdown_year], [chatbot,sources_textbox],concurrency_limit = 8,api_name = "chat_textbox")
477
- .then(finish_chat, None, [textbox],api_name = "finish_chat_textbox"))
478
 
479
  (examples_hidden
480
- .change(start_chat, [examples_hidden,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_examples")
481
- .then(chat, [examples_hidden,chatbot, dropdown_sources,dropdown_reports,dropdown_category,dropdown_year], [chatbot,sources_textbox],concurrency_limit = 8,api_name = "chat_examples")
482
- .then(finish_chat, None, [textbox],api_name = "finish_chat_examples")
483
  )
484
 
485
  demo.queue()
 
1
  import gradio as gr
2
  import pandas as pd
3
  import logging
4
+ import asyncio
5
  import os
6
  import re
7
  import json
 
15
  HumanMessage,
16
  SystemMessage,
17
  )
18
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
19
  from langchain_community.llms import HuggingFaceEndpoint
20
  from auditqa.process_chunks import load_chunks, getconfig
21
  from langchain_community.chat_models.huggingface import ChatHuggingFace
 
217
 
218
  ##-----------------------getting inference endpoints------------------------------
219
 
220
+ callback = StreamingStdOutCallbackHandler()
221
+
222
  llm_qa = HuggingFaceEndpoint(
223
+ endpoint_url=model_config.get('reader', 'ENDPOINT'),
224
  max_new_tokens=512,
225
  repetition_penalty=1.03,
226
  timeout=70,
227
+ huggingfacehub_api_token=HF_token,
228
+ streaming=True,
229
+ callbacks=[callback]
230
+ )
231
 
 
232
  chat_model = ChatHuggingFace(llm=llm_qa)
233
+
 
 
 
 
 
234
  docs_html = []
235
  for i, d in enumerate(context_retrieved, 1):
236
  docs_html.append(make_html_source(d, i))
237
  docs_html = "".join(docs_html)
238
 
239
+ answer_yet = ""
240
+
241
+ async def process_stream():
242
+ nonlocal answer_yet
243
+ async for chunk in chat_model.astream(messages):
244
+ token = chunk.content
245
+ answer_yet += token
246
+ parsed_answer = parse_output_llm_with_sources(answer_yet)
247
+ history[-1] = (query, parsed_answer)
248
+ yield [tuple(x) for x in history], docs_html
249
+
250
+ async for update in process_stream():
251
+ yield update
252
+
253
+ # #callbacks = [StreamingStdOutCallbackHandler()]
254
+ # llm_qa = HuggingFaceEndpoint(
255
+ # endpoint_url= model_config.get('reader','ENDPOINT'),
256
+ # max_new_tokens=512,
257
+ # repetition_penalty=1.03,
258
+ # timeout=70,
259
+ # huggingfacehub_api_token=HF_token,)
260
+
261
+ # # create RAG
262
+ # chat_model = ChatHuggingFace(llm=llm_qa)
263
 
264
+ # ##-------------------------- get answers ---------------------------------------
265
+ # answer_lst = []
266
+ # for question, context in zip(question_lst , context_retrieved_lst):
267
+ # answer = chat_model.invoke(messages)
268
+ # answer_lst.append(answer.content)
269
+ # docs_html = []
270
+ # for i, d in enumerate(context_retrieved, 1):
271
+ # docs_html.append(make_html_source(d, i))
272
+ # docs_html = "".join(docs_html)
273
+
274
+ # previous_answer = history[-1][1]
275
+ # previous_answer = previous_answer if previous_answer is not None else ""
276
+ # answer_yet = previous_answer + answer_lst[0]
277
+ # answer_yet = parse_output_llm_with_sources(answer_yet)
278
+ # history[-1] = (query,answer_yet)
279
+
280
+ # history = [tuple(x) for x in history]
281
 
282
+ # yield history,docs_html
283
 
284
  # logging the event
285
  try:
 
507
  # using event listeners for 1. query box 2. click on example question
508
  # https://www.gradio.app/docs/gradio/textbox#event-listeners-arguments
509
  (textbox
510
+ .submit(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
511
+ .then(chat, [textbox, chatbot, dropdown_sources, dropdown_reports, dropdown_category, dropdown_year], [chatbot, sources_textbox], queue=True, concurrency_limit=8, api_name="chat_textbox")
512
+ .then(finish_chat, None, [textbox], api_name="finish_chat_textbox"))
513
 
514
  (examples_hidden
515
+ .change(start_chat, [examples_hidden, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_examples")
516
+ .then(chat, [examples_hidden, chatbot, dropdown_sources, dropdown_reports, dropdown_category, dropdown_year], [chatbot, sources_textbox], concurrency_limit=8, api_name="chat_examples")
517
+ .then(finish_chat, None, [textbox], api_name="finish_chat_examples")
518
  )
519
 
520
  demo.queue()