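"""Onset-transfer baseline: detect visual onsets in each original video with
VideoOnsetNet, then build a new soundtrack by copying audio snippets from a
conditional video's onsets to the detected positions, and mux the result
back onto the original video."""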
import glob
import os
import shutil

import numpy as np
import librosa
import soundfile as sf
from moviepy.editor import VideoFileClip, AudioFileClip
from PIL import Image

import torch
import torchvision.transforms as transforms

import models
from utils import torch_utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
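
# Frame preprocessing: resize to 128x128, center-crop to 112x112, and
# normalize with the standard ImageNet mean/std. ToTensor is applied
# per-frame in read_image(), so Normalize here operates on floats in [0, 1].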
vision_transform_list = [
    transforms.Resize((128, 128)),
    transforms.CenterCrop((112, 112)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
video_transform = transforms.Compose(vision_transform_list)

def read_image(frame_list):
    """Load a list of frame paths into a normalized (C, T, H, W) tensor."""
    imgs = []
    convert_tensor = transforms.ToTensor()
    for img_path in frame_list:
        image = Image.open(img_path).convert('RGB')
        image = convert_tensor(image)
        imgs.append(image.unsqueeze(0))
    # (T, C, H, W)
    imgs = torch.cat(imgs, dim=0).squeeze()
    imgs = video_transform(imgs)
    imgs = imgs.permute(1, 0, 2, 3)
    # (C, T, H, W)
    return imgs

def get_video_frames(origin_video_path):
    save_path = 'results/temp'
    if os.path.exists(save_path):
        shutil.rmtree(save_path)
    os.makedirs(save_path)
    # Extract frames at 15 fps, scaled to 360 px height.
    command = f'ffmpeg -v quiet -y -i "{origin_video_path}" -f image2 -vf "scale=-1:360,fps=15" -qscale:v 3 "{save_path}"/frame%06d.jpg'
    os.system(command)
    frame_list = glob.glob(f'{save_path}/*.jpg')
    frame_list.sort()
    # Keep the first 2 seconds (2 s * 15 fps = 30 frames).
    frame_list = frame_list[:2 * 15]
    frames = read_image(frame_list)
    return frames

def postprocess_video_onsets(probs, thres=0.5, nearest=5):
    """Greedy peak picking: repeatedly take the highest-probability frame and
    suppress its neighbors within `nearest` frames, until no probability
    above `thres` remains."""
    video_onsets = []
    pred = np.array(probs, copy=True)
    while True:
        max_ind = np.argmax(pred)
        video_onsets.append(max_ind)
        low = max(max_ind - nearest, 0)
        high = min(max_ind + nearest, pred.shape[0])
        pred[low: high] = 0
        if (pred > thres).sum() == 0:
            break
    video_onsets.sort()
    video_onsets = np.array(video_onsets)
    return video_onsets

def detect_onset_of_audio(audio, sample_rate):
    """Detect audio onsets and return their positions in samples."""
    onsets = librosa.onset.onset_detect(
        y=audio, sr=sample_rate, units='samples', delta=0.3)
    return onsets

def get_onset_audio_range(audio_len, onsets, i):
    if i == 0:
        prev_offset = int(onsets[i] // 3)
    else:
        prev_offset = int((onsets[i] - onsets[i - 1]) // 3)
    if i == onsets.shape[0] - 1:
        post_offset = int((audio_len - onsets[i]) // 4 * 2)
    else:
        post_offset = int((onsets[i + 1] - onsets[i]) // 4 * 2)
    return prev_offset, post_offset
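
# Note on get_onset_audio_range: the copy window extends one third of the gap
# to the previous onset backwards and half of the gap to the next onset
# forwards (1/3 + 1/2 < 1), so snippets pasted at adjacent onsets never
# overlap.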

def generate_audio(con_videoclip, video_onsets):
    """Stitch a new 2-second waveform by copying audio snippets from the
    conditional clip's onsets to the detected video onset positions."""
    np.random.seed(2022)
    con_audioclip = con_videoclip.audio
    con_audio, con_sr = con_audioclip.to_soundarray(), con_audioclip.fps
    con_audio = con_audio.mean(-1)  # stereo -> mono
    target_sr = 22050
    if target_sr != con_sr:
        con_audio = librosa.resample(con_audio, orig_sr=con_sr, target_sr=target_sr)
        con_sr = target_sr
    con_onsets = detect_onset_of_audio(con_audio, con_sr)
    gen_audio = np.zeros(int(2 * con_sr))
    for i in range(video_onsets.shape[0]):
        prev_offset, post_offset = get_onset_audio_range(
            int(con_sr * 2), video_onsets, i)
        # Pick a random onset from the conditional audio and clip the copy
        # window to what both signals can accommodate.
        j = np.random.choice(con_onsets.shape[0])
        prev_offset_con, post_offset_con = get_onset_audio_range(
            con_audio.shape[0], con_onsets, j)
        prev_offset = min(prev_offset, prev_offset_con)
        post_offset = min(post_offset, post_offset_con)
        gen_audio[video_onsets[i] - prev_offset: video_onsets[i] + post_offset] = \
            con_audio[con_onsets[j] - prev_offset: con_onsets[j] + post_offset]
    return gen_audio
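
# Output layout under results/onset_baseline/vis: '0_original' holds the
# input clips, '{2k+1}_conditional_k' holds each conditional clip, and
# '{2k+2}_generate_k' holds the originals re-muxed with audio transferred
# from conditional clip k.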

def generate_video(net, original_video_list, cond_video_list_0, cond_video_list_1, cond_video_list_2):
    save_folder = 'results/onset_baseline/vis'
    os.makedirs(save_folder, exist_ok=True)
    origin_video_folder = os.path.join(save_folder, '0_original')
    os.makedirs(origin_video_folder, exist_ok=True)
    for i in range(len(original_video_list)):
        shutil.copy(original_video_list[i], os.path.join(
            origin_video_folder, os.path.basename(original_video_list[i])))
        ori_videoclip = VideoFileClip(original_video_list[i])

        # Predict per-frame onset probabilities for the original video.
        frames = get_video_frames(original_video_list[i])
        inputs = {
            'frames': frames.unsqueeze(0).to(device)
        }
        with torch.no_grad():
            pred = net(inputs).squeeze()
        pred = torch.sigmoid(pred).data.cpu().numpy()
        video_onsets = postprocess_video_onsets(pred, thres=0.5, nearest=4)
        # Convert onset indices from 15 fps frames to 22050 Hz audio samples.
        video_onsets = (video_onsets / 15 * 22050).astype(int)

        for ind, cond_video in enumerate([cond_video_list_0[i], cond_video_list_1[i], cond_video_list_2[i]]):
            cond_video_folder = os.path.join(save_folder, f'{ind * 2 + 1}_conditional_{ind}')
            os.makedirs(cond_video_folder, exist_ok=True)
            shutil.copy(cond_video, os.path.join(
                cond_video_folder, os.path.basename(cond_video)))

            con_videoclip = VideoFileClip(cond_video)
            gen_audio = generate_audio(con_videoclip, video_onsets)
            save_audio_path = 'results/gen_audio.wav'
            sf.write(save_audio_path, gen_audio, 22050)
            gen_audioclip = AudioFileClip(save_audio_path)
            gen_videoclip = ori_videoclip.set_audio(gen_audioclip)
            save_gen_folder = os.path.join(save_folder, f'{ind * 2 + 2}_generate_{ind}')
            os.makedirs(save_gen_folder, exist_ok=True)
            gen_videoclip.write_videofile(os.path.join(save_gen_folder, os.path.basename(original_video_list[i])))

if __name__ == '__main__':
    net = models.VideoOnsetNet(pretrained=False).to(device)
    resume = 'checkpoints/EXP1/checkpoint_ep100.pth.tar'
    net, _ = torch_utils.load_model(resume, net, device=device, strict=True)
    net.eval()  # inference only

    read_folder = ''  # path to a directory generated by `audio_generation.py`
    original_video_list = sorted(glob.glob(f'{read_folder}/2sec_full_orig_video/*.mp4'))
    cond_video_list_0 = sorted(glob.glob(f'{read_folder}/2sec_full_cond_video_0/*.mp4'))
    cond_video_list_1 = sorted(glob.glob(f'{read_folder}/2sec_full_cond_video_1/*.mp4'))
    cond_video_list_2 = sorted(glob.glob(f'{read_folder}/2sec_full_cond_video_2/*.mp4'))

    generate_video(net, original_video_list, cond_video_list_0, cond_video_list_1, cond_video_list_2)