# mag2mag / app_backup.py
# Author: fpramunno. Added in commit edfd4bd ("Create app_backup.py").
import gradio as gr
from pathlib import Path
import os
from PIL import Image
import torch
import torchvision.transforms as transforms
import requests
import numpy as np
# Project-local modules: the PaletteModelV2 network and the Diffusion_cond sampler
from modules import PaletteModelV2
from diffusion import Diffusion_cond
# Run on the GPU when one is present; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Build the conditional network and restore its trained weights.
# NOTE(review): c_in=2 presumably means the noisy sample plus the 1-channel
# conditioning magnetogram — TODO confirm against modules.PaletteModelV2.
model = PaletteModelV2(
    c_in=2,
    c_out=1,
    num_classes=5,
    image_size=256,
    device=device,
    true_img_size=64,
).to(device)

# torch.load unpickles the file, so only load checkpoints from a trusted source.
# The filename suggests EMA weights — unverified here.
ckpt = torch.load('ema_ckpt_cond.pt', map_location=torch.device(device))
model.load_state_dict(ckpt)
model.eval()  # inference only

# Sampler that drives the reverse diffusion at the 256x256 working resolution.
diffusion = Diffusion_cond(img_size=256, device=device)
# Preprocessing applied to an uploaded magnetogram before it conditions the model:
#   1. PIL image -> float tensor (ToTensor);
#   2. resize to the 256x256 working resolution;
#   3. vertical flip — p=1.0 makes this deterministic, and generate_image
#      flips its output back;
#   4. (x - 0.5) / 0.5, i.e. [0, 1] -> [-1, 1] assuming ToTensor produced
#      values in [0, 1] (true for 8-bit images) — TODO confirm for float inputs.
_hmi_steps = [
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
    transforms.RandomVerticalFlip(p=1.0),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
transform_hmi = transforms.Compose(_hmi_steps)
def generate_image(seed_image):
    """Forecast a magnetogram from an uploaded one with the conditional diffusion model.

    Args:
        seed_image: Anything ``PIL.Image.open`` accepts (a path or a binary
            file-like object). Here it is whatever Gradio passes for the
            ``"file"`` input; the exact type depends on the Gradio version —
            TODO confirm.

    Returns:
        A ``PIL.Image.Image`` built from the single generated 256x256 sample.
        It is flipped vertically, mirroring the vertical flip that
        ``transform_hmi`` applies to the input.
    """
    # The context manager closes the uploaded file once its pixels are in a tensor.
    with Image.open(seed_image) as seed:
        # The reshape requires the preprocessed image to hold exactly 256*256
        # values, i.e. a single-channel input.
        condition = transform_hmi(seed).reshape(1, 1, 256, 256).to(device)

    generated = diffusion.sample(model, y=condition, labels=None, n=1)

    # First (only) sample as a 256x256 array. detach() keeps .numpy() working
    # even if the sampler hands back a tensor that tracks gradients.
    frame = generated[0].reshape(256, 256).detach().cpu().numpy()
    forecast = Image.fromarray(frame)
    return forecast.transpose(Image.FLIP_TOP_BOTTOM)
# Gradio UI: an uploaded file goes in, the forecast image comes out.
iface = gr.Interface(
    fn=generate_image,
    inputs="file",
    outputs="image",
    title="Magnetogram-to-Magnetogram: Generative Forecasting of Solar Evolution",
    description="Upload a LoS magnetogram and predict how it is going to be in 24 hours.",
)

# Start the web server only when this file is run as a script, so importing
# the module (e.g. to call generate_image directly) does not launch a server.
if __name__ == "__main__":
    iface.launch()