freQuensy23 committed
Commit 6bfe382
Parent: dfa8941

IMP new mistral

Files changed (2):
  1. app.py +7 -7
  2. generators.py +16 -0
app.py CHANGED
@@ -9,32 +9,32 @@ load_dotenv()
 
 async def handle(system_input: str, user_input: str):
     print(system_input, user_input)
-    buffers = ["", "", "", ""]
+    buffers = ["", "", "", "", ""]
     async for outputs in async_zip_stream(
             generate_gpt2(system_input, user_input),
             generate_mistral_7bvo1(system_input, user_input),
             generate_llama2(system_input, user_input),
             generate_llama3(system_input, user_input),
+            generate_mistral_7bvo3(system_input, user_input),
     ):
         # gpt_output, mistral_output, llama_output, llama2_output, llama3_output, llama4_output = outputs
         for i, b in enumerate(buffers):
             buffers[i] += str(outputs[i])
 
         yield list(buffers) + ["", ""]
-    yield list(buffers) + [(openllama_generation := generate_openllama(system_input, user_input)), '']
-    yield list(buffers) + [openllama_generation, generate_bloom(system_input, user_input)]
+    yield list(buffers) + [generate_bloom(system_input, user_input)]
 
 
 with gr.Blocks() as demo:
     system_input = gr.Textbox(label='System Input', value='You are AI assistant', lines=2)
     with gr.Row():
         gpt = gr.Textbox(label='gpt-2', lines=4, interactive=False)
-        mistral = gr.Textbox(label='mistral', lines=4, interactive=False)
-        llama = gr.Textbox(label='openllama', lines=4, interactive=False)
+        mistral = gr.Textbox(label='mistral-v01', lines=4, interactive=False)
+        mistral_new = gr.Textbox(label='mistral-v03', lines=4, interactive=False)
     with gr.Row():
         llama2 = gr.Textbox(label='llama-2', lines=4, interactive=False)
         llama3 = gr.Textbox(label='llama-3', lines=4, interactive=False)
-        bloom = gr.Textbox(label='bloom', lines=4, interactive=False)
+        bloom = gr.Textbox(label='bloom [GPU]', lines=4, interactive=False)
 
     user_input = gr.Textbox(label='User Input', lines=2)
     gen_button = gr.Button('Generate')
@@ -42,7 +42,7 @@ with gr.Blocks() as demo:
     gen_button.click(
         fn=handle,
         inputs=[system_input, user_input],
-        outputs=[gpt, mistral, llama2, llama3, llama, bloom],
+        outputs=[gpt, mistral, llama2, llama3, mistral_new, bloom],
     )
 
 demo.launch()
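
Note: the async_zip_stream helper that handle() consumes is imported elsewhere in app.py and is not part of this diff. For readers following along, here is a minimal sketch of how such a helper could work, assuming it advances all generators concurrently and substitutes empty strings for streams that finish early; the sentinel _DONE and the _step wrapper are illustrative names, not from the repo:

import asyncio

_DONE = object()  # sentinel marking an exhausted generator


async def _step(gen):
    # Pull one chunk from an async generator, mapping exhaustion to the sentinel.
    try:
        return await gen.__anext__()
    except StopAsyncIteration:
        return _DONE


async def async_zip_stream(*generators):
    # Advance every live generator once per round; finished ones keep yielding the sentinel.
    done = [False] * len(generators)
    while True:
        chunks = await asyncio.gather(*(
            _step(g) if not d else asyncio.sleep(0, result=_DONE)
            for g, d in zip(generators, done)
        ))
        done = [c is _DONE for c in chunks]
        if all(done):
            return
        yield tuple("" if c is _DONE else c for c in chunks)

This lockstep design is what lets handle() grow one shared buffers list and re-yield it to Gradio after every round of chunks, so all five textboxes stream in parallel.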
generators.py CHANGED
@@ -38,6 +38,22 @@ async def generate_mistral_7bvo1(system_input, user_input):
         yield message.choices[0].delta.content
 
 
+async def generate_mistral_7bvo3(system_input, user_input):
+    client = AsyncInferenceClient(
+        "mistralai/Mistral-7B-Instruct-v0.3",
+        token=os.getenv('HF_TOKEN'),
+    )
+
+    async for message in await client.chat_completion(
+        messages=[
+            {"role": "system", "content": system_input},
+            {"role": "user", "content": user_input}, ],
+        max_tokens=256,
+        stream=True,
+    ):
+        yield message.choices[0].delta.content
+
+
 async def generate_gpt2(system_input, user_input):
     output = await query_llm({
         "inputs": (inputs:=f"{system_input}\n{user_input}"),