try:
    from data import *
except ImportError:
    from foleycrafter.models.specvqgan.onset_baseline.data import *

import glob
import json
import os
import random
import sys

import librosa
import numpy as np
import soundfile as sf
import torch
import torchvision.transforms as transforms
from PIL import Image

sys.path.append('..')

class GreatestHitDataset(object):
    def __init__(self, args, split='train'):
        self.split = split
        if split == 'train':
            list_sample = './data/greatesthit_train_2.00.json'
        elif split == 'val':
            list_sample = './data/greatesthit_valid_2.00.json'
        elif split == 'test':
            list_sample = './data/greatesthit_test_2.00.json'
        else:
            raise ValueError(f'unknown split: {split}')
        # save args parameters
self.repeat = args.repeat if split == 'train' else 1
self.max_sample = args.max_sample
self.video_transform = transforms.Compose(
self.generate_video_transform(args))
if isinstance(list_sample, str):
with open(list_sample, "r") as f:
self.list_sample = json.load(f)
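        # optionally cap the sample list, then repeat it to lengthen a training epoch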
if self.max_sample > 0:
self.list_sample = self.list_sample[0:self.max_sample]
self.list_sample = self.list_sample * self.repeat
random.seed(1234)
np.random.seed(1234)
num_sample = len(self.list_sample)
if self.split == 'train':
random.shuffle(self.list_sample)
        print(f'Greatesthit Dataloader: # samples in {self.split}: {num_sample}')
def __getitem__(self, index):
info = self.list_sample[index].split('_')[0]
video_path = os.path.join('data', 'greatesthit', 'greatesthit_processed', info)
frame_path = os.path.join(video_path, 'frames')
audio_path = os.path.join(video_path, 'audio')
audio_path = glob.glob(f"{audio_path}/*.wav")[0]
        # only the sample rate is needed here; the aligned audio is read below
        audio_sample_rate = sf.info(audio_path).samplerate
frame_rate = 15
duration = 2.0
frame_list = glob.glob(f'{frame_path}/*.jpg')
frame_list.sort()
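        # hit times are stored as sample indices at the 22.05 kHz audio rate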
hit_time = float(self.list_sample[index].split('_')[-1]) / 22050
        if self.split == 'train':
            # randomly jitter the window start by -5..+4 frames around the hit
            frame_start = hit_time * frame_rate + np.random.randint(10) - 5
        else:
            frame_start = hit_time * frame_rate
        # clamp so the 2 s window stays inside the available frames
        frame_start = int(np.clip(frame_start, 0, len(frame_list) - duration * frame_rate))
frame_list = frame_list[frame_start: int(
frame_start + np.ceil(duration * frame_rate))]
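        # map the selected frame window to the corresponding span of audio samples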
audio_start = int(frame_start / frame_rate * audio_sample_rate)
audio_end = int(audio_start + duration * audio_sample_rate)
imgs = self.read_image(frame_list)
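        # read only the aligned audio segment and downmix it to mono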
audio, audio_rate = sf.read(audio_path, start=audio_start, stop=audio_end, dtype='float64', always_2d=True)
audio = audio.mean(-1)
        onsets = librosa.onset.onset_detect(y=audio, sr=audio_rate, units='time', delta=0.3)
        # convert onset times to frame indices and clamp to the last frame
        onsets = np.rint(onsets * frame_rate).astype(int)
        onsets = np.clip(onsets, 0, len(frame_list) - 1)
        label = torch.zeros(len(frame_list))
        label[onsets] = 1
batch = {
'frames': imgs,
'label': label
}
return batch
    def getitem_test(self, index):
        return self.__getitem__(index)
def __len__(self):
return len(self.list_sample)
    def read_image(self, frame_list):
        imgs = []
        convert_tensor = transforms.ToTensor()
        for img_path in frame_list:
            image = Image.open(img_path).convert('RGB')
            imgs.append(convert_tensor(image))
        # (T, C, H, W); stacking avoids the squeeze that would drop the time
        # dimension for single-frame clips
        imgs = torch.stack(imgs, dim=0)
        imgs = self.video_transform(imgs)
        # (C, T, H, W)
        imgs = imgs.permute(1, 0, 2, 3)
        return imgs
def generate_video_transform(self, args):
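        # resize to 128x128, crop to 112x112 (random for train, center for eval),
        # optionally jitter color, then normalize with the standard ImageNet statistics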
resize_funct = transforms.Resize((128, 128))
if self.split == 'train':
crop_funct = transforms.RandomCrop(
(112, 112))
color_funct = transforms.ColorJitter(
brightness=0.1, contrast=0.1, saturation=0, hue=0)
else:
crop_funct = transforms.CenterCrop(
(112, 112))
color_funct = transforms.Lambda(lambda img: img)
vision_transform_list = [
resize_funct,
crop_funct,
color_funct,
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
return vision_transform_list
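

# A minimal usage sketch (not part of the original training code): it assumes
# the greatesthit data layout used above and an args object exposing only the
# `repeat` and `max_sample` fields this dataset reads; the values below are
# hypothetical.
if __name__ == '__main__':
    from argparse import Namespace

    from torch.utils.data import DataLoader

    args = Namespace(repeat=1, max_sample=4)
    dataset = GreatestHitDataset(args, split='val')
    loader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=0)
    batch = next(iter(loader))
    # frames: (B, C, T, H, W); label: (B, T)
    print(batch['frames'].shape, batch['label'].shape)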