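# app.py: Gradio demo for Oryx video-language inference (THUdyh/Oryx-7B).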
import gradio as gr
import torch
import re
from decord import VideoReader, cpu
from PIL import Image
import numpy as np
import transformers
from typing import Dict, Optional, Sequence, List
import subprocess
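# Install flash-attn at startup; FLASH_ATTENTION_SKIP_CUDA_BUILD skips compiling the
# CUDA extension at install time (the usual workaround on Hugging Face Spaces, where
# no GPU is available while the Space is built).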
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
from oryx.conversation import conv_templates, SeparatorStyle
from oryx.model.builder import load_pretrained_model
from oryx.utils import disable_torch_init
from oryx.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_anyres_video_genli
from oryx.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
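# Load the Oryx-7B checkpoint with its dynamic-compressor visual resampler.
# SDPA attention is used on sufficiently new torch builds, eager attention otherwise.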
model_path = "THUdyh/Oryx-7B"
model_name = get_model_name_from_path(model_path)
overwrite_config = {}
overwrite_config["mm_resampler_type"] = "dynamic_compressor"
overwrite_config["patchify_video_feature"] = False
overwrite_config["attn_implementation"] = "sdpa" if torch.__version__ >= "2.1.2" else "eager"
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, device_map="cuda:0", overwrite_config=overwrite_config)
model.to('cuda').eval()
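# preprocess_qwen lays out the conversation in the Qwen ChatML format, roughly:
#   <|im_start|>system\nYou are a helpful assistant.<|im_end|>\n
#   <|im_start|>user\n<image>\n{question}<|im_end|>\n
#   <|im_start|>assistant\n
# with each <image> placeholder replaced by IMAGE_TOKEN_INDEX.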
def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> torch.Tensor:
    """Tokenize a single conversation into input_ids for the Qwen chat template.

    Label targets are built alongside the inputs but are unused at inference,
    so only the input_ids tensor is returned.
    """
    roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
    im_start, im_end = tokenizer.additional_special_tokens_ids
    nl_tokens = tokenizer("\n").input_ids
    _system = tokenizer("system").input_ids + nl_tokens
    _user = tokenizer("user").input_ids + nl_tokens
    _assistant = tokenizer("assistant").input_ids + nl_tokens

    # Apply the prompt template turn by turn.
    input_ids, targets = [], []
    source = sources
    if roles[source[0]["from"]] != roles["human"]:
        source = source[1:]
    input_id, target = [], []
    system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
    input_id += system
    target += [im_start] + [IGNORE_INDEX] * (len(system) - 3) + [im_end] + nl_tokens
    assert len(input_id) == len(target)
    for j, sentence in enumerate(source):
        role = roles[sentence["from"]]
        if has_image and sentence["value"] is not None and "<image>" in sentence["value"]:
            # Interleave the text segments with the image token placeholder.
            num_image = len(re.findall(DEFAULT_IMAGE_TOKEN, sentence["value"]))
            texts = sentence["value"].split('<image>')
            _input_id = tokenizer(role).input_ids + nl_tokens
            for i, text in enumerate(texts):
                _input_id += tokenizer(text).input_ids
                if i < len(texts) - 1:
                    _input_id += [IMAGE_TOKEN_INDEX] + nl_tokens
            _input_id += [im_end] + nl_tokens
            assert sum([i == IMAGE_TOKEN_INDEX for i in _input_id]) == num_image
        else:
            if sentence["value"] is None:
                # An empty assistant turn leaves the prompt open for generation.
                _input_id = tokenizer(role).input_ids + nl_tokens
            else:
                _input_id = tokenizer(role).input_ids + nl_tokens + tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
        input_id += _input_id
        if role == "<|im_start|>user":
            _target = [im_start] + [IGNORE_INDEX] * (len(_input_id) - 3) + [im_end] + nl_tokens
        elif role == "<|im_start|>assistant":
            _target = [im_start] + [IGNORE_INDEX] * len(tokenizer(role).input_ids) + _input_id[len(tokenizer(role).input_ids) + 1 : -2] + [im_end] + nl_tokens
        else:
            raise NotImplementedError
        target += _target
    input_ids.append(input_id)
    targets.append(target)
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    targets = torch.tensor(targets, dtype=torch.long)
    return input_ids
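# End-to-end inference: uniformly sample 64 frames with decord, run each frame
# through the anyres video processor, and greedily decode a reply with Oryx-7B.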
def oryx_inference(video, text):
    # Decode the uploaded video and uniformly sample 64 frames across its length.
    vr = VideoReader(video, ctx=cpu(0))
    total_frame_num = len(vr)
    fps = round(vr.get_avg_fps())
    uniform_sampled_frames = np.linspace(0, total_frame_num - 1, 64, dtype=int)
    frame_idx = uniform_sampled_frames.tolist()
    spare_frames = vr.get_batch(frame_idx).asnumpy()
    video = [Image.fromarray(frame) for frame in spare_frames]

    # Build the conversation prompt and tokenize it with the <image> placeholder.
    conv_mode = "qwen_1_5"
    question = "<image>\n" + text
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], question)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    input_ids = preprocess_qwen([{'from': 'human', 'value': question}, {'from': 'gpt', 'value': None}], tokenizer, has_image=True).cuda()

    # `video` already holds only the uniformly sampled frames, so process and keep
    # every one of them. (Filtering on `frame_idx`, which indexes the full clip,
    # would silently drop most frames whenever the video has more than 64 frames.)
    image_processor.do_resize = False
    image_processor.do_center_crop = False
    video_processed = []
    for frame in video:
        frame = process_anyres_video_genli(frame, image_processor)
        video_processed.append(frame.unsqueeze(0))
    video_processed = torch.cat(video_processed, dim=0).bfloat16().cuda()
    # The model expects a (low-res, high-res) pair; this demo feeds the same tensor twice.
    video_processed = (video_processed, video_processed)
    video_data = (video_processed, (384, 384), "video")
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]

    with torch.inference_mode():
        output_ids = model.generate(
            inputs=input_ids,
            images=video_data[0][0],
            images_highres=video_data[0][1],
            modalities=video_data[2],
            do_sample=False,
            temperature=0,
            max_new_tokens=1024,
            use_cache=True,
        )
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
    outputs = outputs.strip()
    # Strip a trailing stop string if the model emitted it.
    if outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)]
    outputs = outputs.strip()
    return outputs
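# The function can also be called directly, without the UI (hypothetical file path):
#   print(oryx_inference("example.mp4", "Describe what happens in this video."))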
# Define input and output for the Gradio interface
demo = gr.Interface(
    fn=oryx_inference,
    inputs=[gr.Video(label="Input Video"), gr.Textbox(label="Input Text")],
    outputs="text",
    title="Oryx Inference",
    description="This is a demo for Oryx inference.",
)

# Launch the Gradio app
demo.launch(server_name="0.0.0.0", server_port=80)
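# Note: binding port 80 normally requires elevated privileges; for a local run,
# Gradio's default server_port=7860 may be the safer choice.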