import os
import base64
from io import BytesIO

import requests
from PIL import Image

import gradio as gr
from huggingface_hub import login
from vllm import LLM, SamplingParams

# Authenticate with the Hugging Face Hub; expects HF_TOKEN to be set
# (e.g. as a Space secret).
login(os.environ["HF_TOKEN"])

repo_id = "mistral-community/pixtral-12b-240910" #Replace to the model you would like to use
sampling_params = SamplingParams(max_tokens=8192, temperature=0.7)
max_tokens_per_img = 4096
max_img_per_msg = 5
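
# With these settings, one request can carry up to max_img_per_msg images,
# and the batched-token budget below is sized to fit them all:
# max_img_per_msg * max_tokens_per_img = 5 * 4096 = 20,480 tokens.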

def encode_image(image: Image.Image, image_format="PNG") -> str:
    """Serialize a PIL image to a base64 string for embedding in a data URL."""
    im_file = BytesIO()
    image.save(im_file, format=image_format)
    im_bytes = im_file.getvalue()
    im_64 = base64.b64encode(im_bytes).decode("utf-8")
    return im_64
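
# The base64 payload is embedded in a data URL, e.g.
#   f"data:image/png;base64,{encode_image(img)}"
# which vLLM's chat API accepts as an image_url, mirroring the OpenAI format.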


# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(image_url, prompt, progress=gr.Progress(track_tqdm=True)):
    # Load the model inside the handler so the demo stays compatible with
    # ZeroGPU; note this re-initializes vLLM on every request.
    llm = LLM(model=repo_id,  # name or path of the model to load
              tokenizer_mode="mistral",
              max_model_len=65536,
              max_num_batched_tokens=max_img_per_msg * max_tokens_per_img,
              limit_mm_per_prompt={"image": max_img_per_msg})

    # Fetch the image and fail fast on HTTP errors.
    response = requests.get(image_url, timeout=30)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
    # Resize to a fixed high resolution (note: this ignores the aspect ratio).
    image = image.resize((3844, 2408))
    new_image_url = f"data:image/png;base64,{encode_image(image, image_format='PNG')}"

    messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": prompt}, {"type": "image_url", "image_url": {"url": new_image_url}}]
        },
    ]

    outputs = llm.chat(messages, sampling_params=sampling_params)

    result_text = outputs[0].outputs[0].text
    print(result_text)
    return result_text
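
# Quick smoke test outside the UI (uncomment to try locally; uses the
# example URL defined below):
# infer("https://picsum.photos/id/237/200/300", "Describe this image.")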


example_images = ["https://picsum.photos/id/237/200/300"]
example_prompts = ["What do you see in this image?"]

css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # Pixtral 12B Image Understanding Demo
        """)

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your prompt",
                container=False,
            )

            image_url = gr.Text(
                label="Image URL",
                show_label=False,
                max_lines=1,
                placeholder="Enter your image URL",
                container=False,
            )

            run_button = gr.Button("Run", scale=0)

        result = gr.Textbox(
            show_label=False
        )

        gr.Examples(
            examples=example_images,
            inputs=[image_url]
        )

        gr.Examples(
            examples=example_prompts,
            inputs=[prompt]
        )
    gr.on(
        triggers=[run_button.click, image_url.submit, prompt.submit],
        fn=infer,
        inputs=[image_url, prompt],
        outputs=[result]
    )

demo.queue().launch()