import torch from typing import Dict, List, Optional, Tuple, Union import functools import fsspec import os import open_clip import torch.nn as nn from functools import partial import clip from einops import rearrange, repeat import kornia import numpy as np from inspect import isfunction from pdb import set_trace as st # from transformers import CLIPTokenizer, CLIPTextModel from ...util import (append_dims, autocast, count_params, default, disabled_train, expand_dims_like, instantiate_from_config) from ..x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test class AbstractEncoder(nn.Module): def __init__(self): super().__init__() def encode(self, *args, **kwargs): raise NotImplementedError class ClassEmbedder(nn.Module): def __init__(self, embed_dim, n_classes=1000, key='class'): super().__init__() self.key = key self.embedding = nn.Embedding(n_classes, embed_dim) def forward(self, batch, key=None): if key is None: key = self.key # this is for use in crossattn c = batch[key][:, None] c = self.embedding(c) return c class TransformerEmbedder(AbstractEncoder): """Some transformer encoder layers""" def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): super().__init__() self.device = device self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, attn_layers=Encoder(dim=n_embed, depth=n_layer)) def forward(self, tokens): tokens = tokens.to(self.device) # meh z = self.transformer(tokens, return_embeddings=True) return z def encode(self, x): return self(x) class BERTTokenizer(AbstractEncoder): """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" def __init__(self, device="cuda", vq_interface=True, max_length=77): super().__init__() from transformers import BertTokenizerFast # TODO: add to reuquirements self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") self.device = device self.vq_interface = vq_interface self.max_length = max_length def forward(self, text): batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt") tokens = batch_encoding["input_ids"].to(self.device) return tokens @torch.no_grad() def encode(self, text): tokens = self(text) if not self.vq_interface: return tokens return None, None, [None, None, tokens] def decode(self, text): return text class BERTEmbedder(AbstractEncoder): """Uses the BERT tokenizr model and add some transformer encoder layers""" def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, device="cuda",use_tokenizer=True, embedding_dropout=0.0): super().__init__() self.use_tknz_fn = use_tokenizer if self.use_tknz_fn: self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) self.device = device self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, attn_layers=Encoder(dim=n_embed, depth=n_layer), emb_dropout=embedding_dropout) def forward(self, text): if self.use_tknz_fn: tokens = self.tknz_fn(text)#.to(self.device) else: tokens = text z = self.transformer(tokens, return_embeddings=True) return z def encode(self, text): # output of length 77 return self(text) class SpatialRescaler(nn.Module): def __init__(self, n_stages=1, method='bilinear', multiplier=0.5, in_channels=3, out_channels=None, bias=False): super().__init__() self.n_stages = n_stages assert self.n_stages >= 0 assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] self.multiplier = multiplier self.interpolator = partial(torch.nn.functional.interpolate, mode=method) self.remap_output = out_channels is not None if self.remap_output: print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) def forward(self,x): for stage in range(self.n_stages): x = self.interpolator(x, scale_factor=self.multiplier) if self.remap_output: x = self.channel_mapper(x) return x def encode(self, x): return self(x) class FrozenCLIPEmbedder(AbstractEncoder): """Uses the CLIP transformer encoder for text (from Hugging Face)""" def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, use_eos_feature=False): super().__init__() from transformers import CLIPTokenizer, CLIPTextModel self.tokenizer = CLIPTokenizer.from_pretrained(version) self.transformer = CLIPTextModel.from_pretrained(version).to(device) self.device = device self.max_length = max_length self.freeze() self.use_eos_feature = use_eos_feature def freeze(self): self.transformer = self.transformer.eval() for param in self.parameters(): param.requires_grad = False def forward(self, text): batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt") tokens = batch_encoding["input_ids"].to(self.device) outputs = self.transformer(input_ids=tokens) if self.use_eos_feature: # for DiT z = outputs.pooler_output # N 77 C else: z = outputs.last_hidden_state # N 77 C return z def encode(self, text): return self(text) class TextEmbedder(nn.Module): """ Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance. """ def __init__(self, dropout_prob=0.1, use_eos_feature=False): super().__init__() self.text_encodder = FrozenCLIPEmbedder(use_eos_feature=use_eos_feature) # no normalization projection self.dropout_prob = dropout_prob def token_drop(self, text_prompts, force_drop_ids=None): """ Drops text to enable classifier-free guidance. """ if force_drop_ids is None: drop_ids = np.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob else: drop_ids = force_drop_ids == 1 labels = list(np.where(drop_ids, "None", text_prompts)) # print(labels) return labels def forward(self, text_prompts, train, force_drop_ids=None): use_dropout = self.dropout_prob > 0 if (train and use_dropout) or (force_drop_ids is not None): text_prompts = self.token_drop(text_prompts, force_drop_ids) embeddings = self.text_encodder(text_prompts) return embeddings class FrozenCLIPTextEmbedder(nn.Module): """ Uses the CLIP transformer encoder for text. """ def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True, dropout_prob=0., scale_clip_encoding=None): super().__init__() self.model, _ = clip.load(version, jit=False, device=device) self.device = device self.max_length = max_length self.n_repeat = n_repeat self.normalize = normalize self.dropout_prob = dropout_prob self.scale_clip_encoding = scale_clip_encoding def freeze(self): self.model = self.model.eval() for param in self.parameters(): param.requires_grad = False def forward(self, text): tokens = clip.tokenize(text).to(self.device) z = self.model.encode_text(tokens) if self.normalize: z = z / torch.linalg.norm(z, dim=1, keepdim=True) if self.scale_clip_encoding is not None: z = z * self.scale_clip_encoding return z def token_drop(self, text_prompts, force_drop_ids=None): """ Drops text to enable classifier-free guidance. """ if force_drop_ids is None: drop_ids = np.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob else: drop_ids = force_drop_ids == 1 labels = list(np.where(drop_ids, "None", text_prompts)) # print(labels) return labels def encode(self, text): z = self(text) if z.ndim==2: # match cross attention shape z = z[:, None, :] z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) return z class FrozenClipImageEmbedder(nn.Module): """ Uses the CLIP image encoder. """ def __init__( self, model, jit=False, device='cuda' if torch.cuda.is_available() else 'cpu', antialias=False, n_repeat=1, dropout_prob=0.2, # follow Rodin normalize_encoding=False, scale_clip_encoding=1.0, ): super().__init__() self.model, _ = clip.load(name=model, device=device, jit=jit) self.n_repeat = n_repeat self.normalize_encoding = normalize_encoding self.scale_clip_encoding = torch.tensor(scale_clip_encoding, dtype=torch.float32, device=device) self.antialias = antialias self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) self.dropout_prob = dropout_prob def freeze(self): self.model = self.model.eval() for param in self.parameters(): param.requires_grad = False def preprocess(self, x): # normalize to [0,1] x = kornia.geometry.resize(x, (224, 224), interpolation='bicubic',align_corners=True, antialias=self.antialias) x = (x + 1.) / 2. # renormalize according to clip x = kornia.enhance.normalize(x, self.mean, self.std) # type: ignore return x def token_drop(self, z): """ zero the image encoding to enable classifier-free guidance. """ drop_ids = np.random.uniform(0, 1, z.shape[0]) < self.dropout_prob # idx token to drop drop_ids = torch.from_numpy(drop_ids).unsqueeze(1).expand_as(z).bool().to(z.device) z = torch.where(drop_ids, torch.zeros_like(z), z) return z def forward(self, x): # x is assumed to be in range [-1,1] # return self.model.encode_image(self.preprocess(x)) z = self.model.encode_image(self.preprocess(x)) # ? normalized features, seems not working? if self.normalize_encoding: z = z / torch.linalg.norm(z, dim=1, keepdim=True) if self.scale_clip_encoding: # st() z = z * self.scale_clip_encoding if self.dropout_prob>0: # for cfg z = self.token_drop(z) if z.ndim==2: # repeat 1 dim, for context shape compatability. z = z[:, None, :] z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) return z class AbstractEmbModel(nn.Module): def __init__(self): super().__init__() self._is_trainable = None self._ucg_rate = None self._input_key = None @property def is_trainable(self) -> bool: return self._is_trainable @property def ucg_rate(self) -> Union[float, torch.Tensor]: return self._ucg_rate @property def input_key(self) -> str: return self._input_key @is_trainable.setter def is_trainable(self, value: bool): self._is_trainable = value @ucg_rate.setter def ucg_rate(self, value: Union[float, torch.Tensor]): self._ucg_rate = value @input_key.setter def input_key(self, value: str): self._input_key = value @is_trainable.deleter def is_trainable(self): del self._is_trainable @ucg_rate.deleter def ucg_rate(self): del self._ucg_rate @input_key.deleter def input_key(self): del self._input_key class FrozenOpenCLIPImageEmbedder(AbstractEmbModel): """ Uses the OpenCLIP vision transformer encoder for images """ def __init__( self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, freeze=True, antialias=True, ucg_rate=0.0, unsqueeze_dim=False, repeat_to_max_len=False, num_image_crops=0, output_tokens=False, init_device=None, ): super().__init__() model, _, _ = open_clip.create_model_and_transforms( arch, device=torch.device(default(init_device, "cpu")), pretrained=version, ) del model.transformer self.model = model self.max_crops = num_image_crops self.pad_to_max_len = self.max_crops > 0 self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len) self.device = device self.max_length = max_length if freeze: self.freeze() self.antialias = antialias self.register_buffer( "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False ) self.register_buffer( "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False ) self.ucg_rate = ucg_rate self.unsqueeze_dim = unsqueeze_dim self.stored_batch = None self.model.visual.output_tokens = output_tokens self.output_tokens = output_tokens def preprocess(self, x): # normalize to [0,1] x = kornia.geometry.resize( x, (224, 224), interpolation="bicubic", align_corners=True, antialias=self.antialias, ) x = (x + 1.0) / 2.0 # renormalize according to clip x = kornia.enhance.normalize(x, self.mean, self.std) return x def freeze(self): self.model = self.model.eval() for param in self.parameters(): param.requires_grad = False @autocast def forward(self, image, no_dropout=False): z = self.encode_with_vision_transformer(image) tokens = None if self.output_tokens: z, tokens = z[0], z[1] z = z.to(image.dtype) if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0): z = ( torch.bernoulli( (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device) )[:, None] * z ) if tokens is not None: tokens = ( expand_dims_like( torch.bernoulli( (1.0 - self.ucg_rate) * torch.ones(tokens.shape[0], device=tokens.device) ), tokens, ) * tokens ) if self.unsqueeze_dim: z = z[:, None, :] if self.output_tokens: assert not self.repeat_to_max_len assert not self.pad_to_max_len return tokens, z if self.repeat_to_max_len: if z.dim() == 2: z_ = z[:, None, :] else: z_ = z return repeat(z_, "b 1 d -> b n d", n=self.max_length), z elif self.pad_to_max_len: assert z.dim() == 3 z_pad = torch.cat( ( z, torch.zeros( z.shape[0], self.max_length - z.shape[1], z.shape[2], device=z.device, ), ), 1, ) return z_pad, z_pad[:, 0, ...] return z def encode_with_vision_transformer(self, img): # if self.max_crops > 0: # img = self.preprocess_by_cropping(img) if img.dim() == 5: assert self.max_crops == img.shape[1] img = rearrange(img, "b n c h w -> (b n) c h w") img = self.preprocess(img) if not self.output_tokens: assert not self.model.visual.output_tokens x = self.model.visual(img) tokens = None else: assert self.model.visual.output_tokens x, tokens = self.model.visual(img) if self.max_crops > 0: x = rearrange(x, "(b n) d -> b n d", n=self.max_crops) # drop out between 0 and all along the sequence axis x = ( torch.bernoulli( (1.0 - self.ucg_rate) * torch.ones(x.shape[0], x.shape[1], 1, device=x.device) ) * x ) if tokens is not None: tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops) print( f"You are running very experimental token-concat in {self.__class__.__name__}. " f"Check what you are doing, and then remove this message." ) if self.output_tokens: return x, tokens return x def encode(self, text): return self(text) class FrozenOpenCLIPImagePredictionEmbedder(AbstractEmbModel): def __init__( self, # open_clip_embedding_config: Dict, n_cond_frames: int, n_copies: int, open_clip_module, ): super().__init__() self.n_cond_frames = n_cond_frames self.n_copies = n_copies # self.open_clip = instantiate_from_config(open_clip_embedding_config) self.open_clip = open_clip_module def forward(self, vid): vid = self.open_clip(vid) vid = rearrange(vid, "(b t) d -> b t d", t=self.n_cond_frames) vid = repeat(vid, "b t d -> (b s) t d", s=self.n_copies) return vid if __name__ == "__main__": from ldm.util import count_params model = FrozenCLIPEmbedder() count_params(model, verbose=True)