reedmayhew committed
Commit 25ad706
1 Parent(s): 13a4c81

Update app.py

Files changed (1):
  app.py +52 -67
app.py CHANGED
@@ -1,90 +1,76 @@
-# Import necessary libraries
+import torch
 from PIL import Image
 import numpy as np
-import torch
 from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
 import gradio as gr
 import spaces
 
-# Function to resize image to max 2048x2048 while maintaining aspect ratio
-def resize_image(image, max_size=2048):
+def split_image(image, chunk_size=512):
     width, height = image.size
-    if width > max_size or height > max_size:
-        aspect_ratio = width / height
-        if width > height:
-            new_width = max_size
-            new_height = int(new_width / aspect_ratio)
-        else:
-            new_height = max_size
-            new_width = int(new_height * aspect_ratio)
-        image = image.resize((new_width, new_height), Image.LANCZOS)
-    return image
+    chunks = []
+    for y in range(0, height, chunk_size):
+        for x in range(0, width, chunk_size):
+            chunk = image.crop((x, y, min(x + chunk_size, width), min(y + chunk_size, height)))
+            chunks.append((chunk, x, y))
+    return chunks
 
-# Function to upscale an image using Swin2SR
-def upscale_image(image, model, processor, device):
-    try:
-        # Convert the image to RGB format
-        image = image.convert("RGB")
-        # Process the image for the model
-        inputs = processor(image, return_tensors="pt")
-        # Move inputs to the same device as model
-        inputs = {k: v.to(device) for k, v in inputs.items()}
-        # Perform inference (upscale)
-        with torch.no_grad():
-            outputs = model(**inputs)
-        # Move output back to CPU for further processing
-        output = outputs.reconstruction.data.squeeze().cpu().float().clamp_(0, 1).numpy()
-        output = np.moveaxis(output, source=0, destination=-1)
-        output_image = (output * 255.0).round().astype(np.uint8)  # Convert from float32 to uint8
-        # Remove 32 pixels from the bottom and right of the image
-        output_image = output_image[:-32, :-32]
-        return Image.fromarray(output_image), None
-    except RuntimeError as e:
-        return None, str(e)
+def stitch_image(chunks, original_size):
+    result = Image.new('RGB', original_size)
+    for img, x, y in chunks:
+        result.paste(img, (x, y))
+    return result
+
+def upscale_chunk(chunk, model, processor, device):
+    inputs = processor(chunk, return_tensors="pt")
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+    with torch.no_grad():
+        outputs = model(**inputs)
+    output = outputs.reconstruction.data.squeeze().cpu().float().clamp_(0, 1).numpy()
+    output = np.moveaxis(output, source=0, destination=-1)
+    output_image = (output * 255.0).round().astype(np.uint8)
+    return Image.fromarray(output_image)
 
 @spaces.GPU
 def main(image, model_choice, save_as_jpg=True):
-    # Resize the input image
-    image = resize_image(image)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    # Define model paths
     model_paths = {
         "Pixel Perfect": "caidas/swin2SR-classical-sr-x4-64",
         "PSNR Match (Recommended)": "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"
     }
 
-    # Load the selected Swin2SR model and processor for 4x upscaling
     processor = AutoImageProcessor.from_pretrained(model_paths[model_choice])
-    model = Swin2SRForImageSuperResolution.from_pretrained(model_paths[model_choice])
+    model = Swin2SRForImageSuperResolution.from_pretrained(model_paths[model_choice]).to(device)
 
-    # Try GPU first, fallback to CPU if there's an error
-    for device in [torch.device("cuda" if torch.cuda.is_available() else "cpu"), torch.device("cpu")]:
-        model.to(device)
-        upscaled_image, error = upscale_image(image, model, processor, device)
-
-        if upscaled_image is not None:
-            if save_as_jpg:
-                # Save the upscaled image as JPG with 98% compression
-                upscaled_image.save("upscaled_image.jpg", quality=98)
-                return "upscaled_image.jpg"
-            else:
-                # Save the upscaled image as PNG
-                upscaled_image.save("upscaled_image.png")
-                return "upscaled_image.png"
-
-        if device.type == "cpu":
-            return f"Error: Unable to process the image. {error}"
-
-    return "Error: Unable to process the image on both GPU and CPU."
+    # Split the image into chunks
+    chunks = split_image(image)
+
+    # Process each chunk
+    upscaled_chunks = []
+    for chunk, x, y in chunks:
+        upscaled_chunk = upscale_chunk(chunk, model, processor, device)
+        # Remove 32 pixels from bottom and right edges
+        upscaled_chunk = upscaled_chunk.crop((0, 0, upscaled_chunk.width - 32, upscaled_chunk.height - 32))
+        upscaled_chunks.append((upscaled_chunk, x * 4, y * 4))  # Multiply coordinates by 4 due to 4x upscaling
+
+    # Stitch the chunks back together
+    final_size = (image.width * 4 - 32, image.height * 4 - 32)  # Adjust for removed pixels
+    upscaled_image = stitch_image(upscaled_chunks, final_size)
+
+    if save_as_jpg:
+        upscaled_image.save("upscaled_image.jpg", quality=95)
+        return "upscaled_image.jpg"
+    else:
+        upscaled_image.save("upscaled_image.png")
+        return "upscaled_image.png"
 
-# Gradio interface
 def gradio_interface(image, model_choice, save_as_jpg):
-    result = main(image, model_choice, save_as_jpg)
-    if result.startswith("Error:"):
-        return gr.update(value=None), result
-    return result, None
+    try:
+        result = main(image, model_choice, save_as_jpg)
+        return result, None
+    except Exception as e:
+        return None, str(e)
 
-# Create a Gradio interface
 interface = gr.Interface(
     fn=gradio_interface,
     inputs=[
@@ -101,8 +87,7 @@ interface = gr.Interface(
         gr.Textbox(label="Error Message", visible=True)
     ],
     title="Image Upscaler",
-    description="Upload an image, select a model, upscale it, and download the new image. Images larger than 2048x2048 will be resized while maintaining aspect ratio. If GPU processing fails, it will attempt to process on CPU.",
+    description="Upload an image, select a model, and upscale it. The image will be processed in 512x512 pixel chunks to handle large images efficiently.",
 )
 
-# Launch the interface
 interface.launch()
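
Below is a minimal, self-contained sketch of the tiling scheme this commit introduces, reusing split_image and stitch_image exactly as defined in the new app.py. It deliberately skips the Swin2SR model call, so there is no 4x scaling and no 32-pixel edge crop here; it only checks that cutting an image into 512x512 chunks and pasting them back at their recorded offsets reproduces the input. The 1300x900 test size is an arbitrary example chosen so that the edge chunks come out smaller than 512 pixels.

# Round-trip check for the tiling helpers (no model involved).
from PIL import Image

def split_image(image, chunk_size=512):
    width, height = image.size
    chunks = []
    for y in range(0, height, chunk_size):
        for x in range(0, width, chunk_size):
            chunk = image.crop((x, y, min(x + chunk_size, width), min(y + chunk_size, height)))
            chunks.append((chunk, x, y))
    return chunks

def stitch_image(chunks, original_size):
    result = Image.new('RGB', original_size)
    for img, x, y in chunks:
        result.paste(img, (x, y))
    return result

if __name__ == "__main__":
    original = Image.new("RGB", (1300, 900), color=(40, 120, 200))
    tiles = split_image(original)          # 3 columns x 2 rows = 6 tiles
    rebuilt = stitch_image(tiles, original.size)
    assert rebuilt.size == original.size
    assert list(rebuilt.getdata()) == list(original.getdata())
    print(f"{len(tiles)} tiles stitched back without loss")

In app.py itself each chunk is additionally upscaled 4x by Swin2SR, cropped by 32 pixels on the right and bottom, and pasted at (x * 4, y * 4) into a canvas of size (width * 4 - 32, height * 4 - 32) before saving.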