"""Gradio demo: forecast a line-of-sight magnetogram 24 hours ahead.

Loads ``PaletteModelV2`` weights from ``ema_ckpt_cond.pt``, samples with
``Diffusion_cond`` conditioned on the uploaded magnetogram, and serves the
result through a Gradio interface that is launched when this module runs.
"""
import gradio as gr
from pathlib import Path
import os
from PIL import Image
import torch
import torchvision.transforms as transforms
import requests
import numpy as np

# Project-local model and diffusion sampler
from modules import PaletteModelV2
from diffusion import Diffusion_cond

# Spatial size (pixels per side) used by the model, the sampler and the
# input preprocessing below.
IMAGE_SIZE = 256

# Check for GPU availability, else use CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = PaletteModelV2(c_in=2, c_out=1, num_classes=5, image_size=IMAGE_SIZE,
                       device=device, true_img_size=64).to(device)
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
ckpt = torch.load('ema_ckpt_cond.pt', map_location=torch.device(device))
model.load_state_dict(ckpt)
diffusion = Diffusion_cond(img_size=IMAGE_SIZE, device=device)
model.eval()

transform_hmi = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    # p=1.0 makes this a deterministic vertical flip; _to_pil flips the
    # generated image back so the output matches the upload's orientation.
    transforms.RandomVerticalFlip(p=1.0),
    # Maps ToTensor's [0, 1] range to [-1, 1].
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])


def _load_seed_tensor(seed_image):
    """Open ``seed_image`` and return it as a (1, 1, IMAGE_SIZE, IMAGE_SIZE) tensor on ``device``."""
    with Image.open(seed_image) as img:
        # The conditioning input is single-channel; multi-band uploads
        # (RGB, RGBA, LA) would otherwise fail the reshape below.
        if len(img.getbands()) > 1:
            img = img.convert("L")
        tensor = transform_hmi(img)
    return tensor.reshape(1, 1, IMAGE_SIZE, IMAGE_SIZE).to(device)


def _to_pil(generated_image):
    """Convert the first sample in ``generated_image`` to a PIL image.

    NOTE(review): the dtype / value range returned by ``diffusion.sample`` is
    not visible here. ``Image.fromarray`` chooses the PIL mode from the array
    dtype (e.g. uint8 -> "L", float32 -> "F"); confirm the sampler's output
    dtype is one Gradio can display.
    """
    pixels = (
        generated_image[0]
        .detach()
        .reshape(IMAGE_SIZE, IMAGE_SIZE)
        .cpu()
        .numpy()
    )
    pil_img = Image.fromarray(pixels)
    # Undo the vertical flip applied by transform_hmi.
    return pil_img.transpose(Image.FLIP_TOP_BOTTOM)


def generate_image(seed_image):
    """Generate the forecast magnetogram for an uploaded seed magnetogram.

    Args:
        seed_image: value supplied by the Gradio ``"file"`` input; anything
            ``PIL.Image.open`` accepts (a path or a file-like object).

    Returns:
        PIL.Image.Image: the generated magnetogram, in the input's orientation.

    Raises:
        gr.Error: if no file was uploaded.
    """
    if seed_image is None:
        raise gr.Error("Please upload a magnetogram image.")
    seed_image_tensor = _load_seed_tensor(seed_image)
    generated_image = diffusion.sample(model, y=seed_image_tensor, labels=None, n=1)
    return _to_pil(generated_image)


# Create Gradio interface
iface = gr.Interface(
    fn=generate_image,
    inputs="file",
    outputs="image",
    title="Magnetogram-to-Magnetogram: Generative Forecasting of Solar Evolution",
    description="Upload a LoS magnetogram and predict how it is going to be in 24 hours."
)

# Launches the app at module level (no __main__ guard), as in the original script.
iface.launch()