File size: 3,114 Bytes
08a6c8d
5f6343c
08a6c8d
f7151f4
08a6c8d
 
 
 
f7151f4
08a6c8d
e67e492
08a6c8d
e67e492
08a6c8d
 
 
 
 
 
 
 
f7151f4
 
 
e67e492
08a6c8d
 
 
 
f7151f4
08a6c8d
 
e325f49
659f477
e325f49
 
 
 
 
 
a2f5d42
e325f49
 
 
a2f5d42
e325f49
a2f5d42
e325f49
 
2ca3c06
e325f49
 
659f477
 
 
08a6c8d
 
 
f7151f4
659f477
 
 
 
 
 
f7151f4
08a6c8d
f7151f4
08a6c8d
 
bca5e76
08a6c8d
659f477
1ae1376
08a6c8d
f7151f4
08a6c8d
 
f7151f4
08a6c8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7151f4
 
 
08a6c8d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from threading import Thread
from typing import Dict

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, TextIteratorStreamer


# Static HTML snippets rendered at the top of the Gradio page.
TITLE = "<h1><center>Chat with PaliGemma-3B-Chat-v0.2</center></h1>"

DESCRIPTION = "<h3><center>Visit <a href='https://huggingface.co/BUAADreamer/PaliGemma-3B-Chat-v0.2' target='_blank'>our model page</a> for details.</center></h3>"

# Custom CSS for the "Duplicate Space" button rendered below the description.
CSS = """
.duplicate-button {
  margin: auto !important;
  color: white !important;
  background: black !important;
  border-radius: 100vh !important;
}
"""


# Load model assets once at import time so every request reuses them.
# device_map="auto" places weights on the available accelerator(s);
# torch_dtype="auto" picks the dtype stored in the checkpoint.
model_id = "BUAADreamer/PaliGemma-3B-Chat-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype="auto", device_map="auto")


@spaces.GPU
def stream_chat(message: Dict[str, str], history: list):
    """Generate a streamed assistant reply for a (possibly image-grounded) chat.

    Args:
        message: Gradio multimodal message, e.g.
            ``{"text": "what is this", "files": ["image-xxx.jpg"]}``.
        history: Prior turns. When an image was uploaded earlier, Gradio
            stores it as a leading ``[(image_path,), None]`` pair followed by
            plain ``[user_text, assistant_text]`` pairs.

    Yields:
        The accumulated response text, one chunk per generated token.
    """
    # Prefer an image attached to the current message.
    image_path = None
    if message["files"]:
        image_path = message["files"][0]

    # An image uploaded in an earlier turn sits at history[0] as a tuple.
    # Strip it from the text history; use it only as a fallback so a freshly
    # uploaded image is not shadowed by the stale one (the original code
    # unconditionally overwrote image_path here).
    if history and isinstance(history[0][0], tuple):
        if image_path is None:
            image_path = history[0][0][0]
        history = history[1:]

    # Fall back to a blank white image for text-only conversations, since the
    # vision tower always expects pixel input.
    if image_path is not None:
        image = Image.open(image_path).convert("RGB")
    else:
        image = Image.new("RGB", (100, 100), (255, 255, 255))

    pixel_values = processor(images=[image], return_tensors="pt").to(model.device)["pixel_values"]

    # Rebuild the conversation in chat-template format.
    conversation = []
    for prompt, answer in history:
        conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])

    conversation.append({"role": "user", "content": message["text"]})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    # PaliGemma expects image_seq_length <image> placeholder tokens prepended
    # to the text prompt.
    image_token_id = tokenizer.convert_tokens_to_ids("<image>")
    image_prefix = torch.full((1, processor.image_seq_length), image_token_id, dtype=input_ids.dtype)
    input_ids = torch.cat((image_prefix, input_ids), dim=-1).to(model.device)

    # skip_prompt avoids echoing the input; timeout guards against a stalled
    # generation thread blocking the iterator forever.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        pixel_values=pixel_values,
        streamer=streamer,
        max_new_tokens=256,
        do_sample=True,
    )

    # Run generation in a background thread so we can yield tokens as they
    # arrive from the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    output = ""
    for new_token in streamer:
        output += new_token
        yield output


# Pre-built chatbot component so ChatInterface uses a fixed 450px height.
chatbot = gr.Chatbot(height=450)

# Assemble the page: header HTML, duplicate-space button, and the multimodal
# chat interface wired to stream_chat (multimodal=True enables file uploads).
with gr.Blocks(css=CSS) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    gr.ChatInterface(
        fn=stream_chat,
        multimodal=True,
        chatbot=chatbot,
        fill_height=True,
        cache_examples=False,
    )


if __name__ == "__main__":
    demo.launch()