BioMike's picture
Upload 9 files
5a9c9b2 verified
raw
history blame
No virus
2.04 kB
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import gradio as gr
import numpy as np
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Model-side preprocessing: resize to 128x128, convert to a tensor, then apply
# Normalize(mean=0.5, std=0.5), i.e. (x - 0.5) / 0.5 on every channel.
transform1 = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Display-side post-processing: resize the reconstruction to 512x512.
transform2 = transforms.Compose([transforms.Resize((512, 512))])
def load_image(image):
    """Convert a raw image array into a single-item batch for the model.

    The array is wrapped as a PIL image and forced to RGB. It is then passed
    through ``transform1``, given a leading batch axis, and moved to ``device``.
    """
    rgb = Image.fromarray(image).convert('RGB')
    batch = transform1(rgb).unsqueeze(0)
    return batch.to(device)
def infer_image(image, noise_level):
    """Encode an image, perturb its latent sample, and decode it for display.

    Args:
        image: Image array from the ``gr.Image(type="numpy")`` input, or
            ``None`` when no image is loaded. This happens when the image is
            cleared, or when the noise slider is moved before any upload.
        noise_level: Scale applied to the standard-normal sample ``eps`` in
            ``z = mu + eps * std``. A value of 0 decodes ``mu`` itself.

    Returns:
        The reconstruction as a uint8 numpy array, resized by ``transform2``.
        Returns ``None`` when there is no input image, which clears the
        Gradio output.
    """
    # Gradio passes None on clear or before the first upload, and
    # Image.fromarray(None) inside load_image would raise.
    if image is None:
        return None

    batch = load_image(image)
    with torch.no_grad():
        # NOTE(review): `model` is not defined in the visible part of this file;
        # presumably it is created/loaded elsewhere — confirm.
        mu, logvar = model.encode(batch)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std) * noise_level
        z = mu + eps * std
        decoded = model.decode(z)

    # Drop only the batch axis added by load_image and move channels last.
    # Then apply x * 0.5 + 0.5, the inverse of transform1's
    # Normalize(0.5, 0.5), and clip to the displayable [0, 1] range.
    pixels = decoded.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(np.float32) * 0.5 + 0.5
    pixels = np.clip(pixels, 0, 1)
    reconstruction = Image.fromarray((pixels * 255).astype(np.uint8))
    reconstruction = transform2(reconstruction)
    return np.array(reconstruction)
# Preset (image path, noise level) pairs shown under the inputs in the UI.
examples = [
    [path, level]
    for path, level in (
        ("example_images/image1.jpg", 0.1),
        ("example_images/image2.png", 0.5),
        ("example_images/image3.jpg", 1.0),
    )
]
# Gradio UI: a noise slider above a row holding the input image and the
# reconstruction side by side. Gradio Blocks lay components out in the order
# they are created, so the statement order below is the layout.
# NOTE(review): no `vae.launch()` is visible in this chunk — confirm it is
# called elsewhere.
with gr.Blocks() as vae:
    # Range 0-10, default 0.01, step 0.01; forwarded to infer_image as noise_level.
    noise_slider = gr.Slider(0, 10, value=0.01, step=0.01, label="Noise Level")
    with gr.Row():
        with gr.Column():
            # type="numpy": the uploaded image reaches infer_image as a numpy array.
            input_image = gr.Image(label="Upload an image", type="numpy")
        with gr.Column():
            output_image = gr.Image(label="Reconstructed Image")
    # Re-run inference whenever either the image or the noise level changes.
    input_image.change(fn=infer_image, inputs=[input_image, noise_slider], outputs=output_image)
    noise_slider.change(fn=infer_image, inputs=[input_image, noise_slider], outputs=output_image)
    # Preset rows (image path, noise level) that fill both inputs when selected.
    gr.Examples(examples=examples, inputs=[input_image, noise_slider])