File size: 3,383 Bytes
0f1af34
 
 
 
 
 
 
be9674f
13afb1c
0f1af34
 
 
 
 
ff34208
 
 
 
 
05493b2
 
 
0f1af34
a5a784e
4eeda6c
0f1af34
 
0210cac
0f1af34
 
 
 
 
 
 
 
 
 
ff34208
 
 
 
 
 
 
 
 
 
 
0f1af34
ff34208
 
 
 
 
eb240c8
0210cac
 
 
 
ff34208
0f1af34
 
cf7ac47
 
 
 
 
 
 
 
 
 
 
0f1af34
cf7ac47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr
from pathlib import Path
import os
from PIL import Image
import torch
import torchvision.transforms as transforms
import requests
import numpy as np
from astropy.io import fits

# Project-local model and diffusion classes
from modules import PaletteModelV2
from diffusion import Diffusion_cond

# HTML header shown at the top of the page (passed to gr.Markdown below):
# a title linking to the MAG2MAG GitHub repository, followed by a teaser image.
DESCRIPTION = '''
<div style="display: flex; justify-content: center; align-items: center; flex-direction: column; font-size: 36px; margin-top: 20px;">
    <h1><a href="https://github.com/fpramunno/MAG2MAG" target="_blank" style="color: black; text-decoration: none;">MAG2MAG</a></h1>
    <img src="https://raw.githubusercontent.com/fpramunno/MAG2MAG/main/pred.png" alt="teaser" style="width: 100%; max-width: 800px; height: auto;">
</div>'''

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Build the network, move it to the selected device and restore its weights
# from the local checkpoint file (mapped onto that same device).
model = PaletteModelV2(
    c_in=2,
    c_out=1,
    num_classes=5,
    image_size=256,
    device=device,
    true_img_size=64,
)
model = model.to(device)
ckpt = torch.load('ema_ckpt_cond.pt', map_location=torch.device(device))
model.load_state_dict(ckpt)
# The app only runs inference, so put the model in evaluation mode.
model.eval()

# Sampler used by generate_image to produce the prediction.
diffusion = Diffusion_cond(img_size=256, device=device)

# Input pipeline applied to an uploaded magnetogram, in order:
#   1. convert to a tensor,
#   2. resize to 256x256,
#   3. flip vertically (p=1.0 means the flip is always applied),
#   4. normalize as (x - 0.5) / 0.5.
_hmi_steps = [
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
    transforms.RandomVerticalFlip(p=1.0),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
transform_hmi = transforms.Compose(_hmi_steps)

def _tensor_to_pil(tensor):
    """Turn a tensor holding one 256x256 single-channel image into a PIL image.

    The image is flipped top-to-bottom, which undoes the vertical flip
    applied by ``transform_hmi`` so the result is displayed in the
    orientation of the uploaded file.
    """
    # detach() keeps .numpy() from failing if the tensor still tracks gradients.
    array = tensor.detach().reshape(256, 256).cpu().numpy()
    return Image.fromarray(array).transpose(Image.FLIP_TOP_BOTTOM)


def generate_image(seed_image):
    """Predict a magnetogram from an uploaded ``.jp2`` or ``.fits`` file.

    Args:
        seed_image: The value delivered by the ``gr.File`` input. Depending on
            the Gradio version/configuration this is either a filesystem path
            or a tempfile-like object exposing the path as ``.name``; ``None``
            when nothing has been uploaded.

    Returns:
        A tuple ``(input_image, predicted_image)`` of PIL images: the
        preprocessed input and the image sampled from the diffusion model.

    Raises:
        gr.Error: If no file was uploaded, the extension is neither ``.jp2``
            nor ``.fits``, or the FITS primary HDU holds no image data.
    """
    if seed_image is None:
        raise gr.Error('Please upload a .jp2 or .fits file first.')

    # Accept both a plain path and a tempfile wrapper (path stored in `.name`).
    if isinstance(seed_image, (str, os.PathLike)):
        file_path = os.fspath(seed_image)
    else:
        file_path = seed_image.name
    file_ext = os.path.splitext(file_path)[1].lower()

    if file_ext == '.jp2':
        # Context manager closes the underlying file once the pixels are read.
        with Image.open(file_path) as input_img:
            seed_image_tensor = transform_hmi(input_img)
    elif file_ext == '.fits':
        with fits.open(file_path) as hdul:
            data = hdul[0].data
            if data is None:
                raise gr.Error('The primary HDU of this FITS file contains no image data.')
            # FITS stores values big-endian, which torch.from_numpy (used by
            # ToTensor) rejects, and Normalize needs a float tensor. Copy into
            # a native-endian float32 array while the file is still open.
            data = np.array(data, dtype=np.float32)
        seed_image_tensor = transform_hmi(data)
    else:
        raise gr.Error(
            f"Unsupported file type '{file_ext}'. Please upload a .jp2 or .fits file."
        )

    # Add the batch dimension expected by the sampler and move to the device.
    seed_image_tensor = seed_image_tensor.reshape(1, 1, 256, 256).to(device)

    generated_image = diffusion.sample(model, y=seed_image_tensor, labels=None, n=1)

    inp = _tensor_to_pil(seed_image_tensor)
    v = _tensor_to_pil(generated_image[0])

    return inp, v

# Create Gradio interface: a header, one row with the file upload and the two
# result images, and a second row with the action buttons.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        input_image = gr.File(label='Input Image')
        output_image1 = gr.Image(label='Input LoS Magnetogram', type='pil', interactive=False)
        output_image2 = gr.Image(label='Predicted LoS Magnetogram in 24 hours', type='pil', interactive=False)
        # The buttons live in a separate Row below this one (a sibling, not a nested Row)
    with gr.Row():
        clear_button = gr.Button('Clear')
        process_button = gr.Button('Generate')
        

    # Run generate_image on the uploaded file; its two return values fill the
    # two output images in order
    process_button.click(
        fn=generate_image,
        inputs=input_image,
        outputs=[output_image1, output_image2]
    )
    
    # Clear button resets the file input by writing None to it
    # (the two output images are left as they are)
    clear_button.click(
        fn=lambda: None,  # Returns None, which becomes the new value of input_image
        inputs=None,
        outputs=input_image
    )

# Start the Gradio server for this app
demo.launch()