import json
import random
from math import ceil

import decord
import numpy as np
import pandas as pd
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

decord.bridge.set_bridge("torch")


class AudioVisualDataset(Dataset):
    """Samples audio-visual clips from video databases.

    Params:
        datafiles: list of JSON datafiles describing the videos
        min_video_frames: used to drop short video clips
        video_resize: resize (height, width) for CLIP processing
        sampling_rate: audio sampling rate in Hz
        sample_av_clip: whether to randomly sample an audio-visual clip
        max_clip_len: max length (seconds) of the audio-visual clip to be sampled
        num_sample_frames: number of image frames to be uniformly sampled from video
        freqm, timem: SpecAugment frequency / time masking parameters
        return_label: whether to return multi-hot AudioSet labels
    """

    def __init__(
        self,
        datafiles=[
            "/mnt/bn/data-xubo/dataset/audioset_videos/datafiles/audioset_balanced_train.json",
        ],
        min_video_frames=30,
        video_resize=[224, 224],
        sampling_rate=16000,
        sample_av_clip=True,
        max_clip_len=10,
        num_sample_frames=10,
        # hyperparameters used for SpecAugment
        freqm=48,
        timem=192,
        return_label=False,
    ):
        all_data_json = []
        for datafile in datafiles:
            with open(datafile, "r") as fp:
                data_json = json.load(fp)["data"]
                all_data_json.extend(data_json)

        # drop short video clips
        self.all_data_json = [
            data
            for data in all_data_json
            if int(data["video_shape"][0]) >= min_video_frames
        ]

        self.max_clip_len = max_clip_len
        self.video_resize = video_resize
        self.sampling_rate = sampling_rate
        self.sample_av_clip = sample_av_clip
        self.num_sample_frames = num_sample_frames
        self.corresponding_audio_len = self.sampling_rate * self.max_clip_len

        # hyperparameters used for AudioMAE
        self.freqm = freqm
        self.timem = timem
        self.norm_mean = -4.2677393
        self.norm_std = 4.5689974
        self.melbins = 128
        self.TARGET_LEN = 1024

        self.return_label = return_label
        if self.return_label:
            self.audioset_label2idx = self._prepare_audioset()

    def __len__(self):
        return len(self.all_data_json)

    def _read_audio_video(self, index):
        video_path = self.all_data_json[index]["mp4"]
        try:
            # read audio
            ar = decord.AudioReader(
                video_path, sample_rate=self.sampling_rate, mono=True
            )
            # read video frames
            vr = decord.VideoReader(
                video_path,
                height=self.video_resize[0],
                width=self.video_resize[1],
            )
            labels = self.all_data_json[index]["labels"]
            return vr, ar, labels

        except Exception as e:
            print(f"error: {e} occurs when loading {video_path}")
            # fall back to a randomly chosen sample
            random_index = random.randint(0, len(self.all_data_json) - 1)
            return self._read_audio_video(index=random_index)

    def _prepare_audioset(self):
        # map AudioSet label codes (e.g., "/m/09x0r") to class indices
        df1 = pd.read_csv(
            "/mnt/bn/lqhaoheliu/datasets/audioset/metadata/class_labels_indices.csv",
            delimiter=",",
            skiprows=0,
        )
        label_set = df1.to_numpy()
        code2id = {}
        for i in range(len(label_set)):
            code2id[label_set[i][1]] = label_set[i][0]
        return code2id

    def __getitem__(self, index):
        # read audio and video
        vr, ar, labels = self._read_audio_video(index)

        # create an audio tensor
        audio_data = ar[:]  # [1, samples]
        audio_len = audio_data.shape[1] / self.sampling_rate
        audio_data = audio_data.squeeze(0)  # [samples]

        # create a video tensor
        full_vid_length = len(vr)
        video_rate = ceil(vr.get_avg_fps())
        samples_per_frame = float(self.sampling_rate) / video_rate
        start_frame = 0

        # sample video clip
        if audio_len > self.max_clip_len and self.sample_av_clip:
            start_frame = random.randint(
                0, max(full_vid_length - video_rate * self.max_clip_len, 0)
            )
        end_frame = min(start_frame + video_rate * self.max_clip_len, full_vid_length)
        video_data = vr.get_batch(range(start_frame, end_frame))

        # sample the audio clip aligned with the sampled video clip
        if audio_len > self.max_clip_len and self.sample_av_clip:
            # corresponding_audio_len = int(video_data.size()[0] * samples_per_frame)
            corresponding_audio_start = int(start_frame * samples_per_frame)
            audio_data = audio_data[corresponding_audio_start:]

        # cut or pad audio clip with respect to the sampled video clip
        if audio_data.shape[0] < self.corresponding_audio_len:
            zero_data = torch.zeros(self.corresponding_audio_len)
            zero_data[: audio_data.shape[0]] = audio_data
            audio_data = zero_data
        elif audio_data.shape[0] > self.corresponding_audio_len:
            audio_data = audio_data[: self.corresponding_audio_len]

        # uniformly sample image frames from video [tentative solution]
        interval = video_data.shape[0] // self.num_sample_frames
        video_data = video_data[::interval][: self.num_sample_frames]

        assert (
            video_data.shape[0] == self.num_sample_frames
        ), f"number of sampled image frames is {video_data.shape[0]}"

        assert (
            audio_data.shape[0] == self.corresponding_audio_len
        ), f"number of audio samples is {audio_data.shape[0]}"

        # video transformation
        video_data = video_data / 255.0
        video_data = video_data.permute(0, 3, 1, 2)  # [N, H, W, C] -> [N, C, H, W]

        # calculate mel fbank of waveform for audio encoder
        audio_data = audio_data.unsqueeze(0)  # [1, samples]
        audio_data = audio_data - audio_data.mean()
        fbank = torchaudio.compliance.kaldi.fbank(
            audio_data,
            htk_compat=True,
            sample_frequency=self.sampling_rate,
            use_energy=False,
            window_type="hanning",
            num_mel_bins=self.melbins,
            dither=0.0,
            frame_shift=10,
        )

        # cut or pad the spectrogram to TARGET_LEN frames
        n_frames = fbank.shape[0]
        p = self.TARGET_LEN - n_frames
        if p > 0:
            m = torch.nn.ZeroPad2d((0, 0, 0, p))
            fbank = m(fbank)
        elif p < 0:
            fbank = fbank[0 : self.TARGET_LEN, :]

        # SpecAugment for training (not for eval)
        freqm = torchaudio.transforms.FrequencyMasking(self.freqm)
        timem = torchaudio.transforms.TimeMasking(self.timem)
        fbank = fbank.transpose(0, 1).unsqueeze(0)  # [1, 128, 1024] (..., freq, time)
        if self.freqm != 0:
            fbank = freqm(fbank)
        if self.timem != 0:
            fbank = timem(fbank)  # (..., freq, time)
        fbank = torch.transpose(fbank.squeeze(), 0, 1)  # [time, freq]

        # normalize with the AudioMAE statistics
        fbank = (fbank - self.norm_mean) / (self.norm_std * 2)
        fbank = fbank.unsqueeze(0)

        if self.return_label:
            # convert AudioSet label codes to a multi-hot vector
            label_indices = np.zeros(527)
            for label_str in labels.split(","):
                label_indices[int(self.audioset_label2idx[label_str])] = 1.0
            label_indices = torch.FloatTensor(label_indices)

            data_dict = {
                "labels": label_indices,
                "images": video_data,
                "fbank": fbank,
                # 'modality': 'audio_visual'
            }
        else:
            data_dict = {
                "images": video_data,
                "fbank": fbank,
                # 'modality': 'audio_visual'
            }

        return data_dict


def collate_fn(list_data_dict):
    r"""Collate mini-batch data to inputs and targets for training.

    Args:
        list_data_dict: e.g., [
            {'images': (num_frames, 3, H, W), 'fbank': (1, time, freq)},
            {'images': (num_frames, 3, H, W), 'fbank': (1, time, freq)},
            ...]

    Returns:
        data_dict: e.g., {
            'images': (batch_size, num_frames, 3, H, W),
            'fbank': (batch_size, 1, time, freq)}
    """
    data_dict = {}

    for key in list_data_dict[0].keys():
        data_dict[key] = torch.stack([d[key] for d in list_data_dict])

    return data_dict
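

# Minimal usage sketch (assumptions: the datafile path below is hypothetical and
# the default hyperparameters above are kept). It shows how the dataset and
# collate_fn are expected to plug into a torch DataLoader; with the defaults,
# batch["images"] is [B, num_sample_frames, 3, 224, 224] and batch["fbank"] is
# [B, 1, 1024, 128].
if __name__ == "__main__":
    dataset = AudioVisualDataset(
        datafiles=["/path/to/audioset_balanced_train.json"],  # hypothetical path
        return_label=False,
    )
    loader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=4,
        collate_fn=collate_fn,
    )
    batch = next(iter(loader))
    print(batch["images"].shape, batch["fbank"].shape)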