File size: 3,122 Bytes
dfe21b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import tempfile
import uuid

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ImageFilter
from torchvision import transforms

from model import model

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Preprocessing for each uploaded image before it is encoded: resize to
# 128x128, convert to a [0, 1] tensor, then map to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Upscales every decoded frame to 512x512 before it is written into the GIF.
resize_transform = transforms.Resize(size=(512, 512))

def load_image(image):
    """Turn a NumPy image from a Gradio input into a batched model tensor.

    Args:
        image: Image array as delivered by a ``gr.Image(type="numpy")`` input.

    Returns:
        The image forced to RGB, passed through the module-level ``transform``,
        given a leading batch dimension of size 1 and moved to ``device``.
    """
    rgb_image = Image.fromarray(image).convert('RGB')
    batch = transform(rgb_image).unsqueeze(0)
    return batch.to(device)

def interpolate_vectors(v1, v2, num_steps):
    """Return ``num_steps`` evenly spaced linear blends from ``v1`` to ``v2``.

    Element ``i`` is ``v1 * (1 - a_i) + v2 * a_i``, where ``a_i`` runs evenly
    from 0 to 1. So the first element is ``v1`` and (for ``num_steps >= 2``)
    the last is ``v2``. ``num_steps == 1`` yields ``[v1]`` and
    ``num_steps == 0`` yields ``[]``.

    Args:
        v1: Start value; anything supporting multiplication by a float and
            addition (torch tensor, NumPy array, plain number).
        v2: End value, compatible with ``v1``.
        num_steps: Number of points to produce. Whole-valued floats such as
            ``24.0`` are accepted, because UI widgets may deliver numbers as
            floats.

    Returns:
        List of interpolated values, ordered from ``v1`` to ``v2``.
    """
    # np.linspace rejects non-integer sample counts, so normalise to int first.
    alphas = np.linspace(0.0, 1.0, int(num_steps))
    # Use plain Python floats so tensor arithmetic keeps the tensor's own dtype
    # instead of depending on NumPy-scalar promotion rules.
    return [v1 * (1.0 - a) + v2 * a for a in map(float, alphas)]

def infer_and_interpolate(image1, image2, num_interpolations=24):
    """Morph between two images in the model's latent space.

    Each image is encoded with ``model.encode``. The first value it returns
    (``mu`` in the original naming) is interpolated linearly over
    ``num_interpolations`` steps, and every step is decoded back to an image.

    Args:
        image1: First input image as a NumPy array (see ``load_image``).
        image2: Second input image as a NumPy array.
        num_interpolations: Number of latent points, including both endpoints.

    Returns:
        List of decoded frames with the leading batch dimension squeezed away.
    """
    # NOTE(review): assumes `model` is already in eval mode and on `device`;
    # neither is done here — confirm in model.py.
    batch_a, batch_b = load_image(image1), load_image(image2)

    with torch.no_grad():
        # Only the first encode() output is used; the second (logvar) is dropped.
        latent_a, _ = model.encode(batch_a)
        latent_b, _ = model.encode(batch_b)
        frames = []
        for latent in interpolate_vectors(latent_a, latent_b, num_interpolations):
            frames.append(model.decode(latent).squeeze(0))

    return frames

