import math import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import json from random import shuffle, choice, sample from moviepy.editor import VideoFileClip import librosa from scipy import signal from scipy.io import wavfile import torchaudio torchaudio.set_audio_backend("sox_io") INTERVAL = 1000 # discard stft = torchaudio.transforms.MelSpectrogram( sample_rate=16000, hop_length=161, n_mels=64).cuda() def log10(x): return torch.log(x)/torch.log(torch.tensor(10.)) def norm_range(x, min_val, max_val): return 2.*(x - min_val)/float(max_val - min_val) - 1. def normalize_spec(spec, spec_min, spec_max): return norm_range(spec, spec_min, spec_max) def db_from_amp(x, cuda=False): # rescale the audio if cuda: return 20. * log10(torch.max(torch.tensor(1e-5).to('cuda'), x.float())) else: return 20. * log10(torch.max(torch.tensor(1e-5), x.float())) def audio_stft(audio, stft=stft): # We'll apply stft to the audio samples to convert it to a HxW matrix N, C, A = audio.size() audio = audio.view(N * C, A) spec = stft(audio) spec = spec.transpose(-1, -2) spec = db_from_amp(spec, cuda=True) spec = normalize_spec(spec, -100., 100.) _, T, F = spec.size() spec = spec.view(N, C, T, F) return spec # discard # def get_spec( # wavs, # sample_rate=16000, # use_volume_jittering=False, # center=False, # ): # # Volume jittering - scale volume by factor in range (0.9, 1.1) # if use_volume_jittering: # wavs = [wav * np.random.uniform(0.9, 1.1) for wav in wavs] # if center: # wavs = [center_only(wav) for wav in wavs] # # Convert to log filterbank # specs = [logfbank( # wav, # sample_rate, # winlen=0.009, # winstep=0.005, # if num_sec==1 else 0.01, # nfilt=256, # nfft=1024 # ).astype('float32').T for wav in wavs] # # Convert to 32-bit float and expand dim # specs = np.stack(specs, axis=0) # specs = np.expand_dims(specs, 1) # specs = torch.as_tensor(specs) # Nx1xFxT # return specs def center_only(audio, sr=16000, L=1.0): # center_wav = np.arange(0, L, L/(0.5*sr)) ** 2 # center_wav = np.concatenate([center_wav, center_wav[::-1]]) # center_wav[L*sr//2:3*L*sr//4] = 1 # only take 0.3 sec audio center_wav = np.zeros(int(L * sr)) center_wav[int(0.4*L*sr):int(0.7*L*sr)] = 1 return audio * center_wav def get_spec_librosa( wavs, sample_rate=16000, use_volume_jittering=False, center=False, ): # Volume jittering - scale volume by factor in range (0.9, 1.1) if use_volume_jittering: wavs = [wav * np.random.uniform(0.9, 1.1) for wav in wavs] if center: wavs = [center_only(wav) for wav in wavs] # Convert to log filterbank specs = [librosa.feature.melspectrogram( y=wav, sr=sample_rate, n_fft=400, hop_length=126, n_mels=128, ).astype('float32') for wav in wavs] # Convert to 32-bit float and expand dim specs = [librosa.power_to_db(spec) for spec in specs] specs = np.stack(specs, axis=0) specs = np.expand_dims(specs, 1) specs = torch.as_tensor(specs) # Nx1xFxT return specs def calcEuclideanDistance_Mat(X, Y): """ Inputs: - X: A numpy array of shape (N, F) - Y: A numpy array of shape (M, F) Returns: A numpy array D of shape (N, M) where D[i, j] is the Euclidean distance between X[i] and Y[j]. """ return ((torch.sum(X ** 2, axis=1, keepdims=True)) + (torch.sum(Y ** 2, axis=1, keepdims=True)).T - 2 * X @ Y.T) ** 0.5 def calcEuclideanDistance(x1, x2): return torch.sum((x1 - x2)**2, dim=1)**0.5 def split_data(in_list, portion=(0.9, 0.95), is_shuffle=True): if is_shuffle: shuffle(in_list) if type(in_list) == str: with open(in_list) as l: fw_list = json.load(l) elif type(in_list) == list: fw_list = in_list else: print(type(in_list)) raise TypeError('Invalid input list type') c1, c2 = int(len(fw_list) * portion[0]), int(len(fw_list) * portion[1]) tr_list, va_list, te_list = fw_list[:c1], fw_list[c1:c2], fw_list[c2:] print( f'==> train set: {len(tr_list)}, validation set: {len(va_list)}, test set: {len(te_list)}') return tr_list, va_list, te_list def load_one_clip(video_path): v = VideoFileClip(video_path) fps = int(v.fps) frames = [f for f in v.iter_frames()][:-1] frame_cnt = len(frames) frame_length = 1000./fps total_length = int(1000 * (frame_cnt / fps)) a = v.audio sr = a.fps a = np.array([fa for fa in a.iter_frames()]) a = librosa.resample(a, sr, 48000) if len(a.shape) > 1: a = np.mean(a, axis=1) while True: idx = np.random.choice(np.arange(frame_cnt - 1), 1)[0] frame_clip = frames[idx] start_time = int(idx * frame_length + 0.5 * frame_length - 500) end_time = start_time + INTERVAL if start_time < 0 or end_time > total_length: continue wave_clip = a[48 * start_time: 48 * end_time] if wave_clip.shape[0] != 48000: continue break return frame_clip, wave_clip def resize_frame(frame): H, W = frame.size short_edge = min(H, W) scale = 256 / short_edge H_tar, W_tar = int(np.round(H * scale)), int(np.round(W * scale)) return frame.resize((H_tar, W_tar)) def get_spectrogram(wave, amp_jitter, amp_jitter_range, log_scale=True, sr=48000): # random clip-level amplitude jittering if amp_jitter: amplified = wave * np.random.uniform(*amp_jitter_range) if wave.dtype == np.int16: amplified[amplified >= 32767] = 32767 amplified[amplified <= -32768] = -32768 wave = amplified.astype('int16') elif wave.dtype == np.float32 or wave.dtype == np.float64: amplified[amplified >= 1] = 1 amplified[amplified <= -1] = -1 # fr, ts, spectrogram = signal.spectrogram(wave[:48000], fs=sr, nperseg=480, noverlap=240, nfft=512) # spectrogram = librosa.feature.melspectrogram(S=spectrogram, n_mels=257) # Try log-mel spectrogram? spectrogram = librosa.feature.melspectrogram( y=wave[:48000], sr=sr, hop_length=240, win_length=480, n_mels=257) if log_scale: spectrogram = librosa.power_to_db(spectrogram, ref=np.max) assert spectrogram.shape[0] == 257 return spectrogram def cropAudio(audio, sr, f_idx, fps=10, length=1., left_shift=0): time_per_frame = 1./fps assert audio.shape[0] > sr * length start_time = f_idx * time_per_frame - left_shift start_time = 0 if start_time < 0 else start_time start_idx = int(np.round(sr * start_time)) end_idx = int(np.round(start_idx + (sr * length))) if end_idx > audio.shape[0]: end_idx = audio.shape[0] start_idx = int(end_idx - (sr * length)) try: assert audio[start_idx:end_idx].shape[0] == sr * length except: print(audio.shape, start_idx, end_idx, end_idx - start_idx) exit(1) return audio[start_idx:end_idx] def pick_async_frame_idx(idx, total_frames, fps=10, gap=2.0, length=1.0, cnt=1): assert idx < total_frames - fps * length lower_bound = idx - int((length + gap) * fps) upper_bound = idx + int((length + gap) * fps) proposal = list(range(0, lower_bound)) + \ list(range(upper_bound, int(total_frames - fps * length))) # assert len(proposal) >= cnt avail_cnt = len(proposal) try: for i in range(cnt - avail_cnt): proposal.append(proposal[i % avail_cnt]) except Exception as e: print(idx, total_frames, proposal) raise e return sample(proposal, k=cnt) def adjust_learning_rate(optimizer, epoch, args): """Decay the learning rate based on schedule""" lr = args.lr if args.cos: # cosine lr schedule lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epoch)) else: # stepwise lr schedule for milestone in args.schedule: lr *= 0.1 if epoch >= milestone else 1. for param_group in optimizer.param_groups: param_group['lr'] = lr