Nekochu committed on
Commit 4a32d8a
1 Parent(s): 53d1a2e

attempt10 fix

Files changed (1)
  1. app.py +35 -15
app.py CHANGED
@@ -11,37 +11,50 @@ MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
-MODELS = {
-    "Nekochu/Luminia-13B-v3": "Default - Nekochu/Luminia-13B-v3",
-    "Nekochu/Llama-2-13B-German-ORPO": "German ORPO - Nekochu/Llama-2-13B-German-ORPO",
-}
-
 DESCRIPTION = """\
-# Text Generation with Selectable Models
-
-This Space demonstrates text generation using different models. Choose a model from the dropdown and experience its creative capabilities!
+# Nekochu/Luminia-13B-v3
+This Space demonstrates model Nekochu/Luminia-13B-v3 by Nekochu, a Llama 2 model with 13B parameters fine-tuned for SD gen prompt
+"""
+
+LICENSE = """
+<p/>
+---.
 """
 
-LICENSE = """<p/> ---."""
+def load_model(model_id):
+    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer.use_default_system_prompt = False
+    return model, tokenizer
 
 if not torch.cuda.is_available():
-    DESCRIPTION += "\n<p>Running on CPU This demo does not work on CPU.</p>"
+    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
+
+if torch.cuda.is_available():
+    model_id = "Nekochu/Luminia-13B-v3"
+    model, tokenizer = load_model(model_id)
+
+MODELS = [
+    {"name": "Nekochu/Luminia-13B-v3", "id": "Nekochu/Luminia-13B-v3"},
+    {"name": "Nekochu/Llama-2-13B-German-ORPO", "id": "Nekochu/Llama-2-13B-German-ORPO"},
+]
 
 @spaces.GPU(duration=120)
 def generate(
+    model_dropdown: str,
+    custom_model_id: str,
     message: str,
     chat_history: list[tuple[str, str]],
     system_prompt: str,
-    model_id: str,
     max_new_tokens: int = 1024,
     temperature: float = 0.6,
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
-    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    tokenizer.use_default_system_prompt = False
+    selected_model_id = custom_model_id if custom_model_id else model_dropdown
+    model, tokenizer = load_model(selected_model_id)
+
     conversation = []
     if system_prompt:
         conversation.append({"role": "system", "content": system_prompt})
@@ -75,11 +88,18 @@ def generate(
         outputs.append(text)
         yield "".join(outputs)
 
-model_dropdown = gr.Dropdown(label="Select Model", choices=list(MODELS.values()))
+model_dropdown = gr.Dropdown(
+    label="Select Predefined Model",
+    choices=[model["name"] for model in MODELS],
+    value=MODELS[0]["name"],  # Default to the first model
+)
+custom_model_id_input = gr.Textbox(label="Or Enter Custom Model ID", placeholder="Enter model ID here")
+
 chat_interface = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
         model_dropdown,
+        custom_model_id_input,
         gr.Textbox(label="System prompt", lines=6),
         gr.Slider(
             label="Max new tokens",
@@ -131,4 +151,4 @@ with gr.Blocks(css="style.css") as demo:
     gr.Markdown(LICENSE)
 
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
+    demo.queue(max_size=20).launch()
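A note on the new load_model helper: load_in_4bit=True quantizes the checkpoint through the bitsandbytes library and only works on a CUDA device, which matches the Space loading the default model only when torch.cuda.is_available(). Recent transformers releases prefer an explicit BitsAndBytesConfig over the bare flag; a minimal sketch of the equivalent call (the compute dtype is an assumption, not something the commit specifies):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

def load_model(model_id: str):
    # 4-bit quantization via bitsandbytes; requires a CUDA-capable GPU
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,  # assumption: fp16 compute
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        quantization_config=bnb_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False
    return model, tokenizer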
 
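Because the updated generate calls load_model(selected_model_id) inside the request handler, every chat turn reloads the checkpoint, which is slow for a 13B model. One option is to memoize the loader so a repeated model choice reuses the already-loaded weights; a sketch that delegates to the app's load_model (the cache size of 1 is an assumption, chosen to avoid holding two 13B models in GPU memory at once):

from functools import lru_cache

@lru_cache(maxsize=1)  # assumption: keep only the most recent model resident
def load_model_cached(model_id: str):
    # Delegates to the load_model defined in app.py above
    return load_model(model_id)

# Inside generate:
selected_model_id = custom_model_id if custom_model_id else model_dropdown
model, tokenizer = load_model_cached(selected_model_id)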
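One detail worth verifying against this commit: gr.ChatInterface invokes its fn as fn(message, history, *additional_inputs), i.e. the message and chat history always arrive first and the additional_inputs follow in list order. The new generate signature places model_dropdown and custom_model_id before message and chat_history, so at call time the dropdown value would land in the message parameter. A minimal sketch of a parameter order that matches Gradio's convention (the echo body is illustrative only, not the Space's streaming logic):

import gradio as gr

def generate(message, chat_history, model_dropdown, custom_model_id, system_prompt):
    # additional_inputs are appended after message and history, in list order
    selected_model_id = custom_model_id if custom_model_id else model_dropdown
    return f"[{selected_model_id}] {message}"

chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Dropdown(
            label="Select Predefined Model",
            choices=["Nekochu/Luminia-13B-v3", "Nekochu/Llama-2-13B-German-ORPO"],
            value="Nekochu/Luminia-13B-v3",
        ),
        gr.Textbox(label="Or Enter Custom Model ID", placeholder="Enter model ID here"),
        gr.Textbox(label="System prompt", lines=6),
    ],
)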