from dataclasses import dataclass
from pathlib import Path
from typing import Any

import torch
from PIL import Image

from refiners.foundationals.clip.concepts import ConceptExtender
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
    MultiUpscaler,
    UpscalerCheckpoints,
)

from esrgan_model import UpscalerESRGAN


@dataclass(kw_only=True)
class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
    # Extends the base checkpoint bundle with an optional path to ESRGAN weights.
    esrgan: Path | None = None


class ESRGANUpscaler(MultiUpscaler):
    def __init__(
        self,
        checkpoints: ESRGANUpscalerCheckpoints,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
        self.esrgan = self.load_esrgan(checkpoints.esrgan)

    def to(self, device: torch.device, dtype: torch.dtype):
        # Move both the ESRGAN model (when loaded) and the Stable Diffusion pipeline.
        if self.esrgan is not None:
            self.esrgan.to(device=device, dtype=dtype)
        self.sd = self.sd.to(device=device, dtype=dtype)
        self.device = device
        self.dtype = dtype

    def load_esrgan(self, path: Path | None) -> UpscalerESRGAN | None:
        # ESRGAN is optional: skip it when no checkpoint path is provided.
        if path is None:
            return None
        return UpscalerESRGAN(path, device=self.device, dtype=self.dtype)

    def load_negative_embedding(self, path: Path | None, key: str | None) -> str:
        if path is None:
            return ""

        embeddings: torch.Tensor | dict[str, Any] = torch.load(  # type: ignore
            path, weights_only=True, map_location=self.device
        )

        # When the file is a state-dict-like mapping, follow the dotted key path
        # down to the embedding tensor.
        if isinstance(embeddings, dict):
            assert (
                key is not None
            ), "Key must be provided to access the negative embedding."
            key_sequence = key.split(".")
            for key in key_sequence:
                assert (
                    key in embeddings
                ), f"Key {key} not found in the negative embedding dictionary. Available keys: {list(embeddings.keys())}"
                embeddings = embeddings[key]

        assert isinstance(
            embeddings, torch.Tensor
        ), f"The negative embedding must be a tensor, found {type(embeddings)}."
        assert (
            embeddings.ndim == 2
        ), f"The negative embedding must be a 2D tensor, found {embeddings.ndim}D tensor."

        # Register each embedding vector as a new CLIP concept token (<0>, <1>, ...)
        # and return the token string to append to the negative prompt.
        extender = ConceptExtender(self.sd.clip_text_encoder)
        negative_embedding_token = ", "
        for i, embedding in enumerate(embeddings):
            embedding = embedding.to(device=self.device, dtype=self.dtype)
            extender.add_concept(token=f"<{i}>", embedding=embedding)
            negative_embedding_token += f"<{i}> "
        extender.inject()

        return negative_embedding_token

    def pre_upscale(
        self,
        image: Image.Image,
        upscale_factor: float,
        use_esrgan: bool = True,
        use_esrgan_tiling: bool = True,
        **_: Any,
    ) -> Image.Image:
        # Fall back to the default pre-upscaling when ESRGAN is unavailable or disabled.
        if self.esrgan is None or not use_esrgan:
            return super().pre_upscale(image=image, upscale_factor=upscale_factor)

        width, height = image.size

        # Upscale with ESRGAN first (at its native factor), then resize the result
        # to the requested factor before the diffusion refinement pass.
        if use_esrgan_tiling:
            image = self.esrgan.upscale_with_tiling(image)
        else:
            image = self.esrgan.upscale_without_tiling(image)

        return image.resize(
            size=(
                int(width * upscale_factor),
                int(height * upscale_factor),
            ),
            resample=Image.LANCZOS,
        )
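

# Example usage: a minimal sketch, not part of the original module. The ESRGAN
# weight path and image file names below are hypothetical placeholders, and the
# checkpoint paths required by the base UpscalerCheckpoints dataclass (UNet,
# text encoder, autoencoder, ...) still need to be filled in per the refiners
# MultiUpscaler API. The `upscale()` call assumes the method inherited from
# MultiUpscaler.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    checkpoints = ESRGANUpscalerCheckpoints(
        esrgan=Path("checkpoints/esrgan_4x.pth"),  # hypothetical local weight file
        # ... other checkpoint paths required by UpscalerCheckpoints go here
    )
    upscaler = ESRGANUpscaler(checkpoints=checkpoints, device=device, dtype=dtype)

    image = Image.open("input.png")  # hypothetical input image
    enhanced = upscaler.upscale(image=image, upscale_factor=2)
    enhanced.save("output.png")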