fpramunno commited on
Commit
8576ba9
1 Parent(s): 78f112e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -3
app.py CHANGED
@@ -28,8 +28,27 @@ model.load_state_dict(ckpt)
28
  diffusion = Diffusion_cond(img_size=256, device=device)
29
  model.eval()
30
 
31
- transform_hmi = transforms.Compose([
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  transforms.ToTensor(),
 
33
  transforms.Resize((256, 256)),
34
  transforms.RandomVerticalFlip(p=1.0),
35
  transforms.Normalize(mean=(0.5,), std=(0.5,))
@@ -39,12 +58,12 @@ def generate_image(seed_image):
39
  _, file_ext = os.path.splitext(seed_image)
40
 
41
  if file_ext.lower() == '.jp2':
42
- input_img = Image.open(seed_image)
43
  input_img_pil = transform_hmi(input_img).reshape(1, 1, 256, 256).to(device)
44
  elif file_ext.lower() == '.fits':
45
  with fits.open(seed_image) as hdul:
46
  data = hdul[0].data
47
- input_img_pil = transform_hmi(data).reshape(1, 1, 256, 256).to(device)
48
  else:
49
  print(f'Format {file_ext.lower()} not supported')
50
 
 
28
  diffusion = Diffusion_cond(img_size=256, device=device)
29
  model.eval()
30
 
31
+ from torchvision import transforms
32
+
33
+ # Define a custom transform to clamp data
34
+ class ClampTransform(object):
35
+ def __init__(self, min_value=-250, max_value=250):
36
+ self.min_value = min_value
37
+ self.max_value = max_value
38
+
39
+ def __call__(self, tensor):
40
+ return torch.clamp(tensor, self.min_value, self.max_value)
41
+
42
+ transform_hmi_jp2 = transforms.Compose([
43
+ transforms.ToTensor(),
44
+ transforms.Resize((256, 256)),
45
+ transforms.RandomVerticalFlip(p=1.0),
46
+ transforms.Normalize(mean=(0.5,), std=(0.5,))
47
+ ])
48
+
49
+ transform_hmi_fits = transforms.Compose([
50
  transforms.ToTensor(),
51
+ ClampTransform(-250, 250),
52
  transforms.Resize((256, 256)),
53
  transforms.RandomVerticalFlip(p=1.0),
54
  transforms.Normalize(mean=(0.5,), std=(0.5,))
 
58
  _, file_ext = os.path.splitext(seed_image)
59
 
60
  if file_ext.lower() == '.jp2':
61
+ input_img = Image.transform_hmi_jp2(seed_image)
62
  input_img_pil = transform_hmi(input_img).reshape(1, 1, 256, 256).to(device)
63
  elif file_ext.lower() == '.fits':
64
  with fits.open(seed_image) as hdul:
65
  data = hdul[0].data
66
+ input_img_pil = transform_hmi_fits(data).reshape(1, 1, 256, 256).to(device)
67
  else:
68
  print(f'Format {file_ext.lower()} not supported')
69