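"""Gradio demo for image forgery detection.

Given an uploaded image, the app returns an image-level forgery score and a
binarized mask highlighting the regions predicted to be manipulated.
"""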
import gradio as gr
import os
import sys
import numpy as np
import torch
import torchvision.transforms as transforms
from mmcv.utils import Config
sys.path.append('.')
from image_forgery_detection import build_detector, Compose
# Converts the predicted mask tensor to a PIL image.
transform_pil = transforms.Compose([
    transforms.ToPILImage(),
])
def inference_api(f_path):
    """Run forgery detection on one image and return (forgery score, binary mask)."""
    print(f_path)
    # Build the record expected by the data pipeline (no ground-truth mask at inference time).
    results = dict(img_info=dict(filename=f_path, ann=dict(seg_map='None')))
    results['seg_fields'] = []
    results['img_prefix'] = None
    results['seg_prefix'] = None
    inputs = pipelines(results)

    img = inputs['img'].data
    img_meta = inputs['img_metas'].data
    if 'dct_vol' in inputs:
        dct_vol = inputs['dct_vol'].data
        qtables = inputs['qtables'].data

    with torch.no_grad():
        # Add a batch dimension before running the detector.
        img = img.unsqueeze(dim=0)
        if 'dct_vol' in inputs:
            # Some models additionally take a DCT volume and JPEG quantization tables.
            dct_vol = dct_vol.unsqueeze(dim=0)
            qtables = qtables.unsqueeze(dim=0)
            cls_pred, seg_pred = model(img, dct_vol, qtables, [img_meta, ],
                                       return_loss=False, rescale=True)
        else:
            cls_pred, seg_pred = model(img, [img_meta, ],
                                       return_loss=False, rescale=True)

    # Image-level forgery score and pixel-level prediction map.
    cls_pred = cls_pred[0]
    seg = seg_pred[0, 0]
    # ToPILImage scales the float mask to 0-255; binarize it at the configured threshold.
    seg = np.array(transform_pil(torch.from_numpy(seg)))
    thresh_int = 255 * thresh
    seg[seg >= thresh_int] = 255
    seg[seg < thresh_int] = 0
    return '{:.3f}'.format(cls_pred), seg
if __name__ == '__main__':
    model_path = './models/latest.pth'
    cfg = Config.fromfile('./models/config.py')

    # thresh, model and pipelines below are read as module-level globals by inference_api.
    thresh = 0.5

    # Inference only: drop the pretrained-weights path so build_detector does not try to fetch it.
    if hasattr(cfg.model.base_model, 'backbone'):
        cfg.model.base_model.backbone.pretrained = None
    else:
        cfg.model.base_model.pretrained = None

    model = build_detector(cfg.model)
    if os.path.exists(model_path):
        checkpoint = torch.load(model_path, map_location='cpu')['state_dict']
        model.load_state_dict(checkpoint, strict=True)
        print("loaded %s" % os.path.basename(model_path))
    else:
        print("%s does not exist" % model_path)
        sys.exit(1)
    model.eval()

    # The validation pipeline from the config handles loading, resizing and normalization.
    pipelines = Compose(cfg.data.val[0].pipeline)

    iface = gr.Interface(
        inference_api,
        inputs=gr.components.Image(label="Upload image to detect", type="filepath"),
        outputs=[gr.components.Textbox(type="text", label="image forgery score"),
                 gr.components.Image(type="numpy", label="predicted mask")],
        title="Forged? Or Not?",
    )
    # iface.launch(server_name='0.0.0.0', share=True)
    iface.launch()