# mag2mag / app.py
# fpramunno's picture
# Update app.py
# ff34208 verified
# raw
# history blame
# No virus
# 2.71 kB
import gradio as gr
from pathlib import Path
import os
from PIL import Image
import torch
import torchvision.transforms as transforms
import requests
import numpy as np
# Preprocessing
from modules import PaletteModelV2
from diffusion import Diffusion_cond
# HTML banner: the project title linking to the MAG2MAG GitHub repository,
# followed by a teaser image loaded from that repository.
# NOTE(review): this constant is not referenced anywhere in the visible code —
# the gr.Interface below passes its own plain-text `description`. Either wire it
# in (e.g. via `description=` or `article=`) or remove it.
DESCRIPTION = '''
<div style="display: flex; justify-content: center; align-items: center; flex-direction: column; font-size: 36px; margin-top: 20px;">
<h1><a href="https://github.com/fpramunno/MAG2MAG" target="_blank" style="color: black; text-decoration: none;">MAG2MAG</a></h1>
<img src="https://raw.githubusercontent.com/fpramunno/MAG2MAG/main/pred.png" alt="teaser" style="width: 100%; max-width: 800px; height: auto;">
</div>'''
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the conditional network (2 input channels, 1 output channel) and
# place it on the chosen device.
model = PaletteModelV2(
    c_in=2,
    c_out=1,
    num_classes=5,
    image_size=256,
    device=device,
    true_img_size=64,
).to(device)

# Restore trained weights from the checkpoint file next to this script,
# remapping whatever device the tensors were saved on onto `device`.
ckpt = torch.load("ema_ckpt_cond.pt", map_location=torch.device(device))
model.load_state_dict(ckpt)

# Diffusion helper whose `sample(...)` generate_image uses to draw predictions
# at the 256x256 working resolution.
diffusion = Diffusion_cond(img_size=256, device=device)

# Inference only: put the network in evaluation mode (affects layers such as
# dropout and batch-norm).
model.eval()
# Preprocessing applied to every uploaded magnetogram before it is fed to the
# model as the conditioning image.
transform_hmi = transforms.Compose(
    [
        # Convert the PIL image / ndarray into a C x H x W torch tensor.
        transforms.ToTensor(),
        # Match the model's 256 x 256 working resolution.
        transforms.Resize((256, 256)),
        # p=1.0 turns the "random" flip into an unconditional vertical flip.
        transforms.RandomVerticalFlip(p=1.0),
        # Per-channel (x - 0.5) / 0.5.
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
def _tensor_to_pil(tensor):
    """Convert one 256x256 image tensor into a PIL image flipped top-to-bottom.

    The flip undoes the vertical flip applied by ``transform_hmi`` for the input
    image; applying it to the model output presumably restores the same
    orientation — confirm against the training pipeline.

    NOTE(review): for floating-point tensors ``Image.fromarray`` produces a
    mode "F" image, which Gradio may fail to encode (e.g. PNG cannot store
    mode "F"). Consider rescaling to uint8 once the value range returned by
    ``Diffusion_cond.sample`` is confirmed.
    """
    # Equivalent to the original reshape(1, 256, 256) -> permute -> squeeze.
    array = tensor.detach().cpu().reshape(256, 256).numpy()
    return Image.fromarray(array).transpose(Image.FLIP_TOP_BOTTOM)


def generate_image(seed_image):
    """Predict the magnetogram 24 hours after an uploaded one.

    Args:
        seed_image: The upload from Gradio's ``"file"`` input. Depending on the
            Gradio version this is either a filesystem path or a file wrapper
            exposing the path as ``.name``; both are accepted. Supported
            extensions (case-insensitive): ``.jp2`` and ``.fits``.

    Returns:
        ``(input_image, predicted_image)``: two PIL images, each flipped
        top-to-bottom for display.

    Raises:
        ValueError: If the file extension is neither ``.jp2`` nor ``.fits``.
    """
    # Older Gradio passes a tempfile wrapper, newer versions a plain path.
    path = getattr(seed_image, "name", seed_image)
    file_ext = os.path.splitext(path)[1].lower()

    if file_ext == ".jp2":
        raw = Image.open(path)
    elif file_ext == ".fits":
        # NOTE(review): `fits` is not imported in the visible file; it is
        # presumably meant to be `from astropy.io import fits` — confirm.
        with fits.open(path) as hdul:
            # Load the pixels into a native-endian float32 array while the file
            # is still open: FITS arrays are big-endian, which torch.from_numpy
            # (used inside ToTensor) refuses to convert.
            raw = np.array(hdul[0].data, dtype=np.float32)
    else:
        # Previously this fell through and crashed on an unbound variable.
        raise ValueError(
            f"Unsupported file type {file_ext!r}; expected a .jp2 or .fits magnetogram."
        )

    # Single-sample, single-channel batch for the conditioning input.
    seed_image_tensor = transform_hmi(raw).reshape(1, 1, 256, 256).to(device)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        generated_image = diffusion.sample(model, y=seed_image_tensor, labels=None, n=1)

    return _tensor_to_pil(seed_image_tensor), _tensor_to_pil(generated_image[0])
# Build the web UI: one uploaded magnetogram in; the preprocessed input and the
# model's prediction out.
iface = gr.Interface(
    fn=generate_image,
    inputs="file",
    # generate_image returns two images (input, prediction), so two output
    # components are required — a single "image" output cannot display a tuple.
    outputs=[
        gr.Image(label="Input magnetogram"),
        gr.Image(label="Predicted magnetogram (+24 h)"),
    ],
    title="Magnetogram-to-Magnetogram: Generative Forecasting of Solar Evolution",
    description="Upload a LoS magnetogram and predict how it is going to be in 24 hours.",
)
iface.launch()