def create_gif(decoded_images, duration=200, apply_blur=False):
    """Write decoded frames to a looping "ping-pong" GIF and return its path.

    The frames play forward and then in reverse, so the animation returns to
    the first image before looping. Both end frames therefore appear twice.

    Args:
        decoded_images: Non-empty list of image tensors, one per frame, in a
            layout ``transforms.ToPILImage`` accepts. Presumably (C, H, W) as
            produced by ``model.decode(...).squeeze(0)`` — TODO confirm.
        duration: Display time of each frame in milliseconds, passed to PIL.
        apply_blur: If True, apply a radius-1 Gaussian blur to every frame
            after it has been resized.

    Returns:
        Path of the newly written GIF inside the system temporary directory.

    Raises:
        ValueError: If ``decoded_images`` is empty.
    """
    frames = list(decoded_images)
    if not frames:
        raise ValueError("create_gif() requires at least one decoded image")

    to_pil = transforms.ToPILImage()  # stateless; build once, not per frame
    pil_images = []
    for img in frames + frames[::-1]:
        # Stretch each frame independently to [0, 1]. A flat frame (max == min)
        # would divide by zero and produce NaNs, so render it as black instead.
        lo = img.min()
        span = img.max() - lo
        if span > 0:
            img = (img - lo) / span
        else:
            img = torch.zeros_like(img)
        img = (img * 255).byte()
        pil_img = to_pil(img.cpu()).convert("RGB")
        pil_img = resize_transform(pil_img)
        if apply_blur:
            pil_img = pil_img.filter(ImageFilter.GaussianBlur(radius=1))
        pil_images.append(pil_img)

    # The system temp dir is portable (honours TMPDIR, works on Windows); the
    # uuid keeps concurrent requests from overwriting each other's output.
    gif_filename = os.path.join(
        tempfile.gettempdir(), f"morphing_{uuid.uuid4().hex}.gif"
    )
    # loop=0 makes the GIF repeat indefinitely.
    pil_images[0].save(
        gif_filename,
        save_all=True,
        append_images=pil_images[1:],
        duration=duration,
        loop=0,
    )
    return gif_filename

def create_morphing_gif(image1, image2, num_interpolations=24, duration=200):
    """Gradio callback: build a morphing GIF between two uploaded images.

    Args:
        image1: First uploaded image as a NumPy array, or None if missing.
        image2: Second uploaded image as a NumPy array, or None if missing.
        num_interpolations: Number of latent interpolation steps. Whole-valued
            floats (as a slider may send) are accepted.
        duration: Display time of each frame in milliseconds.

    Returns:
        Filesystem path of the generated GIF.

    Raises:
        gr.Error: If either image has not been uploaded. Gradio shows this as
            a readable message instead of a traceback from PIL.
    """
    if image1 is None or image2 is None:
        raise gr.Error("Please upload both images before generating the GIF.")

    decoded_images = infer_and_interpolate(image1, image2, int(num_interpolations))
    return create_gif(decoded_images, duration)

# Preset inputs for gr.Examples below. Each row is, in order:
# [first image path, second image path, number of interpolations, frame duration (ms)],
# matching the `inputs` list passed to gr.Examples.
examples = [
    ["example_images/image1.jpg", "example_images/image2.png", 24, 200],
    ["example_images/image3.jpg", "example_images/image4.jpg", 30, 150],
]

# Gradio UI. Components appear on screen in the order they are created inside
# the Blocks/Column/Row contexts: the controls and button first, then the
# output GIF, then the two upload widgets side by side.
# NOTE(review): `morphing` is not launched in this section; presumably another
# module launches or mounts it — confirm.
with gr.Blocks() as morphing:
    with gr.Column():
        with gr.Column():
            # Default values (24 steps, 200 ms) match create_morphing_gif's defaults.
            num_interpolations = gr.Slider(minimum=2, maximum=50, value=24, step=1, label="Number of interpolations")
            duration = gr.Slider(minimum=100, maximum=1000, value=200, step=50, label="Duration per frame (ms)")
            generate_button = gr.Button("Generate Morphing GIF")
        # The callback returns a GIF file path.
        # NOTE(review): confirm this Gradio version's gr.Image renders it animated.
        output_gif = gr.Image(label="Morphing GIF")
        with gr.Row():
            # type="numpy" matches load_image, which calls Image.fromarray on the value.
            image1 = gr.Image(label="Upload first image", type="numpy")
            image2 = gr.Image(label="Upload second image", type="numpy")
            
    # The four inputs map positionally onto create_morphing_gif's parameters.
    generate_button.click(fn=create_morphing_gif, inputs=[image1, image2, num_interpolations, duration], outputs=output_gif)

    # No fn/outputs are passed, so selecting an example presumably only fills in
    # the inputs; the user still clicks the button to generate.
    gr.Examples(examples=examples, inputs=[image1, image2, num_interpolations, duration])