merve's picture
merve HF staff
fix task tag
e23a84d verified
|
raw
history blame
6.82 kB
---
license: mit
pipeline_tag: video-text-to-text
---
<p align="center" width="100%">
<img src="https://i.postimg.cc/MKmyP9wH/new-banner.png" width="80%" height="80%">
</p>
<div>
<div align="center">
<a href='https://brianboli.com/' target='_blank'>Bo Li*<sup>1</sup></a>&emsp;
<a href='https://zhangyuanhan-ai.github.io/' target='_blank'>Yuanhan Zhang*<sup>1</sup></a>&emsp;
<a href='https://cliangyu.com/' target='_blank'>Liangyu Chen*<sup>1</sup></a>&emsp;
<a href='https://king159.github.io/' target='_blank'>Jinghao Wang*<sup>1</sup></a>&emsp;
<a href='https://pufanyi.github.io/' target='_blank'>Fanyi Pu*<sup>1</sup></a>&emsp;
<br>
<a href='https://jingkang50.github.io/' target='_blank'>Jingkang Yang<sup>1</sup></a>&emsp;
<a href='https://chunyuan.li/' target='_blank'>Chunyuan Li<sup>2</sup></a>&emsp;
<a href='https://liuziwei7.github.io/' target='_blank'>Ziwei Liu<sup>1</sup></a>
</div>
</div>
<div>
<div align="center">
<sup>1</sup>S-Lab, Nanyang Technological University&emsp;
<sup>2</sup>Microsoft Research, Redmond
</div>
</div>
-----------------
![](https://img.shields.io/badge/otter-v0.2-darkcyan)
![](https://img.shields.io/github/stars/luodian/otter?style=social)
[![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2FLuodian%2Fotter&count_bg=%23FFA500&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=visitors&edge_flat=false)](https://hits.seeyoufarm.com)
![](https://black.readthedocs.io/en/stable/_static/license.svg)
![](https://img.shields.io/badge/code%20style-black-000000.svg)
Below is an example of running this model on your own video.
First, clone [Otter](https://github.com/Luodian/Otter) to your local disk.
Then place the following script inside the `Otter` folder so that it can import `otter/modeling_otter.py`.
```python
import mimetypes
import os
from typing import Union
import cv2
import requests
import torch
import transformers
from PIL import Image
import sys
# make sure you can properly access the otter folder
from otter.modeling_otter import OtterForConditionalGeneration
# Disable urllib3 warnings globally. Every HTTP call in get_image passes
# verify=False (no TLS certificate verification), presumably so that the
# per-request insecure-request warning does not flood the console.
requests.packages.urllib3.disable_warnings()
# ------------------- Utility Functions -------------------
def get_content_type(file_path):
    """Guess the MIME type (e.g. ``"video/mp4"``) of a path or URL from its suffix.

    Returns ``None`` when :mod:`mimetypes` cannot determine a type.
    """
    guessed_type = mimetypes.guess_type(file_path)[0]
    return guessed_type
# ------------------- Image and Video Handling Functions -------------------
def extract_frames(video_path, num_frames=16):
    """Sample up to ``num_frames`` frames from a video file as RGB PIL images.

    Frames are taken at a fixed stride of ``total_frames // num_frames``. The
    stride is clamped to at least 1, so clips shorter than ``num_frames`` yield
    each frame once instead of re-reading frame 0 repeatedly.

    Args:
        video_path: Path to a video file readable by OpenCV.
        num_frames: Maximum number of frames to sample.

    Returns:
        A list of ``PIL.Image.Image`` in RGB mode. Frames that fail to decode
        are skipped, so the list may be shorter than ``num_frames`` (or empty).
    """
    if num_frames <= 0:
        return []
    video = cv2.VideoCapture(video_path)
    try:
        total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        # A non-positive count means the container did not report its length;
        # fall back to reading the first num_frames consecutive frames.
        frame_step = max(total_frames // num_frames, 1)
        count = num_frames if total_frames <= 0 else min(num_frames, total_frames)
        frames = []
        for i in range(count):
            video.set(cv2.CAP_PROP_POS_FRAMES, i * frame_step)
            ret, frame = video.read()
            if not ret:
                continue
            # OpenCV decodes to BGR; PIL / the CLIP processor expect RGB.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame).convert("RGB"))
        return frames
    finally:
        # Always free the decoder, even if a read/convert raises.
        video.release()
def get_image(url: str) -> Union[Image.Image, list]:
    """Load an image, or sample frames from a video, given a local path or URL.

    Args:
        url: A local file path, or a remote URL (any string containing ``"://"``).

    Returns:
        A ``PIL.Image.Image`` for image content, or the list of frames produced
        by ``extract_frames`` for video content.

    Raises:
        ValueError: If the content type is unknown or is neither image nor video.
    """
    import tempfile

    is_remote = "://" in url
    if is_remote:
        # requests.head does not follow redirects by default; follow them so the
        # Content-Type of the final resource is inspected.
        response = requests.head(url, stream=True, verify=False, allow_redirects=True)
        content_type = response.headers.get("Content-Type")
    else:
        content_type = get_content_type(url)
    # mimetypes / a missing header yield None; treat that as "unknown" so we
    # raise the ValueError below instead of a TypeError from `in None`.
    content_type = content_type or ""

    if "image" in content_type:
        if is_remote:
            return Image.open(requests.get(url, stream=True, verify=False).raw)
        return Image.open(url)

    if "video" in content_type:
        if not is_remote:
            return extract_frames(url)
        # Download to a unique temp file (not a fixed name in the CWD) so runs
        # cannot clobber user files or each other. The file is closed before
        # decoding so OpenCV can open it on every platform.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            tmp.write(requests.get(url, verify=False).content)
            video_path = tmp.name
        try:
            return extract_frames(video_path)
        finally:
            # Remove the download even if frame extraction fails.
            os.remove(video_path)

    raise ValueError("Invalid content type. Expected image or video.")
# ------------------- OTTER Prompt and Response Functions -------------------
def get_formatted_prompt(prompt: str) -> str:
    """Wrap a user question in the single-turn template this script feeds to Otter.

    The template leaves generation open right after the ``<answer>`` marker.
    """
    formatted = f"<image>User: {prompt} GPT:<answer>"
    return formatted
def get_response(input_data, prompt: str, model=None, image_processor=None, tensor_dtype=None) -> str:
    """Run Otter generation on an image or a list of video frames.

    Args:
        input_data: A ``PIL.Image.Image`` or a list of frames (as returned by
            ``get_image``).
        prompt: The user question; wrapped with ``get_formatted_prompt``.
        model: Loaded Otter model exposing ``text_tokenizer``, ``device`` and
            ``generate``.
        image_processor: Processor with a ``preprocess`` method (this script
            passes a ``transformers.CLIPImageProcessor``).
        tensor_dtype: dtype the vision tensor is cast to before generation.

    Returns:
        The decoded text after the last ``<answer>`` marker and before the first
        ``<|endofchunk|>`` marker, stripped of surrounding whitespace and of
        leading/trailing double quotes.

    Raises:
        ValueError: If ``input_data`` is neither a PIL image nor a list.
    """
    if isinstance(input_data, Image.Image):
        pixel_values = image_processor.preprocess([input_data], return_tensors="pt")["pixel_values"]
        # A single image: add an axis at dim 1, then a leading axis.
        vision_x = pixel_values.unsqueeze(1).unsqueeze(0)
    elif isinstance(input_data, list):
        pixel_values = image_processor.preprocess(input_data, return_tensors="pt")["pixel_values"]
        # Video frames: add two leading axes.
        vision_x = pixel_values.unsqueeze(0).unsqueeze(0)
    else:
        raise ValueError("Invalid input data. Expected PIL Image or list of video frames.")

    tokenizer = model.text_tokenizer
    lang_x = tokenizer([get_formatted_prompt(prompt)], return_tensors="pt")
    # Token-id sequences the generator is forbidden to emit (presumably so the
    # model does not start a new "User:"/"GPT:" turn inside its answer).
    bad_words_id = tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids

    device = model.device
    generated_text = model.generate(
        vision_x=vision_x.to(device, dtype=tensor_dtype),
        lang_x=lang_x["input_ids"].to(device),
        attention_mask=lang_x["attention_mask"].to(device),
        max_new_tokens=512,
        num_beams=3,
        no_repeat_ngram_size=3,
        bad_words_ids=bad_words_id,
    )

    decoded = tokenizer.decode(generated_text[0])
    answer = decoded.split("<answer>")[-1].strip()
    answer = answer.split("<|endofchunk|>")[0].strip()
    return answer.strip('"')
# ------------------- Main Function -------------------
# Precision used both to load the weights and to cast the vision input.
load_bit = "fp32"
_DTYPES = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
if load_bit not in _DTYPES:
    # Fail fast with a clear message instead of a NameError on `precision` later.
    raise ValueError(f"Unsupported load_bit {load_bit!r}; expected one of {sorted(_DTYPES)}.")
tensor_dtype = _DTYPES[load_bit]
precision = {"torch_dtype": tensor_dtype}

# This model version is trained on MIMIC-IT DC dataset.
model = OtterForConditionalGeneration.from_pretrained("luodian/OTTER-9B-DenseCaption", device_map="auto", **precision)
model.text_tokenizer.padding_side = "left"
tokenizer = model.text_tokenizer
image_processor = transformers.CLIPImageProcessor()
model.eval()

while True:
    # Path or URL to any common video (or image) format; type 'quit' to exit.
    video_url = input("Enter video path (or 'quit' to exit): ")
    if video_url.strip().lower() == "quit":
        break
    try:
        frames_list = get_image(video_url)
    except (ValueError, OSError, requests.RequestException) as err:
        # Bad path/URL or unsupported content: report it and ask again
        # instead of killing the interactive session.
        print(f"Could not load {video_url!r}: {err}")
        continue
    if isinstance(frames_list, list) and not frames_list:
        # extract_frames returns [] when nothing could be decoded.
        print(f"No frames could be read from {video_url!r}.")
        continue
    while True:
        prompts_input = input("Enter prompts: ")
        if prompts_input.lower() == "quit":
            break
        print(f"\nPrompt: {prompts_input}")
        response = get_response(frames_list, prompts_input, model, image_processor, tensor_dtype)
        print(f"Response: {response}")
```