Spaces:
Build error
Build error
File size: 2,701 Bytes
0ba02e5 13182b9 9c59d94 0ba02e5 9c59d94 9230043 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import gradio as gr
import os
import sys
import numpy as np
import numpy as np
import torch.backends.cudnn as cudnn
import torch.utils.data
import torch.nn.functional as F
import torchvision.transforms as transforms
from mmcv.utils import Config
sys.path.append('.')
from image_forgery_detection import build_detector
from image_forgery_detection import Compose
# Converter applied by predict() to the predicted mask before it is turned
# back into a NumPy array.
transform_pil = transforms.Compose([transforms.ToPILImage()])
def predict(f_path):
    """Run the forgery detector on a single image file.

    Relies on the module-level names ``pipelines``, ``model`` and ``thresh``,
    which the ``__main__`` section of this script assigns before the
    interface is launched.

    Args:
        f_path: Path of the image file to analyse.

    Returns:
        A ``(score, mask)`` tuple: the first classification output formatted
        with three decimals, and the first segmentation map converted to an
        image array and binarised to the values 0 and 255.
    """
    print(f_path)

    # Input record handed to the preprocessing pipeline.
    sample = {
        'img_info': {'filename': f_path, 'ann': {'seg_map': 'None'}},
        'seg_fields': [],
        'img_prefix': None,
        'seg_prefix': None,
    }
    inputs = pipelines(sample)

    image = inputs['img'].data
    meta = inputs['img_metas'].data

    with torch.no_grad():
        # Leading positional model inputs, each given a new first axis of
        # size 1.
        model_args = [image.unsqueeze(dim=0)]
        if 'dct_vol' in inputs:
            # Extra inputs exist only when the pipeline produced a 'dct_vol'
            # entry; 'qtables' is expected to accompany it.
            model_args.append(inputs['dct_vol'].data.unsqueeze(dim=0))
            model_args.append(inputs['qtables'].data.unsqueeze(dim=0))
        cls_pred, seg_pred = model(
            *model_args, [meta], return_loss=False, rescale=True
        )

        score = cls_pred[0]
        # NOTE(review): assumes seg_pred is a NumPy array indexable as
        # [batch, channel] (torch.from_numpy requires an ndarray) - confirm.
        mask = np.array(transform_pil(torch.from_numpy(seg_pred[0, 0])))

    # Binarise on the 0-255 scale: at/above the cutoff -> 255, below -> 0.
    cutoff = 255 * thresh
    mask[mask >= cutoff] = 255
    mask[mask < cutoff] = 0

    return '{:.3f}'.format(score), mask
if __name__ == '__main__':
    # Script entry point: build the detector, load its checkpoint, prepare
    # the preprocessing pipeline, then serve predict() through Gradio.
    model_path = './models/latest.pth'
    cfg = Config.fromfile('./models/config.py')

    # `thresh`, `model` and `pipelines` are read by predict() as module-level
    # names. This code already runs at module scope, so plain assignment is
    # enough and no `global` declarations are needed.
    thresh = 0.5

    # Clear the configured `pretrained` entry before building the model
    # (presumably so no separate pretrained weights are looked up; the
    # checkpoint loaded below provides the weights - TODO confirm).
    if hasattr(cfg.model.base_model, 'backbone'):
        cfg.model.base_model.backbone.pretrained = None
    else:
        cfg.model.base_model.pretrained = None
    model = build_detector(cfg.model)

    if not os.path.exists(model_path):
        print("%s not exist" % model_path)
        # Use sys.exit rather than the `exit` helper: the latter is installed
        # by the `site` module for interactive sessions and is not guaranteed
        # to exist (e.g. `python -S`, frozen or embedded interpreters).
        sys.exit(1)

    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # come from a trusted source.
    checkpoint = torch.load(model_path, map_location='cpu')['state_dict']
    model.load_state_dict(checkpoint, strict=True)
    print("load %s finish" % (os.path.basename(model_path)))
    model.eval()

    pipelines = Compose(cfg.data.val[0].pipeline)

    # NOTE(review): `gr.inputs` / `gr.outputs` are Gradio's legacy component
    # namespaces; newer releases expose `gr.Image` / `gr.Textbox` instead.
    # Confirm against the Gradio version this app is deployed with.
    iface = gr.Interface(
        predict,
        inputs=gr.inputs.Image(label="Upload image to detect", type="filepath"),
        outputs=[
            gr.outputs.Textbox(type="text", label="image forgery score"),
            gr.outputs.Image(type="numpy", label="predict mask"),
        ],
        title="Forged? Or Not?",
    )
    # To listen on all interfaces and create a public share link, call
    # iface.launch(server_name='0.0.0.0', share=True) instead.
    iface.launch()
|