"""Face-detail inpainting helpers.

Detect faces with a YOLO-style detector, build one mask per detection, inpaint
a square crop around each mask with a Stable Diffusion inpaint pipeline and
composite the result back onto the source image.
"""
from dataclasses import dataclass, field
from enum import IntEnum
from functools import partial, reduce
from math import dist
import os
from typing import List, Optional, Tuple, Union

import cv2
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image, ImageChops, ImageDraw, ImageFilter, ImageOps
from torchvision.transforms.functional import to_pil_image

# Human-readable names for the merge modes; the position of each string is the
# integer value of the corresponding ``MergeInvert`` member.
MASK_MERGE_INVERT = ["None", "Merge", "Merge and Invert"]


def adetailer(sd_pipeline, yolodetector, images: list[Image.Image], prompt, negative_prompt, seed=42):
    """Inpaint every detected region of every input image.

    Parameters
    ----------
    sd_pipeline:
        Pipeline whose ``vae``, ``text_encoder``, ``tokenizer``, ``unet``,
        ``scheduler`` and ``feature_extractor`` are reused to build the
        inpaint pipeline.
    yolodetector:
        Detector object forwarded to ``ultralytics_predict``.
    images: list[Image.Image]
        Source images.
    prompt, negative_prompt:
        Text conditioning forwarded to the inpaint pipeline.
    seed: int
        Seed of the CUDA generator shared by all inpainting calls.

    Returns
    -------
    list[Image.Image]
        One RGB image per detected mask (not per input image), in detection
        order. Each is a full-size copy of its source image with only that
        mask's region inpainted. Empty when nothing was detected.
    """
    resolution = 512
    denoising_strength = 0.4

    # Each entry is (crop to inpaint, mask for that crop, paste location, overlay).
    patches = []
    for input_image in images:
        # Positional call: the detector parameter of ultralytics_predict is
        # spelled ``yolodector_model``, so a ``yolodetector_model=`` keyword
        # raises TypeError.
        pred = ultralytics_predict(yolodetector, input_image)
        masks = pred_preprocessing(pred)

        for mask in masks:
            blurred_mask = mask.filter(ImageFilter.GaussianBlur(8))

            crop_region = get_crop_region(np.array(blurred_mask))
            crop_region = expand_crop_region(
                crop_region, resolution, resolution, mask.width, mask.height
            )
            x1, y1, x2, y2 = crop_region
            paste_to = (x1, y1, x2 - x1, y2 - y1)

            image_mask = blurred_mask.crop(crop_region)
            image_mask = image_mask.resize((resolution, resolution), Image.LANCZOS)

            # Overlay = source image with the (blurred) masked area made
            # transparent; it is composited over the inpainted patch later.
            image_masked = Image.new('RGBa', (input_image.width, input_image.height))
            image_masked.paste(
                input_image.convert("RGBA"),
                mask=ImageOps.invert(blurred_mask.convert('L')),
            )
            overlay_image = image_masked.convert('RGBA')

            patch_input_img = input_image.crop(crop_region)
            patch_input_img = patch_input_img.resize((resolution, resolution), Image.LANCZOS)

            # Keep the mask with its patch so each inpainting call uses the
            # mask that belongs to it rather than the last one computed.
            patches.append((patch_input_img, image_mask, paste_to, overlay_image))

    if not patches:
        # Nothing to inpaint: skip building the pipeline.
        return []

    pipe = StableDiffusionInpaintPipeline(
        vae=sd_pipeline.vae,
        text_encoder=sd_pipeline.text_encoder,
        tokenizer=sd_pipeline.tokenizer,
        unet=sd_pipeline.unet,
        scheduler=sd_pipeline.scheduler,
        requires_safety_checker=False,
        safety_checker=None,
        feature_extractor=sd_pipeline.feature_extractor,
    ).to('cuda')
    generator = torch.Generator(device="cuda").manual_seed(seed)

    inpaint_images = []
    for patch_input_img, image_mask, paste_to, overlay_image in patches:
        out = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=[patch_input_img],
            mask_image=image_mask,
            num_inference_steps=30,
            strength=denoising_strength,
            generator=generator,
        ).images[0]
        inpaint_images.append(apply_overlay(out, paste_to, overlay_image))

    return inpaint_images


