|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import functools
import importlib
import os
import tempfile
from pathlib import Path

# Set before the JAX/Flax imports below so XLA picks the flag up.
os.environ["XLA_FLAGS"] = "--xla_gpu_force_compilation_parallelism=1"

import flax
import gradio as gr
import jax.numpy as jnp
import ml_collections
import numpy as np
from PIL import Image
|
|
|
from run_eval import ( |
|
_MODEL_FILENAME, |
|
_MODEL_VARIANT_DICT, |
|
_MODEL_CONFIGS, |
|
get_params, |
|
mod_padding_symmetric, |
|
make_shape_even, |
|
augment_image, |
|
) |
|
|
|
|
|
# Maps each label offered by the Gradio radio widget to the checkpoint file
# that implements it. These keys must stay in sync with the ``gr.Radio``
# choices used to build the interface.
_TASK_CHECKPOINTS = {
    "Image Denoising": "checkpoints/denoising-SIDD/checkpoint.npz",
    "Image Deblurring (GoPro)": "checkpoints/debluring-GoPro/checkpoint.npz",
    "Image Deblurring (REDS)": "checkpoints/debluring-REDS/checkpoint.npz",
    "Image Deblurring (RealBlur_R)": "checkpoints/debluring-Real-Blur-R/checkpoint.npz",
    "Image Deblurring (RealBlur_J)": "checkpoints/debluring-Real-Blur-J/checkpoint.npz",
    "Image Deraining (Rain streak)": "checkpoints/deraining-Rain13k/checkpoint.npz",
    "Image Deraining (Rain drop)": "checkpoints/deraining-Raindrop/checkpoint.npz",
    "Image Dehazing (Indoor)": "checkpoints/dehazing-RESIDE-Indoor/checkpoint.npz",
    "Image Dehazing (Outdoor)": "checkpoints/dehazing-RESIDE-Outdoor/checkpoint.npz",
    "Image Enhancement (Low-light)": "checkpoints/enhancement-LOL/checkpoint.npz",
    "Image Enhancement (Retouching)": "checkpoints/enhancement-FiveK/checkpoint.npz",
}

# Spatial multiple the network input is padded to before inference.
_PAD_FACTOR = 64


@functools.lru_cache(maxsize=None)
def _load_frozen_params(checkpoint_path):
    """Load the checkpoint at *checkpoint_path* once and return it frozen.

    Only the fixed set of paths in ``_TASK_CHECKPOINTS`` is ever requested,
    so the cache stays bounded. The frozen dict is immutable, so sharing the
    cached object between requests is safe.
    """
    return flax.core.freeze(get_params(checkpoint_path))


@functools.lru_cache(maxsize=None)
def _build_model(task):
    """Instantiate (once) the model variant registered for *task*.

    *task* must be a key of ``_MODEL_VARIANT_DICT``; the caller derives it
    from the second word of a radio label (e.g. ``"Denoising"``).

    Raises:
        KeyError: If *task* has no registered variant.
    """
    model_mod = importlib.import_module(f"maxim.models.{_MODEL_FILENAME}")
    model_configs = ml_collections.ConfigDict(_MODEL_CONFIGS)
    model_configs.variant = _MODEL_VARIANT_DICT[task]
    return model_mod.Model(**model_configs)


def _load_rgb_image(path):
    """Read the image at *path* as a float32 RGB array scaled to [0, 1]."""
    # Context manager so the underlying file handle is closed promptly.
    with Image.open(str(path)) as pil_image:
        rgb = pil_image.convert("RGB")
    return np.asarray(rgb, np.float32) / 255.0


def _center_crop(preds, height, width, height_even, width_even):
    """Crop the padded prediction back to the original ``height x width``.

    ``height_even``/``width_even`` are the sizes after ``make_shape_even`` and
    locate the crop origin around the centre of the padded output (assumes
    ``mod_padding_symmetric`` pads roughly evenly on both sides — confirm).
    """
    new_height, new_width = preds.shape[0], preds.shape[1]
    h_start = new_height // 2 - height_even // 2
    w_start = new_width // 2 - width_even // 2
    return preds[h_start:h_start + height, w_start:w_start + width, :]


