# app.py — Gemma-2-9B text-generation demo (Hugging Face Space by Karzan, rev e4ccddf)
import torch
import gradio as gr
import spaces
from transformers import (
    pipeline,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GemmaForCausalLM,
)
# ignore_mismatched_sizes=True
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-9b')
model = GemmaForCausalLM.from_pretrained('google/gemma-2-9b',
quantization_config=quantization_config
)
# pipe = pipeline('text-generation', model=model,tokenizer = tokenizer)
@spaces.GPU(duration=120)
def generate(prompt):
    """Return a text completion for *prompt* from the Gemma model.

    Runs on the ZeroGPU-allocated device for up to 120 s per call.
    """
    # Move inputs to wherever the model actually lives instead of
    # hard-coding "cuda" (works with device_map / CPU fallback too).
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Without max_new_tokens, generate() stops at the tiny default length
    # and truncates almost every completion.
    outputs = model.generate(**inputs, max_new_tokens=256)
    # skip_special_tokens drops <bos>/<eos> markers from the UI output.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Minimal Gradio UI: one text box in, generated text out.
demo = gr.Interface(
    fn=generate,
    inputs=gr.Text(),
    outputs="text",
)
demo.launch()