Hack337 committed
Commit a6d6149
1 parent: 634e7a7

Update space: replace the huggingface_hub InferenceClient with local generation via transformers

Files changed (2)
  1. app.py +29 -18
  2. requirements.txt +4 -1
app.py CHANGED
@@ -1,10 +1,16 @@
 import gradio as gr
-from huggingface_hub import InferenceClient
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
 
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("Hack337/WavGPT-1.0")
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+model_path = "Hack337/WavGPT-1.0"  # Replace with the actual model path
+model = AutoModelForCausalLM.from_pretrained(
+    model_path,
+    torch_dtype="auto",
+    device_map="auto"
+)
+tokenizer = AutoTokenizer.from_pretrained(model_path)
 
 
 def respond(
@@ -25,23 +31,28 @@ def respond(
 
     messages.append({"role": "user", "content": message})
 
-    response = ""
-
-    for message in client.chat_completion(
+    text = tokenizer.apply_chat_template(
         messages,
-        max_tokens=max_tokens,
-        stream=True,
+        tokenize=False,
+        add_generation_prompt=True
+    )
+    model_inputs = tokenizer([text], return_tensors="pt").to(device)
+
+    generated_ids = model.generate(
+        model_inputs.input_ids,
+        max_new_tokens=max_tokens,
+        pad_token_id=tokenizer.eos_token_id,
         temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
+        top_p=top_p
+    )
+    generated_ids = [
+        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+    ]
+    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+    return response
 
-        response += token
-        yield response
 
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
 demo = gr.ChatInterface(
     respond,
     additional_inputs=[
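In short, the commit swaps the hosted Inference API for local generation: render the chat history into a prompt with the tokenizer's chat template, run model.generate, then decode only the newly generated tokens. A minimal standalone sketch of that flow, assuming the model ships a chat template (the history handling and Gradio wiring from the full app are omitted, and the sampling values are illustrative):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "Hack337/WavGPT-1.0"
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype="auto", device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Render the chat history into a single prompt string via the chat template.
messages = [{"role": "user", "content": "Hello!"}]
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
inputs = tokenizer([text], return_tensors="pt").to(model.device)

# do_sample=True is needed for temperature/top_p to actually take effect;
# the committed code omits it, so generation there is effectively greedy.
output_ids = model.generate(
    inputs.input_ids,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    pad_token_id=tokenizer.eos_token_id,
)

# Strip the prompt tokens so only the model's reply is decoded.
reply_ids = output_ids[0][inputs.input_ids.shape[1]:]
print(tokenizer.decode(reply_ids, skip_special_tokens=True))

Two side effects of the change worth noting: without do_sample=True, transformers ignores the temperature and top_p sliders and decodes greedily, and since respond now returns once instead of yielding, the ChatInterface no longer streams tokens the way the InferenceClient version did.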
requirements.txt CHANGED
@@ -1,2 +1,5 @@
 huggingface_hub==0.22.2
-minijinja
+minijinja
+torch
+transformers
+gradio
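Dependency note: torch and transformers back the new local-generation path, while huggingface_hub and minijinja are holdovers from the InferenceClient setup and could likely be dropped now; pinning gradio is harmless, since the Gradio SDK image used by Spaces already provides it.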