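"""Gradio demo for LLaVA Phi3-3.8B (LoRA): upload an image and chat about it, with streamed responses."""
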
import time
from threading import Thread
import gradio as gr
from io import BytesIO
import requests
from PIL import Image
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token
from transformers.generation.streamers import TextIteratorStreamer
import spaces
import subprocess
import torch
if torch.cuda.is_available():
    device = "cuda"
    print("Using GPU")
else:
    device = "cpu"
    print("Using CPU")

# Install the local repository as a package (provides the llava modules).
subprocess.run(["pip", "install", "."])

# The model is placed on the first GPU regardless of the check above.
device = "cuda:0"
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="./checkpoints/llava-phi3-3.8b-lora",
    model_name="llava-phi3-3.8b-lora",
    model_base="microsoft/Phi-3-mini-128k-instruct",
    load_8bit=False, load_4bit=False, device=device,
)
model.to(device)
def load_image(image_file):
    if image_file.startswith('http') or image_file.startswith('https'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image
@spaces.GPU
def bot_streaming(message, history):
    print(message)
    if message["files"]:
        # message["files"][-1] is either a dict or a plain path string
        if isinstance(message["files"][-1], dict):
            image = message["files"][-1]["path"]
        else:
            image = message["files"][-1]
    else:
        raise gr.Error("You need to upload an image for LLaVA to work.")

    image_data = load_image(str(image))
    image_tensor = image_processor.preprocess(image_data, return_tensors='pt')['pixel_values'].half().to(device)

    # Single turn only; the image token is always prepended to the user text.
    prompt = f"<|user|>\n{DEFAULT_IMAGE_TOKEN}\n{message['text']}<|end|>\n<|assistant|>"
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True, timeout=20.0)
    # Run generation in a background thread so tokens can be yielded as they stream in.
    thread = Thread(target=model.generate, kwargs=dict(
        inputs=input_ids,
        images=image_tensor,
        do_sample=True,
        temperature=0.2,
        top_p=1.0,
        max_new_tokens=1024,
        streamer=streamer,
        use_cache=True))
    thread.start()
    buffer = ""
    time.sleep(0.5)
    for new_text in streamer:
        # Strip the <|end|> marker from the streamed chunk.
        if "<|end|>" in new_text:
            new_text = new_text.split("<|end|>")[0]
        buffer += new_text
        time.sleep(0.06)
        # Yield the accumulated text so the chat UI updates incrementally.
        yield buffer
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="LLaVA Phi 3 3.8B",
        examples=[
            {"text": "What is the color of the cat?", "files": ["./FELV-cat.jpg"]},
            {"text": "What is the type of the fish?", "files": ["./fish.jpg"]},
        ],
        description="Try [LLaVA Phi3-3.8B LoRA](https://huggingface.co/praysimanjuntak/llava-phi3-3.8b-lora). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

demo.queue(api_open=False)
demo.launch(show_api=False, share=False)