# Copyright 2022 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os os.environ["XLA_FLAGS"] = "--xla_gpu_force_compilation_parallelism=1" import gradio as gr import numpy as np from PIL import Image from pathlib import Path import importlib import ml_collections import tempfile import jax.numpy as jnp import flax from run_eval import ( _MODEL_FILENAME, _MODEL_VARIANT_DICT, _MODEL_CONFIGS, get_params, mod_padding_symmetric, make_shape_even, augment_image, ) def sentence_builder(image, model): params = { "Image Denoising": get_params("checkpoints/denoising-SIDD/checkpoint.npz"), "Image Deblurring (GoPro)": get_params( "checkpoints/debluring-GoPro/checkpoint.npz" ), "Image Deblurring (REDS)": get_params( "checkpoints/debluring-REDS/checkpoint.npz" ), "Image Deblurring (RealBlur_R)": get_params( "checkpoints/debluring-Real-Blur-R/checkpoint.npz" ), "Image Deblurring (RealBlur_J)": get_params( "checkpoints/debluring-Real-Blur-J/checkpoint.npz" ), "Image Deraining (Rain streak)": get_params( "checkpoints/deraining-Rain13k/checkpoint.npz" ), "Image Deraining (Rain drop)": get_params( "checkpoints/deraining-Raindrop/checkpoint.npz" ), "Image Dehazing (Indoor)": get_params( "checkpoints/dehazing-RESIDE-Indoor/checkpoint.npz" ), "Image Dehazing (Outdoor)": get_params( "checkpoints/dehazing-RESIDE-Outdoor/checkpoint.npz" ), "Image Enhancement (Low-light)": get_params( "checkpoints/enhancement-LOL/checkpoint.npz" ), "Image Enhancement (Retouching)": get_params( "checkpoints/enhancement-FiveK/checkpoint.npz" ), } model_mod = importlib.import_module(f"maxim.models.{_MODEL_FILENAME}") models = {} for task in _MODEL_VARIANT_DICT.keys(): model_configs = ml_collections.ConfigDict(_MODEL_CONFIGS) model_configs.variant = _MODEL_VARIANT_DICT[task] models[task] = model_mod.Model(**model_configs) params = params[model] task = model.split()[1] model = models[task] input_img = ( np.asarray(Image.open(str(image)).convert("RGB"), np.float32) / 255.0 ) # Padding images to have even shapes height, width = input_img.shape[0], input_img.shape[1] input_img = make_shape_even(input_img) height_even, width_even = input_img.shape[0], input_img.shape[1] # padding images to be multiplies of 64 input_img = mod_padding_symmetric(input_img, factor=64) input_img = np.expand_dims(input_img, axis=0) # handle multi-stage outputs, obtain the last scale output of last stage preds = model.apply({"params": flax.core.freeze(params)}, input_img) if isinstance(preds, list): preds = preds[-1] if isinstance(preds, list): preds = preds[-1] preds = np.array(preds[0], np.float32) # unpad images to get the original resolution new_height, new_width = preds.shape[0], preds.shape[1] h_start = new_height // 2 - height_even // 2 h_end = h_start + height w_start = new_width // 2 - width_even // 2 w_end = w_start + width preds = preds[h_start:h_end, w_start:w_end, :] # save files out_path = Path(tempfile.mkdtemp()) / "output.png" Image.fromarray( np.array((np.clip(preds, 0.0, 1.0) * 255.0).astype(jnp.uint8)) ).save(str(out_path)) return out_path title = "Maxim Multi-Axis MLP for Image Processing" description = "" article = "AppsGenz" grApp = gr.Interface( sentence_builder, [ gr.Image(type="filepath", label="Input"), gr.Radio([ "Image Denoising", "Image Deblurring (GoPro)", "Image Deblurring (REDS)", "Image Deblurring (RealBlur_R)", "Image Deblurring (RealBlur_J)", "Image Deraining (Rain streak)", "Image Deraining (Rain drop)", "Image Dehazing (Indoor)", "Image Dehazing (Outdoor)", "Image Enhancement (Low-light)", "Image Enhancement (Retouching)"], type="value", value='Image Denoising', label='Choose a model.'), ], [ gr.Image(type="filepath", label="Output"), ], title=title, description=description, article=article) grApp.queue(concurrency_count=2) grApp.launch(share=False)