# image-enhance / maxim.py
# HungHN
# config maxim
# 3f5bb4a
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_force_compilation_parallelism=1"
import gradio as gr
import numpy as np
from PIL import Image
from pathlib import Path
import importlib
import ml_collections
import tempfile
import jax.numpy as jnp
import flax
from run_eval import (
_MODEL_FILENAME,
_MODEL_VARIANT_DICT,
_MODEL_CONFIGS,
get_params,
mod_padding_symmetric,
make_shape_even,
augment_image,
)
# Checkpoint file for every task label the UI can send. Only the entry that
# matches the request is loaded, instead of reading all checkpoints per call.
_CHECKPOINT_PATHS = {
    "Image Denoising": "checkpoints/denoising-SIDD/checkpoint.npz",
    "Image Deblurring (GoPro)": "checkpoints/debluring-GoPro/checkpoint.npz",
    "Image Deblurring (REDS)": "checkpoints/debluring-REDS/checkpoint.npz",
    "Image Deblurring (RealBlur_R)": (
        "checkpoints/debluring-Real-Blur-R/checkpoint.npz"
    ),
    "Image Deblurring (RealBlur_J)": (
        "checkpoints/debluring-Real-Blur-J/checkpoint.npz"
    ),
    "Image Deraining (Rain streak)": (
        "checkpoints/deraining-Rain13k/checkpoint.npz"
    ),
    "Image Deraining (Rain drop)": (
        "checkpoints/deraining-Raindrop/checkpoint.npz"
    ),
    "Image Dehazing (Indoor)": (
        "checkpoints/dehazing-RESIDE-Indoor/checkpoint.npz"
    ),
    "Image Dehazing (Outdoor)": (
        "checkpoints/dehazing-RESIDE-Outdoor/checkpoint.npz"
    ),
    "Image Enhancement (Low-light)": (
        "checkpoints/enhancement-LOL/checkpoint.npz"
    ),
    "Image Enhancement (Retouching)": (
        "checkpoints/enhancement-FiveK/checkpoint.npz"
    ),
}


def sentence_builder(image, model):
    """Run the selected restoration model on an image and save the result.

    Args:
        image: Path of the input image file; it is passed through ``str()``
            to ``PIL.Image.open``.
        model: Task label chosen in the UI, e.g. ``"Image Deblurring (GoPro)"``.
            The full label selects the checkpoint; its second
            whitespace-separated word (``"Deblurring"``) is the key looked up
            in ``_MODEL_VARIANT_DICT``.

    Returns:
        ``pathlib.Path`` of ``output.png`` written inside a newly created
        temporary directory.

    Raises:
        KeyError: If ``model`` is not a known label, or its task word is not
            a key of ``_MODEL_VARIANT_DICT``.
    """
    # Resolve the request first so an unknown label fails before any
    # checkpoint or model work is done.
    checkpoint_path = _CHECKPOINT_PATHS[model]
    task = model.split()[1]

    # Build only the network variant required for this task.
    model_mod = importlib.import_module(f"maxim.models.{_MODEL_FILENAME}")
    model_configs = ml_collections.ConfigDict(_MODEL_CONFIGS)
    model_configs.variant = _MODEL_VARIANT_DICT[task]
    network = model_mod.Model(**model_configs)

    # Load only the checkpoint that belongs to the selected label.
    params = get_params(checkpoint_path)

    # Read the image as float32 RGB scaled to [0, 1]; the context manager
    # releases the underlying file handle once the pixels are copied out.
    with Image.open(str(image)) as pil_img:
        input_img = np.asarray(pil_img.convert("RGB"), np.float32) / 255.0

    # Pad the image to even height/width, remembering both sizes for unpadding.
    height, width = input_img.shape[0], input_img.shape[1]
    input_img = make_shape_even(input_img)
    height_even, width_even = input_img.shape[0], input_img.shape[1]

    # Pad the image so both sides are multiples of 64, then add a batch axis.
    input_img = mod_padding_symmetric(input_img, factor=64)
    input_img = np.expand_dims(input_img, axis=0)

    preds = network.apply({"params": flax.core.freeze(params)}, input_img)

    # Handle multi-stage outputs: take the last scale output of the last
    # stage (at most two levels of list nesting are unwrapped).
    for _ in range(2):
        if isinstance(preds, list):
            preds = preds[-1]
    preds = np.array(preds[0], np.float32)

    # Crop the centred region back to the original resolution.
    new_height, new_width = preds.shape[0], preds.shape[1]
    h_start = new_height // 2 - height_even // 2
    h_end = h_start + height
    w_start = new_width // 2 - width_even // 2
    w_end = w_start + width
    preds = preds[h_start:h_end, w_start:w_end, :]

    # Save as 8-bit PNG. The temporary directory is intentionally left on
    # disk because the returned path is read after this function returns.
    out_path = Path(tempfile.mkdtemp()) / "output.png"
    pixels = (np.clip(preds, 0.0, 1.0) * 255.0).astype(np.uint8)
    Image.fromarray(pixels).save(str(out_path))
    return out_path
# Static text shown on the demo page.
title = "Maxim Multi-Axis MLP for Image Processing"
description = ""
article = "AppsGenz"

# Task labels offered in the radio group; the selected label is passed to
# sentence_builder as its second argument.
_task_labels = [
    "Image Denoising",
    "Image Deblurring (GoPro)",
    "Image Deblurring (REDS)",
    "Image Deblurring (RealBlur_R)",
    "Image Deblurring (RealBlur_J)",
    "Image Deraining (Rain streak)",
    "Image Deraining (Rain drop)",
    "Image Dehazing (Indoor)",
    "Image Dehazing (Outdoor)",
    "Image Enhancement (Low-light)",
    "Image Enhancement (Retouching)",
]

# UI components, created in the same order as they appear on the page.
_input_image = gr.Image(type="filepath", label="Input")
_task_picker = gr.Radio(
    choices=_task_labels,
    type="value",
    value='Image Denoising',
    label='Choose a model.',
)
_output_image = gr.Image(type="filepath", label="Output")

# Wire the components to the inference function.
grApp = gr.Interface(
    fn=sentence_builder,
    inputs=[_input_image, _task_picker],
    outputs=[_output_image],
    title=title,
    description=description,
    article=article,
)

# Queue requests with two concurrent workers, then serve without a share link.
grApp.queue(concurrency_count=2)
grApp.launch(share=False)