try:
    from data import *
except ImportError:
    from foleycrafter.models.specvqgan.onset_baseline.data import *

import glob
import json
import os
import random
import sys

import librosa
import numpy as np
import soundfile as sf
import torch
import torchvision.transforms as transforms
from PIL import Image

sys.path.append('..')


class GreatestHitDataset(object):
    """Greatest Hits dataset: returns a 2-second window of video frames with a
    per-frame binary onset label derived from the audio track."""

    def __init__(self, args, split='train'):
        self.split = split
        if split == 'train':
            list_sample = './data/greatesthit_train_2.00.json'
        elif split == 'val':
            list_sample = './data/greatesthit_valid_2.00.json'
        elif split == 'test':
            list_sample = './data/greatesthit_test_2.00.json'
        else:
            raise ValueError(f'Unknown split: {split}')

        # save args parameters
        self.repeat = args.repeat if split == 'train' else 1
        self.max_sample = args.max_sample

        self.video_transform = transforms.Compose(
            self.generate_video_transform(args))

        if isinstance(list_sample, str):
            with open(list_sample, 'r') as f:
                self.list_sample = json.load(f)

        if self.max_sample > 0:
            self.list_sample = self.list_sample[0:self.max_sample]
        self.list_sample = self.list_sample * self.repeat

        random.seed(1234)
        np.random.seed(1234)
        num_sample = len(self.list_sample)
        if self.split == 'train':
            random.shuffle(self.list_sample)

        # self.class_dist = self.unbalanced_dist()
        print('Greatesthit Dataloader: # sample of {}: {}'.format(self.split, num_sample))

    def __getitem__(self, index):
        # Each entry is '<video_id>_<onset_sample_index>'.
        info = self.list_sample[index].split('_')[0]
        video_path = os.path.join('data', 'greatesthit', 'greatesthit_processed', info)
        frame_path = os.path.join(video_path, 'frames')
        audio_path = os.path.join(video_path, 'audio')
        audio_path = glob.glob(f'{audio_path}/*.wav')[0]

        # NOTE: the hit_record.json metadata is loaded but currently unused;
        # consider removing this block.
        meta_path = os.path.join(video_path, 'hit_record.json')
        if os.path.exists(meta_path):
            with open(meta_path, 'r') as f:
                meta_dict = json.load(f)

        # Read a short snippet only to obtain the audio sample rate.
        _, audio_sample_rate = sf.read(audio_path, start=0, stop=1000, dtype='float64', always_2d=True)
        frame_rate = 15
        duration = 2.0

        frame_list = glob.glob(f'{frame_path}/*.jpg')
        frame_list.sort()

        # The onset position is stored as a sample index at 22.05 kHz.
        hit_time = float(self.list_sample[index].split('_')[-1]) / 22050
        frame_start = hit_time * frame_rate
        if self.split == 'train':
            # Jitter the window start by [-5, 4] frames during training.
            frame_start += np.random.randint(10) - 5
        frame_start = max(frame_start, 0)
        frame_start = min(frame_start, len(frame_list) - duration * frame_rate)
        frame_start = int(frame_start)

        frame_list = frame_list[frame_start: int(frame_start + np.ceil(duration * frame_rate))]
        audio_start = int(frame_start / frame_rate * audio_sample_rate)
        audio_end = int(audio_start + duration * audio_sample_rate)

        imgs = self.read_image(frame_list)
        audio, audio_rate = sf.read(audio_path, start=audio_start, stop=audio_end, dtype='float64', always_2d=True)
        audio = audio.mean(-1)

        # Detect audio onsets in the window and mark the corresponding frames.
        onsets = librosa.onset.onset_detect(y=audio, sr=audio_rate, units='time', delta=0.3)
        onsets = np.rint(onsets * frame_rate).astype(int)
        onsets[onsets > 29] = 29  # clamp to the last of the 30 frames (2 s x 15 fps)
        label = torch.zeros(len(frame_list))
        label[onsets] = 1

        batch = {
            'frames': imgs,
            'label': label
        }
        return batch

    def getitem_test(self, index):
        return self.__getitem__(index)

    def __len__(self):
        return len(self.list_sample)

    def read_image(self, frame_list):
        imgs = []
        convert_tensor = transforms.ToTensor()
        for img_path in frame_list:
            image = Image.open(img_path).convert('RGB')
            image = convert_tensor(image)
            imgs.append(image.unsqueeze(0))

        imgs = torch.cat(imgs, dim=0).squeeze()  # (T, C, H, W)
        imgs = self.video_transform(imgs)
        imgs = imgs.permute(1, 0, 2, 3)  # (C, T, H, W)
        return imgs

    def generate_video_transform(self, args):
        resize_funct = transforms.Resize((128, 128))
        if self.split == 'train':
            crop_funct = transforms.RandomCrop((112, 112))
            color_funct = transforms.ColorJitter(
                brightness=0.1, contrast=0.1, saturation=0, hue=0)
        else:
            crop_funct = transforms.CenterCrop((112, 112))
            color_funct = transforms.Lambda(lambda img: img)

        vision_transform_list = [
            resize_funct,
            crop_funct,
            color_funct,
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]
        return vision_transform_list
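

if __name__ == '__main__':
    # Minimal smoke test: a sketch, not part of the original pipeline. It
    # assumes the split JSONs and processed clips exist under ./data, and it
    # fakes the `args` object with only the two fields the dataset reads
    # (`repeat` and `max_sample`).
    from argparse import Namespace

    args = Namespace(repeat=1, max_sample=-1)
    dataset = GreatestHitDataset(args, split='val')
    print('dataset size:', len(dataset))

    batch = dataset[0]
    # Expected shapes: frames (C, T, H, W) = (3, 30, 112, 112), label (T,) = (30,).
    print('frames:', tuple(batch['frames'].shape), 'label:', tuple(batch['label'].shape))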