remove_eye_glass / util /prepare_images.py
shiva46's picture
added
5393d64
raw
history blame contribute delete
No virus
4.99 kB
import copy
import glob
import os
from multiprocessing.dummy import Pool as ThreadPool
from PIL import Image
from torchvision.transforms.functional import to_tensor
from Models import *
class ImageSplitter:
    """Split an image into overlapping patches and stitch processed patches back.

    Key points:
    Border padding and overlapping image splitting are used to avoid unstable
    values at patch edges. Thanks to Waifu2x's author nagadomi for the
    suggestion (https://github.com/nagadomi/waifu2x/issues/238).

    ``split_img_tensor`` stores the padded image size and the patch size it
    used on the instance, and ``merge_img_tensor`` reads them back, so a merge
    always refers to the most recent split.
    """

    def __init__(self, seg_size=48, scale_factor=2, boarder_pad_size=3):
        """Configure the splitter.

        Args:
            seg_size: Step, in input pixels, between neighbouring patches.
            scale_factor: Ratio between output and input resolution. It is
                used in ``range`` steps and slice bounds, so it must be an int.
            boarder_pad_size: Width, in input pixels, of the replication
                padding added around the image and of the overlap kept around
                every patch.
        """
        self.seg_size = seg_size
        self.scale_factor = scale_factor
        self.pad_size = boarder_pad_size
        self.height = 0
        self.width = 0
        self.upsampler = nn.Upsample(scale_factor=scale_factor, mode='bilinear')
        # Bookkeeping that keeps the residual adjustment in split_img_tensor
        # from compounding when one splitter handles several images.
        self._requested_seg_size = seg_size
        self._applied_seg_size = seg_size

    def _resolve_seg_size(self, height, width):
        """Return the patch step for a padded image and store it in ``seg_size``."""
        applied = getattr(self, "_applied_seg_size", None)
        if applied is None or self.seg_size != applied:
            # First use, or the caller reassigned ``seg_size`` since the last
            # split: treat the current value as the new request.
            self._requested_seg_size = self.seg_size
        seg = self._requested_seg_size
        # Avoid a residual part that is smaller than the padded size.
        if height % seg < self.pad_size or width % seg < self.pad_size:
            seg += self.scale_factor * self.pad_size
        self.seg_size = self._applied_seg_size = seg
        return seg

    def split_img_tensor(self, pil_img, scale_method=Image.BILINEAR, img_pad=0):
        """Cut ``pil_img`` into overlapping tensor patches.

        The image is converted with ``to_tensor``, given a batch dimension and
        replication-padded by ``pad_size`` before it is sliced.

        Args:
            pil_img: Image to split; it must support ``.size`` and ``.resize``.
            scale_method: Resampling filter used to build an enlarged reference
                image. Pass ``None`` to skip the reference and get plain patches.
            img_pad: When greater than zero, every patch is additionally
                zero-padded by this amount on each side.
                NOTE(review): ``merge_img_tensor`` does not account for this
                extra padding - confirm how callers remove it.

        Returns:
            A list of patches in row-major order. Each entry is a tensor, or a
            ``(patch, enlarged_patch)`` tuple when ``scale_method`` is given.
        """
        pad = self.pad_size
        scale = self.scale_factor

        img_tensor = to_tensor(pil_img).unsqueeze(0)
        img_tensor = nn.ReplicationPad2d(pad)(img_tensor)
        _, _, height, width = img_tensor.size()
        # Remembered for merge_img_tensor.
        self.height = height
        self.width = width

        img_up = None
        if scale_method is not None:
            # Enlarge by the configured factor (not a fixed 2) so the slices
            # below, which are scaled by ``scale``, stay aligned.
            up_size = (scale * pil_img.size[0], scale * pil_img.size[1])
            img_up = to_tensor(pil_img.resize(up_size, scale_method)).unsqueeze(0)
            img_up = nn.ReplicationPad2d(pad * scale)(img_up)

        seg = self._resolve_seg_size(height, width)

        # Split the image into overlapping pieces.
        patch_box = []
        for i in range(pad, height, seg):
            row_end = min(i + pad + seg, height)
            for j in range(pad, width, seg):
                col_end = min(j + pad + seg, width)
                part = img_tensor[:, :, i - pad:row_end, j - pad:col_end]
                if img_pad > 0:
                    part = nn.ZeroPad2d(img_pad)(part)
                if img_up is None:
                    patch_box.append(part)
                else:
                    part_up = img_up[:, :,
                                     scale * (i - pad):scale * row_end,
                                     scale * (j - pad):scale * col_end]
                    patch_box.append((part, part_up))
        return patch_box

    def merge_img_tensor(self, list_img_tensor):
        """Stitch processed patches back into one image tensor.

        Args:
            list_img_tensor: Patches in the order produced by
                ``split_img_tensor``. Presumably each one is the model output
                for the matching input patch, i.e. ``scale_factor`` times
                larger - verify against the caller. The sequence is not
                modified.

        Returns:
            A ``(1, 3, H, W)`` tensor with the outer padding cropped away.

        Raises:
            IndexError: If fewer patches are supplied than the last split
                produced.
        """
        scale = self.scale_factor
        out_h = scale * self.height
        out_w = scale * self.width
        # Overlap margin at output resolution; this is also the width of the
        # outer replication padding that has to be cropped at the end.
        margin = scale * self.pad_size
        step = scale * self.seg_size

        out = torch.zeros((1, 3, out_h, out_w))
        idx = 0
        for i in range(margin, out_h, step):
            for j in range(margin, out_w, step):
                part = list_img_tensor[idx]
                idx += 1
                # ``x[0:-0]`` is an empty slice, so only trim a real margin.
                if margin > 0:
                    part = part[:, :, margin:-margin, margin:-margin]
                # Parts that are not 4-D are skipped.
                if len(part.size()) > 3:
                    _, _, p_h, p_w = part.size()
                    out[:, :, i:i + p_h, j:j + p_w] = part
        if margin > 0:
            out = out[:, :, margin:-margin, margin:-margin]
        return out
