fpramunno commited on
Commit
ff34208
1 Parent(s): edfd4bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -3
app.py CHANGED
@@ -11,6 +11,11 @@ import numpy as np
11
  from modules import PaletteModelV2
12
  from diffusion import Diffusion_cond
13
 
 
 
 
 
 
14
 
15
  # Check for GPU availability, else use CPU
16
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -30,15 +35,29 @@ transform_hmi = transforms.Compose([
30
  ])
31
 
32
  def generate_image(seed_image):
33
- seed_image_tensor = transform_hmi(Image.open(seed_image)).reshape(1, 1, 256, 256).to(device)
 
 
 
 
 
 
 
 
 
 
34
  generated_image = diffusion.sample(model, y=seed_image_tensor, labels=None, n=1)
35
- # generated_image_pil = transforms.ToPILImage()(generated_image.squeeze().cpu())
 
 
 
 
36
  img = generated_image[0].reshape(1, 256, 256).permute(1, 2, 0) # Permute dimensions to height x width x channels
37
  img = np.squeeze(img.cpu().numpy())
38
  v = Image.fromarray(img) # Create a PIL Image from array
39
  v = v.transpose(Image.FLIP_TOP_BOTTOM)
40
 
41
- return v
42
 
43
  # Create Gradio interface
44
  iface = gr.Interface(
 
11
  from modules import PaletteModelV2
12
  from diffusion import Diffusion_cond
13
 
14
+ DESCRIPTION = '''
15
+ <div style="display: flex; justify-content: center; align-items: center; flex-direction: column; font-size: 36px; margin-top: 20px;">
16
+ <h1><a href="https://github.com/fpramunno/MAG2MAG" target="_blank" style="color: black; text-decoration: none;">MAG2MAG</a></h1>
17
+ <img src="https://raw.githubusercontent.com/fpramunno/MAG2MAG/main/pred.png" alt="teaser" style="width: 100%; max-width: 800px; height: auto;">
18
+ </div>'''
19
 
20
  # Check for GPU availability, else use CPU
21
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
35
  ])
36
 
37
  def generate_image(seed_image):
38
+ _, file_ext = os.path.splitext(seed_image)
39
+
40
+ if file_ext.lower() == '.jp2':
41
+ input_img = Image.open(seed_image)
42
+ input_img_pil = transform_hmi(input_img).reshape(1, 1, 256, 256).to(device)
43
+ elif file_ext.lower() == '.fits':
44
+ with fits.open(seed_image) as hdul:
45
+ data = hdul[0].data
46
+
47
+ input_img_pil = transform_hmi(data).reshape(1, 1, 256, 256).to(device)
48
+
49
  generated_image = diffusion.sample(model, y=seed_image_tensor, labels=None, n=1)
50
+
51
+ inp_img = seed_image_tensor.reshape(1, 256, 256).permute(1, 2, 0)
52
+ inp_img = np.squeeze(inp_img.cpu().numpy())
53
+ inp = Image.fromarray(inp_img) # Create a PIL Image from array
54
+ inp = inp.transpose(Image.FLIP_TOP_BOTTOM)
55
  img = generated_image[0].reshape(1, 256, 256).permute(1, 2, 0) # Permute dimensions to height x width x channels
56
  img = np.squeeze(img.cpu().numpy())
57
  v = Image.fromarray(img) # Create a PIL Image from array
58
  v = v.transpose(Image.FLIP_TOP_BOTTOM)
59
 
60
+ return inp, v
61
 
62
  # Create Gradio interface
63
  iface = gr.Interface(