import time
from threading import Thread

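# Install the local llava package before importing from it below
# (useful when the app starts in a fresh environment, e.g. a Hugging Face Space).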
import subprocess
subprocess.run(["pip", "install", "."])

import gradio as gr
from io import BytesIO
import requests
from PIL import Image
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token
from transformers.generation.streamers import TextIteratorStreamer
import spaces

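# Load the LoRA checkpoint on top of the Phi-3-mini base model, without 4-/8-bit quantization.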
device = "cuda:0"
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="./checkpoints/llava-phi3-3.8b-lora",
    model_name="llava-phi3-3.8b-lora",
    model_base="microsoft/Phi-3-mini-128k-instruct",
    load_8bit=False,
    load_4bit=False,
    device=device,
)
model.to(device)

def load_image(image_file):
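    """Open an image from an http(s) URL or a local path and return it as RGB."""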
    if image_file.startswith(('http://', 'https://')):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image

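# Streamed chat handler; @spaces.GPU requests a GPU for each call when the app runs on
# Hugging Face ZeroGPU hardware (and is effectively a no-op elsewhere).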
@spaces.GPU
def bot_streaming(message, history):
    if message["files"]:
        # message["files"][-1] is a Dict or just a string
        if isinstance(message["files"][-1], dict):
            image = message["files"][-1]["path"]
        else:
            image = message["files"][-1]
    else:
        raise gr.Error("You need to upload an image for LLaVA to work.")

    # print(f"prompt: {prompt}")
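    # Load the image and preprocess it into half-precision pixel values on the GPU.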
    image_data = load_image(str(image))
    image_tensor = image_processor.preprocess(image_data, return_tensors='pt')['pixel_values'].half().to(device)

    # just one turn, always prepend image token

    prompt = f"<|user|>\n{DEFAULT_IMAGE_TOKEN}\n{message['text']}<|end|>\n<|assistant|>"
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)

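    # Stream decoded tokens out of generate(); skip the prompt and special tokens in the output.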
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True, timeout=20.0)

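    # Run generation in a background thread so the streamer can be consumed as tokens arrive.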
    thread = Thread(target=model.generate, kwargs=dict(
                inputs=input_ids,
                images=image_tensor,
                do_sample=True,
                temperature=0.2,
                top_p=1.0,
                max_new_tokens=1024,
                streamer=streamer,
                use_cache=True))
    thread.start()

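    # Accumulate streamed text, trim the <|end|> stop marker, and yield partial output to the UI.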
    buffer = ""
    time.sleep(0.5)
    for new_text in streamer:
        # find <|end|> and remove it from the new_text
        if "<|end|>" in new_text:
            new_text = new_text.split("<|end|>")[0]
        buffer += new_text

        time.sleep(0.06)
        yield buffer


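# Gradio multimodal chat UI: a text box with image upload wired to the streaming handler.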
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="LLaVA Phi 3 3.8B",
        examples=[{"text": "What is the color of the cat?", "files": ["./FELV-cat.jpg"]},
                  {"text": "What is the type of the fish?", "files": ["./fish.jpg"]}],
        description="Try [LLaVA Phi3-3.8B LoRA](https://huggingface.co/praysimanjuntak/llava-phi3-3.8b-lora). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

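# Queue incoming requests and launch the demo without API docs or a public share link.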
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)