AnishHF committed on
Commit 95b2a11
1 Parent(s): a51f824

Update app.py

Files changed (1)
  1. app.py +4 -8
app.py CHANGED
@@ -1,16 +1,17 @@
 import os
+import torch
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 access_token = os.environ["GATED_ACCESS_TOKEN"]
 
 # Load the tokenizer and model
-model_id = "mistralai/Mixtral-8x22B-v0.1"
+model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 tokenizer = AutoTokenizer.from_pretrained(model_id, token=access_token)
-model = AutoModelForCausalLM.from_pretrained(model_id, token=access_token)
+model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True, device_map="auto", token=access_token)
 
 # Function to generate text using the model
-def generate_text(prompt, max_length=500, temperature=0.7, top_k=50, top_p=0.95, num_return_sequences=1):
+def generate_text(prompt):
     text = prompt
     inputs = tokenizer(text, return_tensors="pt")
 
@@ -22,11 +23,6 @@ iface = gr.Interface(
     fn=generate_text,
     inputs=[
         gr.inputs.Textbox(lines=5, label="Input Prompt"),
-        gr.inputs.Slider(minimum=100, maximum=1000, default=500, step=50, label="Max Length"),
-        gr.inputs.Slider(minimum=0.1, maximum=1.0, default=0.7, step=0.1, label="Temperature"),
-        gr.inputs.Slider(minimum=1, maximum=100, default=50, step=1, label="Top K"),
-        gr.inputs.Slider(minimum=0.1, maximum=1.0, default=0.95, step=0.05, label="Top P"),
-        gr.inputs.Slider(minimum=1, maximum=10, default=1, step=1, label="Num Return Sequences"),
     ],
     outputs=gr.outputs.Textbox(label="Generated Text"),
     title="MixTRAL 8x22B Text Generation",