"""Gradio demo for Oryx-7B: answer a text question about an uploaded image or video."""

import gradio as gr
import torch
import re
import os
from decord import VideoReader, cpu
from PIL import Image
import numpy as np
import transformers
import spaces
from typing import Dict, Optional, Sequence, List
import subprocess
import sys

# Best-effort runtime install of flash-attn without compiling its CUDA
# extension; it runs before the oryx imports below, which may rely on it.
# The current environment is inherited (replacing it wholesale would drop
# PATH, HOME, virtualenv variables, ...) and pip is invoked through this very
# interpreter so the package lands where it will be imported. No shell is
# involved. The return code is intentionally not checked.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

from oryx.conversation import conv_templates, SeparatorStyle
from oryx.model.builder import load_pretrained_model
from oryx.utils import disable_torch_init
from oryx.mm_utils import (
    tokenizer_image_token,
    get_model_name_from_path,
    KeywordsStoppingCriteria,
    process_anyres_video_genli,
    process_anyres_highres_image_genli,
)
from oryx.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX

model_path = "THUdyh/Oryx-7B"
model_name = get_model_name_from_path(model_path)
overwrite_config = {
    "mm_resampler_type": "dynamic_compressor",
    "patchify_video_feature": False,
    # NOTE: in recent torch releases __version__ is a TorchVersion, whose
    # comparison operators compare version numbers rather than raw strings.
    "attn_implementation": "sdpa" if torch.__version__ >= "2.1.2" else "eager",
}
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, None, model_name, device_map="cpu", overwrite_config=overwrite_config
)
# Weights are loaded on CPU first, then moved to the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()

cur_dir = os.path.dirname(os.path.abspath(__file__))

# Conversation template used to derive the stop string trimmed from replies.
CONV_MODE = "qwen_1_5"
# Upper bound on the number of frames uniformly sampled from an uploaded video.
MAX_VIDEO_FRAMES = 64

title_markdown = """
Oryx

Oryx MLLM: On-Demand Spatial-Temporal Understanding at Arbitrary Resolution

Project Page | Github | Huggingface | Paper | Twitter
"""

bibtext = """
### Citation
```
@article{liu2024oryx,
  title={Oryx MLLM: On-Demand Spatial-Temporal Understanding at Arbitrary Resolution},
  author={Liu, Zuyan and Dong, Yuhao and Liu, Ziwei and Hu, Winston and Lu, Jiwen and Rao, Yongming},
  journal={arXiv preprint arXiv:2409.12961},
  year={2024}
}
```
"""


def preprocess_qwen(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
    max_len=2048,
    system_message: str = "You are a helpful assistant.",
) -> torch.Tensor:
    """Tokenize a conversation into Qwen-style chat input ids for generation.

    The sequence starts with a system segment
    (``[im_start] "system" NL system_message [im_end] NL``) followed by one
    segment per turn: the role header (``<|im_start|>user`` or
    ``<|im_start|>assistant``) and a newline, then the turn text, ``[im_end]``
    and a newline. A turn whose value is ``None`` contributes only its role
    header and newline, leaving it open for the model to complete.

    Args:
        sources: list of turn dicts with keys ``"from"`` (``"human"`` or
            ``"gpt"``) and ``"value"`` (str or ``None``). A leading non-human
            turn is dropped.
        tokenizer: tokenizer whose ``additional_special_tokens_ids`` must
            unpack into exactly two ids, used as (im_start, im_end).
        has_image: when true, every ``DEFAULT_IMAGE_TOKEN`` occurrence in a
            turn is replaced by ``IMAGE_TOKEN_INDEX`` followed by newline ids.
        max_len: unused; kept for signature compatibility.
        system_message: text of the system segment.

    Returns:
        A ``torch.long`` tensor of shape ``(1, seq_len)``.

    Raises:
        KeyError: if a turn's ``"from"`` is not ``"human"`` or ``"gpt"``.
    """
    roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}

    im_start, im_end = tokenizer.additional_special_tokens_ids
    nl_tokens = tokenizer("\n").input_ids

    source = sources
    if roles[source[0]["from"]] != roles["human"]:
        source = source[1:]

    input_id = (
        [im_start]
        + tokenizer("system").input_ids
        + nl_tokens
        + tokenizer(system_message).input_ids
        + [im_end]
        + nl_tokens
    )
    for sentence in source:
        header = tokenizer(roles[sentence["from"]]).input_ids + nl_tokens
        value = sentence["value"]
        if value is None:
            # Open turn: the generation prompt for the model's reply.
            input_id += header
        elif has_image and DEFAULT_IMAGE_TOKEN in value:
            # NOTE(review): this branch was lost in the source this file was
            # recovered from and has been reconstructed (image placeholder id
            # followed by newline ids) — verify against upstream Oryx code.
            pieces = value.split(DEFAULT_IMAGE_TOKEN)
            turn = list(header)
            for i, piece in enumerate(pieces):
                turn += tokenizer(piece).input_ids
                if i < len(pieces) - 1:
                    turn += [IMAGE_TOKEN_INDEX] + nl_tokens
            input_id += turn + [im_end] + nl_tokens
        else:
            input_id += header + tokenizer(value).input_ids + [im_end] + nl_tokens

    return torch.tensor([input_id], dtype=torch.long)


