import torch
import gradio as gr
from transformers import TextIteratorStreamer, AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import threading
import spaces

DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Krypton 🕋</h1>
<p>This demo uses the open-source model <a href="https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers"><b>xtuner/llava-llama-3-8b-v1_1-transformers</b></a>.</p>
</div>
'''

model_id = "xtuner/llava-llama-3-8b-v1_1-transformers"
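# Load the model in half precision; low_cpu_mem_usage requires the
# `accelerate` package to be installed (it does not need to be imported).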
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
).to('cuda')

processor = AutoProcessor.from_pretrained(model_id)

# Llama 3 marks end-of-turn with <|eot_id|> (token id 128009); use it as
# the eos token so generation stops after the assistant's reply.
model.generation_config.eos_token_id = 128009

@spaces.GPU(duration=120)
def krypton(input,
            history,
            max_new_tokens,
            temperature,
            num_beams,
            do_sample: bool = True):
    """
    Recieves inputs (prompts with images if they were added),
    the image is formated for pil and prompt is formated for the model,
    to place it's output to the user, these prompts and images are passed in
    the processor and generation of the model, than the output is decoded from the processor,
    onto the UI.
    """
    if input["files"]:
        if type(input["files"][-1]) == dict:
            image = input["files"][-1]["path"]
        else:
            image = input["files"][-1]
    else:
        # If no images were passed now, look at the past images to keep up as reference still to the prompts
        # kept inside in tuples, the last one
        for hist in history:
            if type(hist[0]) == tuple:
                image = hist[0][0]
    try:
        if image is None:
            gr.Error("You need to upload an image please for krypton to work.")
    except NameError:
        # Image is not defined at all
        gr.Error("Uplaod an image for Krypton to work")

    image = Image.open(image)
    # Llama 3 chat template with LLaVA's <image> placeholder; note the
    # f-string so the user's text is actually interpolated.
    prompt = (f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{input['text']}<|eot_id|>"
              "<|start_header_id|>assistant<|end_header_id|>\n\n")
    inputs = processor(text=prompt, images=image, return_tensors='pt').to('cuda', torch.float16)

    # Stream decoded tokens as they are produced; timeout makes the iterator
    # raise if no new token arrives within 10 seconds.
    streamer = TextIteratorStreamer(processor.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    if temperature == 0.0:
        do_sample = False

    # transformers' generate() cannot combine a streamer with beam search,
    # so fall back to a single beam while streaming.
    if num_beams > 1:
        num_beams = 1

    # Unpack the processor tensors (input_ids, attention_mask, pixel_values)
    # into generate()'s keyword arguments; passing the whole BatchFeature as
    # `inputs=` would fail.
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        num_beams=num_beams,
        do_sample=do_sample
    )

    # Run generation on a background thread so the streamer can yield
    # partial text while the model is still generating.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Accumulate streamed chunks; each yield replaces the visible chat
    # message with the text so far, as gr.ChatInterface expects.
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

chatbot = gr.Chatbot(height=600, label="Krypt AI")
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter your question or upload an image.",
    show_label=False,
)
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
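    # The sliders below are passed to krypton positionally after
    # (input, history), in order: max_new_tokens, temperature, num_beams.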
    gr.ChatInterface(
        fn=krypton,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(minimum=20,
                      maximum=80,
                      step=1,
                      value=50,
                      label="Max New Tokens",
                      render=False),
            gr.Slider(minimum=0.0,
                      maximum=1.0,
                      step=0.1,
                      value=0.7,
                      label="Temperature",
                      render=False),
            gr.Slider(minimum=1,
                      maximum=12,
                      step=1,
                      value=5,
                      label="Number of Beams",
                      render=False),
        ],
        multimodal=True,
        textbox=chat_input,
    )

if __name__ == "__main__":
    demo.launch()
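
# A public shareable link can be requested at launch time if needed, e.g.:
#     demo.launch(share=True)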