CMLL committed
Commit f3b7005
1 Parent(s): 6e999ef

Update app.py

Files changed (1)
  1. app.py +34 -29
app.py CHANGED
@@ -1,38 +1,43 @@
  import os
  import gradio as gr
  import spaces
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from threading import Thread
- from typing import Iterator

- # Constants
  MAX_MAX_NEW_TOKENS = 2048
  DEFAULT_MAX_NEW_TOKENS = 1024
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

  DESCRIPTION = """\
- # Llama-2 7B Chat
- This Space demonstrates model [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta, a Llama 2 model with 7B parameters fine-tuned for chat instructions. Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
- 🔎 For more details about the Llama 2 family of models and how to use them with `transformers`, take a look [at our blog post](https://huggingface.co/blog/llama2).
- 🔨 Looking for an even more powerful model? Check out the [13B version](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat) or the large [70B model demo](https://huggingface.co/spaces/ysharma/Explore_llamav2_with_TGI).
  """

  LICENSE = """
  <p/>
  ---
- As a derivate work of [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta,
- this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
  """

- # Set the device
- device = "cuda" if torch.cuda.is_available() else "cpu"

- # Model loading with the replacement setup
- base_model_id = "Qwen/Qwen1.5-1.8B-Chat"
- model = AutoModelForCausalLM.from_pretrained(base_model_id, device_map="auto")
- model.load_adapter("CMLM/ZhongJing-2-1_8b")
- tokenizer = AutoTokenizer.from_pretrained("CMLM/ZhongJing-2-1_8b", padding_side="right", trust_remote_code=True, pad_token='')

  @spaces.GPU
  def generate(
@@ -56,26 +61,27 @@ def generate(
      if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
          input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
          gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-
-     input_ids = input_ids.to(device)  # Ensure the input tensor is on the correct device

-     outputs = []
-     generated_ids = model.generate(
-         input_ids,
          max_new_tokens=max_new_tokens,
          do_sample=True,
          top_p=top_p,
          top_k=top_k,
          temperature=temperature,
          num_beams=1,
-         repetition_penalty=repetition_penalty
      )

-     generated_ids = generated_ids.to(device)  # Ensure the generated ids are moved to the device
-
-     outputs.append(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
-     return "".join(outputs)
-

  chat_interface = gr.ChatInterface(
      fn=generate,
@@ -135,4 +141,3 @@ with gr.Blocks(css="style.css") as demo:

  if __name__ == "__main__":
      demo.queue(max_size=20).launch()
-
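
Neither hunk shows the first part of generate(): old lines 39-55 and new lines 44-60, where the chat history is turned into input_ids, are elided by the diff. A rough, hypothetical sketch of that elided step, assuming the module-level tokenizer from the listing and the usual chat-template flow (the actual code may differ):

# Hypothetical helper, not taken from the diff: builds model-ready input ids
# from a Gradio-style chat history using the tokenizer's chat template.
def build_chat_inputs(message, chat_history, system_prompt=""):
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user_msg, assistant_msg in chat_history:
        conversation.append({"role": "user", "content": user_msg})
        conversation.append({"role": "assistant", "content": assistant_msg})
    conversation.append({"role": "user", "content": message})
    # apply_chat_template formats the turns with the model's chat template and tokenizes them
    return tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")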
 
app.py (updated)

  import os
+ from threading import Thread
+ from typing import Iterator
+
  import gradio as gr
  import spaces
  import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

  MAX_MAX_NEW_TOKENS = 2048
  DEFAULT_MAX_NEW_TOKENS = 1024
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

  DESCRIPTION = """\
+
+ ZhongJing-2-1_8b Chat
+ This Space demonstrates the ZhongJing-2-1_8b model, a model fine-tuned for chat instructions. Feel free to play with it, or duplicate it to run generations without a queue! If you want to run your own service, you can also deploy the model on Inference Endpoints.
  """

  LICENSE = """
+
  <p/>
  ---
+ As a derivative work of [ZhongJing-2-1_8b](https://huggingface.co/CMLM/ZhongJing-2-1_8b) by 医哲未来 of Fudan University, this demo is governed by the original license and acceptable use policy.
  """

+ if not torch.cuda.is_available():
+     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

+ if torch.cuda.is_available():
+     base_model_id = "Qwen/Qwen1.5-1.8B-Chat"
+     peft_model_id = "CMLM/ZhongJing-2-1_8b"
+     model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.float16, device_map="auto")
+     model.load_adapter(peft_model_id)
+     tokenizer = AutoTokenizer.from_pretrained(
+         peft_model_id,
+         padding_side="right",
+         trust_remote_code=True,
+         pad_token=''
+     )

  @spaces.GPU
  def generate(
 
      if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
          input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
          gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     input_ids = input_ids.to(model.device)

+     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         input_ids=input_ids,
+         streamer=streamer,
          max_new_tokens=max_new_tokens,
          do_sample=True,
          top_p=top_p,
          top_k=top_k,
          temperature=temperature,
          num_beams=1,
+         repetition_penalty=repetition_penalty,
      )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()

+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield "".join(outputs)

  chat_interface = gr.ChatInterface(
      fn=generate,
 

  if __name__ == "__main__":
      demo.queue(max_size=20).launch()
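
The core change in this commit is the switch from a blocking model.generate call to streamed output: generation runs in a background thread while a TextIteratorStreamer yields decoded text as it is produced, and generate() yields the growing reply the way a streaming gr.ChatInterface function expects. A minimal, self-contained sketch of that pattern (the model id is only a placeholder and the sampling arguments are trimmed for brevity):

from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "Qwen/Qwen1.5-1.8B-Chat"  # placeholder: any causal LM works the same way
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

def stream_reply(prompt, max_new_tokens=256):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(input_ids=input_ids, streamer=streamer, max_new_tokens=max_new_tokens, do_sample=True)
    # model.generate runs off the main thread so its output can be consumed while it is still running
    Thread(target=model.generate, kwargs=kwargs).start()
    partial = []
    for piece in streamer:  # blocks until the next chunk of decoded text is available
        partial.append(piece)
        yield "".join(partial)  # yield the cumulative reply, as a streaming chat fn does

for text in stream_reply("What is this model fine-tuned for?"):
    print(text)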