def sentence_builder(image, model):
    """Restore *image* with the MAXIM checkpoint selected by *model*.

    Args:
        image: Filesystem path of the input image (as supplied by
            ``gr.Image(type="filepath")``), or ``None`` if nothing was uploaded.
        model: One of the labels in ``_TASK_CHECKPOINTS``,
            e.g. ``"Image Denoising"``.

    Returns:
        ``pathlib.Path`` of a PNG written into a fresh temporary directory.
        The directory is intentionally left in place because Gradio serves
        the file after this function returns.

    Raises:
        gr.Error: If no input image was provided.
        KeyError: If *model* is not a known label or its task has no variant.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")

    # Only the selected checkpoint is loaded (and cached), not all of them.
    params = _load_frozen_params(_TASK_CHECKPOINTS[model])
    # "Image Deblurring (GoPro)" -> "Deblurring": the second word names the task.
    net = _build_model(model.split()[1])

    input_img = _load_rgb_image(image)
    height, width = input_img.shape[0], input_img.shape[1]
    input_img = make_shape_even(input_img)
    height_even, width_even = input_img.shape[0], input_img.shape[1]
    input_img = mod_padding_symmetric(input_img, factor=_PAD_FACTOR)
    input_img = np.expand_dims(input_img, axis=0)  # add a leading batch axis

    preds = net.apply({"params": params}, input_img)
    # The model may return (nested) lists of outputs; keep the last entry at
    # up to two nesting levels (presumably the final stage / finest scale).
    for _ in range(2):
        if isinstance(preds, list):
            preds = preds[-1]

    # Drop the batch axis and move the result to a host float32 array.
    preds = np.asarray(preds[0], np.float32)
    preds = _center_crop(preds, height, width, height_even, width_even)

    out_path = Path(tempfile.mkdtemp()) / "output.png"
    # Truncating cast to uint8 (values are clipped first, no rounding).
    out_pixels = (np.clip(preds, 0.0, 1.0) * 255.0).astype(np.uint8)
    Image.fromarray(out_pixels).save(str(out_path))
    return out_path
|
|
|
# Page text rendered around the Gradio interface.
title = "Maxim Multi-Axis MLP for Image Processing"
description = ""
article = "AppsGenz"

# Radio labels offered to the user; each one is passed verbatim to
# ``sentence_builder`` as its ``model`` argument.
_MODEL_CHOICES = [
    "Image Denoising",
    "Image Deblurring (GoPro)",
    "Image Deblurring (REDS)",
    "Image Deblurring (RealBlur_R)",
    "Image Deblurring (RealBlur_J)",
    "Image Deraining (Rain streak)",
    "Image Deraining (Rain drop)",
    "Image Dehazing (Indoor)",
    "Image Dehazing (Outdoor)",
    "Image Enhancement (Low-light)",
    "Image Enhancement (Retouching)",
]

# Inputs: an image path and the chosen restoration model.
_input_components = [
    gr.Image(type="filepath", label="Input"),
    gr.Radio(
        _MODEL_CHOICES,
        type="value",
        value="Image Denoising",
        label="Choose a model.",
    ),
]
# Output: the path of the restored image written by ``sentence_builder``.
_output_components = [
    gr.Image(type="filepath", label="Output"),
]

grApp = gr.Interface(
    sentence_builder,
    _input_components,
    _output_components,
    title=title,
    description=description,
    article=article,
)

# Enable Gradio's request queue with two workers.
# NOTE(review): ``concurrency_count`` is the Gradio 3.x keyword; newer Gradio
# releases replaced it (``default_concurrency_limit``) — confirm the pinned version.
grApp.queue(concurrency_count=2)
# Start the local server without creating a public share link.
grApp.launch(share=False)