Nekochu committed on
Commit 88bb7df
1 Parent(s): ddbc638

Update app.py

Files changed (1)
  1. app.py +29 -17
app.py CHANGED
@@ -13,7 +13,7 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 DESCRIPTION = """\
 # Nekochu/Luminia-13B-v3
-This Space demonstrates model [Nekochu/Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3) by Nekochu, a Llama 2 model with 13B parameters fine-tuned for SD gen prompt
+This Space demonstrates model Nekochu/Luminia-13B-v3 by Nekochu, a Llama 2 model with 13B parameters fine-tuned for Stable Diffusion prompt generation.
 """
 
 LICENSE = """
@@ -21,22 +21,29 @@ LICENSE = """
 ---.
 """
 
+def load_model(model_id):
+    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer.use_default_system_prompt = False
+    return model, tokenizer
+
 if not torch.cuda.is_available():
     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
 
-
 if torch.cuda.is_available():
     model_id = "Nekochu/Luminia-13B-v3"
-    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    tokenizer.use_default_system_prompt = False
-
+    model, tokenizer = load_model(model_id)
 
-models_cache = {}
+MODELS = [
+    {"name": "Nekochu/Luminia-13B-v3", "id": "Nekochu/Luminia-13B-v3"},
+    {"name": "Nekochu/Llama-2-13B-German-ORPO", "id": "Nekochu/Llama-2-13B-German-ORPO"},
+    # Add more models here in the future
+]
 
 @spaces.GPU(duration=120)
 def generate(
-    model_id: str,
     message: str,
     chat_history: list[tuple[str, str]],
+    model_dropdown: str,
+    custom_model_id: str,
     system_prompt: str,
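Note: with the per-request cache from the previous revision removed, load_model re-instantiates the weights on every generate call (next hunk). A minimal sketch of how the old caching behavior could be restored while still supporting arbitrary IDs, reusing the same transformers calls as above; load_model_cached and the maxsize bound are illustrative assumptions, not part of this commit:

from functools import lru_cache
from transformers import AutoModelForCausalLM, AutoTokenizer

@lru_cache(maxsize=2)  # assumption: keep at most two models resident to bound memory use
def load_model_cached(model_id: str):
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False
    return model, tokenizer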
@@ -46,13 +53,11 @@ def generate(
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
-    if model_id not in models_cache:
-        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        tokenizer.use_default_system_prompt = False
-        models_cache[model_id] = (model, tokenizer)
-    else:
-        model, tokenizer = models_cache[model_id]
+    # Prioritize custom model ID if provided, otherwise use the dropdown selection
+    selected_model_id = custom_model_id if custom_model_id else model_dropdown
+    model, tokenizer = load_model(selected_model_id)
+
+    conversation = []
     if system_prompt:
         conversation.append({"role": "system", "content": system_prompt})
     for user, assistant in chat_history:
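The resolution rule above means a non-empty custom ID always wins, and the dropdown string is otherwise passed straight to load_model. That works because each MODELS entry's display name equals its repo id; if the two ever diverged, the name would have to be mapped back. A short sketch of such a resolver under that assumption (resolve_model_id is hypothetical, not in this commit):

def resolve_model_id(model_dropdown: str, custom_model_id: str) -> str:
    # a typed-in custom ID takes priority over the dropdown selection
    if custom_model_id:
        return custom_model_id
    # map the displayed name back to its repo id, in case they ever differ
    return next(m["id"] for m in MODELS if m["name"] == model_dropdown)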
@@ -85,11 +90,18 @@ def generate(
         outputs.append(text)
         yield "".join(outputs)
 
+model_dropdown = gr.Dropdown(
+    label="Select Predefined Model",
+    choices=[model["name"] for model in MODELS],
+    value=MODELS[0]["name"],  # Default to the first model
+)
+custom_model_id_input = gr.Textbox(label="Or Enter Custom Model ID", placeholder="Enter model ID here")
 
 chat_interface = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
-        gr.Textbox(label="Model ID", placeholder="Nekochu/Luminia-13B-v3"),
+        model_dropdown,
+        custom_model_id_input,
         gr.Textbox(label="System prompt", lines=6),
         gr.Slider(
             label="Max new tokens",
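gr.ChatInterface calls fn as fn(message, history, *additional_inputs), so the dropdown and textbox values arrive after the message and chat history; generate's signature in the earlier hunk is ordered to match. A self-contained sketch of that wiring with a stub in place of generate (echo_model is hypothetical):

import gradio as gr

def echo_model(message, history, model_dropdown, custom_model_id):
    # echo which model would serve this reply, mirroring the selection rule in generate
    return f"[{custom_model_id or model_dropdown}] {message}"

demo = gr.ChatInterface(
    fn=echo_model,
    additional_inputs=[
        gr.Dropdown(label="Select Predefined Model",
                    choices=["Nekochu/Luminia-13B-v3", "Nekochu/Llama-2-13B-German-ORPO"],
                    value="Nekochu/Luminia-13B-v3"),
        gr.Textbox(label="Or Enter Custom Model ID"),
    ],
)

if __name__ == "__main__":
    demo.launch()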
@@ -141,4 +153,4 @@ with gr.Blocks(css="style.css") as demo:
     gr.Markdown(LICENSE)
 
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
+    demo.queue(max_size=20).launch()
 