def get_crop_region(mask, pad=0):
    """Find the rectangle that contains all masked (non-zero) areas of a mask.

    Returns (x1, y1, x2, y2) coordinates of the rectangle, grown by ``pad``
    pixels on every side and clamped to the mask bounds.
    For example, if a user has painted the top-right part of a 512x512 image,
    the result may be (256, 0, 512, 256).

    For a mask with no non-zero pixel and ``pad=0`` the result is the inverted
    rectangle (w, h, 0, 0).
    """
    mask = np.asarray(mask)
    h, w = mask.shape

    painted = mask != 0
    cols = np.flatnonzero(painted.any(axis=0))
    rows = np.flatnonzero(painted.any(axis=1))

    if cols.size:
        # Number of fully empty columns/rows on each side.
        crop_left = cols[0]
        crop_right = w - 1 - cols[-1]
        crop_top = rows[0]
        crop_bottom = h - 1 - rows[-1]
    else:
        crop_left = crop_right = w
        crop_top = crop_bottom = h

    return (
        int(max(crop_left - pad, 0)),
        int(max(crop_top - pad, 0)),
        int(min(w - crop_right + pad, w)),
        int(min(h - crop_bottom + pad, h)),
    )


def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
    """Expand a crop region to the aspect ratio it will be processed in.

    Takes the region returned by get_crop_region() and returns the expanded
    region. For example, if the user drew a mask in a 128x32 region and the
    dimensions for processing are 512x512, the region is expanded to 128x128.

    The region is grown symmetrically along the short axis, shifted back
    inside the image when it overflows an edge, and finally clamped, so the
    result may not reach the target ratio when the image is too small.

    Raises ZeroDivisionError when the input region has zero height.
    """
    x1, y1, x2, y2 = crop_region

    ratio_crop_region = (x2 - x1) / (y2 - y1)
    ratio_processing = processing_width / processing_height

    if ratio_crop_region > ratio_processing:
        # Region is too wide for the target ratio: grow its height.
        desired_height = (x2 - x1) / ratio_processing
        desired_height_diff = int(desired_height - (y2 - y1))
        y1 -= desired_height_diff // 2
        y2 += desired_height_diff - desired_height_diff // 2
        if y2 >= image_height:
            # Shift up so the bottom edge sits on the image border.
            diff = y2 - image_height
            y2 -= diff
            y1 -= diff
        if y1 < 0:
            # Shift down so the top edge sits on the image border.
            y2 -= y1
            y1 = 0
        if y2 >= image_height:
            y2 = image_height
    else:
        # Region is too tall (or already matching): grow its width.
        desired_width = (y2 - y1) * ratio_processing
        desired_width_diff = int(desired_width - (x2 - x1))
        x1 -= desired_width_diff // 2
        x2 += desired_width_diff - desired_width_diff // 2
        if x2 >= image_width:
            diff = x2 - image_width
            x2 -= diff
            x1 -= diff
        if x1 < 0:
            x2 -= x1
            x1 = 0
        if x2 >= image_width:
            x2 = image_width

    return x1, y1, x2, y2


@dataclass
class PredictOutput:
    """Detection result: bounding boxes, one mask per box, and a preview image."""

    # [x1, y1, x2, y2] per detection.
    bboxes: List[List[Union[int, float]]] = field(default_factory=list)
    # One mask per entry of ``bboxes``, in the same order.
    masks: List[Image.Image] = field(default_factory=list)
    # Image whose ``size`` is used as the reference frame by the filters/sorters.
    preview: Optional[Image.Image] = None


