# Hugging Face file-view header captured with this source (not code):
# author: reedmayhew · commit 25ad706 "Update app.py" (verified) · 3.44 kB · malware scan: "No virus"
import torch
from PIL import Image
import numpy as np
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
import gradio as gr
import spaces
def split_image(image, chunk_size=512):
    """Cut an image into a row-major grid of tiles at most ``chunk_size`` on a side.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a PIL-style
            ``crop((left, upper, right, lower))`` method.
        chunk_size: Side length of a full tile. Tiles along the right and bottom
            edges are narrower/shorter when the image size is not a multiple.

    Returns:
        List of ``(tile, x, y)`` triples, where ``(x, y)`` is the tile's top-left
        corner in the source image, ordered top-to-bottom, then left-to-right.
    """
    width, height = image.size
    return [
        (
            image.crop(
                (left, top, min(left + chunk_size, width), min(top + chunk_size, height))
            ),
            left,
            top,
        )
        for top in range(0, height, chunk_size)
        for left in range(0, width, chunk_size)
    ]
def stitch_image(chunks, original_size):
    """Paste positioned tiles onto a new RGB canvas of the given size.

    Args:
        chunks: Iterable of ``(tile, x, y)`` triples; each tile is pasted with its
            top-left corner at ``(x, y)`` on the canvas.
        original_size: ``(width, height)`` of the canvas to create.

    Returns:
        The assembled RGB ``PIL.Image.Image``. Canvas areas no tile covers keep
        the default fill of ``Image.new``.
    """
    canvas = Image.new('RGB', original_size)
    for tile, left, top in chunks:
        canvas.paste(tile, box=(left, top))
    return canvas
def upscale_chunk(chunk, model, processor, device):
    """Run one image tile through the super-resolution model and return a PIL image.

    Args:
        chunk: PIL image tile to upscale.
        model: Model whose call result exposes a ``reconstruction`` tensor.
        processor: Callable turning ``chunk`` into a mapping of model-input tensors.
        device: Torch device the model lives on; every input tensor is moved there.

    Returns:
        An 8-bit ``PIL.Image.Image`` built from the model output. Any padding the
        processor added to the tile is still present; the caller trims it.
    """
    batch = {
        name: tensor.to(device)
        for name, tensor in processor(chunk, return_tensors="pt").items()
    }
    with torch.no_grad():
        reconstruction = model(**batch).reconstruction
    # Drop the size-1 batch axis, copy to host memory as float32 and clip to
    # [0, 1] so the 0-255 scaling below cannot overflow uint8.
    chw = reconstruction.data.squeeze().cpu().float().clamp_(0, 1).numpy()
    # Channels-first -> channels-last, the layout Image.fromarray expects.
    hwc = chw.transpose(1, 2, 0)
    return Image.fromarray(np.round(hwc * 255.0).astype(np.uint8))
@spaces.GPU
def main(image, model_choice, save_as_jpg=True):
    """Upscale a PIL image with a Swin2SR model and save the result to a file.

    The image is processed tile by tile (see ``split_image``) so large inputs
    fit in memory. Each upscaled tile is trimmed to exactly ``scale`` times its
    input size and pasted into an output canvas ``scale`` times the input size.

    Args:
        image: Input ``PIL.Image.Image``. Images that are not RGB (e.g. RGBA,
            greyscale or palette) are converted to RGB first.
        model_choice: Display name of the checkpoint to use; must be a key of
            ``model_paths`` below.
        save_as_jpg: Save as JPEG (quality 95) when true, otherwise as PNG.

    Returns:
        Filesystem path of the saved upscaled image.

    Raises:
        ValueError: If ``image`` is ``None`` or ``model_choice`` is unknown.
    """
    import os
    import tempfile

    if image is None:
        raise ValueError("Please upload an image to upscale.")
    # The output canvas built by stitch_image is RGB, and the Swin2SR
    # processor/model are fed 3-channel tiles; normalise other modes up front.
    if image.mode != "RGB":
        image = image.convert("RGB")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_paths = {
        "Pixel Perfect": "caidas/swin2SR-classical-sr-x4-64",
        "PSNR Match (Recommended)": "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"
    }
    try:
        model_path = model_paths[model_choice]
    except KeyError:
        raise ValueError(
            f"Unknown model {model_choice!r}; expected one of {sorted(model_paths)}"
        ) from None
    processor = AutoImageProcessor.from_pretrained(model_path)
    model = Swin2SRForImageSuperResolution.from_pretrained(model_path).to(device)
    # Both checkpoints are x4; read the factor from the model config instead of
    # hard-coding it so tile offsets and sizes always match the model.
    scale = getattr(model.config, "upscale", 4)

    upscaled_chunks = []
    for chunk, x, y in split_image(image):
        tile = upscale_chunk(chunk, model, processor, device)
        # The Swin2SR image processor pads its input on the bottom/right edges
        # (to a multiple of its pad size) before inference, so the output is
        # larger than scale * tile by an amount that depends on the tile size.
        # Cropping to the exact expected size removes only that padding, also
        # on short edge tiles, where a fixed 32 px crop would eat real pixels.
        tile = tile.crop((0, 0, chunk.width * scale, chunk.height * scale))
        upscaled_chunks.append((tile, x * scale, y * scale))

    # Tiles now cover the canvas exactly, so no size adjustment is needed.
    final_size = (image.width * scale, image.height * scale)
    upscaled_image = stitch_image(upscaled_chunks, final_size)

    # A fresh directory per request so concurrent users never overwrite each
    # other's result; the downloadable file name stays the same as before.
    out_dir = tempfile.mkdtemp(prefix="upscaled_")
    if save_as_jpg:
        out_path = os.path.join(out_dir, "upscaled_image.jpg")
        upscaled_image.save(out_path, quality=95)
    else:
        out_path = os.path.join(out_dir, "upscaled_image.png")
        upscaled_image.save(out_path)
    return out_path
def gradio_interface(image, model_choice, save_as_jpg):
    """Gradio callback: run ``main`` and report failures in the UI instead of raising.

    Args:
        image: Uploaded image, passed straight through to ``main``.
        model_choice: Selected model name, passed straight through to ``main``.
        save_as_jpg: Output-format flag, passed straight through to ``main``.

    Returns:
        ``(file_path, None)`` on success, or ``(None, error_message)`` when
        ``main`` raises any ``Exception``.
    """
    import logging

    try:
        result = main(image, model_choice, save_as_jpg)
    except Exception as e:
        # Deliberate catch-all at the UI boundary, but keep the full traceback
        # in the server logs so failures are diagnosable.
        logging.getLogger(__name__).exception("Upscaling failed")
        # Some exceptions stringify to ""; fall back to the class name so the
        # error box is never blank.
        return None, str(e) or type(e).__name__
    return result, None
# Input widgets, in the positional order gradio_interface takes them:
# (image, model_choice, save_as_jpg).
input_components = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Dropdown(
        choices=["PSNR Match (Recommended)", "Pixel Perfect"],
        label="Select Model",
        value="PSNR Match (Recommended)",
    ),
    gr.Checkbox(value=True, label="Save as JPEG"),
]
# Output widgets, filled from gradio_interface's (file_path, error_message) pair.
output_components = [
    gr.File(label="Download Upscaled Image"),
    gr.Textbox(label="Error Message", visible=True),
]
interface = gr.Interface(
    fn=gradio_interface,
    inputs=input_components,
    outputs=output_components,
    title="Image Upscaler",
    description="Upload an image, select a model, and upscale it. The image will be processed in 512x512 pixel chunks to handle large images efficiently.",
)
# Start the web server (blocking call).
interface.launch()