CMLL committed
Commit 220ce3a
Parent(s): 8724a4d

Update app.py

Files changed (1):
  app.py (+14, -15)
app.py CHANGED
@@ -36,16 +36,14 @@ if torch.cuda.is_available():
 def generate(
     message: str,
     chat_history: list[tuple[str, str]],
-    system_prompt: str,
+    system_prompt: str = "You are a helpful TCM medical assistant named 仲景中医大语言模型, created by 医哲未来.",
     max_new_tokens: int = 1024,
     temperature: float = 0.6,
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
-    conversation = []
-    if system_prompt:
-        conversation.append({"role": "system", "content": system_prompt})
+    conversation = [{"role": "system", "content": system_prompt}]
     for user, assistant in chat_history:
         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
     conversation.append({"role": "user", "content": message})
@@ -61,26 +59,27 @@ def generate(
         "repetition_penalty": repetition_penalty,
     }
 
+    # Function to run the generation
     def run_generation():
         try:
-            return pipe(input_text, **generate_kwargs)
+            results = pipe(input_text, **generate_kwargs)
+            return results
         except Exception as e:
-            gr.Error(f"Error in generation: {e}")
-            return []
-
-    t = Thread(target=run_generation)
-    t.start()
-    t.join()  # Ensure the thread completes before proceeding
+            return [f"Error in generation: {e}"]
 
+    # Run generation in a separate thread and wait for it to finish
     outputs = []
-    for text in run_generation():
-        outputs.append(text['generated_text'])
-        yield "".join(outputs)
+    generation_thread = Thread(target=lambda: outputs.extend(run_generation()))
+    generation_thread.start()
+    generation_thread.join()
+
+    for output in outputs:
+        yield output['generated_text'] if isinstance(output, dict) else output
 
 chat_interface = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
-        gr.Textbox(label="System prompt", lines=6),
+        gr.Textbox(label="System prompt", lines=6, value="You are a helpful TCM medical assistant named 仲景中医大语言模型, created by 医哲未来."),
         gr.Slider(
             label="Max new tokens",
             minimum=1,
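
Note on streaming: the rewritten generate() joins the worker thread before yielding anything, so gr.ChatInterface receives the whole reply in one piece rather than token by token. If incremental streaming is the goal, the usual transformers pattern is TextIteratorStreamer. The following is a minimal sketch only, assuming direct access to the model and tokenizer objects behind pipe; the names model, tokenizer, and generate_streaming are illustrative and not part of app.py:

from threading import Thread
from transformers import TextIteratorStreamer

def generate_streaming(input_text, model, tokenizer, **generate_kwargs):
    # The streamer yields decoded text chunks as soon as the model produces them.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    # Generation runs in a background thread; no join() is needed, because the
    # loop below only ends once the streamer is exhausted.
    thread = Thread(
        target=model.generate,
        kwargs={**inputs, "streamer": streamer, **generate_kwargs},
    )
    thread.start()

    outputs = []
    for chunk in streamer:
        outputs.append(chunk)
        yield "".join(outputs)  # partial text for gr.ChatInterface to render

With this shape, the incremental yield "".join(outputs) restores the streaming behaviour that the pre-commit code appeared to be aiming for with its own yield loop.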