Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
# Checkpoint identifier shared by the tokenizer and the model so the two
# can never point at different repositories.
MODEL_ID = "mistralai/Mixtral-8x22B-v0.1"

# Load the tokenizer and model once, at import time.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# torch_dtype="auto" keeps the precision stored in the checkpoint instead of
# the library default (float32 in transformers 4.x), which would multiply the
# memory needed for a model of this size. device_map="auto" is kept from the
# original call.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
)
# Function to generate text using the model
def generate_text(prompt, max_length=500, temperature=0.7, top_k=50, top_p=0.95, num_return_sequences=1):
    """Generate one or more sampled continuations of ``prompt``.

    Args:
        prompt: Text to continue.
        max_length: Value forwarded to ``model.generate(max_length=...)``.
        temperature: Sampling temperature forwarded to ``generate``.
        top_k: Top-k cutoff forwarded to ``generate``.
        top_p: Nucleus (top-p) cutoff forwarded to ``generate``.
        num_return_sequences: How many sequences to request.

    Returns:
        The decoded text. When more than one sequence is requested, the
        sequences are joined with a blank line between them; with a single
        sequence the result is just that sequence.
    """
    # UI sliders may deliver whole numbers as floats; these three arguments
    # must be integers for generate().
    max_length = int(max_length)
    top_k = int(top_k)
    num_return_sequences = int(num_return_sequences)

    # Put the prompt on the same device as the model's first parameters;
    # the tokenizer returns CPU tensors.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    output = model.generate(
        input_ids,
        max_length=max_length,
        # Without do_sample=True, temperature/top_k/top_p are ignored and
        # num_return_sequences > 1 is rejected by greedy decoding.
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        num_return_sequences=num_return_sequences,
    )

    # Return every requested sequence rather than silently dropping all but
    # the first.
    decoded = tokenizer.batch_decode(output, skip_special_tokens=True)
    return "\n\n".join(decoded)
# Create the Gradio interface.
# Components are taken from the top-level gradio namespace and use `value=`
# for their initial setting; the older `gr.inputs.*` / `gr.outputs.*`
# namespaces and the `default=` keyword are not available in current Gradio.
iface = gr.Interface(
    fn=generate_text,
    # Order matches the positional parameters of generate_text.
    inputs=[
        gr.Textbox(lines=5, label="Input Prompt"),
        gr.Slider(minimum=100, maximum=1000, value=500, step=50, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top K"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top P"),
        gr.Slider(minimum=1, maximum=10, value=1, step=1, label="Num Return Sequences"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="MixTRAL 8x22B Text Generation",
    description="Use this interface to generate text using the MixTRAL 8x22B language model.",
)

# Launch the Gradio interface
iface.launch()