# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import cv2
import glob
import os
import torch
import requests
import numpy as np
from os import path as osp
from collections import OrderedDict
from torch.utils.data import DataLoader

from models.network_vrt import VRT as net
from utils import utils_image as util
from data.dataset_video_test import VideoRecurrentTestDataset, VideoTestVimeo90KDataset, \
    SingleVideoRecurrentTestDataset, VFI_DAVIS, VFI_UCF101, VFI_Vid4


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='001_VRT_videosr_bi_REDS_6frames', help='tasks: 001 to 009')
    parser.add_argument('--sigma', type=int, default=0, help='noise level for denoising: 10, 20, 30, 40, 50')
    parser.add_argument('--folder_lq', type=str, default='testsets/REDS4/sharp_bicubic',
                        help='input low-quality test video folder')
    parser.add_argument('--folder_gt', type=str, default=None,
                        help='input ground-truth test video folder')
    parser.add_argument('--tile', type=int, nargs='+', default=[40,128,128],
                        help='Tile size, [0,0,0] for no tile during testing (testing as a whole)')
    parser.add_argument('--tile_overlap', type=int, nargs='+', default=[2,20,20],
                        help='Overlapping of different tiles')
    parser.add_argument('--num_workers', type=int, default=16, help='number of workers in data loading')
    parser.add_argument('--save_result', action='store_true', help='save resulting image')
    args = parser.parse_args()

    # define model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = prepare_model_dataset(args)
    model.eval()
    model = model.to(device)

    # define test dataset according to the task and the input folders
    if 'vimeo' in args.folder_lq.lower():
        if 'videofi' in args.task:
            test_set = VideoTestVimeo90KDataset({'dataroot_gt': args.folder_gt,
                                                 'dataroot_lq': args.folder_gt,
                                                 'meta_info_file': "data/meta_info/meta_info_Vimeo90K_test_GT.txt",
                                                 'pad_sequence': False, 'num_frame': 7, 'temporal_scale': 2,
                                                 'cache_data': False})
        else:
            test_set = VideoTestVimeo90KDataset({'dataroot_gt': args.folder_gt,
                                                 'dataroot_lq': args.folder_lq,
                                                 'meta_info_file': "data/meta_info/meta_info_Vimeo90K_test_GT.txt",
                                                 'pad_sequence': True, 'num_frame': 7, 'temporal_scale': 1,
                                                 'cache_data': False})
    elif 'davis' in args.folder_lq.lower() and 'videofi' in args.task:
        test_set = VFI_DAVIS(data_root=args.folder_gt)
    elif 'ucf101' in args.folder_lq.lower() and 'videofi' in args.task:
        test_set = VFI_UCF101(data_root=args.folder_gt)
    elif 'vid4' in args.folder_lq.lower() and 'videofi' in args.task:
        test_set = VFI_Vid4(data_root=args.folder_gt)
    elif args.folder_gt is not None:
        test_set = VideoRecurrentTestDataset({'dataroot_gt': args.folder_gt, 'dataroot_lq': args.folder_lq,
                                              'sigma': args.sigma, 'num_frame': -1, 'cache_data': False})
    else:
        test_set = SingleVideoRecurrentTestDataset({'dataroot_gt': args.folder_gt, 'dataroot_lq': args.folder_lq,
                                                    'sigma': args.sigma, 'num_frame': -1, 'cache_data': False})

    test_loader = DataLoader(dataset=test_set, num_workers=args.num_workers, batch_size=1, shuffle=False)

    save_dir = f'results/{args.task}'
    if args.save_result:
        os.makedirs(save_dir, exist_ok=True)
    test_results = OrderedDict()
    test_results['psnr'] = []
    test_results['ssim'] = []
    test_results['psnr_y'] = []
    test_results['ssim_y'] = []

    assert len(test_loader) != 0, f'No dataset found at {args.folder_lq}'

    for idx, batch in enumerate(test_loader):
        lq = batch['L'].to(device)
        folder = batch['folder']
        gt = batch['H'] if 'H' in batch else None

        # inference
        with torch.no_grad():
            output = test_video(lq, model, args)

        if 'videofi' in args.task:
            output = output[:, :1, ...]
            batch['lq_path'] = batch['gt_path']
        elif 'videosr' in args.task and 'vimeo' in args.folder_lq.lower():
            output = output[:, 3:4, :, :, :]
            batch['lq_path'] = batch['gt_path']

        test_results_folder = OrderedDict()
        test_results_folder['psnr'] = []
        test_results_folder['ssim'] = []
        test_results_folder['psnr_y'] = []
        test_results_folder['ssim_y'] = []

        for i in range(output.shape[1]):
            # save image
            img = output[:, i, ...].data.squeeze().float().cpu().clamp_(0, 1).numpy()
            if img.ndim == 3:
                img = np.transpose(img[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HWC-BGR
            img = (img * 255.0).round().astype(np.uint8)  # float32 to uint8
            if args.save_result:
                seq_ = osp.basename(batch['lq_path'][i][0]).split('.')[0]
                os.makedirs(f'{save_dir}/{folder[0]}', exist_ok=True)
                cv2.imwrite(f'{save_dir}/{folder[0]}/{seq_}.png', img)

            # evaluate psnr/ssim
            if gt is not None:
                img_gt = gt[:, i, ...].data.squeeze().float().cpu().clamp_(0, 1).numpy()
                if img_gt.ndim == 3:
                    img_gt = np.transpose(img_gt[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HWC-BGR
                img_gt = (img_gt * 255.0).round().astype(np.uint8)  # float32 to uint8
                img_gt = np.squeeze(img_gt)

                test_results_folder['psnr'].append(util.calculate_psnr(img, img_gt, border=0))
                test_results_folder['ssim'].append(util.calculate_ssim(img, img_gt, border=0))
                if img_gt.ndim == 3:  # RGB image
                    img = util.bgr2ycbcr(img.astype(np.float32) / 255.) * 255.
                    img_gt = util.bgr2ycbcr(img_gt.astype(np.float32) / 255.) * 255.
                    test_results_folder['psnr_y'].append(util.calculate_psnr(img, img_gt, border=0))
                    test_results_folder['ssim_y'].append(util.calculate_ssim(img, img_gt, border=0))
                else:
                    test_results_folder['psnr_y'] = test_results_folder['psnr']
                    test_results_folder['ssim_y'] = test_results_folder['ssim']

        if gt is not None:
            psnr = sum(test_results_folder['psnr']) / len(test_results_folder['psnr'])
            ssim = sum(test_results_folder['ssim']) / len(test_results_folder['ssim'])
            psnr_y = sum(test_results_folder['psnr_y']) / len(test_results_folder['psnr_y'])
            ssim_y = sum(test_results_folder['ssim_y']) / len(test_results_folder['ssim_y'])
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)
            test_results['psnr_y'].append(psnr_y)
            test_results['ssim_y'].append(ssim_y)
            print('Testing {:20s} ({:2d}/{}) - PSNR: {:.2f} dB; SSIM: {:.4f}; PSNR_Y: {:.2f} dB; SSIM_Y: {:.4f}'.
                  format(folder[0], idx, len(test_loader), psnr, ssim, psnr_y, ssim_y))
        else:
            print('Testing {:20s} ({:2d}/{})'.format(folder[0], idx, len(test_loader)))

    # summarize psnr/ssim
    if gt is not None:
        ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
        ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
        ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
        ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
        print('\n{} \n-- Average PSNR: {:.2f} dB; SSIM: {:.4f}; PSNR_Y: {:.2f} dB; SSIM_Y: {:.4f}'.
              format(save_dir, ave_psnr, ave_ssim, ave_psnr_y, ave_ssim_y))


def prepare_model_dataset(args):
    ''' prepare model and dataset according to args.task.
    '''
    # define model
    if args.task == '001_VRT_videosr_bi_REDS_6frames':
        model = net(upscale=4, img_size=[6,64,64], window_size=[6,8,8],
                    depths=[8,8,8,8,8,8,8, 4,4,4,4, 4,4], indep_reconsts=[11,12],
                    embed_dims=[120,120,120,120,120,120,120, 180,180,180,180, 180,180],
                    num_heads=[6,6,6,6,6,6,6, 6,6,6,6, 6,6],
                    pa_frames=2, deformable_groups=12)
        datasets = ['REDS4']
        args.scale = 4
        args.window_size = [6,8,8]
        args.nonblind_denoising = False

    elif args.task == '002_VRT_videosr_bi_REDS_16frames':
        model = net(upscale=4, img_size=[16,64,64], window_size=[8,8,8],
                    depths=[8,8,8,8,8,8,8, 4,4,4,4, 4,4], indep_reconsts=[11,12],
                    embed_dims=[120,120,120,120,120,120,120, 180,180,180,180, 180,180],
                    num_heads=[6,6,6,6,6,6,6, 6,6,6,6, 6,6],
                    pa_frames=6, deformable_groups=24)
        datasets = ['REDS4']
        args.scale = 4
        args.window_size = [8,8,8]
        args.nonblind_denoising = False

    elif args.task in ['003_VRT_videosr_bi_Vimeo_7frames', '004_VRT_videosr_bd_Vimeo_7frames']:
        model = net(upscale=4, img_size=[8,64,64], window_size=[8,8,8],
                    depths=[8,8,8,8,8,8,8, 4,4,4,4, 4,4], indep_reconsts=[11,12],
                    embed_dims=[120,120,120,120,120,120,120, 180,180,180,180, 180,180],
                    num_heads=[6,6,6,6,6,6,6, 6,6,6,6, 6,6],
                    pa_frames=4, deformable_groups=16)
        datasets = ['Vid4']  # 'Vimeo'. Vimeo dataset is too large. Please refer to #training to download it.
        args.scale = 4
        args.window_size = [8,8,8]
        args.nonblind_denoising = False

    elif args.task in ['005_VRT_videodeblurring_DVD']:
        model = net(upscale=1, img_size=[6,192,192], window_size=[6,8,8],
                    depths=[8,8,8,8,8,8,8, 4,4, 4,4], indep_reconsts=[9,10],
                    embed_dims=[96,96,96,96,96,96,96, 120,120, 120,120],
                    num_heads=[6,6,6,6,6,6,6, 6,6, 6,6],
                    pa_frames=2, deformable_groups=16)
        datasets = ['DVD10']
        args.scale = 1
        args.window_size = [6,8,8]
        args.nonblind_denoising = False

    elif args.task in ['006_VRT_videodeblurring_GoPro']:
        model = net(upscale=1, img_size=[6,192,192], window_size=[6,8,8],
                    depths=[8,8,8,8,8,8,8, 4,4, 4,4], indep_reconsts=[9,10],
                    embed_dims=[96,96,96,96,96,96,96, 120,120, 120,120],
                    num_heads=[6,6,6,6,6,6,6, 6,6, 6,6],
                    pa_frames=2, deformable_groups=16)
        datasets = ['GoPro11-part1', 'GoPro11-part2']
        args.scale = 1
        args.window_size = [6,8,8]
        args.nonblind_denoising = False

    elif args.task in ['007_VRT_videodeblurring_REDS']:
        model = net(upscale=1, img_size=[6,192,192], window_size=[6,8,8],
                    depths=[8,8,8,8,8,8,8, 4,4, 4,4], indep_reconsts=[9,10],
                    embed_dims=[96,96,96,96,96,96,96, 120,120, 120,120],
                    num_heads=[6,6,6,6,6,6,6, 6,6, 6,6],
                    pa_frames=2, deformable_groups=16)
        datasets = ['REDS4']
        args.scale = 1
        args.window_size = [6,8,8]
        args.nonblind_denoising = False

    elif args.task == '008_VRT_videodenoising_DAVIS':
        model = net(upscale=1, img_size=[6,192,192], window_size=[6,8,8],
                    depths=[8,8,8,8,8,8,8, 4,4, 4,4], indep_reconsts=[9,10],
                    embed_dims=[96,96,96,96,96,96,96, 120,120, 120,120],
                    num_heads=[6,6,6,6,6,6,6, 6,6, 6,6],
                    pa_frames=2, deformable_groups=16,
                    nonblind_denoising=True)
        datasets = ['Set8', 'DAVIS-test']
        args.scale = 1
        args.window_size = [6,8,8]
        args.nonblind_denoising = True

    elif args.task == '009_VRT_videofi_Vimeo_4frames':
        model = net(upscale=1, out_chans=3, img_size=[4,192,192], window_size=[4,8,8],
                    depths=[8,8,8,8,8,8,8, 4,4, 4,4], indep_reconsts=[],
                    embed_dims=[96,96,96,96,96,96,96, 120,120, 120,120],
                    num_heads=[6,6,6,6,6,6,6, 6,6, 6,6],
                    pa_frames=0)
        datasets = ['UCF101', 'DAVIS-train']  # 'Vimeo'. Vimeo dataset is too large. Please refer to #training to download it.
        args.scale = 1
        args.window_size = [4,8,8]
        args.nonblind_denoising = False

    # download model
    model_path = f'model_zoo/vrt/{args.task}.pth'
    if os.path.exists(model_path):
        print(f'loading model from {model_path}')
    else:
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        url = 'https://github.com/JingyunLiang/VRT/releases/download/v0.0/{}'.format(os.path.basename(model_path))
        print(f'downloading model {model_path}')
        r = requests.get(url, allow_redirects=True)
        open(model_path, 'wb').write(r.content)

    pretrained_model = torch.load(model_path)
    model.load_state_dict(pretrained_model['params'] if 'params' in pretrained_model.keys() else pretrained_model,
                          strict=True)

    # download datasets
    if os.path.exists(f'{args.folder_lq}'):
        print(f'using dataset from {args.folder_lq}')
    else:
        if 'vimeo' in args.folder_lq.lower():
            print(f'Vimeo dataset is not at {args.folder_lq}! Please refer to #training of Readme.md to download it.')
        else:
            os.makedirs('testsets', exist_ok=True)
            for dataset in datasets:
                url = f'https://github.com/JingyunLiang/VRT/releases/download/v0.0/testset_{dataset}.tar.gz'
                print(f'downloading testing dataset {dataset}')
                r = requests.get(url, allow_redirects=True)
                open(f'testsets/{dataset}.tar.gz', 'wb').write(r.content)
                os.system(f'tar -xvf testsets/{dataset}.tar.gz -C testsets')
                os.system(f'rm testsets/{dataset}.tar.gz')

    return model


def test_video(lq, model, args):
    '''test the video as a whole or as clips (divided temporally). '''

    num_frame_testing = args.tile[0]
    if num_frame_testing:
        # test as multiple clips if out-of-memory
        sf = args.scale
        num_frame_overlapping = args.tile_overlap[0]
        not_overlap_border = False
        b, d, c, h, w = lq.size()
        c = c - 1 if args.nonblind_denoising else c
        stride = num_frame_testing - num_frame_overlapping
        d_idx_list = list(range(0, d-num_frame_testing, stride)) + [max(0, d-num_frame_testing)]
        E = torch.zeros(b, d, c, h*sf, w*sf)
        W = torch.zeros(b, d, 1, 1, 1)

        for d_idx in d_idx_list:
            lq_clip = lq[:, d_idx:d_idx+num_frame_testing, ...]
            out_clip = test_clip(lq_clip, model, args)
            out_clip_mask = torch.ones((b, min(num_frame_testing, d), 1, 1, 1))

            if not_overlap_border:
                if d_idx < d_idx_list[-1]:
                    out_clip[:, -num_frame_overlapping//2:, ...] *= 0
                    out_clip_mask[:, -num_frame_overlapping//2:, ...] *= 0
                if d_idx > d_idx_list[0]:
                    out_clip[:, :num_frame_overlapping//2, ...] *= 0
                    out_clip_mask[:, :num_frame_overlapping//2, ...] *= 0

            E[:, d_idx:d_idx+num_frame_testing, ...].add_(out_clip)
            W[:, d_idx:d_idx+num_frame_testing, ...].add_(out_clip_mask)
        output = E.div_(W)
    else:
        # test as one clip (the whole video) if you have enough memory
        window_size = args.window_size
        d_old = lq.size(1)
        d_pad = (window_size[0] - d_old % window_size[0]) % window_size[0]
        lq = torch.cat([lq, torch.flip(lq[:, -d_pad:, ...], [1])], 1) if d_pad else lq
        output = test_clip(lq, model, args)
        output = output[:, :d_old, :, :, :]

    return output


def test_clip(lq, model, args):
    ''' test the clip as a whole or as patches. '''

    sf = args.scale
    window_size = args.window_size
    size_patch_testing = args.tile[1]
    assert size_patch_testing % window_size[-1] == 0, 'testing patch size should be a multiple of window_size.'
    if size_patch_testing:
        # divide the clip to patches (spatially only, tested patch by patch)
        overlap_size = args.tile_overlap[1]
        not_overlap_border = True

        # test patch by patch
        b, d, c, h, w = lq.size()
        c = c - 1 if args.nonblind_denoising else c
        stride = size_patch_testing - overlap_size
        h_idx_list = list(range(0, h-size_patch_testing, stride)) + [max(0, h-size_patch_testing)]
        w_idx_list = list(range(0, w-size_patch_testing, stride)) + [max(0, w-size_patch_testing)]
        E = torch.zeros(b, d, c, h*sf, w*sf)
        W = torch.zeros_like(E)

        for h_idx in h_idx_list:
            for w_idx in w_idx_list:
                in_patch = lq[..., h_idx:h_idx+size_patch_testing, w_idx:w_idx+size_patch_testing]
                out_patch = model(in_patch).detach().cpu()
                out_patch_mask = torch.ones_like(out_patch)

                if not_overlap_border:
                    if h_idx < h_idx_list[-1]:
                        out_patch[..., -overlap_size//2:, :] *= 0
                        out_patch_mask[..., -overlap_size//2:, :] *= 0
                    if w_idx < w_idx_list[-1]:
                        out_patch[..., :, -overlap_size//2:] *= 0
                        out_patch_mask[..., :, -overlap_size//2:] *= 0
                    if h_idx > h_idx_list[0]:
                        out_patch[..., :overlap_size//2, :] *= 0
                        out_patch_mask[..., :overlap_size//2, :] *= 0
                    if w_idx > w_idx_list[0]:
                        out_patch[..., :, :overlap_size//2] *= 0
                        out_patch_mask[..., :, :overlap_size//2] *= 0

                E[..., h_idx*sf:(h_idx+size_patch_testing)*sf, w_idx*sf:(w_idx+size_patch_testing)*sf].add_(out_patch)
                W[..., h_idx*sf:(h_idx+size_patch_testing)*sf, w_idx*sf:(w_idx+size_patch_testing)*sf].add_(out_patch_mask)
        output = E.div_(W)
    else:
        _, _, _, h_old, w_old = lq.size()
        h_pad = (window_size[1] - h_old % window_size[1]) % window_size[1]
        w_pad = (window_size[2] - w_old % window_size[2]) % window_size[2]

        lq = torch.cat([lq, torch.flip(lq[:, :, :, -h_pad:, :], [3])], 3) if h_pad else lq
        lq = torch.cat([lq, torch.flip(lq[:, :, :, :, -w_pad:], [4])], 4) if w_pad else lq

        output = model(lq).detach().cpu()
        output = output[:, :, :, :h_old*sf, :w_old*sf]

    return output


if __name__ == '__main__':
    main()
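
# Example invocation (an illustrative sketch, not part of the original script: the script file name
# and the ground-truth folder below are assumptions; adjust them to your local file name and
# test-set layout, or drop --folder_gt to skip PSNR/SSIM evaluation):
#
#   python main_test_vrt.py --task 001_VRT_videosr_bi_REDS_6frames \
#       --folder_lq testsets/REDS4/sharp_bicubic --folder_gt testsets/REDS4/GT \
#       --tile 40 128 128 --tile_overlap 2 20 20 --save_result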