File size: 4,351 Bytes
57d22e1
 
 
 
fb43320
57d22e1
 
 
 
 
 
 
fc6ae8e
 
 
 
 
 
57d22e1
 
fb43320
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57d22e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f5bb4a
57d22e1
 
fb43320
 
 
57d22e1
fb43320
 
 
3f5bb4a
 
57d22e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb43320
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57d22e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5fd17f4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import shutil
import uuid
import cv2
import gc
import gradio as gr
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from gfpgan.utils import GFPGANer
from realesrgan.utils import RealESRGANer

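# Download the model weights (RealESRGAN, GFPGAN, SwinIR) into model_zoo/ when missing.
# Currently commented out, so the weight files must already exist locally.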
#if not os.path.exists('model_zoo/real/RealESRGAN_x4plus.pth'):
#    os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P model_zoo/real")
#if not os.path.exists('model_zoo/gan/GFPGANv1.4.pth'):
#    os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P model_zoo/gan")
#if not os.path.exists('model_zoo/swinir/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth'):
#    os.system('wget https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth -P model_zoo/swinir')

def inference(img, scale):
    """Restore faces in an image and upscale its background.

    The background is upsampled with RealESRGAN (x4plus weights) and faces are
    restored with GFPGAN v1.4. The result is written to the ``output``
    directory under a random file name.

    Args:
        img: Path of the input image file.
        scale: Requested rescaling factor. Values above 4 are clamped to 4;
            ``None`` (e.g. an emptied number field) falls back to 2.

    Returns:
        A ``(image, save_path)`` tuple: the restored image converted to RGB
        channel order and the path of the saved file. ``(None, None)`` is
        returned when the input is unusable or processing fails.
    """
    if scale is None:
        scale = 2  # same value the original flow treats as "no extra resize"
    if scale > 4:
        scale = 4  # avoid too large scale value
    # Reject unusable input before paying for model construction; these cases
    # could only end in a failure further down anyway.
    if img is None or scale <= 0:
        print('global exception', 'invalid input image or scale')
        return None, None

    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
    model_path = 'model_zoo/real/RealESRGAN_x4plus.pth'
    netscale = 4
    # Tiled processing only when CUDA is available; 0 disables tiling.
    tile = 400 if torch.cuda.is_available() else 0
    dni_weight = None
    # background enhancer with RealESRGAN
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=tile,
        tile_pad=10,
        pre_pad=0,
        half=False,  # use fp32 precision during inference
        gpu_id=None)  # gpu device to use (default=None) can be 0,1,2 for multi-gpu
    os.makedirs('output', exist_ok=True)
    try:
        image = cv2.imread(img, cv2.IMREAD_UNCHANGED)
        if image is None:
            # cv2.imread signals an unreadable/unsupported file by returning
            # None instead of raising.
            raise ValueError(f'could not read image: {img}')

        is_rgba = image.ndim == 3 and image.shape[2] == 4
        if image.ndim == 2:  # for gray inputs
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

        h, w = image.shape[0:2]
        if h < 300 and h > 0 and w > 0:
            # Small inputs are doubled before restoration.
            image = cv2.resize(image, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

        face_enhancer = GFPGANer(
            model_path='model_zoo/gan/GFPGANv1.4.pth', upscale=scale, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
        _, _, output = face_enhancer.enhance(image, has_aligned=False, only_center_face=False, paste_back=True)

        if scale != 2:
            interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
            # Re-read the size: the image may have been doubled above.
            h, w = image.shape[0:2]
            if h > 0 and w > 0:
                output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)

        # RGBA images should be saved in png format
        extension = 'png' if is_rgba else 'jpg'

        filename = str(uuid.uuid4())
        save_path = f'output/out_{filename}.{extension}'
        cv2.imwrite(save_path, output)

        output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
        return output, save_path
    except Exception as error:
        print('global exception', error)
        return None, None
    finally:
        # The saved file is handed back to the caller, so 'output' is not
        # cleaned here; only memory is released.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

def clean_folder(folder):
    """Delete everything inside *folder* while keeping the folder itself.

    Regular files and symbolic links are unlinked; real sub-directories are
    removed recursively. A failure on one entry is printed and does not stop
    the remaining entries from being processed.
    """
    for entry in os.listdir(folder):
        target = os.path.join(folder, entry)
        try:
            # Links are checked alongside files so a link is unlinked
            # rather than followed into its target.
            if os.path.islink(target) or os.path.isfile(target):
                os.unlink(target)
                continue
            if os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as exc:
            print('Failed to delete %s. Reason: %s' % (target, exc))

# Static text passed to the Gradio interface.
title = "Real Esrgan Restore Ai Face Restoration by appsgenz.com"
description = ""
article = "AppsGenz"

# Wire `inference` to the UI: an image (handed over as a file path) plus a
# numeric scale go in; the restored image and a downloadable file come out.
grApp = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(type="filepath", label="Input"),
        gr.Number(label="Rescaling factor. Note max rescaling factor is 4", value=2),
    ],
    outputs=[
        gr.Image(type="numpy", label="Output (The whole image)"),
        gr.File(label="Download the output image"),
    ],
    title=title,
    description=description,
    article=article,
)

# NOTE(review): `concurrency_count` presumably caps parallel requests at two;
# confirm the keyword is accepted by the installed Gradio version.
grApp.queue(concurrency_count=2)
grApp.launch(share=False)