# ClassicalPortraitsVAE / morphing.py
# Author: BioMike — last commit "Update morphing.py" (dfe21b4, verified)
# (Hugging Face file-viewer header: raw · history · blame · "No virus" scan · 3.12 kB)
import os
import tempfile
import uuid

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ImageFilter
from torchvision import transforms

from model import model
# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Encoder input pipeline: resize to 128x128, convert to a float tensor in
# [0, 1], then map every channel to [-1, 1] via (x - 0.5) / 0.5.
_INPUT_SIZE = (128, 128)
transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Resizes GIF frames (PIL images) to 512x512 before they are written out.
_GIF_FRAME_SIZE = (512, 512)
resize_transform = transforms.Resize(_GIF_FRAME_SIZE)
def load_image(image):
    """Convert an uploaded image into a normalized single-image batch for the encoder.

    Args:
        image: the uploaded image, either a NumPy array (as delivered by the
            ``gr.Image(type="numpy")`` inputs) or a ``PIL.Image.Image``.

    Returns:
        Tensor of shape (1, 3, 128, 128) on ``device``. The image is forced to
        RGB, resized by ``transform`` and normalized to [-1, 1].

    Raises:
        ValueError: if ``image`` is None (e.g. nothing was uploaded).
    """
    if image is None:
        raise ValueError("load_image() received no image (got None).")
    # Arrays from Gradio need wrapping; PIL images can be used directly.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Force 3 channels so grayscale / RGBA uploads match the model input.
    tensor = transform(image.convert('RGB'))
    return tensor.unsqueeze(0).to(device)
def interpolate_vectors(v1, v2, num_steps):
    """Linearly blend ``v1`` into ``v2`` over ``num_steps`` evenly spaced weights.

    The weights come from ``np.linspace(0, 1, num_steps)``, so for
    ``num_steps >= 2`` the first element equals ``v1`` and the last equals ``v2``.
    Works with anything supporting scalar multiplication and addition
    (floats, NumPy arrays, torch tensors).

    Returns:
        List of ``num_steps`` blended values, in order from ``v1`` to ``v2``.
    """
    weights = np.linspace(0, 1, num_steps)
    blended = []
    for w in weights:
        blended.append(v1 * (1 - w) + v2 * w)
    return blended
def infer_and_interpolate(image1, image2, num_interpolations=24):
    """Encode two images and decode evenly spaced blends of their latent means.

    Args:
        image1, image2: uploaded images, in any form accepted by ``load_image``.
        num_interpolations: number of frames to produce; passed to
            ``interpolate_vectors``, so both endpoints are included when it is >= 2.

    Returns:
        List of decoded image tensors, one per interpolation step, with the
        leading batch dimension squeezed away (still on the model's device).
    """
    batch1 = load_image(image1)
    batch2 = load_image(image2)
    with torch.no_grad():
        # Only the means are blended; the second value returned by
        # ``model.encode`` (the log-variance) is not needed here.
        mu1, _ = model.encode(batch1)
        mu2, _ = model.encode(batch2)
        latents = interpolate_vectors(mu1, mu2, num_interpolations)
        decoded_images = [model.decode(z).squeeze(0) for z in latents]
    return decoded_images
def _tensor_to_frame(img, apply_blur=False, blur_radius=1):
    """Turn one decoded image tensor into a resized, optionally blurred RGB PIL frame."""
    # Min-max rescale this frame on its own to [0, 1]. Frames are therefore not
    # guaranteed to share a common brightness scale.
    lo, hi = img.min(), img.max()
    img = img - lo
    span = hi - lo
    # A constant frame has zero span. Leave it at 0 (black) instead of computing
    # 0 / 0, which yields NaN, whose byte() cast is undefined.
    if span > 0:
        img = img / span
    img = (img * 255).byte()
    frame = transforms.ToPILImage()(img.cpu()).convert("RGB")
    frame = resize_transform(frame)
    if apply_blur:
        frame = frame.filter(ImageFilter.GaussianBlur(radius=blur_radius))
    return frame