def create_mask_from_bbox(
    bboxes: List[List[float]], shape: Tuple[int, int]
) -> List[Image.Image]:
    """
    Draw one filled-rectangle mask per bounding box.

    Parameters
    ----------
        bboxes: List[List[float]]
            list of [x1, y1, x2, y2] bounding boxes
        shape: Tuple[int, int]
            shape of the image (width, height)

    Returns
    -------
        masks: List[Image.Image]
            A list of "L" mode masks, 255 inside the box and 0 elsewhere
    """
    masks = []
    for bbox in bboxes:
        mask = Image.new("L", shape, 0)
        mask_draw = ImageDraw.Draw(mask)
        mask_draw.rectangle(bbox, fill=255)
        masks.append(mask)
    return masks


def ultralytics_predict(
    yolodector_model,
    image: Image.Image,
    confidence: float = 0.5,
    device: str = "cuda",
) -> PredictOutput:
    """
    Run the detector on one image and wrap the result.

    Parameters
    ----------
        yolodector_model:
            Detector exposing ``predict(array, conf_thres=..., iou_thres=...)``
            that returns a pair whose first item is indexable per image.
            NOTE(review): presumably a yoloface ``YoloDetector`` - confirm.
        image: Image.Image
            Image to run detection on
        confidence: float
            Passed to the detector as ``conf_thres``
        device: str
            Unused; kept for interface compatibility

    Returns
    -------
        PredictOutput
            Boxes of the first (only) image, rectangular masks built from
            them, and the input image as ``preview``
    """
    bboxes, _ = yolodector_model.predict(np.array(image), conf_thres=confidence, iou_thres=0.5)
    masks = create_mask_from_bbox(bboxes[0], image.size)
    return PredictOutput(bboxes=bboxes[0], masks=masks, preview=image)


def mask_to_pil(masks, shape: Tuple[int, int]) -> List[Image.Image]:
    """
    Convert a stack of mask tensors to PIL images resized to ``shape``.

    Parameters
    ----------
    masks: torch.Tensor, dtype=torch.float32, shape=(N, H, W).
        The device can be CUDA, but `to_pil_image` takes care of that.
    shape: Tuple[int, int]
        (width, height) of the original image
    """
    n = masks.shape[0]
    return [to_pil_image(masks[i], mode="L").resize(shape) for i in range(n)]


class MergeInvert(IntEnum):
    """How the masks are combined; values index into ``MASK_MERGE_INVERT``."""

    NONE = 0
    MERGE = 1
    MERGE_INVERT = 2


def offset(img: Image.Image, x: int = 0, y: int = 0) -> Image.Image:
    """
    The offset function takes an image and offsets it by a given x(→) and y(↑) value.

    Parameters
    ----------
    img: Image.Image
        The mask image to shift
    x: int
        →
    y: int
        ↑

    Returns
    -------
    PIL.Image.Image
        A new image that is offset by x and y
    """
    # ImageChops.offset counts y downwards, hence the sign flip.
    return ImageChops.offset(img, x, -y)


def is_all_black(img: Image.Image) -> bool:
    """Return True when the image contains no non-zero pixel."""
    arr = np.array(img)
    return cv2.countNonZero(arr) == 0


def _dilate(arr: np.ndarray, value: int) -> np.ndarray:
    """Dilate ``arr`` once with a ``value`` x ``value`` rectangular kernel."""
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (value, value))
    return cv2.dilate(arr, kernel, iterations=1)


def _erode(arr: np.ndarray, value: int) -> np.ndarray:
    """Erode ``arr`` once with a ``value`` x ``value`` rectangular kernel."""
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (value, value))
    return cv2.erode(arr, kernel, iterations=1)


def dilate_erode(img: Image.Image, value: int) -> Image.Image:
    """
    The dilate_erode function takes an image and a value.
    If the value is positive, it dilates the image by that amount.
    If the value is negative, it erodes the image by that amount.
    If the value is zero, the image is returned unchanged.

    Parameters
    ----------
        img: PIL.Image.Image
            the image to be processed
        value: int
            kernel size of dilation or erosion

    Returns
    -------
        PIL.Image.Image
            The image that has been dilated or eroded
    """
    if value == 0:
        return img

    arr = np.array(img)
    arr = _dilate(arr, value) if value > 0 else _erode(arr, -value)
    return Image.fromarray(arr)


