import base64
import os
from io import BytesIO

import gradio as gr
import requests
from huggingface_hub import login
from PIL import Image
from vllm import LLM, SamplingParams

# Authenticate with the Hugging Face Hub; HF_TOKEN must be set in the environment.
login(os.environ["HF_TOKEN"])

repo_id = "mistralai/Pixtral-12B-2409"  # Replace with the model you would like to use

sampling_params = SamplingParams(max_tokens=8192, temperature=0.7)
max_tokens_per_img = 4096  # token budget reserved per image
max_img_per_msg = 5  # maximum number of images in a single prompt

def encode_image(image: Image.Image, image_format="PNG") -> str:
    """Serialize a PIL image to a base64 string for use in a data URL."""
    im_file = BytesIO()
    image.save(im_file, format=image_format)
    im_bytes = im_file.getvalue()
    im_64 = base64.b64encode(im_bytes).decode("utf-8")
    return im_64
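
# Illustrative round-trip check (an assumption, not used by the app): decoding
# the base64 payload back into a PIL image is a quick way to verify
# encode_image. The decode_image name is hypothetical.
def decode_image(im_64: str) -> Image.Image:
    return Image.open(BytesIO(base64.b64decode(im_64)))
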
# @spaces.GPU  # [uncomment to use ZeroGPU]
def infer(image_url, prompt, progress=gr.Progress(track_tqdm=True)):
    # Build the vLLM engine. Note that doing this inside infer reloads the
    # model on every request; see the cached-engine sketch after this function.
    llm = LLM(
        model=repo_id,
        tokenizer_mode="mistral",
        max_model_len=65536,
        max_num_batched_tokens=max_img_per_msg * max_tokens_per_img,
        limit_mm_per_prompt={"image": max_img_per_msg},
    )

    # Fetch the image, resize it to a fixed resolution, and re-encode it as a
    # base64 data URL so it can be passed inline in the chat message.
    image = Image.open(BytesIO(requests.get(image_url, timeout=30).content))
    image = image.resize((3844, 2408))
    new_image_url = f"data:image/png;base64,{encode_image(image, image_format='PNG')}"

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": new_image_url}},
            ],
        },
    ]
    outputs = llm.chat(messages, sampling_params=sampling_params)
    text = outputs[0].outputs[0].text
    print(text)
    return text
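
# A minimal alternative sketch (an assumption, not part of the original script):
# cache the engine so the model is built once, on the first request, and reused
# afterwards. The _llm/get_llm names are hypothetical.
_llm = None

def get_llm():
    global _llm
    if _llm is None:  # build lazily on first use
        _llm = LLM(
            model=repo_id,
            tokenizer_mode="mistral",
            max_model_len=65536,
            max_num_batched_tokens=max_img_per_msg * max_tokens_per_img,
            limit_mm_per_prompt={"image": max_img_per_msg},
        )
    return _llm
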
example_images = ["https://picsum.photos/id/237/200/300"]
example_prompts = ["What do you see in this image?"]
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # Image-to-Text Gradio Template
        """)
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your prompt",
                container=False,
            )
            image_url = gr.Text(
                label="Image URL",
                show_label=False,
                max_lines=1,
                placeholder="Enter your image URL",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Textbox(
            label="Result",
            show_label=False,
        )
        gr.Examples(
            examples=example_images,
            inputs=[image_url],
        )
        gr.Examples(
            examples=example_prompts,
            inputs=[prompt],
        )
    gr.on(
        triggers=[run_button.click, image_url.submit, prompt.submit],
        fn=infer,
        inputs=[image_url, prompt],
        outputs=[result],
    )

demo.queue().launch()
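
# Usage note (assumption: running locally rather than as a hosted Space):
#   HF_TOKEN=<your token> python app.py
# then open the local URL that Gradio prints.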