"""Gradio demo that answers free-form questions about an uploaded video with Oryx-7B."""

import os
import re
import subprocess
import sys
from typing import Dict, Optional, Sequence, List

import gradio as gr
import numpy as np
import torch
import transformers
from decord import VideoReader, cpu
from PIL import Image

# Install flash-attn at start-up with its CUDA build skipped. The child process inherits the
# current environment (passing only the one variable would drop PATH, HOME, ...), and
# sys.executable makes sure the package lands in the interpreter running this app.
# Best-effort, like the original: a failed install does not abort start-up.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

from oryx.conversation import conv_templates, SeparatorStyle
from oryx.model.builder import load_pretrained_model
from oryx.utils import disable_torch_init
from oryx.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_anyres_video_genli
from oryx.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX

model_path = "THUdyh/Oryx-7B"
model_name = get_model_name_from_path(model_path)

overwrite_config = {
    "mm_resampler_type": "dynamic_compressor",
    "patchify_video_feature": False,
    # torch.__version__ is a TorchVersion, so this comparison is semantic, not lexicographic.
    "attn_implementation": "sdpa" if torch.__version__ >= "2.1.2" else "eager",
}
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, None, model_name, device_map="cuda:0", overwrite_config=overwrite_config
)
model.to("cuda").eval()

# Number of frames sampled uniformly across each uploaded video.
NUM_FRAMES = 64


def preprocess_qwen(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
    max_len=2048,
    system_message: str = "You are a helpful assistant.",
) -> torch.Tensor:
    """Tokenize a conversation into Qwen ChatML-style input ids for generation.

    Args:
        sources: Sequence of turns, each a dict with ``"from"`` (``"human"`` or ``"gpt"``)
            and ``"value"`` (the turn text, or ``None`` for an open turn the model should
            complete). A leading non-human turn is dropped.
        tokenizer: Tokenizer whose ``additional_special_tokens_ids`` unpack to exactly two
            ids, used as ``(im_start, im_end)``.
        has_image: When true, each ``DEFAULT_IMAGE_TOKEN`` in a turn's text is replaced by
            ``IMAGE_TOKEN_INDEX`` followed by a newline token.
        max_len: Unused; kept for signature compatibility.
        system_message: Text of the system turn prepended to the conversation.

    Returns:
        A ``torch.long`` tensor of shape ``(1, seq_len)``.

    Raises:
        KeyError: if a turn's ``"from"`` is neither ``"human"`` nor ``"gpt"``.
    """
    roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}

    im_start, im_end = tokenizer.additional_special_tokens_ids
    nl_tokens = tokenizer("\n").input_ids
    system_ids = tokenizer("system").input_ids + nl_tokens

    source = list(sources)
    # The conversation must open with the user; drop a leading assistant turn.
    if source and roles[source[0]["from"]] != roles["human"]:
        source = source[1:]

    input_id = [im_start] + system_ids + tokenizer(system_message).input_ids + [im_end] + nl_tokens
    for sentence in source:
        role = roles[sentence["from"]]
        value = sentence["value"]
        # Role header, e.g. "<|im_start|>user\n".
        turn_ids = tokenizer(role).input_ids + nl_tokens
        if value is not None:
            if has_image and DEFAULT_IMAGE_TOKEN in value:
                texts = value.split(DEFAULT_IMAGE_TOKEN)
                for i, text in enumerate(texts):
                    turn_ids += tokenizer(text).input_ids
                    # Splice a visual placeholder id between consecutive text pieces.
                    if i < len(texts) - 1:
                        turn_ids += [IMAGE_TOKEN_INDEX] + nl_tokens
            else:
                turn_ids += tokenizer(value).input_ids
            turn_ids += [im_end] + nl_tokens
        # A None value leaves the turn open (no im_end) so generation continues from it.
        input_id += turn_ids

    return torch.tensor([input_id], dtype=torch.long)


def oryx_inference(video, text):
    """Answer ``text`` about ``video`` with the loaded Oryx model.

    Args:
        video: The uploaded video as provided by ``gr.Video`` (passed to ``VideoReader``).
        text: The user's question.

    Returns:
        The greedily decoded response, stripped and with the conversation stop string removed.

    Raises:
        gr.Error: if no video was supplied or the video has no frames.
    """
    if not video:
        raise gr.Error("Please upload a video.")

    vr = VideoReader(video, ctx=cpu(0))
    total_frame_num = len(vr)
    if total_frame_num == 0:
        raise gr.Error("The uploaded video contains no frames.")
    # Uniformly spaced source-frame indices; short clips yield repeated indices.
    frame_idx = np.linspace(0, total_frame_num - 1, NUM_FRAMES, dtype=int).tolist()
    frames = [Image.fromarray(frame) for frame in vr.get_batch(frame_idx).asnumpy()]

    # One image placeholder stands for the whole video. Placeholders typed by the user are
    # removed so the placeholder count matches the single video passed to generate().
    question = DEFAULT_IMAGE_TOKEN + "\n" + (text or "").replace(DEFAULT_IMAGE_TOKEN, "")
    input_ids = preprocess_qwen(
        [{"from": "human", "value": question}, {"from": "gpt", "value": None}],
        tokenizer,
        has_image=True,
    ).cuda()

    # Frames are fed without resize / center crop.
    image_processor.do_resize = False
    image_processor.do_center_crop = False
    # Every sampled frame is used: frame_idx holds source-video positions, not positions
    # within `frames`, so it must not be used to filter this list.
    video_tensor = torch.cat(
        [process_anyres_video_genli(frame, image_processor).unsqueeze(0) for frame in frames],
        dim=0,
    ).bfloat16().cuda()

    conv = conv_templates["qwen_1_5"].copy()
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

    with torch.inference_mode():
        output_ids = model.generate(
            inputs=input_ids,
            images=video_tensor,
            images_highres=video_tensor,
            modalities="video",
            do_sample=False,
            temperature=0,
            max_new_tokens=1024,
            use_cache=True,
        )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    # Guard the empty case: outputs[:-0] would otherwise erase the whole answer.
    if stop_str and outputs.endswith(stop_str):
        outputs = outputs[: -len(stop_str)]
    return outputs.strip()


# Define input and output for the Gradio interface.
demo = gr.Interface(
    fn=oryx_inference,
    inputs=[gr.Video(label="Input Video"), gr.Textbox(label="Input Text")],
    outputs="text",
    title="Oryx Inference",
    description="This is a demo for Oryx inference.",
)

if __name__ == "__main__":
    # Launch the Gradio app.
    demo.launch(server_name="0.0.0.0", server_port=80)