# humanLLaVa / app.py
# ponytail's picture
# Update app.py
# 96292e9 verified
# raw
# history blame contribute delete
# No virus
# 2.79 kB
import gradio as gr
import spaces
from transformers import AutoProcessor, LlavaForConditionalGeneration
# from qwen_vl_utils import process_vision_info
import torch
from PIL import Image
import subprocess
from datetime import datetime
import numpy as np
import os
# Exempt loopback addresses from any HTTP(S) proxy configured in the
# environment — presumably so requests to the app's own local server are not
# routed through a proxy (assumption; confirm against the deployment).
os.environ.update(no_proxy="localhost,127.0.0.1,::1")
def array_to_image_path(image_array):
    """Save an image array as a uniquely named PNG and return its absolute path.

    Args:
        image_array: Pixel data accepted by ``Image.fromarray`` after being
            converted with ``np.uint8`` (presumably the HxWxC array produced
            by the ``gr.Image`` input — confirm against the component's type).

    Returns:
        The absolute path of the PNG written to the current working directory.
    """
    import uuid  # local import keeps this helper self-contained

    img = Image.fromarray(np.uint8(image_array))
    # A seconds-resolution timestamp alone collides when two requests arrive
    # within the same second: the second save would overwrite the first file
    # before it is read back. Microseconds plus a random suffix keep every
    # name unique while the prefix stays chronologically sortable.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    filename = f"image_{timestamp}_{uuid.uuid4().hex[:8]}.png"
    img.save(filename)
    # NOTE(review): these files are never deleted, so they accumulate on disk
    # for the lifetime of the deployment.
    return os.path.abspath(filename)
# Device the model and its inputs are placed on. Despite the variable name,
# this is the CPU; the name is kept because run_example() reads it.
# NOTE(review): float16 weights on CPU run slowly, and some ops may lack
# half-precision CPU kernels — confirm generation works here, or move to a
# GPU (and re-enable @spaces.GPU on run_example) when one is available.
cuda = "cpu"
# Hugging Face Hub repository providing both the weights and the processor.
model_id = "huangfx1020/human_llama3_8b"
# UI display name (also the dropdown value) -> loaded model in eval mode.
models = {
    "HumanLlaVA-8B": LlavaForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
    ).to(cuda).eval()
}
# UI display name -> processor that encodes image+text prompts and decodes
# the generated ids for the matching model.
processors = {
    "HumanLlaVA-8B": AutoProcessor.from_pretrained(model_id)
}
# Markdown heading of the demo, linking to the Hub page of the served model.
DESCRIPTION = "[HumanLlaVA Demo](https://huggingface.co/huangfx1020/human_llama3_8b)"
# NOTE(review): not referenced anywhere in this file (models load in float16
# above); kept in case other code imports it.
kwargs = {'torch_dtype': torch.bfloat16}
# @spaces.GPU  # NOTE(review): re-enable when running on a ZeroGPU Space.
def run_example(image, text_input=None, model_id="HumanLlaVA-8B"):
    """Answer a question about an image with the selected model.

    Args:
        image: Picture from the ``gr.Image`` input. It is saved to disk with
            ``array_to_image_path`` and reopened. It is ``None`` when the
            user submits without uploading a picture.
        text_input: The user's question. ``None`` is treated as an empty
            question.
        model_id: Key into ``models`` / ``processors`` (the dropdown value),
            not a Hub repository id.

    Returns:
        The text generated by the model, without the echoed prompt and
        without special tokens.

    Raises:
        gr.Error: If no image was provided; Gradio shows this message in
            the UI instead of a generic error.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    image_path = array_to_image_path(image)
    print(image_path)
    model = models[model_id]
    processor = processors[model_id]
    prompt = "USER: <image>\n" + (text_input or "") + "\nASSISTANT:"
    # The context manager closes the file once the processor has consumed it.
    with Image.open(image_path) as raw_image:
        # BatchFeature.to casts only floating-point tensors to float16;
        # the integer token ids keep their dtype.
        inputs = processor(images=raw_image, text=prompt, return_tensors='pt').to(cuda, torch.float16)
    output = model.generate(**inputs, max_new_tokens=400, do_sample=False)
    # generate() returns the prompt ids followed by the newly generated ids.
    # Drop the prompt so the UI shows only the answer. Skipping special tokens
    # hides markers such as BOS/EOS and the image placeholder tokens.
    prompt_len = inputs["input_ids"].shape[-1]
    predict = processor.decode(output[0][prompt_len:], skip_special_tokens=True)
    print(predict)
    return predict
# Fixed-height, scrollable, bordered output area. "#output" is an element-id
# selector, so it only applies to a component created with elem_id="output".
css = """
#output {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""
with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tab(label="HumanLlaVA-8B Input"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Picture")
                model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value="HumanLlaVA-8B")
                text_input = gr.Textbox(label="Question")
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                # elem_id makes the "#output" rule in `css` apply to this box.
                output_text = gr.Textbox(label="Output Text", elem_id="output")
    # Inputs are passed positionally as run_example(image, text_input, model_id).
    submit_btn.click(run_example, [input_img, text_input, model_selector], [output_text])

# Queue requests; api_open=False keeps the queue from being reachable
# through the programmatic API routes.
demo.queue(api_open=False)

if __name__ == "__main__":
    demo.launch(debug=True)