mag2mag / app.py
fpramunno's picture
Upload 3 files
0f1af34 verified
raw
history blame
No virus
2.46 kB
import gradio as gr
from pathlib import Path
import os
from PIL import Image
import torch
import torchvision.transforms as transforms
import requests
def download_file_from_google_drive(id, destination):
    """Download a publicly shared Google Drive file to a local path.

    If the first response carries a ``download_warning`` cookie, the request
    is repeated with that value as the ``confirm`` parameter before the body
    is written to disk.

    Parameters
    ----------
    id : str
        Google Drive file ID. (The name shadows the ``id`` builtin but is kept
        so existing keyword callers keep working.)
    destination : str or path-like
        Local path the downloaded bytes are written to.

    Raises
    ------
    requests.HTTPError
        If the final response has an HTTP error status.
    requests.Timeout
        If connecting, or waiting for the next chunk of data, takes longer
        than the timeout.
    RuntimeError
        If the final response is an HTML page rather than file content.
    """
    URL = "https://drive.google.com/uc?export=download"
    # With stream=True the timeout bounds the connect and each wait between
    # received bytes, not the total download time.
    timeout = 60
    # The context manager releases pooled connections even if a request fails.
    with requests.Session() as session:
        response = session.get(URL, params={'id': id}, stream=True, timeout=timeout)
        token = get_confirm_token(response)
        if token:
            # The first streamed response is abandoned; close it so its
            # connection is returned instead of leaking.
            response.close()
            params = {'id': id, 'confirm': token}
            response = session.get(URL, params=params, stream=True, timeout=timeout)
        with response:
            response.raise_for_status()
            # NOTE(review): an HTML body here presumably means Drive served an
            # interstitial/confirmation page instead of the file. Saving it
            # would produce a bogus checkpoint that only fails later, inside
            # torch.load, so fail loudly now.
            content_type = response.headers.get('Content-Type', '')
            if content_type.startswith('text/html'):
                raise RuntimeError(
                    f"Google Drive returned an HTML page instead of file {id!r}; "
                    "the file may be private, missing, or require manual confirmation."
                )
            save_response_content(response, destination)
def get_confirm_token(response):
    """Look up Google Drive's download-confirmation token in *response*.

    Returns the value of the first cookie whose name starts with
    ``download_warning``, or ``None`` when there is no such cookie.
    """
    warning_values = (
        cookie_value
        for cookie_name, cookie_value in response.cookies.items()
        if cookie_name.startswith('download_warning')
    )
    return next(warning_values, None)
def save_response_content(response, destination):
    """Stream the body of *response* into the file at *destination*.

    The body is read in 32 KiB chunks via ``response.iter_content`` and the
    file is opened in ``"wb"`` mode, so any existing file is overwritten.
    """
    chunk_size = 32768
    with open(destination, "wb") as out_file:
        # Empty (falsy) chunks are keep-alive noise; filter(None, ...) drops them.
        out_file.writelines(filter(None, response.iter_content(chunk_size)))
# Google Drive file ID of the model checkpoint, and the local path it is saved to.
file_id = '1WJ33nys02XpPDsMO5uIZFiLqTuAT_iuV'
destination = 'ema_ckpt_cond.pt'
# Only download when the checkpoint is not already on disk, so warm restarts
# skip the transfer. Write to a temporary name and rename it into place
# afterwards: an interrupted download then never leaves a truncated file at
# `destination` that a later start would mistake for a complete checkpoint.
if not os.path.exists(destination):
    _partial_destination = destination + '.part'
    download_file_from_google_drive(file_id, _partial_destination)
    os.replace(_partial_destination, destination)
# Model loading
from modules import PaletteModelV2
from diffusion import Diffusion_cond

# Use the GPU when one is present and fall back to CPU otherwise; hard-coding
# 'cuda' makes start-up fail on hosts without a CUDA device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): if PaletteModelV2 or Diffusion_cond create tensors on a
# hard-coded 'cuda' device internally, CPU inference will still fail — confirm.
model = PaletteModelV2(c_in=2, c_out=1, num_classes=5, image_size=256, true_img_size=64).to(device)
# map_location places the stored tensors on the selected device, so a
# checkpoint saved on GPU also loads on a CPU-only host.
# NOTE(security): torch.load unpickles the file, and this file is fetched over
# the network at start-up. Since it is used as a plain state dict, consider
# passing weights_only=True once the installed torch version supports it.
ckpt = torch.load(destination, map_location=device)
model.load_state_dict(ckpt)
diffusion = Diffusion_cond(noise_steps=1000, img_size=256, device=device)
# Evaluation mode (affects dropout / batch-norm layers, if the model has any).
model.eval()
# Preprocessing applied to an uploaded magnetogram before it is fed to the
# model. The steps run in list order.
transform_hmi = transforms.Compose(
    [
        # PIL image -> float tensor of shape (C, H, W); 8-bit images are
        # scaled to [0, 1].
        transforms.ToTensor(),
        # Resize to the 256x256 resolution the model was built with.
        transforms.Resize((256, 256)),
        # p=1.0 makes this an unconditional top-to-bottom flip.
        # NOTE(review): presumably this matches the orientation used in
        # training — confirm.
        transforms.RandomVerticalFlip(p=1.0),
        # (x - 0.5) / 0.5 maps [0, 1] onto [-1, 1].
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
def generate_image(seed_image):
    """Sample the model's prediction conditioned on an uploaded magnetogram.

    Parameters
    ----------
    seed_image : str or file-like
        Whatever the Gradio ``"file"`` input delivers (a path or a temporary
        file object, depending on the Gradio version). Anything
        ``PIL.Image.open`` accepts works.

    Returns
    -------
    PIL.Image.Image
        The sampled image, converted with ``transforms.ToPILImage``.
    """
    # The context manager closes the uploaded file as soon as it has been
    # decoded, instead of leaking the handle until garbage collection.
    with Image.open(seed_image) as img:
        # The model takes one input channel. Multi-band images (RGB, RGBA, LA, ...)
        # would make the reshape below raise, and palette ('P') images would be
        # fed as raw palette indices, so collapse both to 8-bit grayscale.
        if img.mode == 'P' or len(img.getbands()) > 1:
            img = img.convert('L')
        seed_image_tensor = transform_hmi(img).reshape(1, 1, 256, 256).to(device)
    generated_image = diffusion.sample(model, y=seed_image_tensor, labels=None, n=1)
    # NOTE(review): transform_hmi flips the input vertically, but the output is
    # not flipped back here — confirm this is intended.
    return transforms.ToPILImage()(generated_image.squeeze().cpu())
# Web UI: one uploaded file in, the generated magnetogram image out.
# Interface's first three positional parameters are fn, inputs, and outputs.
iface = gr.Interface(
    generate_image,
    "file",
    "image",
    title="Magnetogram-to-Magnetogram: Generative Forecasting of Solar Evolution",
    description="Upload a LoS magnetogram and predict how it is going to be in 24 hours.",
)
iface.launch()