innat committed
Commit bbc8456
1 Parent(s): 01acd5f

Update app.py

Files changed (1):
  1. app.py +42 -30

app.py CHANGED
@@ -10,6 +10,26 @@ from utils import read_video, frame_sampling, denormalize, reconstrunction
 from utils import IMAGENET_MEAN, IMAGENET_STD, num_frames, patch_size, input_size
 from labels import K400_label_map, SSv2_label_map, UCF_label_map
 
+MODELS = {
+    'K400': [
+        './TFVideoMAE_S_K400_16x224_FT',
+        './TFVideoMAE_S_K400_16x224_PT'
+    ],
+    'SSv2': [
+        './TFVideoMAE_S_K400_16x224_FT',
+        './TFVideoMAE_S_K400_16x224_PT'
+    ],
+    'UCF' : [
+        './TFVideoMAE_S_K400_16x224_FT',
+        './TFVideoMAE_S_K400_16x224_PT'
+    ]
+}
+
+LABEL_MAPS = {
+    'K400': K400_label_map,
+    'SSv2': SSv2_label_map,
+    'UCF' : UCF_label_map
+}
 
 def tube_mask_generator(mask_ratio):
     window_size = (
@@ -31,21 +51,27 @@ def tube_mask_generator(mask_ratio):
 def get_model(data_type):
     ft_model = keras.models.load_model(MODELS[data_type][0])
     pt_model = keras.models.load_model(MODELS[data_type][1])
-    label_map = {v: k for k, v in K400_label_map.items()}
+
+    label_map = LABEL_MAPS.get(data_type)
+    label_map = {v: k for k, v in label_map.items()}
+
     return ft_model, pt_model, label_map
 
 
-def inference(video_file, dataset_type, mask_ratio):
+def inference(video_file, data_type, mask_ratio):
     print('---------------------------')
     print(video_file)
-    print(dataset_type)
+    print(data_type)
     print(mask_ratio)
     print('---------------------------')
-
+
+    # get sample data
     container = read_video(video_file)
     frames = frame_sampling(container, num_frames=num_frames)
+
+    # get models
     bool_masked_pos_tf = tube_mask_generator(mask_ratio)
-    ft_model, pt_model, label_map = get_model(dataset_type)
+    ft_model, pt_model, label_map = get_model(data_type)
     ft_model.trainable = False
     pt_model.trainable = False
 
@@ -78,25 +104,11 @@ def inference(video_file, dataset_type, mask_ratio):
 
 
 def main():
-    MODELS = {
-        'K400': [
-            './TFVideoMAE_S_K400_16x224_FT',
-            './TFVideoMAE_S_K400_16x224_PT'
-        ],
-        'SSv2': [
-            './TFVideoMAE_S_K400_16x224_FT',
-            './TFVideoMAE_S_K400_16x224_PT'
-        ],
-        'UCF' : [
-            './TFVideoMAE_S_K400_16x224_FT',
-            './TFVideoMAE_S_K400_16x224_PT'
-        ]
-    }
-    BENCHMARK_DATASETS = ['K400', 'SSv2', 'UCF']
-    SAMPLE_EXAMPLES = [
+    datasets = ['K400', 'SSv2', 'UCF']
+    sample_example = [
         ["examples/k400.mp4", 'Kintetics-400'],
-        ["examples/k400.mp4", 'SSv2'],
-        ["examples/k400.mp4", 'UCF']
+        ["examples/k400.mp4", 'Something-Something-V2'],
+        ["examples/k400.mp4", 'UCF101']
     ]
 
     iface = gr.Interface(
@@ -104,16 +116,16 @@ def main():
         inputs=[
            gr.Video(type="file", label="Input Video"),
            gr.Radio(
-                BENCHMARK_DATASETS,
+                datasets,
                type='value',
-                # default=BENCHMARK_DATASETS[0],
+                default=datasets[0],
                label='Dataset',
            ),
            gr.Slider(
-                0,
-                1,
-                step=0.05,
-                # default=0.5,
+                0.5,
+                1.0,
+                step=0.1,
+                default=0.5,
                label='Mask Ratio'
            )
        ],
@@ -121,7 +133,7 @@ def main():
            gr.Label(num_top_classes=3, label='scores'),
            gr.Image(type="filepath", label='reconstructed')
        ],
-        examples=SAMPLE_EXAMPLES,
+        examples=sample_example,
        title="VideoMAE",
        description="Keras reimplementation of <a href='https://github.com/innat/VideoMAE'>VideoMAE</a> is presented here."
    )
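
For orientation, a rough usage sketch of the refactored helpers follows; it is not part of the commit. It assumes app.py guards main() behind if __name__ == '__main__', that read_video() accepts a plain file path, and that inference() returns the score dict and reconstructed-image path expected by the gr.Label and gr.Image outputs. Note that the example rows pass display names ('Kintetics-400', 'Something-Something-V2', 'UCF101') while MODELS and LABEL_MAPS are keyed by the short names in `datasets`, so the sketch uses the short keys.

# Usage sketch only (assumptions noted above); not part of the commit.
from app import get_model, inference   # assumes main() does not run on import

# MODELS / LABEL_MAPS are keyed by the short dataset names: 'K400', 'SSv2', 'UCF'.
ft_model, pt_model, label_map = get_model('K400')
print(len(label_map))                  # size of the inverted id -> label map

# mask_ratio should fall in the new Slider range (0.5 to 1.0, step 0.1).
scores, reconstructed = inference('examples/k400.mp4', 'K400', mask_ratio=0.9)
print(scores)                          # assumed: {label: probability} for gr.Label
print(reconstructed)                   # assumed: image file path for gr.Image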