BioMike commited on
Commit
dfe21b4
1 Parent(s): e888cf9

Update morphing.py

Browse files
Files changed (1) hide show
  1. morphing.py +84 -83
morphing.py CHANGED
@@ -1,83 +1,84 @@
1
- import torch
2
- import torch.nn as nn
3
- from torchvision import transforms
4
- from PIL import Image, ImageFilter
5
- import gradio as gr
6
- import numpy as np
7
- import os
8
- import uuid
9
-
10
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
-
12
- transform = transforms.Compose([
13
- transforms.Resize((128, 128)),
14
- transforms.ToTensor(),
15
- transforms.Normalize((0.5,), (0.5,))
16
- ])
17
-
18
- resize_transform = transforms.Resize((512, 512))
19
-
20
- def load_image(image):
21
- image = Image.fromarray(image).convert('RGB')
22
- image = transform(image)
23
- return image.unsqueeze(0).to(device)
24
-
25
- def interpolate_vectors(v1, v2, num_steps):
26
- return [v1 * (1 - alpha) + v2 * alpha for alpha in np.linspace(0, 1, num_steps)]
27
-
28
- def infer_and_interpolate(image1, image2, num_interpolations=24):
29
- image1 = load_image(image1)
30
- image2 = load_image(image2)
31
-
32
- with torch.no_grad():
33
- mu1, logvar1 = model.encode(image1)
34
- mu2, logvar2 = model.encode(image2)
35
- interpolated_vectors = interpolate_vectors(mu1, mu2, num_interpolations)
36
- decoded_images = [model.decode(vec).squeeze(0) for vec in interpolated_vectors]
37
-
38
- return decoded_images
39
-
40
- def create_gif(decoded_images, duration=200, apply_blur=False):
41
- reversed_images = decoded_images[::-1]
42
- all_images = decoded_images + reversed_images
43
-
44
- pil_images = []
45
- for img in all_images:
46
- img = (img - img.min()) / (img.max() - img.min())
47
- img = (img * 255).byte()
48
- pil_img = transforms.ToPILImage()(img.cpu()).convert("RGB")
49
- pil_img = resize_transform(pil_img)
50
- if apply_blur:
51
- pil_img = pil_img.filter(ImageFilter.GaussianBlur(radius=1))
52
- pil_images.append(pil_img)
53
-
54
- gif_filename = f"/tmp/morphing_{uuid.uuid4().hex}.gif"
55
- pil_images[0].save(gif_filename, save_all=True, append_images=pil_images[1:], duration=duration, loop=0)
56
-
57
- return gif_filename
58
-
59
- def create_morphing_gif(image1, image2, num_interpolations=24, duration=200):
60
- decoded_images = infer_and_interpolate(image1, image2, num_interpolations)
61
- gif_path = create_gif(decoded_images, duration)
62
-
63
- return gif_path
64
-
65
- examples = [
66
- ["example_images/image1.jpg", "example_images/image2.png", 24, 200],
67
- ["example_images/image3.jpg", "example_images/image4.jpg", 30, 150],
68
- ]
69
-
70
- with gr.Blocks() as morphing:
71
- with gr.Column():
72
- with gr.Column():
73
- num_interpolations = gr.Slider(minimum=2, maximum=50, value=24, step=1, label="Number of interpolations")
74
- duration = gr.Slider(minimum=100, maximum=1000, value=200, step=50, label="Duration per frame (ms)")
75
- generate_button = gr.Button("Generate Morphing GIF")
76
- output_gif = gr.Image(label="Morphing GIF")
77
- with gr.Row():
78
- image1 = gr.Image(label="Upload first image", type="numpy")
79
- image2 = gr.Image(label="Upload second image", type="numpy")
80
-
81
- generate_button.click(fn=create_morphing_gif, inputs=[image1, image2, num_interpolations, duration], outputs=output_gif)
82
-
83
- gr.Examples(examples=examples, inputs=[image1, image2, num_interpolations, duration])
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import transforms
4
+ from PIL import Image, ImageFilter
5
+ import gradio as gr
6
+ import numpy as np
7
+ import os
8
+ import uuid
9
+ from model import model
10
+
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+
13
+ transform = transforms.Compose([
14
+ transforms.Resize((128, 128)),
15
+ transforms.ToTensor(),
16
+ transforms.Normalize((0.5,), (0.5,))
17
+ ])
18
+
19
+ resize_transform = transforms.Resize((512, 512))
20
+
21
+ def load_image(image):
22
+ image = Image.fromarray(image).convert('RGB')
23
+ image = transform(image)
24
+ return image.unsqueeze(0).to(device)
25
+
26
+ def interpolate_vectors(v1, v2, num_steps):
27
+ return [v1 * (1 - alpha) + v2 * alpha for alpha in np.linspace(0, 1, num_steps)]
28
+
29
+ def infer_and_interpolate(image1, image2, num_interpolations=24):
30
+ image1 = load_image(image1)
31
+ image2 = load_image(image2)
32
+
33
+ with torch.no_grad():
34
+ mu1, logvar1 = model.encode(image1)
35
+ mu2, logvar2 = model.encode(image2)
36
+ interpolated_vectors = interpolate_vectors(mu1, mu2, num_interpolations)
37
+ decoded_images = [model.decode(vec).squeeze(0) for vec in interpolated_vectors]
38
+
39
+ return decoded_images
40
+
41
+ def create_gif(decoded_images, duration=200, apply_blur=False):
42
+ reversed_images = decoded_images[::-1]
43
+ all_images = decoded_images + reversed_images
44
+
45
+ pil_images = []
46
+ for img in all_images:
47
+ img = (img - img.min()) / (img.max() - img.min())
48
+ img = (img * 255).byte()
49
+ pil_img = transforms.ToPILImage()(img.cpu()).convert("RGB")
50
+ pil_img = resize_transform(pil_img)
51
+ if apply_blur:
52
+ pil_img = pil_img.filter(ImageFilter.GaussianBlur(radius=1))
53
+ pil_images.append(pil_img)
54
+
55
+ gif_filename = f"/tmp/morphing_{uuid.uuid4().hex}.gif"
56
+ pil_images[0].save(gif_filename, save_all=True, append_images=pil_images[1:], duration=duration, loop=0)
57
+
58
+ return gif_filename
59
+
60
+ def create_morphing_gif(image1, image2, num_interpolations=24, duration=200):
61
+ decoded_images = infer_and_interpolate(image1, image2, num_interpolations)
62
+ gif_path = create_gif(decoded_images, duration)
63
+
64
+ return gif_path
65
+
66
+ examples = [
67
+ ["example_images/image1.jpg", "example_images/image2.png", 24, 200],
68
+ ["example_images/image3.jpg", "example_images/image4.jpg", 30, 150],
69
+ ]
70
+
71
+ with gr.Blocks() as morphing:
72
+ with gr.Column():
73
+ with gr.Column():
74
+ num_interpolations = gr.Slider(minimum=2, maximum=50, value=24, step=1, label="Number of interpolations")
75
+ duration = gr.Slider(minimum=100, maximum=1000, value=200, step=50, label="Duration per frame (ms)")
76
+ generate_button = gr.Button("Generate Morphing GIF")
77
+ output_gif = gr.Image(label="Morphing GIF")
78
+ with gr.Row():
79
+ image1 = gr.Image(label="Upload first image", type="numpy")
80
+ image2 = gr.Image(label="Upload second image", type="numpy")
81
+
82
+ generate_button.click(fn=create_morphing_gif, inputs=[image1, image2, num_interpolations, duration], outputs=output_gif)
83
+
84
+ gr.Examples(examples=examples, inputs=[image1, image2, num_interpolations, duration])