Nekochu committed on
Commit 2bc6f48
1 Parent(s): b0302a5

Update app.py

Files changed (1)
  1. app.py +14 -15
app.py CHANGED
@@ -1,6 +1,7 @@
 import os
 from threading import Thread
 from typing import Iterator
+
 import gradio as gr
 import spaces
 import torch
@@ -20,35 +21,33 @@ LICENSE = """
 ---.
 """
 
-if not torch.cuda.is_available():
-    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
-
-models_cache = {}
-
-def load_model(model_id: str):
+def load_model(model_id):
     model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     tokenizer.use_default_system_prompt = False
     return model, tokenizer
 
+if not torch.cuda.is_available():
+    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
+
+if torch.cuda.is_available():
+    model_id = "Nekochu/Luminia-13B-v3"
+    model, tokenizer = load_model(model_id)
+
+
 @spaces.GPU(duration=120)
 def generate(
-    model_id: str,
     message: str,
     chat_history: list[tuple[str, str]],
     system_prompt: str,
+    model_id: str = "Nekochu/Luminia-13B-v3",
     max_new_tokens: int = 1024,
     temperature: float = 0.6,
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
-    # Load the model if it's not already loaded
-    if model_id not in models_cache:
-        model, tokenizer = load_model(model_id)
-        models_cache[model_id] = (model, tokenizer)
-    else:
-        model, tokenizer = models_cache[model_id]
+    model, tokenizer = load_model(model_id)
     conversation = []
     if system_prompt:
         conversation.append({"role": "system", "content": system_prompt})
@@ -86,8 +85,8 @@ def generate(
 chat_interface = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
-        gr.Textbox(label="Model ID", placeholder="Nekochu/Luminia-13B-v3"),
         gr.Textbox(label="System prompt", lines=6),
+        gr.Textbox(label="Model ID", default="Nekochu/Luminia-13B-v3"),
         gr.Slider(
             label="Max new tokens",
             minimum=1,
@@ -138,4 +137,4 @@ with gr.Blocks(css="style.css") as demo:
     gr.Markdown(LICENSE)
 
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
+    demo.queue(max_size=20).launch()
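
Note on the change: the removed models_cache dict kept each checkpoint in memory after its first use, whereas the updated code calls load_model(model_id) on every generate() request and therefore reloads the weights each time a request comes in. If that turns out to be too slow, one possible cached variant is sketched below; it reuses the same load_model body as in the diff, and the maxsize value is an illustrative assumption, not something from this commit.

from functools import lru_cache

from transformers import AutoModelForCausalLM, AutoTokenizer


@lru_cache(maxsize=2)  # illustrative: keep at most two models resident; tune to GPU memory
def load_model(model_id: str):
    # Same body as the committed load_model; lru_cache keys on model_id, so repeated
    # generate() calls with the same Model ID reuse the already-loaded model and tokenizer.
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False
    return model, tokenizer

Separately, recent Gradio releases pre-fill a Textbox through value= (or show a hint through placeholder=) rather than default=, so the added Model ID textbox line may need value="Nekochu/Luminia-13B-v3" instead of default= to work as intended.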