---
license: mit
---

<p align="center" width="100%">
<img src="https://i.postimg.cc/MKmyP9wH/new-banner.png" width="80%" height="80%">
</p>

<div align="center">
<a href='https://brianboli.com/' target='_blank'>Bo Li*<sup>,1</sup></a>&emsp;
<a href='https://zhangyuanhan-ai.github.io/' target='_blank'>Yuanhan Zhang*<sup>,1</sup></a>&emsp;
<a href='https://cliangyu.com/' target='_blank'>Liangyu Chen*<sup>,1</sup></a>&emsp;
<a href='https://king159.github.io/' target='_blank'>Jinghao Wang*<sup>,1</sup></a>&emsp;
<a href='https://pufanyi.github.io/' target='_blank'>Fanyi Pu*<sup>,1</sup></a>&emsp;
<br>
<a href='https://jingkang50.github.io/' target='_blank'>Jingkang Yang<sup>1</sup></a>&emsp;
<a href='https://chunyuan.li/' target='_blank'>Chunyuan Li<sup>2</sup></a>&emsp;
<a href='https://liuziwei7.github.io/' target='_blank'>Ziwei Liu<sup>1</sup></a>
</div>

<div align="center">
<sup>1</sup>S-Lab, Nanyang Technological University&emsp;
<sup>2</sup>Microsoft Research, Redmond
</div>

-----------------

![](https://img.shields.io/badge/otter-v0.2-darkcyan)
![](https://img.shields.io/github/stars/luodian/otter?style=social)
[![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2FLuodian%2Fotter&count_bg=%23FFA500&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=visitors&edge_flat=false)](https://hits.seeyoufarm.com)
![](https://black.readthedocs.io/en/stable/_static/license.svg)
![](https://img.shields.io/badge/code%20style-black-000000.svg)

Below is an example of using this model to run inference on your own video. First clone [Otter](https://github.com/Luodian/Otter) to your local disk, then place the following script inside the `Otter` folder so that it can import `otter/modeling_otter.py`.

```python
import mimetypes
import os
from typing import Union

import cv2
import requests
import transformers
from PIL import Image

from otter.modeling_otter import OtterForConditionalGeneration

# Silence urllib3 warnings about unverified HTTPS requests (we pass verify=False below)
requests.packages.urllib3.disable_warnings()

# ------------------- Utility Functions -------------------


def get_content_type(file_path):
    # Guess the MIME type (e.g. "image/png", "video/mp4") from the file extension
    content_type, _ = mimetypes.guess_type(file_path)
    return content_type


# ------------------- Image and Video Handling Functions -------------------


def extract_frames(video_path, num_frames=128):
    """Sample num_frames evenly spaced RGB frames from a video file."""
    video = cv2.VideoCapture(video_path)
    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    # Guard against videos shorter than num_frames, where the step would otherwise be 0
    frame_step = max(total_frames // num_frames, 1)
    frames = []

    for i in range(num_frames):
        video.set(cv2.CAP_PROP_POS_FRAMES, i * frame_step)
        ret, frame = video.read()
        if ret:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame).convert("RGB")
            frames.append(frame)

    video.release()
    return frames


def get_image(url: str) -> Union[Image.Image, list]:
    """Load a PIL image, or a list of frames for a video, from a local path or remote URL."""
    if "://" not in url:  # Local file
        content_type = get_content_type(url)
    else:  # Remote URL
        content_type = requests.head(url, stream=True, verify=False).headers.get("Content-Type")

    if "image" in content_type:
        if "://" not in url:  # Local file
            return Image.open(url)
        else:  # Remote URL
            return Image.open(requests.get(url, stream=True, verify=False).raw)
    elif "video" in content_type:
        video_path = "temp_video.mp4"
        if "://" not in url:  # Local file
            video_path = url
        else:  # Remote URL
            with open(video_path, "wb") as f:
                f.write(requests.get(url, stream=True, verify=False).content)
        frames = extract_frames(video_path)
        if "://" in url:  # Only remove the temporary video file if it was downloaded
            os.remove(video_path)
        return frames
    else:
        raise ValueError("Invalid content type. Expected image or video.")

# ------------------- OTTER Prompt and Response Functions -------------------


def get_formatted_prompt(prompt: str) -> str:
    # Wrap the user prompt in Otter's instruction format; the model's reply follows <answer>
    return f"<image>User: {prompt} GPT:<answer>"


def get_response(input_data, prompt: str, model=None, image_processor=None) -> str:
    """Run one round of generation on a single image or a list of video frames."""
    if isinstance(input_data, Image.Image):
        # Add the media and batch dimensions the model expects
        vision_x = image_processor.preprocess([input_data], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
    elif isinstance(input_data, list):  # list of video frames
        vision_x = image_processor.preprocess(input_data, return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
    else:
        raise ValueError("Invalid input data. Expected PIL Image or list of video frames.")

    lang_x = model.text_tokenizer(
        [
            get_formatted_prompt(prompt),
        ],
        return_tensors="pt",
    )

    generated_text = model.generate(
        vision_x=vision_x.to(model.device),
        lang_x=lang_x["input_ids"].to(model.device),
        attention_mask=lang_x["attention_mask"].to(model.device),
        max_new_tokens=512,
        num_beams=3,
        no_repeat_ngram_size=3,
    )
    # Keep only the text between <answer> and <|endofchunk|>, then trim whitespace and quotes
    parsed_output = (
        model.text_tokenizer.decode(generated_text[0])
        .split("<answer>")[-1]
        .split("<|endofchunk|>")[0]
        .strip()
        .strip('"')
    )
    return parsed_output


# ------------------- Main Function -------------------

if __name__ == "__main__":
    model = OtterForConditionalGeneration.from_pretrained(
        "luodian/otter-9b-dc-hf",
    )
    model.text_tokenizer.padding_side = "left"  # decoder-only generation expects left padding
    image_processor = transformers.CLIPImageProcessor()
    model.eval()

    video_url = "dc_demo.mp4"  # Replace with the path to your video file
    frames_list = get_image(video_url)  # Decode the video into frames once, up front

    while True:
        prompts_input = input("Enter prompts (comma-separated), or 'quit' to exit: ")
        if prompts_input.lower() == "quit":
            break
        prompts = [prompt.strip() for prompt in prompts_input.split(",")]

        for prompt in prompts:
            print(f"\nPrompt: {prompt}")
            response = get_response(frames_list, prompt, model, image_processor)
            print(f"Response: {response}")
```
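
Loading the 9B checkpoint in full fp32 precision can exceed the memory of a single consumer GPU. As a minimal sketch, you could load it in half precision instead. This assumes a CUDA GPU, an installed `accelerate` package, and that `OtterForConditionalGeneration.from_pretrained` passes these standard Hugging Face options through; none of this is specified by the model card itself.

```python
import torch
from otter.modeling_otter import OtterForConditionalGeneration

# Sketch only: half-precision weights with automatic placement (assumes CUDA + accelerate).
model = OtterForConditionalGeneration.from_pretrained(
    "luodian/otter-9b-dc-hf",
    torch_dtype=torch.bfloat16,  # use torch.float16 on GPUs without bfloat16 support
    device_map="auto",           # let accelerate decide where to put the weights
)
model.eval()
```

On a single-GPU machine this keeps `model.device` well defined, so the `.to(model.device)` calls in `get_response` above continue to work unchanged.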