fpramunno commited on
Commit
0210cac
1 Parent(s): f4dd292

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -17,7 +17,7 @@ model = PaletteModelV2(c_in=2, c_out=1, num_classes=5, image_size=256, true_img_
17
  ckpt = torch.load('ema_ckpt_cond.pt')
18
  model.load_state_dict(ckpt)
19
 
20
- diffusion = Diffusion_cond(noise_steps=1000, img_size=256, device=device)
21
  model.eval()
22
 
23
  transform_hmi = transforms.Compose([
@@ -30,8 +30,13 @@ transform_hmi = transforms.Compose([
30
  def generate_image(seed_image):
31
  seed_image_tensor = transform_hmi(Image.open(seed_image)).reshape(1, 1, 256, 256).to(device)
32
  generated_image = diffusion.sample(model, y=seed_image_tensor, labels=None, n=1)
33
- generated_image_pil = transforms.ToPILImage()(generated_image.squeeze().cpu())
34
- return generated_image_pil
 
 
 
 
 
35
 
36
  # Create Gradio interface
37
  iface = gr.Interface(
 
17
  ckpt = torch.load('ema_ckpt_cond.pt')
18
  model.load_state_dict(ckpt)
19
 
20
+ diffusion = Diffusion_cond(img_size=256, device=device)
21
  model.eval()
22
 
23
  transform_hmi = transforms.Compose([
 
30
  def generate_image(seed_image):
31
  seed_image_tensor = transform_hmi(Image.open(seed_image)).reshape(1, 1, 256, 256).to(device)
32
  generated_image = diffusion.sample(model, y=seed_image_tensor, labels=None, n=1)
33
+ # generated_image_pil = transforms.ToPILImage()(generated_image.squeeze().cpu())
34
+ img = generated_image[0].permute(1, 2, 0) # Permute dimensions to height x width x channels
35
+ img = np.squeeze(img.cpu().numpy())
36
+ v = Image.fromarray(img) # Create a PIL Image from array
37
+ v = v.transpose(Image.FLIP_TOP_BOTTOM)
38
+
39
+ return v
40
 
41
  # Create Gradio interface
42
  iface = gr.Interface(