def _load_video_frames(path, max_frames=MAX_VIDEO_FRAMES):
    """Decode up to ``max_frames`` uniformly spaced frames of ``path`` as PIL images.

    Videos with fewer than ``max_frames`` frames yield every frame exactly once
    rather than repeating frames to pad the count.

    Raises:
        gr.Error: if the video has no frames.
    """
    reader = VideoReader(path, ctx=cpu(0))
    total = len(reader)
    if total == 0:
        raise gr.Error("The uploaded video contains no frames.")
    indices = np.linspace(0, total - 1, min(max_frames, total), dtype=int).tolist()
    return [Image.fromarray(frame) for frame in reader.get_batch(indices).asnumpy()]


def _video_tensor(frames):
    """Preprocess every sampled frame and concatenate them into one bfloat16 tensor on ``device``."""
    # Every frame in ``frames`` is already one of the sampled frames, so all of
    # them are kept.
    processed = [process_anyres_video_genli(frame, image_processor).unsqueeze(0) for frame in frames]
    return torch.cat(processed, dim=0).bfloat16().to(device)


def _image_tensors(image):
    """Return the (base, high-res) bfloat16 tensors for one PIL image, each with a leading dim of 1."""
    base, highres = process_anyres_highres_image_genli(image, image_processor)
    return (
        base.unsqueeze(0).bfloat16().to(device),
        highres.unsqueeze(0).bfloat16().to(device),
    )


@spaces.GPU(duration=120)
def oryx_inference(multimodal):
    """Answer a text question about one uploaded image or ``.mp4`` video with Oryx-7B.

    Args:
        multimodal: value of the ``gr.MultimodalTextbox``; ``multimodal["files"]``
            lists uploaded file paths (only the first is used) and
            ``multimodal["text"]`` holds the question.

    Returns:
        The greedy-decoded answer with surrounding whitespace and a trailing
        conversation stop string removed.

    Raises:
        gr.Error: if no file was uploaded or the video has no frames.
    """
    files = multimodal["files"]
    if not files:
        raise gr.Error("Please upload an image or an .mp4 video.")
    visual = files[0]
    text = multimodal["text"] or ""

    # The video example is listed in the UI with a PNG; substitute the real clip.
    if visual.endswith("case/image2.png"):
        visual = f"{cur_dir}/case/case1.mp4"
    is_video = visual.lower().endswith(".mp4")

    # These flags live on the shared module-level processor; both the video and
    # the image paths run with resizing and center-cropping disabled.
    image_processor.do_resize = False
    image_processor.do_center_crop = False

    # The image token marks where preprocess_qwen inserts IMAGE_TOKEN_INDEX.
    question = DEFAULT_IMAGE_TOKEN + "\n" + text
    input_ids = preprocess_qwen(
        [{"from": "human", "value": question}, {"from": "gpt", "value": None}],
        tokenizer,
        has_image=True,
    ).to(device)

    if is_video:
        video = _video_tensor(_load_video_frames(visual))
        visual_kwargs = {
            "images": video,
            "images_highres": video,
            # NOTE(review): the video path passes a plain string while the
            # image path passes a list; kept as-is — confirm what generate expects.
            "modalities": "video",
        }
    else:
        # Load eagerly (closing the file) and normalize to 3 channels.
        with Image.open(visual) as img:
            image = img.convert("RGB")
        image_tensor, image_highres_tensor = _image_tensors(image)
        visual_kwargs = {
            "images": image_tensor,
            "images_highres": image_highres_tensor,
            "image_sizes": [image.size],
            "modalities": ["image"],
        }

    with torch.inference_mode():
        output_ids = model.generate(
            inputs=input_ids,
            do_sample=False,
            temperature=0,
            max_new_tokens=1024,
            use_cache=True,
            **visual_kwargs,
        )

    conv = conv_templates[CONV_MODE]
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    # Guard against an empty separator, for which [:-0] would erase everything.
    if stop_str and outputs.endswith(stop_str):
        outputs = outputs[: -len(stop_str)]
    return outputs.strip()


# Define input and output for the Gradio interface
demo = gr.Interface(
    fn=oryx_inference,
    inputs=gr.MultimodalTextbox(file_types=[".mp4", "image"], placeholder="Enter message or upload file..."),
    outputs="text",
    examples=[
        {
            "files": [f"{cur_dir}/case/image2.png"],
            "text": "Describe what is happening in this video in detail.",
        },
        {
            "files": [f"{cur_dir}/case/image.png"],
            "text": "Describe this icon.",
        },
    ],
    title="Oryx-7B Demo",
    description=title_markdown,
    article=bibtext,
)

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()