Nekochu committed
Commit 9ec97f1
1 Parent(s): d32b641

Update app.py

Files changed (1): app.py (+17 -19)
app.py CHANGED
@@ -24,12 +24,15 @@ LICENSE = """
 if not torch.cuda.is_available():
     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
 
-# Define the available models
-MODELS = [
-    {"name": "Nekochu/Luminia-13B-v3", "id": "Nekochu/Luminia-13B-v3"},
-    {"name": "Nekochu/Llama-2-13B-German-ORPO", "id": "Nekochu/Llama-2-13B-German-ORPO"},
-    # Add more models here in the future
-]
+
+if torch.cuda.is_available():
+    model_id = "Nekochu/Luminia-13B-v3"
+    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer.use_default_system_prompt = False
+
+
+models_cache = {}
 
 @spaces.GPU(duration=120)
 def generate(
@@ -43,12 +46,13 @@ def generate(
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
-    # Load the model and tokenizer based on the selected model ID
-    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    tokenizer.use_default_system_prompt = False
-
-    conversation = []
+    if model_id not in models_cache:
+        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer.use_default_system_prompt = False
+        models_cache[model_id] = (model, tokenizer)
+    else:
+        model, tokenizer = models_cache[model_id]
     if system_prompt:
         conversation.append({"role": "system", "content": system_prompt})
     for user, assistant in chat_history:
@@ -81,17 +85,11 @@ def generate(
         outputs.append(text)
         yield "".join(outputs)
 
-# Add a dropdown for model selection
-model_dropdown = gr.Dropdown(
-    label="Select Model",
-    choices=[model["name"] for model in MODELS],
-    value=MODELS[0]["name"],  # Default to the first model
-)
 
 chat_interface = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
+        gr.Textbox(label="Model ID", default="Nekochu/Luminia-13B-v3"),
         gr.Textbox(label="System prompt", lines=6),
         gr.Slider(
             label="Max new tokens",