def create_gif(decoded_images, duration=200, apply_blur=False, blur_radius=1):
    """Write the decoded frames to a looping, back-and-forth animated GIF.

    Args:
        decoded_images: non-empty sequence of image tensors (C, H, W), or a
            batched tensor iterated along its first dimension.
        duration: display time per frame in milliseconds (Pillow ``duration``).
        apply_blur: apply a Gaussian blur to every frame when True.
        blur_radius: Gaussian blur radius used when ``apply_blur`` is True.

    Returns:
        Path of the written GIF, a uniquely named file in the system temp dir.

    Raises:
        ValueError: if ``decoded_images`` is empty.
    """
    images = list(decoded_images)
    if not images:
        raise ValueError("create_gif() needs at least one decoded image.")
    # Forward sequence followed by its reverse, so the animation plays back and
    # forth. Each endpoint frame appears twice in a row (the end of one pass and
    # the start of the next), which holds it for two frame durations.
    sequence = images + images[::-1]
    frames = [_tensor_to_frame(img, apply_blur, blur_radius) for img in sequence]
    # Use the platform temp dir rather than a hard-coded "/tmp" (not portable).
    gif_filename = os.path.join(tempfile.gettempdir(), f"morphing_{uuid.uuid4().hex}.gif")
    frames[0].save(gif_filename, save_all=True, append_images=frames[1:], duration=duration, loop=0)
    return gif_filename
def create_morphing_gif(image1, image2, num_interpolations=24, duration=200):
    """Gradio click handler: morph between two uploaded images and return a GIF path.

    Args:
        image1, image2: uploaded images from the two ``gr.Image`` inputs
            (None when nothing has been uploaded).
        num_interpolations: number of frames in one direction of the morph.
        duration: display time per frame in milliseconds.

    Returns:
        File path of the generated GIF.

    Raises:
        gr.Error: if either image is missing, so Gradio shows a readable
            message instead of an internal traceback.
    """
    if image1 is None or image2 is None:
        raise gr.Error("Please upload both images before generating a morph.")
    # Slider values may arrive as floats; np.linspace requires an integer count.
    decoded_images = infer_and_interpolate(image1, image2, int(num_interpolations))
    return create_gif(decoded_images, duration)
# Example rows for gr.Examples, one value per input component:
# [first image path, second image path, num_interpolations, duration in ms].
examples = [
    ["example_images/image1.jpg", "example_images/image2.png", 24, 200],
    ["example_images/image3.jpg", "example_images/image4.jpg", 30, 150],
]

# Morphing UI. `morphing` is presumably mounted or launched by another module
# (nothing in this file calls launch()) — TODO confirm.
# NOTE(review): the nesting below was reconstructed from a whitespace-stripped
# copy of this file; confirm the intended layout.
with gr.Blocks() as morphing:
    with gr.Column():
        with gr.Column():
            # Controls: number of frames and per-frame display time.
            num_interpolations = gr.Slider(minimum=2, maximum=50, value=24, step=1, label="Number of interpolations")
            duration = gr.Slider(minimum=100, maximum=1000, value=200, step=50, label="Duration per frame (ms)")
            generate_button = gr.Button("Generate Morphing GIF")
            # Shows the GIF whose file path create_morphing_gif returns.
            output_gif = gr.Image(label="Morphing GIF")
        with gr.Row():
            # type="numpy" hands uploads to the handler as NumPy arrays, as load_image expects.
            image1 = gr.Image(label="Upload first image", type="numpy")
            image2 = gr.Image(label="Upload second image", type="numpy")
    # Input order must match create_morphing_gif's positional parameters.
    generate_button.click(fn=create_morphing_gif, inputs=[image1, image2, num_interpolations, duration], outputs=output_gif)
    # No fn/outputs given, so selecting an example only fills the input components.
    gr.Examples(examples=examples, inputs=[image1, image2, num_interpolations, duration])