File size: 1,677 Bytes
3643b73
f11657e
95b2a11
dc09589
0b7787a
d591ad9
3643b73
 
0b7787a
 
 
42db216
0b7787a
 
42db216
a1b3c51
0b7787a
d591ad9
0b7787a
 
 
8b8d45c
f11657e
8b8d45c
f11657e
 
8b8d45c
d591ad9
 
95b2a11
a3606dd
 
 
 
 
d591ad9
 
 
 
 
 
dc09589
d591ad9
da2c202
 
dc09589
 
d591ad9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
import bitsandbytes as bnb
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Hugging Face access token; the Mistral repo is gated, so every download
# (model weights AND tokenizer files) must be authenticated with it.
access_token = os.environ["GATED_ACCESS_TOKEN"]

# Single source of truth for the checkpoint so the model and tokenizer
# can never be loaded from different repositories.
MODEL_ID = "mistralai/Mistral-7B-v0.1"

# 4-bit NF4 quantization via bitsandbytes; matmuls are computed in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="float16",
)

# Load the quantized model, letting accelerate place layers on the available
# devices, and the matching tokenizer (which also needs the token).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    device_map="auto",
    token=access_token,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=access_token)

# Function to generate text using the model
def generate_text(prompt, max_new_tokens=20):
    """Generate a continuation of *prompt* with the module-level model.

    Args:
        prompt: Text to continue.
        max_new_tokens: Maximum number of tokens to generate after the
            prompt. Defaults to 20, the value previously hard-coded.

    Returns:
        The decoded first output sequence (the prompt followed by the
        generated continuation) with special tokens removed.
    """
    # The tokenizer returns CPU tensors; move them to the model's device so
    # generation does not mix CPU inputs with accelerator-resident weights.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        # generate() falls back to EOS as the pad token (with a warning) when
        # none is configured; passing it explicitly keeps that behavior quietly.
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Create the Gradio interface.
# The old gr.inputs / gr.outputs namespaces were deprecated in Gradio 3.x and
# removed in 4.0; the top-level gr.Textbox component serves as both an input
# and an output component.
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=5, label="Input Prompt"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="MisTRAL Text Generation",
    description="Use this interface to generate text using the MisTRAL language model.",
)

# Launch the Gradio interface (starts the web server that serves the UI).
iface.launch()