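"""Gradio app for 4x image super-resolution with Swin2SR (Hugging Face Transformers).

Large images can optionally be split into 512 px tiles, upscaled tile by tile,
and stitched back together to keep GPU memory use bounded.
"""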
import torch
from PIL import Image
import numpy as np
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
import gradio as gr
import spaces
import os

def resize_image(image, max_size=2048):
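    """Downscale `image` so that neither side exceeds `max_size` pixels, preserving aspect ratio."""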
    width, height = image.size
    if width > max_size or height > max_size:
        aspect_ratio = width / height
        if width > height:
            new_width = max_size
            new_height = int(new_width / aspect_ratio)
        else:
            new_height = max_size
            new_width = int(new_height * aspect_ratio)
        image = image.resize((new_width, new_height), Image.LANCZOS)
    return image

def split_image(image, chunk_size=512):
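    """Split `image` into tiles of at most `chunk_size` x `chunk_size` pixels, with their (x, y) offsets."""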
    width, height = image.size
    chunks = []
    for y in range(0, height, chunk_size):
        for x in range(0, width, chunk_size):
            chunk = image.crop((x, y, min(x + chunk_size, width), min(y + chunk_size, height)))
            chunks.append((chunk, x, y))
    return chunks

def stitch_image(chunks, original_size):
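    """Paste upscaled tiles onto a blank canvas of `original_size` at their recorded offsets."""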
    result = Image.new('RGB', original_size)
    for img, x, y in chunks:
        result.paste(img, (x, y))
    return result

def upscale_chunk(chunk, model, processor, device):
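    """Run one tile through the Swin2SR model and return the upscaled tile as a PIL image."""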
    inputs = processor(chunk, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    output = outputs.reconstruction.squeeze().cpu().float().clamp(0, 1).numpy()
    output = np.moveaxis(output, source=0, destination=-1)
    output_image = (output * 255.0).round().astype(np.uint8)
    return Image.fromarray(output_image)

@spaces.GPU
def main(image, original_filename, model_choice, save_as_jpg=True, use_tiling=True):
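    """Upscale `image` 4x with the chosen Swin2SR model and save the result to disk.

    Returns the path of the saved file so Gradio can offer it as a download.
    """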
    # Resize the input image
    image = resize_image(image)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    model_paths = {
        "Pixel Perfect": "caidas/swin2SR-classical-sr-x4-64",
        "PSNR Match (Recommended)": "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"
    }
    
    processor = AutoImageProcessor.from_pretrained(model_paths[model_choice])
    model = Swin2SRForImageSuperResolution.from_pretrained(model_paths[model_choice]).to(device)
    
    if use_tiling:
        # Split the image into chunks
        chunks = split_image(image)
        
        # Process each chunk
        upscaled_chunks = []
        for chunk, x, y in chunks:
            upscaled_chunk = upscale_chunk(chunk, model, processor, device)
            # The Swin2SR processor pads each tile on the bottom/right before inference,
            # so the 4x output typically carries an extra 32-pixel border there; trim it off.
            upscaled_chunk = upscaled_chunk.crop((0, 0, upscaled_chunk.width - 32, upscaled_chunk.height - 32))
            upscaled_chunks.append((upscaled_chunk, x * 4, y * 4))  # Multiply coordinates by 4 due to 4x upscaling
        
        # Stitch the chunks back together
        final_size = (image.width * 4 - 32, image.height * 4 - 32)  # Adjust for removed pixels
        upscaled_image = stitch_image(upscaled_chunks, final_size)
    else:
        # Process the entire image at once
        upscaled_image = upscale_chunk(image, model, processor, device)
    
    # Generate output filename
    original_basename = os.path.splitext(original_filename)[0] if original_filename else "image"
    output_filename = f"{original_basename}_upscaled"
    
    if save_as_jpg:
        output_filename += ".jpg"
        upscaled_image.save(output_filename, quality=95)
    else:
        output_filename += ".png"
        upscaled_image.save(output_filename)
    
    return output_filename

def gradio_interface(image, model_choice, save_as_jpg, use_tiling):
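    """Gradio callback: wraps main() and returns (file, error) so failures surface in the UI."""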
    try:
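        # Gradio's PIL images usually carry no filename attribute, so fall back to "image".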
        original_filename = getattr(image, 'name', 'image')
        result = main(image, original_filename, model_choice, save_as_jpg, use_tiling)
        return result, None
    except Exception as e:
        return None, str(e)

interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Dropdown(
            choices=["PSNR Match (Recommended)", "Pixel Perfect"],
            label="Select Model",
            value="PSNR Match (Recommended)"
        ),
        gr.Checkbox(value=True, label="Save as JPEG"),
        gr.Checkbox(value=True, label="Use Tiling"),
    ],
    outputs=[
        gr.File(label="Download Upscaled Image"),
        gr.Textbox(label="Error Message", visible=True)
    ],
    title="Image Upscaler",
    description="Upload an image, select a model, and upscale it 4x. Images with a side longer than 2048 px are downscaled first (aspect ratio preserved). Enable tiling to process large images in 512 px chunks and keep memory use bounded.",
)

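# launch() serves the UI locally on http://127.0.0.1:7860 by default; on Hugging Face Spaces the app is hosted automatically.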
interface.launch()