import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
from random import shuffle, choice, sample

from moviepy.editor import VideoFileClip
import librosa
from scipy import signal
from scipy.io import wavfile
import torchaudio
torchaudio.set_audio_backend("sox_io")

INTERVAL = 1000  # length of the sampled audio window, in milliseconds

# Mel-spectrogram transform for 16 kHz audio, created once on the GPU and
# used as the default `stft` argument of audio_stft below.
stft = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000, hop_length=161, n_mels=64).cuda()


def log10(x):
    return torch.log(x) / torch.log(torch.tensor(10.))


def norm_range(x, min_val, max_val):
    return 2.*(x - min_val)/float(max_val - min_val) - 1.


def normalize_spec(spec, spec_min, spec_max):
    return norm_range(spec, spec_min, spec_max)


def db_from_amp(x, cuda=False):
    # Convert linear amplitude to decibels, flooring at 1e-5 to avoid log(0).
    if cuda:
        return 20. * log10(torch.max(torch.tensor(1e-5).to('cuda'), x.float()))
    else:
        return 20. * log10(torch.max(torch.tensor(1e-5), x.float()))


def audio_stft(audio, stft=stft):
    # Convert raw waveforms (N x C x samples) into normalized log-mel
    # spectrograms (N x C x T x n_mels), mapping the [-100, 100] dB range
    # to [-1, 1].
    N, C, A = audio.size()
    audio = audio.view(N * C, A)
    spec = stft(audio)
    spec = spec.transpose(-1, -2)
    spec = db_from_amp(spec, cuda=True)
    spec = normalize_spec(spec, -100., 100.)
    _, T, n_freq = spec.size()
    spec = spec.view(N, C, T, n_freq)
    return spec
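
# Illustrative usage (a sketch, not exercised here; assumes a CUDA device
# since the default `stft` transform lives on the GPU). With hop_length=161
# at 16 kHz, one second of audio gives about 100 frames:
#   audio = torch.randn(2, 1, 16000, device='cuda')  # N x C x samples
#   spec = audio_stft(audio)                         # -> 2 x 1 x 100 x 64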


def center_only(audio, sr=16000, L=1.0):
    # Keep only the central 0.3*L seconds of the clip (from 0.4*L to 0.7*L);
    # zero out everything else.
    center_wav = np.zeros(int(L * sr))
    center_wav[int(0.4*L*sr):int(0.7*L*sr)] = 1

    return audio * center_wav


def get_spec_librosa(
    wavs,
    sample_rate=16000,
    use_volume_jittering=False,
    center=False,
):
    # Volume jittering - scale volume by a factor in range (0.9, 1.1)
    if use_volume_jittering:
        wavs = [wav * np.random.uniform(0.9, 1.1) for wav in wavs]
    if center:
        wavs = [center_only(wav) for wav in wavs]

    # Convert to log filterbank
    specs = [librosa.feature.melspectrogram(
        y=wav,
        sr=sample_rate,
        n_fft=400,
        hop_length=126,
        n_mels=128,
    ).astype('float32') for wav in wavs]

    # Convert to 32-bit float and expand dim
    specs = [librosa.power_to_db(spec) for spec in specs]
    specs = np.stack(specs, axis=0)
    specs = np.expand_dims(specs, 1)
    specs = torch.as_tensor(specs)  # Nx1xFxT

    return specs
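
# Illustrative usage (a sketch with synthetic audio): with 1 s of 16 kHz
# audio, n_fft=400 / hop_length=126 gives about 127 frames, so each clip
# becomes a near-square 128-mel log spectrogram:
#   wavs = [np.random.randn(16000).astype('float32') for _ in range(4)]
#   specs = get_spec_librosa(wavs)  # -> torch tensor, 4 x 1 x 128 x 127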


def calcEuclideanDistance_Mat(X, Y):
    """
    Inputs:
    - X: A torch tensor of shape (N, F)
    - Y: A torch tensor of shape (M, F)

    Returns:
    A torch tensor D of shape (N, M) where D[i, j] is the Euclidean distance
    between X[i] and Y[j].
    """
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y; the clamp guards against
    # small negative values from floating-point error before the square root.
    sq = torch.sum(X ** 2, dim=1, keepdim=True) \
        + torch.sum(Y ** 2, dim=1, keepdim=True).T - 2 * X @ Y.T
    return sq.clamp(min=0) ** 0.5


def calcEuclideanDistance(x1, x2):
    # Row-wise Euclidean distance between matching rows of x1 and x2 (N x F).
    return torch.sum((x1 - x2)**2, dim=1)**0.5
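
# Illustrative consistency check (a sketch): the pairwise matrix agrees with
# the row-wise distance on matching rows, and a tensor's distance matrix
# against itself has a (near-)zero diagonal:
#   x = torch.randn(5, 32)
#   D = calcEuclideanDistance_Mat(x, x)   # 5 x 5, diagonal ~0
#   d = calcEuclideanDistance(x, x)       # 5 zeros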


def split_data(in_list, portion=(0.9, 0.95), is_shuffle=True):
    # Accept either a path to a JSON list or an in-memory list.
    if isinstance(in_list, str):
        with open(in_list) as l:
            fw_list = json.load(l)
    elif isinstance(in_list, list):
        fw_list = in_list
    else:
        print(type(in_list))
        raise TypeError('Invalid input list type')
    # Shuffle after loading: shuffling a path string would raise a TypeError.
    if is_shuffle:
        shuffle(fw_list)
    c1, c2 = int(len(fw_list) * portion[0]), int(len(fw_list) * portion[1])
    tr_list, va_list, te_list = fw_list[:c1], fw_list[c1:c2], fw_list[c2:]
    print(
        f'==> train set: {len(tr_list)}, validation set: {len(va_list)}, test set: {len(te_list)}')
    return tr_list, va_list, te_list
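
# Illustrative usage (a sketch): the default portion=(0.9, 0.95) marks the
# train/val and val/test boundaries, i.e. a 90/5/5 split:
#   tr, va, te = split_data(list(range(100)), is_shuffle=False)
#   # -> 90, 5, and 5 items respectively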


def load_one_clip(video_path):
    v = VideoFileClip(video_path)
    fps = int(v.fps)
    frames = [f for f in v.iter_frames()][:-1]
    frame_cnt = len(frames)
    frame_length = 1000. / fps  # per-frame duration in ms
    total_length = int(1000 * (frame_cnt / fps))

    # Load the audio track, mix stereo down to mono, then resample to 48 kHz.
    # (Mixing must happen first: librosa resamples along the last axis, which
    # for an (n, 2) stereo array would be the channel axis.)
    a = v.audio
    sr = a.fps
    a = np.array([fa for fa in a.iter_frames()])
    if len(a.shape) > 1:
        a = np.mean(a, axis=1)
    a = librosa.resample(a, orig_sr=sr, target_sr=48000)

    # Pick a random frame and the 1 s audio window centered on it; retry
    # until the whole window lies inside the clip.
    while True:
        idx = np.random.choice(np.arange(frame_cnt - 1), 1)[0]
        frame_clip = frames[idx]
        start_time = int(idx * frame_length + 0.5 * frame_length - 500)
        end_time = start_time + INTERVAL
        if start_time < 0 or end_time > total_length:
            continue
        wave_clip = a[48 * start_time: 48 * end_time]  # 48 samples/ms at 48 kHz
        if wave_clip.shape[0] != 48000:
            continue
        break
    return frame_clip, wave_clip
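
# Illustrative usage (a sketch; the path is hypothetical): returns one RGB
# frame as an (H, W, 3) uint8 array plus the 48,000-sample (1 s at 48 kHz)
# audio window centered on that frame:
#   frame, wave = load_one_clip('data/clips/example.mp4')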


def resize_frame(frame):
    # PIL's Image.size is (width, height); scale so the short edge is 256.
    W, H = frame.size
    short_edge = min(W, H)
    scale = 256 / short_edge
    W_tar, H_tar = int(np.round(W * scale)), int(np.round(H * scale))
    return frame.resize((W_tar, H_tar))


def get_spectrogram(wave, amp_jitter, amp_jitter_range, log_scale=True, sr=48000):
    # Random clip-level amplitude jittering, with clipping to the valid range
    # of the input dtype.
    if amp_jitter:
        amplified = wave * np.random.uniform(*amp_jitter_range)
        if wave.dtype == np.int16:
            amplified[amplified >= 32767] = 32767
            amplified[amplified <= -32768] = -32768
            wave = amplified.astype('int16')
        elif wave.dtype == np.float32 or wave.dtype == np.float64:
            amplified[amplified >= 1] = 1
            amplified[amplified <= -1] = -1
            wave = amplified

    spectrogram = librosa.feature.melspectrogram(
        y=wave[:48000], sr=sr, hop_length=240, win_length=480, n_mels=257)
    if log_scale:
        spectrogram = librosa.power_to_db(spectrogram, ref=np.max)
    assert spectrogram.shape[0] == 257

    return spectrogram
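
# Illustrative usage (a sketch with synthetic audio): 1 s at 48 kHz with
# hop_length=240 yields 201 frames, so the output is a 257 x 201 log-mel array:
#   wave = np.random.randn(48000).astype(np.float32)
#   spec = get_spectrogram(wave, amp_jitter=False, amp_jitter_range=None)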


def cropAudio(audio, sr, f_idx, fps=10, length=1., left_shift=0):
    # Crop a `length`-second audio window aligned to video frame `f_idx`,
    # optionally shifted left by `left_shift` seconds, clamped to clip bounds.
    time_per_frame = 1. / fps
    assert audio.shape[0] > sr * length
    start_time = f_idx * time_per_frame - left_shift
    start_time = 0 if start_time < 0 else start_time
    start_idx = int(np.round(sr * start_time))
    end_idx = int(np.round(start_idx + (sr * length)))
    if end_idx > audio.shape[0]:
        end_idx = audio.shape[0]
        start_idx = int(end_idx - (sr * length))
    crop = audio[start_idx:end_idx]
    if crop.shape[0] != sr * length:
        raise ValueError(
            f'bad audio crop: audio {audio.shape}, start {start_idx}, '
            f'end {end_idx}, got {end_idx - start_idx} samples')
    return crop
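
# Illustrative usage (a sketch): at the default fps=10, frame 25 starts at
# 2.5 s, so this grabs the samples for [2.5 s, 3.5 s) from a 48 kHz track:
#   clip = cropAudio(audio, sr=48000, f_idx=25)  # -> 48000 samples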


def pick_async_frame_idx(idx, total_frames, fps=10, gap=2.0, length=1.0, cnt=1):
    # Propose `cnt` frame indices whose audio windows are at least `gap`
    # seconds away from the window starting at `idx` (negative samples).
    assert idx < total_frames - fps * length
    lower_bound = idx - int((length + gap) * fps)
    upper_bound = idx + int((length + gap) * fps)
    proposal = list(range(0, lower_bound)) + \
        list(range(upper_bound, int(total_frames - fps * length)))
    # If fewer than `cnt` candidates exist, pad by cycling through them.
    avail_cnt = len(proposal)
    try:
        for i in range(cnt - avail_cnt):
            proposal.append(proposal[i % avail_cnt])
    except Exception as e:
        print(idx, total_frames, proposal)
        raise e
    return sample(proposal, k=cnt)
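
# Illustrative usage (a sketch): negatives for frame 50 of a 200-frame clip
# at 10 fps exclude indices within (1.0 + 2.0) * 10 = 30 frames of the anchor:
#   neg = pick_async_frame_idx(50, 200, cnt=3)  # indices < 20 or >= 80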


def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate based on schedule"""
    lr = args.lr
    if args.cos:  # cosine lr schedule
        lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epoch))
    else:  # stepwise lr schedule
        for milestone in args.schedule:
            lr *= 0.1 if epoch >= milestone else 1.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
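

if __name__ == "__main__":
    # Minimal smoke test for the CPU-only helpers (illustrative; the CUDA
    # mel-spectrogram path in audio_stft needs a GPU and is not covered here).
    x = torch.randn(4, 128)
    y = torch.randn(6, 128)
    D = calcEuclideanDistance_Mat(x, y)
    assert D.shape == (4, 6)
    # The pairwise matrix should agree with the row-wise distance on
    # matching rows.
    d = calcEuclideanDistance(x, y[:4])
    assert torch.allclose(torch.diagonal(D), d, atol=1e-3)

    wave = np.random.randn(48000).astype(np.float32)
    spec = get_spectrogram(wave, amp_jitter=False, amp_jitter_range=None)
    assert spec.shape[0] == 257

    tr, va, te = split_data(list(range(100)), is_shuffle=False)
    assert (len(tr), len(va), len(te)) == (90, 5, 5)
    print('utils smoke test passed')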