"""Gradio chat demo that streams LLaVA Phi-3 answers about an uploaded image."""

import time
from threading import Thread
import subprocess

# Runs at import time, before the `llava` imports below.
# NOTE(review): presumably this installs the local package that provides
# `llava` -- confirm before moving or removing it.
subprocess.run(["pip", "install", "."])

import gradio as gr
from io import BytesIO
import requests
from PIL import Image
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token
from transformers.generation.streamers import TextIteratorStreamer
import spaces

device = "cuda:0"

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="./checkpoints/llava-phi3-3.8b-lora",
    model_name="llava-phi3-3.8b-lora",
    model_base="microsoft/Phi-3-mini-128k-instruct",
    load_8bit=False,
    load_4bit=False,
    device=device,
)
model.to(device)


def load_image(image_file):
    """Open an image from a URL or a local path and convert it to RGB.

    Args:
        image_file: An ``http://`` / ``https://`` URL or a filesystem path.

    Returns:
        The image as a ``PIL.Image.Image`` in RGB mode.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an error status.
    """
    # Match the full scheme so a local file whose name merely starts with
    # "http" is still opened from disk.
    if image_file.startswith(("http://", "https://")):
        # A timeout keeps an unresponsive host from hanging the request.
        response = requests.get(image_file, timeout=30)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert('RGB')
    return Image.open(image_file).convert('RGB')


@spaces.GPU
def bot_streaming(message, history):
    """Stream the model's answer to a single image + text message.

    Args:
        message: Multimodal message dict with a ``"text"`` entry and a
            ``"files"`` list; the last file is used as the image. Each file
            entry is either a dict with a ``"path"`` key or a plain string.
        history: Previous chat turns. Unused: the prompt is always built as
            a single turn with the image token prepended.

    Yields:
        The text generated so far, growing with every streamed chunk.

    Raises:
        gr.Error: If the message carries no file.
    """
    if not message["files"]:
        # The error has to be raised (not merely constructed); otherwise
        # execution would continue without any image to process.
        raise gr.Error("You need to upload an image for LLaVA to work.")

    # message["files"][-1] is a dict or just a string
    last_file = message["files"][-1]
    image = last_file["path"] if isinstance(last_file, dict) else last_file

    image_data = load_image(str(image))
    image_tensor = image_processor.preprocess(
        image_data, return_tensors='pt'
    )['pixel_values'].half().to(device)

    # just one turn, always prepend image token
    prompt = f"<|user|>\n{DEFAULT_IMAGE_TOKEN}\n{message['text']}<|end|>\n<|assistant|>"
    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
    ).unsqueeze(0).to(device)

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_special_tokens=True,
        skip_prompt=True,
        timeout=20.0,
    )
    # Generation runs in a background thread so this generator can consume
    # the streamer while tokens are still being produced.
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            inputs=input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=0.2,
            top_p=1.0,
            max_new_tokens=1024,
            streamer=streamer,
            use_cache=True,
        ),
    )
    thread.start()

    buffer = ""
    time.sleep(0.5)
    for new_text in streamer:
        # Keep only the text that precedes an "<|end|>" marker, if present.
        buffer += new_text.partition("<|end|>")[0]
        time.sleep(0.06)
        yield buffer


chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)

with gr.Blocks(fill_height=True, ) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="LLaVA Phi 3 3.8B",
        examples=[
            {"text": "What is the color of the cat?", "files": ["./FELV-cat.jpg"]},
            {"text": "What is the type of the fish?", "files": ["./fish.jpg"]},
        ],
        description="Try [LLaVA Phi3-3.8B LoRA](https://huggingface.co/praysimanjuntak/llava-phi3-3.8b-lora). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

demo.queue(api_open=False)
demo.launch(show_api=False, share=False)