Detsutut commited on
Commit
82aa38e
1 Parent(s): 5ed1df0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -22
app.py CHANGED
@@ -1,45 +1,75 @@
1
  import gradio as gr
2
  import transformers
3
  import torch
4
-
5
- import os
6
-
7
- hf_key = os.getenv("HF_TOKEN")
8
 
9
  # Initialize the model
10
- model_id = "bmi-labmedinfo/Igea-350M-v0.0.1"
11
  pipeline = transformers.pipeline(
12
  "text-generation",
13
  model=model_id,
14
  model_kwargs={"torch_dtype": torch.bfloat16},
15
- token=hf_key
16
  )
17
 
18
  # Define the function to generate text
19
- def generate_text(input_text, max_new_tokens, temperature, top_k, top_p):
20
  output = pipeline(
21
  input_text,
22
  max_new_tokens=max_new_tokens,
23
  temperature=temperature,
24
- top_k=top_k,
25
  top_p=top_p,
26
  )
27
- return output[0]['generated_text']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  # Create the Gradio interface
30
- iface = gr.Interface(
31
- fn=generate_text,
32
- inputs=[
33
- gr.Textbox(lines=2, placeholder="Enter your text here...", label="Input Text"),
34
- gr.Slider(minimum=1, maximum=200, value=128, step=1, label="Max New Tokens"),
35
- gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature"),
36
- gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"),
37
- gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top-p")
38
- ],
39
- outputs="text",
40
- title="Text Generation Interface",
41
- description="Enter a prompt to generate text using the Igea-350M model and adjust the hyperparameters."
42
- )
 
 
 
 
 
 
 
 
 
 
43
 
44
  # Launch the interface
45
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import transformers
3
  import torch
4
+ import re
 
 
 
5
 
6
  # Initialize the model
7
+ model_id = "Detsutut/Igea-350M-v0.0.1"
8
  pipeline = transformers.pipeline(
9
  "text-generation",
10
  model=model_id,
11
  model_kwargs={"torch_dtype": torch.bfloat16},
12
+ device_map="auto",
13
  )
14
 
15
  # Define the function to generate text
16
+ def generate_text(input_text, max_new_tokens, temperature, top_k, top_p, split_output):
17
  output = pipeline(
18
  input_text,
19
  max_new_tokens=max_new_tokens,
20
  temperature=temperature,
 
21
  top_p=top_p,
22
  )
23
+ generated_text = output[0]['generated_text']
24
+ if split_output:
25
+ sentences = re.split('(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s)', generated_text)
26
+ if sentences:
27
+ return sentences[0] + '.'
28
+ return generated_text
29
+
30
+ # JavaScript to dynamically enable/disable sliders based on the checkbox state
31
+ js_code = """
32
+ () => {
33
+ const checkbox = document.querySelector('input[type="checkbox"]');
34
+ const sliders = document.querySelectorAll('input[type="range"]');
35
+ checkbox.addEventListener('change', () => {
36
+ sliders.forEach(slider => {
37
+ slider.disabled = checkbox.checked;
38
+ });
39
+ });
40
+ if (checkbox.checked) {
41
+ sliders.forEach(slider => {
42
+ slider.disabled = true;
43
+ });
44
+ }
45
+ }
46
+ """
47
+
48
 
49
  # Create the Gradio interface
50
+ input_text = gr.Textbox(lines=2, placeholder="Enter your text here...", label="Input Text")
51
+
52
+ max_new_tokens = gr.Slider(minimum=1, maximum=200, value=30, step=1, label="Max New Tokens")
53
+ temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")
54
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top-p")
55
+ split_output = gr.Checkbox(label="Quick single-sentence output", value=True)
56
+
57
+ with gr.Blocks() as iface:
58
+ gr.Markdown("# Igea Text Generation Interface")
59
+ gr.Markdown("Enter a prompt to generate text using the **Igea-350M** model and adjust the hyperparameters.")
60
+ input_text.render()
61
+ with gr.Accordion("Advanced Options", open=False):
62
+ max_new_tokens.render()
63
+ temperature.render()
64
+ top_p.render()
65
+ split_output.render()
66
+ output = gr.Textbox(label="Generated Text")
67
+
68
+ btn = gr.Button("Generate")
69
+ btn.click(generate_text, [input_text, max_new_tokens, temperature, top_p, split_output], output)
70
+
71
+ # Add custom JavaScript
72
+ iface.load(js_code)
73
 
74
  # Launch the interface
75
  if __name__ == "__main__":