import gradio as gr
import os
import sys
import numpy as np
import torch
import torchvision.transforms as transforms
from mmcv.utils import Config

sys.path.append('.')
from image_forgery_detection import build_detector
from image_forgery_detection import Compose


# Converts the float mask tensor to a PIL image (ToPILImage rescales float
# inputs to [0, 255]) so it can be binarized and displayed.
transform_pil = transforms.Compose([
    transforms.ToPILImage(),
])


def inference_api(f_path):
    print(f_path)

    # Build the input record expected by the preprocessing pipeline.
    results = dict(img_info=dict(filename=f_path, ann=dict(seg_map='None')))
    results['seg_fields'] = []
    results['img_prefix'] = None
    results['seg_prefix'] = None

    inputs = pipelines(results)

    img = inputs['img'].data
    img_meta = inputs['img_metas'].data
    # Some configs add a DCT volume and JPEG quantization tables as extra inputs.
    if 'dct_vol' in inputs:
        dct_vol = inputs['dct_vol'].data
        qtables = inputs['qtables'].data

    with torch.no_grad():
        # Add a batch dimension before running the detector.
        img = img.unsqueeze(dim=0)
        if 'dct_vol' in inputs:
            dct_vol = dct_vol.unsqueeze(dim=0)
            qtables = qtables.unsqueeze(dim=0)
            cls_pred, seg_pred = model(img, dct_vol, qtables, [img_meta, ], return_loss=False, rescale=True)
        else:
            cls_pred, seg_pred = model(img, [img_meta, ], return_loss=False, rescale=True)
        cls_pred = cls_pred[0]
        seg = seg_pred[0, 0]
    # Convert the float mask to uint8 in [0, 255], then binarize at `thresh`.
    seg = np.array(transform_pil(torch.from_numpy(seg)))
    thresh_int = 255 * thresh
    seg[seg >= thresh_int] = 255
    seg[seg < thresh_int] = 0

    return '{:.3f}'.format(cls_pred), seg
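
# A minimal sketch of calling inference_api directly, assuming the model,
# pipelines, and thresh globals below have already been set up and that
# './demo/sample.jpg' is a hypothetical image path:
#
#   score, mask = inference_api('./demo/sample.jpg')
#   print(score)  # forgery score as a string with three decimals, e.g. '0.873'
#   # mask is a uint8 H x W numpy array with 255 on predicted-forged pixels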


if __name__ == '__main__':
    model_path = './models/latest.pth'
    cfg = Config.fromfile('./models/config.py')

    # Shared state read by inference_api at call time.
    global model
    global pipelines
    global thresh

    thresh = 0.5  # binarization threshold for the predicted mask
    # Clear the pretrained-weights path so build_detector does not try to
    # fetch backbone weights; the full checkpoint is loaded below instead.
    if hasattr(cfg.model.base_model, 'backbone'):
        cfg.model.base_model.backbone.pretrained = None
    else:
        cfg.model.base_model.pretrained = None
    # Build the detector from the config and load the checkpoint weights.
    model = build_detector(cfg.model)
    if os.path.exists(model_path):
        checkpoint = torch.load(model_path, map_location='cpu')['state_dict']
        model.load_state_dict(checkpoint, strict=True)
        print("loaded %s" % os.path.basename(model_path))
    else:
        print("%s does not exist" % model_path)
        sys.exit(1)
    model.eval()

    # Reuse the validation preprocessing pipeline defined in the config.
    pipelines = Compose(cfg.data.val[0].pipeline)

    iface = gr.Interface(
        inference_api,
        inputs=gr.components.Image(label="Upload image to detect", type="filepath"),
        outputs=[gr.components.Textbox(type="text", label="image forgery score"),
                 gr.components.Image(type="numpy", label="predicted mask")],
        title="Forged? Or Not?",
    )
    # Use iface.launch(server_name='0.0.0.0', share=True) to expose the demo
    # on the network and get a public share link.
    iface.launch()
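
# A hedged sketch of querying the running demo programmatically with the
# gradio_client package (not part of this repo; the endpoint name and file
# handling vary across gradio versions, and 'sample.jpg' is a hypothetical
# path):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860/")
#   score, mask = client.predict(handle_file('sample.jpg'), api_name="/predict")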