|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import cv2 |
|
import glob |
|
import os |
|
import torch |
|
import requests |
|
import numpy as np |
|
from os import path as osp |
|
from collections import OrderedDict |
|
from torch.utils.data import DataLoader |
|
|
|
from models.network_vrt import VRT as net |
|
from utils import utils_image as util |
|
from data.dataset_video_test import VideoRecurrentTestDataset, VideoTestVimeo90KDataset, \ |
|
SingleVideoRecurrentTestDataset, VFI_DAVIS, VFI_UCF101, VFI_Vid4 |
|
|
|
|
|
def main():
    """Run VRT inference for the selected task and report PSNR/SSIM.

    Builds the model and test dataset from the CLI arguments, runs tiled
    (or whole-sequence) inference on every video in the test folder,
    optionally saves the restored frames, and prints per-sequence and
    average PSNR/SSIM (RGB and Y-channel) when ground truth is available.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='001_VRT_videosr_bi_REDS_6frames', help='tasks: 001 to 008')
    parser.add_argument('--sigma', type=int, default=0, help='noise level for denoising: 10, 20, 30, 40, 50')
    parser.add_argument('--folder_lq', type=str, default='testsets/REDS4/sharp_bicubic',
                        help='input low-quality test video folder')
    parser.add_argument('--folder_gt', type=str, default=None,
                        help='input ground-truth test video folder')
    parser.add_argument('--tile', type=int, nargs='+', default=[40,128,128],
                        help='Tile size, [0,0,0] for no tile during testing (testing as a whole)')
    parser.add_argument('--tile_overlap', type=int, nargs='+', default=[2,20,20],
                        help='Overlapping of different tiles')
    parser.add_argument('--num_workers', type=int, default=16, help='number of workers in data loading')
    parser.add_argument('--save_result', action='store_true', help='save resulting image')
    args = parser.parse_args()

    # prepare_model_dataset also sets args.scale, args.window_size and
    # args.nonblind_denoising, which test_video/test_clip rely on.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = prepare_model_dataset(args)
    model.eval()
    model = model.to(device)

    # Pick the dataset class from the LQ folder name and the task.
    if 'vimeo' in args.folder_lq.lower():
        if 'videofi' in args.task:
            # Frame interpolation: the LQ frames are temporally subsampled
            # from GT (temporal_scale=2), hence dataroot_lq = folder_gt.
            test_set = VideoTestVimeo90KDataset({'dataroot_gt':args.folder_gt, 'dataroot_lq':args.folder_gt,
                                                 'meta_info_file': "data/meta_info/meta_info_Vimeo90K_test_GT.txt",
                                                 'pad_sequence': False, 'num_frame': 7, 'temporal_scale': 2,
                                                 'cache_data': False})
        else:
            test_set = VideoTestVimeo90KDataset({'dataroot_gt': args.folder_gt, 'dataroot_lq': args.folder_lq,
                                                 'meta_info_file': "data/meta_info/meta_info_Vimeo90K_test_GT.txt",
                                                 'pad_sequence': True, 'num_frame': 7, 'temporal_scale': 1,
                                                 'cache_data': False})
    elif 'davis' in args.folder_lq.lower() and 'videofi' in args.task:
        test_set = VFI_DAVIS(data_root=args.folder_gt)
    elif 'ucf101' in args.folder_lq.lower() and 'videofi' in args.task:
        test_set = VFI_UCF101(data_root=args.folder_gt)
    elif 'vid4' in args.folder_lq.lower() and 'videofi' in args.task:
        test_set = VFI_Vid4(data_root=args.folder_gt)
    elif args.folder_gt is not None:
        test_set = VideoRecurrentTestDataset({'dataroot_gt':args.folder_gt, 'dataroot_lq':args.folder_lq,
                                              'sigma':args.sigma, 'num_frame':-1, 'cache_data': False})
    else:
        # No ground truth available: LQ-only dataset, metrics are skipped.
        test_set = SingleVideoRecurrentTestDataset({'dataroot_gt':args.folder_gt, 'dataroot_lq':args.folder_lq,
                                                    'sigma':args.sigma, 'num_frame':-1, 'cache_data': False})

    test_loader = DataLoader(dataset=test_set, num_workers=args.num_workers, batch_size=1, shuffle=False)

    save_dir = f'results/{args.task}'
    if args.save_result:
        os.makedirs(save_dir, exist_ok=True)
    # Per-sequence averages accumulated across the whole test set.
    test_results = OrderedDict()
    test_results['psnr'] = []
    test_results['ssim'] = []
    test_results['psnr_y'] = []
    test_results['ssim_y'] = []

    assert len(test_loader) != 0, f'No dataset found at {args.folder_lq}'

    for idx, batch in enumerate(test_loader):
        # batch['L']: low-quality clip, shape (b, d, c, h, w); batch['H']: GT if present.
        lq = batch['L'].to(device)
        folder = batch['folder']
        gt = batch['H'] if 'H' in batch else None

        with torch.no_grad():
            output = test_video(lq, model, args)

        if 'videofi' in args.task:
            # Frame interpolation: only the first (interpolated) frame is evaluated.
            output = output[:, :1, ...]
            batch['lq_path'] = batch['gt_path']
        elif 'videosr' in args.task and 'vimeo' in args.folder_lq.lower():
            # Vimeo SR evaluates only the center frame (index 3 of 7).
            output = output[:, 3:4, :, :, :]
            batch['lq_path'] = batch['gt_path']

        # Per-frame metrics for the current sequence.
        test_results_folder = OrderedDict()
        test_results_folder['psnr'] = []
        test_results_folder['ssim'] = []
        test_results_folder['psnr_y'] = []
        test_results_folder['ssim_y'] = []

        for i in range(output.shape[1]):

            # CHW RGB float [0,1] -> HWC BGR uint8 (cv2 convention).
            img = output[:, i, ...].data.squeeze().float().cpu().clamp_(0, 1).numpy()
            if img.ndim == 3:
                img = np.transpose(img[[2, 1, 0], :, :], (1, 2, 0))
            img = (img * 255.0).round().astype(np.uint8)
            if args.save_result:
                seq_ = osp.basename(batch['lq_path'][i][0]).split('.')[0]
                os.makedirs(f'{save_dir}/{folder[0]}', exist_ok=True)
                cv2.imwrite(f'{save_dir}/{folder[0]}/{seq_}.png', img)

            if gt is not None:
                img_gt = gt[:, i, ...].data.squeeze().float().cpu().clamp_(0, 1).numpy()
                if img_gt.ndim == 3:
                    img_gt = np.transpose(img_gt[[2, 1, 0], :, :], (1, 2, 0))
                img_gt = (img_gt * 255.0).round().astype(np.uint8)
                img_gt = np.squeeze(img_gt)

                test_results_folder['psnr'].append(util.calculate_psnr(img, img_gt, border=0))
                test_results_folder['ssim'].append(util.calculate_ssim(img, img_gt, border=0))
                if img_gt.ndim == 3:
                    # Color: also compute metrics on the Y (luma) channel.
                    img = util.bgr2ycbcr(img.astype(np.float32) / 255.) * 255.
                    img_gt = util.bgr2ycbcr(img_gt.astype(np.float32) / 255.) * 255.
                    test_results_folder['psnr_y'].append(util.calculate_psnr(img, img_gt, border=0))
                    test_results_folder['ssim_y'].append(util.calculate_ssim(img, img_gt, border=0))
                else:
                    # Grayscale: Y metrics alias the RGB lists (same list object,
                    # so later appends to 'psnr'/'ssim' show up here too).
                    test_results_folder['psnr_y'] = test_results_folder['psnr']
                    test_results_folder['ssim_y'] = test_results_folder['ssim']

        if gt is not None:
            psnr = sum(test_results_folder['psnr']) / len(test_results_folder['psnr'])
            ssim = sum(test_results_folder['ssim']) / len(test_results_folder['ssim'])
            psnr_y = sum(test_results_folder['psnr_y']) / len(test_results_folder['psnr_y'])
            ssim_y = sum(test_results_folder['ssim_y']) / len(test_results_folder['ssim_y'])
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)
            test_results['psnr_y'].append(psnr_y)
            test_results['ssim_y'].append(ssim_y)
            print('Testing {:20s} ({:2d}/{}) - PSNR: {:.2f} dB; SSIM: {:.4f}; PSNR_Y: {:.2f} dB; SSIM_Y: {:.4f}'.
                  format(folder[0], idx, len(test_loader), psnr, ssim, psnr_y, ssim_y))
        else:
            print('Testing {:20s} ({:2d}/{})'.format(folder[0], idx, len(test_loader)))

    # Average over all sequences (only meaningful when GT was provided).
    if gt is not None:
        ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
        ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
        ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
        ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
        print('\n{} \n-- Average PSNR: {:.2f} dB; SSIM: {:.4f}; PSNR_Y: {:.2f} dB; SSIM_Y: {:.4f}'.
              format(save_dir, ave_psnr, ave_ssim, ave_psnr_y, ave_ssim_y))
|
|
|
|
|
def prepare_model_dataset(args):
    '''Prepare model and dataset according to args.task.

    Instantiates the VRT architecture for the selected task, downloads the
    pretrained weights (and, when possible, the testing datasets) if they
    are missing, and sets task-dependent attributes on ``args`` (``scale``,
    ``window_size``, ``nonblind_denoising``) that the testing functions use.

    Args:
        args (argparse.Namespace): parsed CLI arguments; mutated in place.

    Returns:
        torch.nn.Module: the VRT model with pretrained weights loaded.

    Raises:
        ValueError: if ``args.task`` is not a recognized task name.
    '''

    if args.task == '001_VRT_videosr_bi_REDS_6frames':
        model = net(upscale=4, img_size=[6,64,64], window_size=[6,8,8], depths=[8,8,8,8,8,8,8, 4,4,4,4, 4,4],
                    indep_reconsts=[11,12], embed_dims=[120,120,120,120,120,120,120, 180,180,180,180, 180,180],
                    num_heads=[6,6,6,6,6,6,6, 6,6,6,6, 6,6], pa_frames=2, deformable_groups=12)
        datasets = ['REDS4']
        args.scale = 4
        args.window_size = [6,8,8]
        args.nonblind_denoising = False

    elif args.task == '002_VRT_videosr_bi_REDS_16frames':
        model = net(upscale=4, img_size=[16,64,64], window_size=[8,8,8], depths=[8,8,8,8,8,8,8, 4,4,4,4, 4,4],
                    indep_reconsts=[11,12], embed_dims=[120,120,120,120,120,120,120, 180,180,180,180, 180,180],
                    num_heads=[6,6,6,6,6,6,6, 6,6,6,6, 6,6], pa_frames=6, deformable_groups=24)
        datasets = ['REDS4']
        args.scale = 4
        args.window_size = [8,8,8]
        args.nonblind_denoising = False

    elif args.task in ['003_VRT_videosr_bi_Vimeo_7frames', '004_VRT_videosr_bd_Vimeo_7frames']:
        model = net(upscale=4, img_size=[8,64,64], window_size=[8,8,8], depths=[8,8,8,8,8,8,8, 4,4,4,4, 4,4],
                    indep_reconsts=[11,12], embed_dims=[120,120,120,120,120,120,120, 180,180,180,180, 180,180],
                    num_heads=[6,6,6,6,6,6,6, 6,6,6,6, 6,6], pa_frames=4, deformable_groups=16)
        datasets = ['Vid4']
        args.scale = 4
        args.window_size = [8,8,8]
        args.nonblind_denoising = False

    elif args.task in ['005_VRT_videodeblurring_DVD']:
        model = net(upscale=1, img_size=[6,192,192], window_size=[6,8,8], depths=[8,8,8,8,8,8,8, 4,4, 4,4],
                    indep_reconsts=[9,10], embed_dims=[96,96,96,96,96,96,96, 120,120, 120,120],
                    num_heads=[6,6,6,6,6,6,6, 6,6, 6,6], pa_frames=2, deformable_groups=16)
        datasets = ['DVD10']
        args.scale = 1
        args.window_size = [6,8,8]
        args.nonblind_denoising = False

    elif args.task in ['006_VRT_videodeblurring_GoPro']:
        model = net(upscale=1, img_size=[6,192,192], window_size=[6,8,8], depths=[8,8,8,8,8,8,8, 4,4, 4,4],
                    indep_reconsts=[9,10], embed_dims=[96,96,96,96,96,96,96, 120,120, 120,120],
                    num_heads=[6,6,6,6,6,6,6, 6,6, 6,6], pa_frames=2, deformable_groups=16)
        datasets = ['GoPro11-part1', 'GoPro11-part2']
        args.scale = 1
        args.window_size = [6,8,8]
        args.nonblind_denoising = False

    elif args.task in ['007_VRT_videodeblurring_REDS']:
        model = net(upscale=1, img_size=[6,192,192], window_size=[6,8,8], depths=[8,8,8,8,8,8,8, 4,4, 4,4],
                    indep_reconsts=[9,10], embed_dims=[96,96,96,96,96,96,96, 120,120, 120,120],
                    num_heads=[6,6,6,6,6,6,6, 6,6, 6,6], pa_frames=2, deformable_groups=16)
        datasets = ['REDS4']
        args.scale = 1
        args.window_size = [6,8,8]
        args.nonblind_denoising = False

    elif args.task == '008_VRT_videodenoising_DAVIS':
        # Non-blind denoising: the noise-level map is fed as an extra input channel.
        model = net(upscale=1, img_size=[6,192,192], window_size=[6,8,8], depths=[8,8,8,8,8,8,8, 4,4, 4,4],
                    indep_reconsts=[9,10], embed_dims=[96,96,96,96,96,96,96, 120,120, 120,120],
                    num_heads=[6,6,6,6,6,6,6, 6,6, 6,6], pa_frames=2, deformable_groups=16,
                    nonblind_denoising=True)
        datasets = ['Set8', 'DAVIS-test']
        args.scale = 1
        args.window_size = [6,8,8]
        args.nonblind_denoising = True

    elif args.task == '009_VRT_videofi_Vimeo_4frames':
        model = net(upscale=1, out_chans=3, img_size=[4,192,192], window_size=[4,8,8], depths=[8,8,8,8,8,8,8, 4,4, 4,4],
                    indep_reconsts=[], embed_dims=[96,96,96,96,96,96,96, 120,120, 120,120],
                    num_heads=[6,6,6,6,6,6,6, 6,6, 6,6], pa_frames=0)
        datasets = ['UCF101', 'DAVIS-train']
        args.scale = 1
        args.window_size = [4,8,8]
        args.nonblind_denoising = False

    else:
        # Fail fast: previously an unknown task fell through and triggered a
        # confusing NameError on `model`/`datasets` further down.
        raise ValueError(f'Unknown task: {args.task}')

    # Download pretrained weights if they are not cached locally.
    model_path = f'model_zoo/vrt/{args.task}.pth'
    if os.path.exists(model_path):
        # Fixed: the message used to prepend './model_zoo/vrt/' to a path that
        # already contained it, printing a duplicated directory.
        print(f'loading model from {model_path}')
    else:
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        url = 'https://github.com/JingyunLiang/VRT/releases/download/v0.0/{}'.format(os.path.basename(model_path))
        r = requests.get(url, allow_redirects=True)
        r.raise_for_status()  # don't silently save an HTML error page as weights
        print(f'downloading model {model_path}')
        with open(model_path, 'wb') as f:  # close the handle deterministically
            f.write(r.content)

    # map_location='cpu' allows loading GPU-saved checkpoints on CPU-only
    # machines; main() moves the model to the right device afterwards.
    pretrained_model = torch.load(model_path, map_location='cpu')
    model.load_state_dict(pretrained_model['params'] if 'params' in pretrained_model.keys() else pretrained_model, strict=True)

    # Download the testing datasets if the LQ folder is missing (Vimeo is too
    # large to fetch automatically and must be downloaded manually).
    if os.path.exists(f'{args.folder_lq}'):
        print(f'using dataset from {args.folder_lq}')
    else:
        if 'vimeo' in args.folder_lq.lower():
            print(f'Vimeo dataset is not at {args.folder_lq}! Please refer to #training of Readme.md to download it.')
        else:
            os.makedirs('testsets', exist_ok=True)
            for dataset in datasets:
                url = f'https://github.com/JingyunLiang/VRT/releases/download/v0.0/testset_{dataset}.tar.gz'
                r = requests.get(url, allow_redirects=True)
                r.raise_for_status()
                print(f'downloading testing dataset {dataset}')
                with open(f'testsets/{dataset}.tar.gz', 'wb') as f:
                    f.write(r.content)
                os.system(f'tar -xvf testsets/{dataset}.tar.gz -C testsets')
                os.system(f'rm testsets/{dataset}.tar.gz')

    return model
|
|
|
|
|
def test_video(lq, model, args):
    '''Test the video as a whole or as clips (divided temporally).

    Args:
        lq: low-quality input of shape (b, d, c, h, w).
        model: the restoration network.
        args: namespace providing tile, tile_overlap, scale, window_size
            and nonblind_denoising.

    Returns:
        Restored video tensor on CPU.
    '''

    clip_len = args.tile[0]
    if not clip_len:
        # Whole-sequence path: mirror-pad the temporal axis up to a multiple
        # of the temporal window size, run once, then crop back.
        t_win = args.window_size[0]
        d_old = lq.size(1)
        d_pad = (t_win - d_old % t_win) % t_win
        if d_pad:
            lq = torch.cat([lq, torch.flip(lq[:, -d_pad:, ...], [1])], 1)
        return test_clip(lq, model, args)[:, :d_old, :, :, :]

    # Tiled path: overlapping temporal clips, blended by a running average.
    sf = args.scale
    overlap = args.tile_overlap[0]
    trim_borders = False  # kept for symmetry with test_clip; disabled here
    b, d, c, h, w = lq.size()
    if args.nonblind_denoising:
        c -= 1  # the noise-level map channel is not part of the output
    step = clip_len - overlap
    start_ids = list(range(0, d - clip_len, step)) + [max(0, d - clip_len)]
    acc = torch.zeros(b, d, c, h * sf, w * sf)      # accumulated outputs
    weight = torch.zeros(b, d, 1, 1, 1)             # per-frame blend weights

    for s in start_ids:
        clip = lq[:, s:s + clip_len, ...]
        out = test_clip(clip, model, args)
        mask = torch.ones((b, min(clip_len, d), 1, 1, 1))

        if trim_borders:
            # Drop half of each overlapped border except at sequence edges.
            if s < start_ids[-1]:
                out[:, -overlap // 2:, ...] *= 0
                mask[:, -overlap // 2:, ...] *= 0
            if s > start_ids[0]:
                out[:, :overlap // 2, ...] *= 0
                mask[:, :overlap // 2, ...] *= 0

        acc[:, s:s + clip_len, ...].add_(out)
        weight[:, s:s + clip_len, ...].add_(mask)

    return acc.div_(weight)
|
|
|
|
|
def test_clip(lq, model, args): |
|
''' test the clip as a whole or as patches. ''' |
|
|
|
sf = args.scale |
|
window_size = args.window_size |
|
size_patch_testing = args.tile[1] |
|
assert size_patch_testing % window_size[-1] == 0, 'testing patch size should be a multiple of window_size.' |
|
|
|
if size_patch_testing: |
|
|
|
overlap_size = args.tile_overlap[1] |
|
not_overlap_border = True |
|
|
|
|
|
b, d, c, h, w = lq.size() |
|
c = c - 1 if args.nonblind_denoising else c |
|
stride = size_patch_testing - overlap_size |
|
h_idx_list = list(range(0, h-size_patch_testing, stride)) + [max(0, h-size_patch_testing)] |
|
w_idx_list = list(range(0, w-size_patch_testing, stride)) + [max(0, w-size_patch_testing)] |
|
E = torch.zeros(b, d, c, h*sf, w*sf) |
|
W = torch.zeros_like(E) |
|
|
|
for h_idx in h_idx_list: |
|
for w_idx in w_idx_list: |
|
in_patch = lq[..., h_idx:h_idx+size_patch_testing, w_idx:w_idx+size_patch_testing] |
|
out_patch = model(in_patch).detach().cpu() |
|
|
|
out_patch_mask = torch.ones_like(out_patch) |
|
|
|
if not_overlap_border: |
|
if h_idx < h_idx_list[-1]: |
|
out_patch[..., -overlap_size//2:, :] *= 0 |
|
out_patch_mask[..., -overlap_size//2:, :] *= 0 |
|
if w_idx < w_idx_list[-1]: |
|
out_patch[..., :, -overlap_size//2:] *= 0 |
|
out_patch_mask[..., :, -overlap_size//2:] *= 0 |
|
if h_idx > h_idx_list[0]: |
|
out_patch[..., :overlap_size//2, :] *= 0 |
|
out_patch_mask[..., :overlap_size//2, :] *= 0 |
|
if w_idx > w_idx_list[0]: |
|
out_patch[..., :, :overlap_size//2] *= 0 |
|
out_patch_mask[..., :, :overlap_size//2] *= 0 |
|
|
|
E[..., h_idx*sf:(h_idx+size_patch_testing)*sf, w_idx*sf:(w_idx+size_patch_testing)*sf].add_(out_patch) |
|
W[..., h_idx*sf:(h_idx+size_patch_testing)*sf, w_idx*sf:(w_idx+size_patch_testing)*sf].add_(out_patch_mask) |
|
output = E.div_(W) |
|
|
|
else: |
|
_, _, _, h_old, w_old = lq.size() |
|
h_pad = (window_size[1] - h_old % window_size[1]) % window_size[1] |
|
w_pad = (window_size[2] - w_old % window_size[2]) % window_size[2] |
|
|
|
lq = torch.cat([lq, torch.flip(lq[:, :, :, -h_pad:, :], [3])], 3) if h_pad else lq |
|
lq = torch.cat([lq, torch.flip(lq[:, :, :, :, -w_pad:], [4])], 4) if w_pad else lq |
|
|
|
output = model(lq).detach().cpu() |
|
|
|
output = output[:, :, :, :h_old*sf, :w_old*sf] |
|
|
|
return output |
|
|
|
|
|
# Entry point: run the full test pipeline when executed as a script.
if __name__ == '__main__':
    main()
|
|