def mask_preprocess(
    masks: List[Image.Image],
    kernel: int = 0,
    x_offset: int = 0,
    y_offset: int = 0,
    merge_invert: Union[int, 'MergeInvert', str] = MergeInvert.NONE,
) -> List[Image.Image]:
    """
    The mask_preprocess function takes a list of masks and preprocesses them.
    It offsets them by x_offset and y_offset, dilates or erodes them, drops
    masks that ended up entirely black, then applies the merge/invert mode.

    Parameters
    ----------
        masks: List[Image.Image]
            A list of masks
        kernel: int
            kernel size of dilation or erosion
        x_offset: int
            →
        y_offset: int
            ↑
        merge_invert: int, MergeInvert or str
            merge mode, see ``mask_merge_invert``

    Returns
    -------
        List[Image.Image]
            A list of processed masks
    """
    if not masks:
        return []

    if x_offset != 0 or y_offset != 0:
        masks = [offset(m, x_offset, y_offset) for m in masks]

    if kernel != 0:
        masks = [dilate_erode(m, kernel) for m in masks]
        # Erosion can wipe a mask out completely; such masks are useless.
        masks = [m for m in masks if not is_all_black(m)]

    return mask_merge_invert(masks, mode=merge_invert)


def _mask_merge(masks: List[Image.Image]) -> List[Image.Image]:
    """Combine all masks into a single mask with a pixel-wise bitwise OR."""
    arrs = [np.array(m) for m in masks]
    merged = reduce(cv2.bitwise_or, arrs)
    return [Image.fromarray(merged)]


def _mask_invert(masks: List[Image.Image]) -> List[Image.Image]:
    """Invert every mask in the list."""
    return [ImageChops.invert(m) for m in masks]


def mask_merge_invert(
    masks: List[Image.Image], mode: Union[int, 'MergeInvert', str]
) -> List[Image.Image]:
    """
    Apply a merge mode to a list of masks.

    Parameters
    ----------
        masks: List[Image.Image]
            A list of masks
        mode: int, MergeInvert or str
            A ``MergeInvert`` value, or one of the ``MASK_MERGE_INVERT``
            strings (an unknown string raises ValueError)

    Returns
    -------
        List[Image.Image]
            The masks unchanged (NONE or empty input), a single merged mask
            (MERGE), or a single merged and inverted mask (MERGE_INVERT)

    Raises
    ------
        RuntimeError
            If ``mode`` is not a known merge mode
    """
    if isinstance(mode, str):
        mode = MASK_MERGE_INVERT.index(mode)

    if mode == MergeInvert.NONE or not masks:
        return masks

    if mode == MergeInvert.MERGE:
        return _mask_merge(masks)

    if mode == MergeInvert.MERGE_INVERT:
        return _mask_invert(_mask_merge(masks))

    raise RuntimeError


def bbox_area(bbox: List[float]):
    """Return the area of an [x1, y1, x2, y2] bounding box."""
    return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])


def filter_by_ratio(pred: PredictOutput, low: float, high: float) -> PredictOutput:
    """
    Keep only detections whose area, relative to the preview image area,
    lies within [low, high].

    ``pred`` is modified in place and returned.
    """
    def is_in_ratio(bbox: List[float], low: float, high: float, orig_area: int) -> bool:
        area = bbox_area(bbox)
        return low <= area / orig_area <= high

    if not pred.bboxes:
        return pred

    w, h = pred.preview.size
    orig_area = w * h
    items = len(pred.bboxes)
    idx = [i for i in range(items) if is_in_ratio(pred.bboxes[i], low, high, orig_area)]
    pred.bboxes = [pred.bboxes[i] for i in idx]
    pred.masks = [pred.masks[i] for i in idx]
    return pred


class SortBy(IntEnum):
    """Available orderings for ``sort_bboxes``."""

    NONE = 0
    LEFT_TO_RIGHT = 1
    CENTER_TO_EDGE = 2
    AREA = 3


