import gradio as gr import os import sys import numpy as np import numpy as np import torch.backends.cudnn as cudnn import torch.utils.data import torch.nn.functional as F import torchvision.transforms as transforms from mmcv.utils import Config sys.path.append('.') from image_forgery_detection import build_detector from image_forgery_detection import Compose transform_pil = transforms.Compose([ transforms.ToPILImage(), ]) def predict(f_path): print(f_path) results = dict(img_info=dict(filename=f_path, ann=dict(seg_map='None'))) results['seg_fields'] = [] results['img_prefix'] = None results['seg_prefix'] = None inputs = pipelines(results) img = inputs['img'].data img_meta = inputs['img_metas'].data if 'dct_vol' in inputs: dct_vol = inputs['dct_vol'].data qtables = inputs['qtables'].data with torch.no_grad(): img = img.unsqueeze(dim=0) if 'dct_vol' in inputs: dct_vol = dct_vol.unsqueeze(dim=0) qtables = qtables.unsqueeze(dim=0) cls_pred, seg_pred = model(img, dct_vol, qtables, [img_meta, ], return_loss=False, rescale=True) else: cls_pred, seg_pred = model(img, [img_meta, ], return_loss=False, rescale=True) cls_pred = cls_pred[0] seg = seg_pred[0, 0] seg = np.array(transform_pil(torch.from_numpy(seg))) thresh_int = 255 * thresh seg[seg>=thresh_int] = 255 seg[seg