# mag2mag / app.py
# Gradio demo app. Last commit 13afb1c ("Update app.py") by fpramunno; 3.38 kB.
import gradio as gr
from pathlib import Path
import os
from PIL import Image
import torch
import torchvision.transforms as transforms
import requests
import numpy as np
from astropy.io import fits
# Project-local modules: the network (PaletteModelV2) and the conditional diffusion sampler (Diffusion_cond)
from modules import PaletteModelV2
from diffusion import Diffusion_cond
# HTML header rendered at the top of the Gradio page (passed to gr.Markdown below):
# the project title linking to the GitHub repository, followed by a teaser figure.
DESCRIPTION = '''
<div style="display: flex; justify-content: center; align-items: center; flex-direction: column; font-size: 36px; margin-top: 20px;">
<h1><a href="https://github.com/fpramunno/MAG2MAG" target="_blank" style="color: black; text-decoration: none;">MAG2MAG</a></h1>
<img src="https://raw.githubusercontent.com/fpramunno/MAG2MAG/main/pred.png" alt="teaser" style="width: 100%; max-width: 800px; height: auto;">
</div>'''
# Prefer the GPU whenever CUDA is available; otherwise everything runs on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Conditional network: 2 input channels, 1 output channel (presumably the noisy
# target plus the conditioning magnetogram -> predicted noise; confirm in modules.py).
model = PaletteModelV2(
    c_in=2,
    c_out=1,
    num_classes=5,
    image_size=256,
    device=device,
    true_img_size=64,
)
model = model.to(device)

# Pretrained weights; the file name suggests an EMA copy of the parameters.
# map_location lets a checkpoint saved on GPU load on a CPU-only host.
ckpt = torch.load('ema_ckpt_cond.pt', map_location=torch.device(device))
model.load_state_dict(ckpt)
# Inference only: put layers such as dropout/normalization into eval behavior.
model.eval()

# Sampler that runs the reverse diffusion process at the 256x256 working size.
diffusion = Diffusion_cond(img_size=256, device=device)
# Preprocessing applied to every uploaded magnetogram (PIL image or 2-D numpy array):
#   ToTensor           -> (C, H, W) float tensor; uint8 data is scaled to [0, 1]
#   Resize             -> the model's 256x256 working resolution
#   RandomVerticalFlip -> p=1.0, so the vertical flip is always applied
#   Normalize          -> (x - 0.5) / 0.5, i.e. [0, 1] is mapped to [-1, 1]
transform_hmi = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(256, 256)),
        transforms.RandomVerticalFlip(p=1.0),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
# File extensions routed to the astropy FITS reader.
_FITS_EXTENSIONS = ('.fits', '.fit', '.fts')


def _load_input_tensor(seed_image):
    """Read an uploaded magnetogram and return a (1, 1, 256, 256) tensor on ``device``.

    Raises gr.Error (shown in the Gradio UI) for a missing upload, an
    unsupported extension, or a FITS file without image data.
    """
    if seed_image is None:
        raise gr.Error('Please upload a .jp2 or .fits magnetogram first.')
    # Depending on the Gradio version, gr.File delivers either a path or a
    # tempfile wrapper that exposes the path as ``.name``.
    if isinstance(seed_image, (str, os.PathLike)):
        path = seed_image
    else:
        path = seed_image.name
    file_ext = os.path.splitext(path)[1].lower()

    if file_ext == '.jp2':
        # The context manager closes the handle PIL keeps open for lazy decoding.
        # NOTE(review): assumes a single-channel JPEG2000; RGB files would yield
        # 3 channels and fail the reshape below.
        with Image.open(path) as img:
            tensor = transform_hmi(img)
    elif file_ext in _FITS_EXTENSIONS:
        with fits.open(path) as hdul:
            # Tile-compressed FITS files leave the primary HDU empty and store
            # the image in an extension, so take the first HDU that has data.
            data = next((hdu.data for hdu in hdul if hdu.data is not None), None)
            if data is None:
                raise gr.Error('The FITS file does not contain any image data.')
            # FITS arrays are big-endian, which torch.from_numpy (used by
            # ToTensor) rejects; converting to native float32 also matches the
            # model's dtype. NaNs (e.g. blank/off-disk pixels) would propagate
            # through the network, so they are replaced with 0.
            # Copy while the file is still open (astropy may memory-map data).
            array = np.nan_to_num(np.asarray(data, dtype=np.float32))
        # NOTE(review): float arrays are not rescaled by ToTensor, so raw FITS
        # values reach Normalize unchanged; confirm this matches training.
        tensor = transform_hmi(array)
    else:
        raise gr.Error(
            f'Unsupported file type {file_ext!r}: please upload a .jp2 or .fits file.'
        )
    return tensor.reshape(1, 1, 256, 256).to(device)


