import gradio as gr
from pathlib import Path
import os
from PIL import Image
import torch
import torchvision.transforms as transforms
import requests
import numpy as np

# Preprocessing
from modules import PaletteModelV2
from diffusion import Diffusion_cond

# NOTE(review): `fits` (used below for .fits uploads) is never imported in this
# file; it presumably refers to `astropy.io.fits`. Until
# `from astropy.io import fits` is added (with astropy in the requirements),
# uploading a .fits file raises NameError. TODO confirm and add the import.

DESCRIPTION = '''

MAG2MAG

teaser
'''
# NOTE: DESCRIPTION is not currently passed to the Gradio interface below.

# Spatial size (pixels per side) shared by the model, the sampler and the
# input preprocessing.
_IMG_SIZE = 256

# Check for GPU availability, else use CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = PaletteModelV2(c_in=2, c_out=1, num_classes=5, image_size=_IMG_SIZE,
                       device=device, true_img_size=64).to(device)
ckpt = torch.load('ema_ckpt_cond.pt', map_location=torch.device(device))
model.load_state_dict(ckpt)
diffusion = Diffusion_cond(img_size=_IMG_SIZE, device=device)
model.eval()

# Input preprocessing: to a (1, H, W) float tensor, resize, flip vertically
# (p=1.0 means the flip is always applied; it is undone for display in
# _to_display_image), then map x -> (x - 0.5) / 0.5.
transform_hmi = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((_IMG_SIZE, _IMG_SIZE)),
    transforms.RandomVerticalFlip(p=1.0),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])


def _to_display_image(tensor):
    """Convert a 2-D image tensor into an 8-bit PIL image for display.

    uint8 data is used as-is. Floating-point data is mapped linearly from
    [-s, s] to [0, 255] with s = max(1, max|x|). For data inside [-1, 1]
    this exactly undoes ToTensor + Normalize(0.5, 0.5) on an 8-bit image.
    Wider-range data is scaled symmetrically, so 0 stays mid-gray. Without
    this conversion, float32 data becomes a mode 'F' PIL image, which
    Pillow cannot write as PNG, and float64 data cannot be converted at all.

    The vertical flip applied by ``transform_hmi`` is undone.
    """
    arr = tensor.detach().cpu().numpy()
    if arr.dtype != np.uint8:
        arr = np.nan_to_num(arr.astype(np.float32))
        scale = max(1.0, float(np.abs(arr).max()))
        arr = np.clip((arr / scale + 1.0) * 127.5, 0, 255).round().astype(np.uint8)
    return Image.fromarray(arr).transpose(Image.FLIP_TOP_BOTTOM)


def generate_image(seed_image):
    """Run the conditional diffusion model on an uploaded magnetogram.

    Args:
        seed_image: Value of the Gradio "file" input. Depending on the Gradio
            version this is a path string or a temp-file wrapper exposing
            the path as ``.name``; both are accepted.

    Returns:
        A pair ``(input_image, generated_image)`` of PIL images: the
        preprocessed input and the model's sample, both flipped back to the
        upload's vertical orientation.

    Raises:
        ValueError: if no file was uploaded, or its extension is neither
            ``.jp2`` nor ``.fits``.
    """
    if seed_image is None:
        raise ValueError("Please upload a .jp2 or .fits magnetogram.")
    path = getattr(seed_image, "name", seed_image)
    file_ext = os.path.splitext(path)[1].lower()

    if file_ext == '.jp2':
        # `with` closes the file handle once the pixels have been read.
        with Image.open(path) as input_img:
            seed_image_tensor = transform_hmi(input_img)
    elif file_ext == '.fits':
        with fits.open(path) as hdul:
            # FITS stores data big-endian, which torch.from_numpy (used by
            # ToTensor) rejects. Copy it to native-order float32 while the
            # file is still open.
            data = np.asarray(hdul[0].data, dtype=np.float32)
        seed_image_tensor = transform_hmi(data)
    else:
        raise ValueError(
            f"Unsupported file type {file_ext!r}; expected .jp2 or .fits."
        )

    # Batch of one single-channel image, as the sampler's `y` condition.
    seed_image_tensor = seed_image_tensor.reshape(
        1, 1, _IMG_SIZE, _IMG_SIZE).to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        generated_image = diffusion.sample(
            model, y=seed_image_tensor, labels=None, n=1)

    inp = _to_display_image(seed_image_tensor.reshape(_IMG_SIZE, _IMG_SIZE))
    v = _to_display_image(generated_image[0].reshape(_IMG_SIZE, _IMG_SIZE))
    return inp, v


# Create Gradio interface. generate_image returns (input, prediction), so two
# image outputs are declared.
iface = gr.Interface(
    fn=generate_image,
    inputs="file",
    outputs=["image", "image"],
    title="Magnetogram-to-Magnetogram: Generative Forecasting of Solar Evolution",
    description="Upload a LoS magnetogram and predict how it is going to be in 24 hours."
)
iface.launch()