Spaces:
Runtime error
Runtime error
import cv2 | |
import torch | |
import numpy as np | |
from PIL import Image | |
from torchvision import transforms | |
from util.prepare_images import * | |
from torchvision.utils import save_image | |
import os | |
os.environ["LRU_CACHE_CAPACITY"] = "1" | |
def get_potrait(test_image, interpreter,input_details,output_details): | |
# get the potrait mask output | |
im = np.asarray(test_image) | |
h, w, _ = im.shape | |
face_rgba = cv2.cvtColor(im, cv2.COLOR_RGB2BGR) | |
# resize | |
image = cv2.resize(face_rgba, (512, 512), interpolation=cv2.INTER_AREA) | |
# Preprocess the input image | |
test_image = image / 255.0 | |
test_image = np.expand_dims(test_image, axis=0).astype(input_details["dtype"]) | |
# Run the interpreter and get the output | |
interpreter.set_tensor(input_details["index"], test_image) | |
interpreter.invoke() | |
output = interpreter.get_tensor(output_details["index"])[0] | |
# Compute mask from segmentaion output | |
mask = np.reshape(output, (512, 512)) > 0.5 | |
mask = (mask * 255).astype(np.uint8) | |
# resize the mask output | |
bin_mask = cv2.resize(mask, (w, h)) | |
# extract the potrait | |
image = np.dstack((im, bin_mask)) | |
# make background white | |
face = image[:, :, :3].copy() | |
mask = image[:, :, 3].copy()[:, :, np.newaxis] / 255.0 | |
face_white_bg = (face * mask + (1 - mask) * 255).astype(np.uint8) | |
# convert image to PIL format | |
mask = Image.fromarray(bin_mask) | |
im = Image.fromarray(face_white_bg) | |
return im, mask | |
def upscale(img, model_cran_v2): | |
# convert pil image to tensor | |
img_t = transforms.ToTensor()(img).unsqueeze(0) | |
# used to compare the origin | |
img = img.resize((img.size[0] // 2, img.size[1] // 2), Image.BICUBIC) | |
img_splitter = ImageSplitter(seg_size=64, scale_factor=2, boarder_pad_size=3) | |
img_patches = img_splitter.split_img_tensor(img, scale_method=None, img_pad=0) | |
with torch.no_grad(): | |
out = [model_cran_v2(i) for i in img_patches] | |
img_upscale = img_splitter.merge_img_tensor(out) | |
save_image(img_upscale, "app/removal.png") | |
return Image.open("app/removal.png") | |