def _tensor_to_pil(tensor):
    """Convert a 256x256 single-channel tensor into an 8-bit greyscale PIL image.

    Float tensors are mapped from [-1, 1] to [0, 255]; out-of-range values are
    clipped. Gradio cannot encode 32-bit float ("F" mode) PIL images, so an
    8-bit image is required for display.
    """
    array = tensor.detach().reshape(256, 256).cpu().numpy()
    if np.issubdtype(array.dtype, np.floating):
        # NOTE(review): assumes float output lies in the model's [-1, 1] range
        # (what transform_hmi produces for 8-bit images); confirm against
        # Diffusion_cond.sample. Integer output is taken as already 0-255.
        array = (np.clip(array, -1.0, 1.0) + 1.0) * 127.5
    array = np.clip(np.rint(array), 0, 255).astype(np.uint8)
    # Undo the vertical flip that transform_hmi applies (equivalent to
    # Image.FLIP_TOP_BOTTOM on the PIL image).
    return Image.fromarray(np.ascontiguousarray(np.flipud(array)))


def generate_image(seed_image):
    """Predict the magnetogram 24 hours ahead from an uploaded magnetogram.

    Parameters
    ----------
    seed_image : str, os.PathLike, file wrapper with ``.name``, or None
        Value delivered by the ``gr.File`` input. Supported extensions are
        ``.jp2`` (read with PIL) and ``.fits``/``.fit``/``.fts`` (read with
        astropy).

    Returns
    -------
    tuple[PIL.Image.Image, PIL.Image.Image]
        ``(preprocessed_input, prediction)``. Both are 256x256 8-bit
        greyscale images, flipped back to the upload's orientation.

    Raises
    ------
    gr.Error
        If no file was uploaded, the extension is unsupported, or a FITS file
        has no image data.
    """
    input_tensor = _load_input_tensor(seed_image)
    # Inference only: no autograd graph is needed while sampling.
    with torch.no_grad():
        generated_image = diffusion.sample(model, y=input_tensor, labels=None, n=1)
    return _tensor_to_pil(input_tensor), _tensor_to_pil(generated_image[0])
# Build the Gradio interface: a row holding the file upload and the two image
# previews, then a row of action buttons underneath.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        input_image = gr.File(label='Input Image')
        output_image1 = gr.Image(label='Input LoS Magnetogram', type='pil', interactive=False)
        output_image2 = gr.Image(label='Predicted LoS Magnetogram in 24 hours', type='pil', interactive=False)
    # The buttons sit in their own row, below the upload/preview row.
    with gr.Row():
        clear_button = gr.Button('Clear')
        process_button = gr.Button('Generate')
    # Generate: run the model on the uploaded file and fill both previews.
    process_button.click(
        fn=generate_image,
        inputs=input_image,
        outputs=[output_image1, output_image2],
    )
    # Clear: reset the upload AND both previews, so a previous prediction is
    # not left displayed next to an empty input.
    clear_button.click(
        fn=lambda: (None, None, None),
        inputs=None,
        outputs=[input_image, output_image1, output_image2],
    )
demo.launch()