def load_single_image(img_file,
                      up_scale=False,
                      up_scale_factor=2,
                      up_scale_method=Image.BILINEAR,
                      zero_padding=False):
    """Load an image file as an RGB tensor with a leading batch dimension.

    Args:
        img_file: Path or file object accepted by ``Image.open``.
        up_scale: When truthy, also return an enlarged copy of the image.
        up_scale_factor: Multiplier applied to both image dimensions for the
            enlarged copy.
        up_scale_method: Resampling filter passed to ``resize``.
        zero_padding: Falsy to disable; otherwise the value is handed to
            ``nn.ZeroPad2d`` and applied to the base tensor only.

    Returns:
        The image tensor, or an ``(image, enlarged_image)`` tuple when
        ``up_scale`` is truthy.
    """
    # The context manager releases the underlying file deterministically;
    # ``convert`` yields an independent in-memory image that outlives it.
    with Image.open(img_file) as source:
        img = source.convert("RGB")

    out = to_tensor(img).unsqueeze(0)
    if zero_padding:
        out = nn.ZeroPad2d(zero_padding)(out)
    if up_scale:
        size = tuple(dim * up_scale_factor for dim in img.size)
        img_up = img.resize(size, up_scale_method)
        img_up = to_tensor(img_up).unsqueeze(0)
        out = (out, img_up)
    return out
def standardize_img_format(img_folder):
    """Rename every ``.png``, ``.jpeg`` and ``.jpg`` file under a folder to ``.JPEG``.

    The folder is searched recursively. Only the file extension changes; the
    file contents are not re-encoded.

    NOTE: files that differ only in extension (``a.png`` and ``a.jpg``) map to
    the same target name, so one may replace the other.

    Args:
        img_folder: Root folder to search, with or without a trailing
            path separator.
    """

    def process(img_file):
        # splitext only strips the final extension, so names that contain
        # additional dots (``photo.v2.jpg``) are handled.
        stem, _ = os.path.splitext(img_file)
        os.rename(img_file, stem + ".JPEG")

    list_imgs = []
    for ext in ('png', 'jpeg', 'jpg'):
        # os.path.join keeps ``**`` as its own path component even when
        # img_folder has no trailing separator; glob needs that for recursion.
        pattern = os.path.join(img_folder, "**", "*." + ext)
        list_imgs.extend(glob.glob(pattern, recursive=True))
    print("Found {} images.".format(len(list_imgs)))

    pool = ThreadPool(4)
    try:
        pool.map(process, list_imgs)
    finally:
        # Always shut the workers down, even when a rename fails.
        pool.close()
        pool.join()