# Bbox sorting
def _key_left_to_right(bbox: List[float]) -> float:
    """
    Left to right

    Parameters
    ----------
    bbox: list[float]
        list of [x1, y1, x2, y2]
    """
    return bbox[0]


def _key_center_to_edge(bbox: List[float], *, center: Tuple[float, float]) -> float:
    """
    Center to edge

    Parameters
    ----------
    bbox: list[float]
        list of [x1, y1, x2, y2]
    center: tuple[float, float]
        (x, y) point the distance of the box centre is measured from
    """
    bbox_center = ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)
    return dist(center, bbox_center)


def _key_area(bbox: List[float]) -> float:
    """
    Large to small

    Parameters
    ----------
    bbox: list[float]
        list of [x1, y1, x2, y2]
    """
    return -bbox_area(bbox)


def sort_bboxes(
    pred: PredictOutput, order: Union[int, 'SortBy'] = SortBy.NONE
) -> PredictOutput:
    """
    Reorder the boxes and their masks according to ``order``.

    ``pred`` is modified in place and returned. Raises RuntimeError for an
    unknown ``order``.
    """
    if order == SortBy.NONE or len(pred.bboxes) <= 1:
        return pred

    if order == SortBy.LEFT_TO_RIGHT:
        key = _key_left_to_right
    elif order == SortBy.CENTER_TO_EDGE:
        width, height = pred.preview.size
        center = (width / 2, height / 2)
        key = partial(_key_center_to_edge, center=center)
    elif order == SortBy.AREA:
        key = _key_area
    else:
        raise RuntimeError

    items = len(pred.bboxes)
    idx = sorted(range(items), key=lambda i: key(pred.bboxes[i]))
    pred.bboxes = [pred.bboxes[i] for i in idx]
    pred.masks = [pred.masks[i] for i in idx]
    return pred


def filter_k_largest(pred: PredictOutput, k: int = 0) -> PredictOutput:
    """
    Keep only the ``k`` detections with the largest area (``k == 0`` keeps all).

    The kept detections are ordered by ascending area. ``pred`` is modified
    in place and returned.
    """
    if not pred.bboxes or k == 0:
        return pred
    areas = [bbox_area(bbox) for bbox in pred.bboxes]
    idx = np.argsort(areas)[-k:]
    pred.bboxes = [pred.bboxes[i] for i in idx]
    pred.masks = [pred.masks[i] for i in idx]
    return pred


def pred_preprocessing(pred: PredictOutput) -> List[Image.Image]:
    """
    Turn a detection result into the final list of masks to inpaint.

    Applies the area-ratio filter (0.0-1.0), the k-largest filter (disabled,
    k=0), sorts by area from large to small, then dilates each mask with a
    kernel of 4 without offset or merging.
    """
    pred = filter_by_ratio(pred, low=0.0, high=1.0)
    pred = filter_k_largest(pred, k=0)
    pred = sort_bboxes(pred, SortBy.AREA)
    return mask_preprocess(
        pred.masks,
        kernel=4,
        x_offset=0,
        y_offset=0,
        merge_invert="None",
    )


def apply_overlay(image, paste_loc, overlay):
    """
    Composite ``overlay`` on top of ``image`` and return an RGB image.

    Parameters
    ----------
        image: Image.Image
            The processed patch (or full image when ``paste_loc`` is None)
        paste_loc: tuple or None
            (x, y, w, h); when given, ``image`` is resized to (w, h) and
            pasted at (x, y) on a blank canvas of the overlay's size first
        overlay: Image.Image or None
            RGBA image composited over the result; when None, ``image`` is
            returned untouched

    Returns
    -------
        Image.Image
            The composited image in RGB mode
    """
    if overlay is None:
        return image

    if paste_loc is not None:
        x, y, w, h = paste_loc
        base_image = Image.new('RGBA', (overlay.width, overlay.height))
        image = image.resize((w, h), Image.LANCZOS)
        base_image.paste(image, (x, y))
        image = base_image

    image = image.convert('RGBA')
    image.alpha_composite(overlay)
    image = image.convert('RGB')

    return image