import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
from random import shuffle, choice, sample
from moviepy.editor import VideoFileClip
import librosa
from scipy import signal
from scipy.io import wavfile
import torchaudio
torchaudio.set_audio_backend("sox_io")
INTERVAL = 1000
# discard
stft = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000, hop_length=161, n_mels=64).cuda()
def log10(x): return torch.log(x)/torch.log(torch.tensor(10.))
def norm_range(x, min_val, max_val):
    return 2. * (x - min_val) / float(max_val - min_val) - 1.


def normalize_spec(spec, spec_min, spec_max):
    return norm_range(spec, spec_min, spec_max)


def db_from_amp(x, cuda=False):
    # convert amplitude to decibels, clamping values below 1e-5 (-100 dB)
    if cuda:
        return 20. * log10(torch.max(torch.tensor(1e-5).to('cuda'), x.float()))
    else:
        return 20. * log10(torch.max(torch.tensor(1e-5), x.float()))
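
# Illustrative check (a sketch, not part of the pipeline): the 1e-5 floor maps
# silence to exactly -100 dB, matching the normalize_spec(-100, 100) range used below.
#   db_from_amp(torch.tensor([1e-6, 1.0]))  # -> tensor([-100., 0.])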
def audio_stft(audio, stft=stft):
    # apply the mel STFT to convert raw samples into a (time, frequency) matrix
    N, C, A = audio.size()
    audio = audio.view(N * C, A)
    spec = stft(audio)
    spec = spec.transpose(-1, -2)
    spec = db_from_amp(spec, cuda=True)
    spec = normalize_spec(spec, -100., 100.)
    _, T, F = spec.size()
    spec = spec.view(N, C, T, F)
    return spec
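
# Illustrative usage (a sketch; assumes a CUDA device, since both `stft` and
# db_from_amp(cuda=True) run on the GPU, and a batch of 1 s, 16 kHz clips):
#   audio = torch.randn(4, 1, 16000).cuda()  # (N, C, A)
#   spec = audio_stft(audio)                 # -> (4, 1, 100, 64), i.e. (N, C, T, n_mels)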
# discard
# def get_spec(
#     wavs,
#     sample_rate=16000,
#     use_volume_jittering=False,
#     center=False,
# ):
#     # Volume jittering - scale volume by factor in range (0.9, 1.1)
#     if use_volume_jittering:
#         wavs = [wav * np.random.uniform(0.9, 1.1) for wav in wavs]
#     if center:
#         wavs = [center_only(wav) for wav in wavs]
#     # Convert to log filterbank
#     specs = [logfbank(
#         wav,
#         sample_rate,
#         winlen=0.009,
#         winstep=0.005,  # if num_sec==1 else 0.01,
#         nfilt=256,
#         nfft=1024
#     ).astype('float32').T for wav in wavs]
#     # Convert to 32-bit float and expand dim
#     specs = np.stack(specs, axis=0)
#     specs = np.expand_dims(specs, 1)
#     specs = torch.as_tensor(specs)  # Nx1xFxT
#     return specs
def center_only(audio, sr=16000, L=1.0):
    # center_wav = np.arange(0, L, L/(0.5*sr)) ** 2
    # center_wav = np.concatenate([center_wav, center_wav[::-1]])
    # center_wav[L*sr//2:3*L*sr//4] = 1
    # keep only the 0.3 s around the center (0.4L to 0.7L); zero everything else
    center_wav = np.zeros(int(L * sr))
    center_wav[int(0.4 * L * sr):int(0.7 * L * sr)] = 1
    return audio * center_wav
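
# Illustrative usage (a sketch): zero out everything but the center of a 1 s clip.
#   centered = center_only(np.random.randn(16000))  # non-zero only in samples 6400..11199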
def get_spec_librosa(
    wavs,
    sample_rate=16000,
    use_volume_jittering=False,
    center=False,
):
    # Volume jittering - scale volume by a factor in range (0.9, 1.1)
    if use_volume_jittering:
        wavs = [wav * np.random.uniform(0.9, 1.1) for wav in wavs]
    if center:
        wavs = [center_only(wav) for wav in wavs]
    # Convert to log-mel spectrograms
    specs = [librosa.feature.melspectrogram(
        y=wav,
        sr=sample_rate,
        n_fft=400,
        hop_length=126,
        n_mels=128,
    ).astype('float32') for wav in wavs]
    # Convert power to dB, stack, and expand the channel dim
    specs = [librosa.power_to_db(spec) for spec in specs]
    specs = np.stack(specs, axis=0)
    specs = np.expand_dims(specs, 1)
    specs = torch.as_tensor(specs)  # Nx1xFxT
    return specs
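
# Illustrative usage (a sketch; assumes 1 s float32 mono clips at 16 kHz):
#   wavs = [np.random.randn(16000).astype('float32') for _ in range(4)]
#   specs = get_spec_librosa(wavs, use_volume_jittering=True)  # -> (4, 1, 128, 127)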
def calcEuclideanDistance_Mat(X, Y):
    """
    Inputs:
    - X: A torch tensor of shape (N, F)
    - Y: A torch tensor of shape (M, F)
    Returns:
    A torch tensor D of shape (N, M) where D[i, j] is the Euclidean distance
    between X[i] and Y[j].
    """
    return ((torch.sum(X ** 2, dim=1, keepdim=True))
            + (torch.sum(Y ** 2, dim=1, keepdim=True)).T
            - 2 * X @ Y.T) ** 0.5
def calcEuclideanDistance(x1, x2):
    return torch.sum((x1 - x2) ** 2, dim=1) ** 0.5
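
# Illustrative check (a sketch): the matrix version agrees with the pairwise one
# up to floating-point error.
#   X, Y = torch.randn(5, 8), torch.randn(3, 8)
#   D = calcEuclideanDistance_Mat(X, Y)         # (5, 3)
#   d = calcEuclideanDistance(X[0:1], Y[0:1])   # d[0] ~= D[0, 0]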
def split_data(in_list, portion=(0.9, 0.95), is_shuffle=True):
    # load the list first; shuffling must happen after loading, since
    # random.shuffle cannot operate on a path string
    if isinstance(in_list, str):
        with open(in_list) as l:
            fw_list = json.load(l)
    elif isinstance(in_list, list):
        fw_list = in_list
    else:
        print(type(in_list))
        raise TypeError('Invalid input list type')
    if is_shuffle:
        shuffle(fw_list)
    c1, c2 = int(len(fw_list) * portion[0]), int(len(fw_list) * portion[1])
    tr_list, va_list, te_list = fw_list[:c1], fw_list[c1:c2], fw_list[c2:]
    print(
        f'==> train set: {len(tr_list)}, validation set: {len(va_list)}, test set: {len(te_list)}')
    return tr_list, va_list, te_list
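
# Illustrative usage (a sketch; 'files.json' is a hypothetical JSON list of sample paths):
#   tr, va, te = split_data('files.json')  # 90% / 5% / 5% split by default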
def load_one_clip(video_path):
    v = VideoFileClip(video_path)
    fps = int(v.fps)
    frames = [f for f in v.iter_frames()][:-1]
    frame_cnt = len(frames)
    frame_length = 1000. / fps  # duration of one frame in ms
    total_length = int(1000 * (frame_cnt / fps))
    a = v.audio
    sr = a.fps
    a = np.array([fa for fa in a.iter_frames()])
    # mix stereo down to mono before resampling along the sample axis
    if len(a.shape) > 1:
        a = np.mean(a, axis=1)
    a = librosa.resample(a, orig_sr=sr, target_sr=48000)
    while True:
        # sample a frame and take the 1 s audio window centered on it
        idx = np.random.choice(np.arange(frame_cnt - 1), 1)[0]
        frame_clip = frames[idx]
        start_time = int(idx * frame_length + 0.5 * frame_length - 500)
        end_time = start_time + INTERVAL
        if start_time < 0 or end_time > total_length:
            continue
        wave_clip = a[48 * start_time: 48 * end_time]  # 48 samples per ms at 48 kHz
        if wave_clip.shape[0] != 48000:
            continue
        break
    return frame_clip, wave_clip
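
# Illustrative usage (a sketch; 'clip.mp4' is a hypothetical path):
#   frame, wave = load_one_clip('clip.mp4')  # one RGB frame + its centered 1 s, 48 kHz audio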
def resize_frame(frame):
    # PIL's Image.size is (width, height); scale so the short edge is 256 px
    W, H = frame.size
    short_edge = min(W, H)
    scale = 256 / short_edge
    W_tar, H_tar = int(np.round(W * scale)), int(np.round(H * scale))
    return frame.resize((W_tar, H_tar))
def get_spectrogram(wave, amp_jitter, amp_jitter_range, log_scale=True, sr=48000):
    # random clip-level amplitude jittering
    if amp_jitter:
        amplified = wave * np.random.uniform(*amp_jitter_range)
        if wave.dtype == np.int16:
            amplified[amplified >= 32767] = 32767
            amplified[amplified <= -32768] = -32768
            wave = amplified.astype('int16')
        elif wave.dtype == np.float32 or wave.dtype == np.float64:
            amplified[amplified >= 1] = 1
            amplified[amplified <= -1] = -1
            wave = amplified  # write the jittered signal back
    # fr, ts, spectrogram = signal.spectrogram(wave[:48000], fs=sr, nperseg=480, noverlap=240, nfft=512)
    # spectrogram = librosa.feature.melspectrogram(S=spectrogram, n_mels=257)  # Try log-mel spectrogram?
    spectrogram = librosa.feature.melspectrogram(
        y=wave[:48000], sr=sr, hop_length=240, win_length=480, n_mels=257)
    if log_scale:
        spectrogram = librosa.power_to_db(spectrogram, ref=np.max)
    assert spectrogram.shape[0] == 257
    return spectrogram
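
# Illustrative usage (a sketch; assumes a 1 s float32 wave at 48 kHz):
#   wave = np.random.randn(48000).astype('float32')
#   spec = get_spectrogram(wave, amp_jitter=True, amp_jitter_range=(0.9, 1.1))  # -> (257, 201)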
def cropAudio(audio, sr, f_idx, fps=10, length=1., left_shift=0):
    # crop a `length`-second audio window aligned with video frame `f_idx`
    time_per_frame = 1. / fps
    assert audio.shape[0] > sr * length
    start_time = f_idx * time_per_frame - left_shift
    start_time = 0 if start_time < 0 else start_time
    start_idx = int(np.round(sr * start_time))
    end_idx = int(np.round(start_idx + (sr * length)))
    if end_idx > audio.shape[0]:
        # clamp to the end of the track and shift the window back
        end_idx = audio.shape[0]
        start_idx = int(end_idx - (sr * length))
    try:
        assert audio[start_idx:end_idx].shape[0] == sr * length
    except AssertionError:
        print(audio.shape, start_idx, end_idx, end_idx - start_idx)
        exit(1)
    return audio[start_idx:end_idx]
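
# Illustrative usage (a sketch): the 1 s window aligned with frame 25 of a 10 fps video.
#   audio = np.random.randn(10 * 48000).astype('float32')  # 10 s at 48 kHz
#   clip = cropAudio(audio, sr=48000, f_idx=25)             # -> 48000 samples from t = 2.5 s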
def pick_async_frame_idx(idx, total_frames, fps=10, gap=2.0, length=1.0, cnt=1):
    # sample `cnt` frame indices whose audio windows are at least `gap` seconds
    # away from the window of frame `idx`
    assert idx < total_frames - fps * length
    lower_bound = idx - int((length + gap) * fps)
    upper_bound = idx + int((length + gap) * fps)
    proposal = list(range(0, lower_bound)) + \
        list(range(upper_bound, int(total_frames - fps * length)))
    # assert len(proposal) >= cnt
    avail_cnt = len(proposal)
    try:
        # pad with repeats when there are fewer candidates than requested
        for i in range(cnt - avail_cnt):
            proposal.append(proposal[i % avail_cnt])
    except Exception as e:
        print(idx, total_frames, proposal)
        raise e
    return sample(proposal, k=cnt)
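
# Illustrative usage (a sketch): three negative (asynchronous) frame indices for
# frame 50 of a 300-frame, 10 fps video, at least 2 s away from its window.
#   neg_idx = pick_async_frame_idx(50, 300, cnt=3)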
def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate based on schedule"""
    lr = args.lr
    if args.cos:  # cosine lr schedule
        lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epoch))
    else:  # stepwise lr schedule
        for milestone in args.schedule:
            lr *= 0.1 if epoch >= milestone else 1.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
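
# Illustrative usage (a sketch; `args` is assumed to expose lr, cos, epoch, and
# schedule, e.g. an argparse.Namespace as in common training scripts):
#   args = argparse.Namespace(lr=0.1, cos=True, epoch=100, schedule=[])
#   adjust_learning_rate(optimizer, epoch=10, args=args)  # sets lr ~= 0.0976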