Krypton / app.py
sandz7's picture
added the model to cuda
d3f5533
raw
history blame
No virus
2.63 kB
import torch
import gradio as gr
from transformers import TextIteratorStreamer, AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import threading
import spaces
import accelerate
import time
# HTML banner rendered at the top of the demo page via gr.Markdown.
# The heading emoji is U+1F54B; it had been corrupted into mojibake by a
# wrong-encoding round trip.
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Krypton 🕋</h1>
<p>This uses an Open Source model from <a href="https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers"><b>xtuner/llava-llama-3-8b-v1_1-transformers</b></a></p>
</div>
'''
# Hugging Face Hub checkpoint that backs the demo.
model_id = "xtuner/llava-llama-3-8b-v1_1-transformers"

# Load the weights in half precision with the low-CPU-memory loading path,
# then move the model to the CUDA device.
_load_options = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
model = LlavaForConditionalGeneration.from_pretrained(model_id, **_load_options)
model.to('cuda')

# The processor is loaded from the same checkpoint as the model.
processor = AutoProcessor.from_pretrained(model_id)

# Use the tokenizer's end-of-sequence id as the generation stop token.
_tokenizer_eos_id = processor.tokenizer.eos_token_id
model.generation_config.eos_token_id = _tokenizer_eos_id
@spaces.GPU(duration=120)
def krypton(input, history):
    """Stream the model's answer to a question about an image.

    Args:
        input: Multimodal chat message. This function reads ``input["files"]``
            (the last entry is used; it is either a path or a dict with a
            ``"path"`` key) and ``input["text"]`` (the user's question).
        history: Earlier chat turns. A turn whose first element is a tuple is
            treated as an image upload, with the image path at index 0.

    Yields:
        The text generated so far; it grows with every streamed chunk.

    Raises:
        gr.Error: If neither the current message nor the history has an image.
    """
    image = None
    if input["files"]:
        latest = input["files"][-1]
        image = latest["path"] if isinstance(latest, dict) else latest
    else:
        # No upload this turn: fall back to the most recent image found in
        # earlier turns (the last matching turn wins).
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]

    if not image:
        # The error has to be raised; merely constructing it shows nothing
        # to the user.
        raise gr.Error("You need to upload an image for Krypton to work.")

    # NOTE(review): the checkpoint's chat template may expect special header
    # tokens around the role names - confirm against the model card.
    prompt = f"user\n\n<image>\n{input['text']}\nassistant\n\n"

    # Close the image file as soon as the processor has consumed it.
    with Image.open(image) as picture:
        inputs = processor(prompt, images=picture, return_tensors='pt').to(0, torch.float16)

    # Streamer
    streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=False, skip_prompt=True)

    # Generation kwargs: forward every processor output (including
    # pixel_values) so the model actually receives the image, not just the
    # text tokens.
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=False,
    )

    # Generation runs in a background thread; this generator consumes the
    # streamer as chunks become available.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    time.sleep(0.5)
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.06)
        yield buffer
# Chat widgets are created up front and handed to the ChatInterface below.
chatbot = gr.Chatbot(label="Krypt AI", height=600)
chat_input = gr.MultimodalTextbox(
    placeholder="Enter your question or upload an image.",
    file_types=["image"],
    show_label=False,
    interactive=True,
)

# Page layout: the HTML banner followed by the multimodal chat interface,
# which routes every message through krypton().
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=krypton,
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
        fill_height=True,
    )

# Enable the queue with api_open=False, then launch with the API page and
# share link turned off.
demo.queue(api_open=False)
demo.launch(share=False, show_api=False)