import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from functools import partial

from lib.model_zoo.common.get_model import register

symbol = 'clip'


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


from transformers import CLIPTokenizer, CLIPTextModel


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


@register('clip_text_context_encoder_sdv1')
class CLIPTextContextEncoderSDv1(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)."""
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda",
                 max_length=77, freeze=True):  # clip-vit-base-patch32
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        with torch.no_grad():
            batch_encoding = self.tokenizer(
                text, truncation=True, max_length=self.max_length, return_length=True,
                return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
            tokens = batch_encoding["input_ids"].to(self.device)
            max_token_n = self.transformer.text_model.embeddings.position_ids.shape[1]
            positional_ids = torch.arange(max_token_n)[None].to(self.device)
            outputs = self.transformer(
                input_ids=tokens,
                position_ids=positional_ids, )
        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)
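
# --- usage sketch (not part of the original module) ---------------------------
# A minimal example of how the SDv1 text context encoder might be driven. The
# `_example_*` helper below is hypothetical and illustrative only; it assumes
# the HuggingFace weights can be downloaded and that the installed
# `transformers` version still exposes `text_model.embeddings.position_ids`.
def _example_clip_text_context_encoder_sdv1():
    encoder = CLIPTextContextEncoderSDv1(device='cpu')
    z = encoder.encode(["a photograph of an astronaut riding a horse"])
    # For openai/clip-vit-large-patch14 the context tensor is [batch, 77, 768].
    print(z.shape)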

#############################
# copied from justin's code #
#############################

@register('clip_image_context_encoder_justin')
class CLIPImageContextEncoderJustin(AbstractEncoder):
    """
    Uses the CLIP image encoder.
    """
    def __init__(
            self,
            model='ViT-L/14',
            jit=False,
            device='cuda' if torch.cuda.is_available() else 'cpu',
            antialias=False, ):
        super().__init__()
        from . import clip_justin
        self.model, _ = clip_justin.load(name=model, device=device, jit=jit)
        self.device = device
        self.antialias = antialias
        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
        # I didn't call this originally, but seems like it was frozen anyway
        self.freeze()

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def preprocess(self, x):
        import kornia
        # Expects inputs in the range [-1, 1]
        x = kornia.geometry.resize(
            x, (224, 224), interpolation='bicubic',
            align_corners=True, antialias=self.antialias)
        x = (x + 1.) / 2.
        # Renormalize according to the CLIP statistics
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x):
        # x is assumed to be in range [-1, 1]
        return self.model.encode_image(self.preprocess(x)).float()

    def encode(self, im):
        return self(im).unsqueeze(1)


###############
# for vd next #
###############

from transformers import CLIPModel


@register('clip_text_context_encoder')
class CLIPTextContextEncoder(AbstractEncoder):
    def __init__(self,
                 version="openai/clip-vit-large-patch14",
                 max_length=77,
                 fp16=False, ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.model = CLIPModel.from_pretrained(version)
        self.max_length = max_length
        self.fp16 = fp16
        self.freeze()

    def get_device(self):
        # A trick to get the current device
        return self.model.text_projection.weight.device

    def freeze(self):
        self.model = self.model.eval()
        self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def encode(self, text):
        batch_encoding = self.tokenizer(
            text, truncation=True, max_length=self.max_length, return_length=True,
            return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.get_device())
        outputs = self.model.text_model(input_ids=tokens)
        z = self.model.text_projection(outputs.last_hidden_state)
        z_pooled = self.model.text_projection(outputs.pooler_output)
        z = z / torch.norm(z_pooled.unsqueeze(1), dim=-1, keepdim=True)
        return z

from transformers import CLIPProcessor


@register('clip_image_context_encoder')
class CLIPImageContextEncoder(AbstractEncoder):
    def __init__(self,
                 version="openai/clip-vit-large-patch14",
                 fp16=False, ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.processor = CLIPProcessor.from_pretrained(version)
        self.model = CLIPModel.from_pretrained(version)
        self.fp16 = fp16
        self.freeze()

    def get_device(self):
        # A trick to get the current device
        return self.model.text_projection.weight.device

    def freeze(self):
        self.model = self.model.eval()
        self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def _encode(self, images):
        if isinstance(images, torch.Tensor):
            import torchvision.transforms as tvtrans
            images = [tvtrans.ToPILImage()(i) for i in images]
        inputs = self.processor(images=images, return_tensors="pt")
        pixels = inputs['pixel_values'].half() if self.fp16 else inputs['pixel_values']
        pixels = pixels.to(self.get_device())
        outputs = self.model.vision_model(pixel_values=pixels)
        z = outputs.last_hidden_state
        z = self.model.vision_model.post_layernorm(z)
        z = self.model.visual_projection(z)
        z_pooled = z[:, 0:1]
        z = z / torch.norm(z_pooled, dim=-1, keepdim=True)
        return z

    @torch.no_grad()
    def _encode_wmask(self, images, masks):
        assert isinstance(masks, torch.Tensor)
        assert (len(masks.shape) == 4) and (masks.shape[1] == 1)
        masks = torch.clamp(masks, 0, 1)
        masked_images = images * masks
        masks = masks.float()
        masks = F.interpolate(masks, [224, 224], mode='bilinear')
        if masks.sum() == masks.numel():
            return self._encode(images)

        device = images.device
        dtype = images.dtype
        gscale = masks.mean(axis=[1, 2, 3], keepdim=True).flatten(2)

        vtoken_kernel_size = self.model.vision_model.embeddings.patch_embedding.kernel_size
        vtoken_stride = self.model.vision_model.embeddings.patch_embedding.stride
        mask_kernal = torch.ones([1, 1, *vtoken_kernel_size], device=device, requires_grad=False).float()
        vtoken_mask = torch.nn.functional.conv2d(masks, mask_kernal, stride=vtoken_stride).flatten(2).transpose(1, 2)
        vtoken_mask = vtoken_mask / np.prod(vtoken_kernel_size)
        vtoken_mask = torch.concat([gscale, vtoken_mask], axis=1)

        import types

        def customized_embedding_forward(self, pixel_values):
            batch_size = pixel_values.shape[0]
            patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
            patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
            class_embeds = self.class_embedding.expand(batch_size, 1, -1)
            embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
            embeddings = embeddings + self.position_embedding(self.position_ids)
            embeddings = embeddings * vtoken_mask.to(embeddings.dtype)
            return embeddings

        old_forward = self.model.vision_model.embeddings.forward
        self.model.vision_model.embeddings.forward = types.MethodType(
            customized_embedding_forward, self.model.vision_model.embeddings)
        z = self._encode(images)
        self.model.vision_model.embeddings.forward = old_forward
        z = z * vtoken_mask.to(dtype)
        return z

    # def _encode_wmask(self, images, masks):
    #     assert isinstance(masks, torch.Tensor)
    #     assert (len(masks.shape)==4) and (masks.shape[1]==1)
    #     masks = torch.clamp(masks, 0, 1)
    #     masks = masks.float()
    #     masks = F.interpolate(masks, [224, 224], mode='bilinear')
    #     if masks.sum() == masks.numel():
    #         return self._encode(images)
    #     device = images.device
    #     dtype = images.dtype
    #     vtoken_kernel_size = self.model.vision_model.embeddings.patch_embedding.kernel_size
    #     vtoken_stride = self.model.vision_model.embeddings.patch_embedding.stride
    #     mask_kernal = torch.ones([1, 1, *vtoken_kernel_size], device=device, requires_grad=False).float()
    #     vtoken_mask = torch.nn.functional.conv2d(masks, mask_kernal, stride=vtoken_stride).flatten(2).transpose(1, 2)
    #     vtoken_mask = vtoken_mask/np.prod(vtoken_kernel_size)
    #     z = self._encode(images)
    #     z[:, 1:, :] = z[:, 1:, :] * vtoken_mask.to(dtype)
    #     z[:, 0, :] = 0
    #     return z

    def encode(self, images, masks=None):
        if masks is None:
            return self._encode(images)
        else:
            return self._encode_wmask(images, masks)
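
# --- usage sketch (not part of the original module) ---------------------------
# Illustrates the mask-weighted image context path above: patch tokens whose
# receptive field is mostly masked out get down-weighted. The helper name and
# the random inputs are assumptions for illustration; images are RGB in [0, 1]
# and the binary mask (1 = keep, 0 = drop) covers the left half of the image.
def _example_clip_image_context_encoder_masked():
    encoder = CLIPImageContextEncoder()
    images = torch.rand(1, 3, 512, 512)
    masks = torch.zeros(1, 1, 512, 512)
    masks[:, :, :, :256] = 1
    z = encoder.encode(images, masks=masks)
    # For the default ViT-L/14: 1 class token + 16x16 patch tokens -> [1, 257, 768].
    print(z.shape)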

@register('clip_image_context_encoder_position_agnostic')
class CLIPImageContextEncoderPA(CLIPImageContextEncoder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        import types

        def customized_embedding_forward(self, pixel_values):
            batch_size = pixel_values.shape[0]
            patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
            patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
            class_embeds = self.class_embedding.expand(batch_size, 1, -1)
            embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
            pembeddings = self.position_embedding(self.position_ids)
            pembeddings = torch.cat([
                pembeddings[:, 0:1],
                pembeddings[:, 1:].mean(dim=1, keepdim=True).repeat(1, 256, 1)], dim=1)
            embeddings = embeddings + pembeddings
            return embeddings

        self.model.vision_model.embeddings.forward = types.MethodType(
            customized_embedding_forward, self.model.vision_model.embeddings)


##############
# from sd2.0 #
##############

import open_clip
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


@register('openclip_text_context_encoder_sdv2')
class FrozenOpenCLIPTextEmbedderSDv2(AbstractEncoder):
    """
    Uses the OpenCLIP transformer encoder for text
    """
    LAYERS = [
        # "pooled",
        "last",
        "penultimate",
    ]

    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k",
                 device="cuda", max_length=77, freeze=True, layer="last"):
        super().__init__()
        assert layer in self.LAYERS
        model, _, _ = open_clip.create_model_and_transforms(
            arch, device=torch.device('cpu'), pretrained=version)
        del model.visual
        self.model = model
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        return z

    def encode_with_transformer(self, text):
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        return self(text)
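
# --- usage sketch (not part of the original module) ---------------------------
# Hypothetical helper showing the SDv2-style text context with the penultimate
# transformer block as output, assuming the laion2b ViT-H-14 weights are
# available locally or can be downloaded.
def _example_openclip_text_embedder_sdv2():
    encoder = FrozenOpenCLIPTextEmbedderSDv2(device='cpu', layer="penultimate")
    z = encoder.encode(["a watercolor painting of a fox"])
    # ViT-H-14 text tower -> [batch, 77, 1024]
    print(z.shape)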

@register('openclip_text_context_encoder')
class FrozenOpenCLIPTextEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP transformer encoder for text
    """
    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k",
                 max_length=77, freeze=True, ):
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(
            arch, device=torch.device('cpu'), pretrained=version)
        del model.visual
        self.model = model
        self.max_length = max_length
        self.device = 'cpu'
        if freeze:
            self.freeze()

    def to(self, device):
        self.device = device
        super().to(device)

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        self.device = self.model.ln_final.weight.device  # ugly trick
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        return z

    def encode_with_transformer(self, text):
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.model.transformer(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        x_pool = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.model.text_projection
        # x_pool_debug = F.normalize(x_pool, dim=-1)
        x = x @ self.model.text_projection
        x = x / x_pool.norm(dim=1, keepdim=True).unsqueeze(1)
        return x

    def encode(self, text):
        return self(text)


@register('openclip_image_context_encoder')
class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP transformer encoder for images
    """
    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k",
                 freeze=True, ):
        super().__init__()
        model, _, preprocess = open_clip.create_model_and_transforms(
            arch, device=torch.device('cpu'), pretrained=version)
        self.model = model.visual
        self.device = 'cpu'
        import torchvision.transforms as tvtrans
        # we only need resize & normalization
        preprocess.transforms[0].size = [224, 224]  # make it more precise
        self.preprocess = tvtrans.Compose([
            preprocess.transforms[0],
            preprocess.transforms[4], ])
        if freeze:
            self.freeze()

    def to(self, device):
        self.device = device
        super().to(device)

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, image):
        z = self.preprocess(image)
        z = self.encode_with_transformer(z)
        return z

    def encode_with_transformer(self, image):
        x = self.model.conv1(image)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)
        x = torch.cat([
            self.model.class_embedding.to(x.dtype)
            + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
            x], dim=1)
        x = x + self.model.positional_embedding.to(x.dtype)
        x = self.model.ln_pre(x)
        x = x.permute(1, 0, 2)
        x = self.model.transformer(x)
        x = x.permute(1, 0, 2)
        x = self.model.ln_post(x)
        if self.model.proj is not None:
            x = x @ self.model.proj
        x_pool = x[:, 0, :]
        # x_pool_debug = self.model(image)
        # x_pooln_debug = F.normalize(x_pool_debug, dim=-1)
        x = x / x_pool.norm(dim=1, keepdim=True).unsqueeze(1)
        return x

    def _encode(self, image):
        return self(image)

    def _encode_wmask(self, images, masks):
        z = self._encode(images)
        device = z.device
        vtoken_kernel_size = self.model.conv1.kernel_size
        vtoken_stride = self.model.conv1.stride
        mask_kernal = torch.ones([1, 1, *vtoken_kernel_size], device=device, dtype=z.dtype, requires_grad=False)
        mask_kernal /= np.prod(vtoken_kernel_size)

        assert isinstance(masks, torch.Tensor)
        assert (len(masks.shape) == 4) and (masks.shape[1] == 1)
        masks = torch.clamp(masks, 0, 1)
        masks = F.interpolate(masks, [224, 224], mode='bilinear')

        vtoken_mask = torch.nn.functional.conv2d(1 - masks, mask_kernal, stride=vtoken_stride).flatten(2).transpose(1, 2)
        z[:, 1:, :] = z[:, 1:, :] * vtoken_mask
        z[:, 0, :] = 0
        return z

    def encode(self, images, masks=None):
        if masks is None:
            return self._encode(images)
        else:
            return self._encode_wmask(images, masks)
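
# --- usage sketch (not part of the original module) ---------------------------
# Hypothetical helper for the OpenCLIP image context encoder with a mask:
# following `_encode_wmask` above, patch tokens are scaled by how much of their
# receptive field falls outside the mask and the class token is zeroed. Inputs
# are random RGB tensors in [0, 1]; the mask keeps the top half of the image.
def _example_openclip_image_embedder_masked():
    encoder = FrozenOpenCLIPImageEmbedder()
    images = torch.rand(1, 3, 512, 512)
    masks = torch.zeros(1, 1, 512, 512)
    masks[:, :, :256, :] = 1
    z = encoder.encode(images, masks=masks)
    # ViT-H-14 visual tower with projection -> [1, 257, 1024]
    print(z.shape)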

############################
# def customized tokenizer #
############################

from open_clip import SimpleTokenizer


@register('openclip_text_context_encoder_sdv2_customized_tokenizer_v1')
class FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV1(FrozenOpenCLIPTextEmbedderSDv2):
    """
    Uses the OpenCLIP transformer encoder for text
    """
    def __init__(self, customized_tokens, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if isinstance(customized_tokens, str):
            customized_tokens = [customized_tokens]
        self.tokenizer = open_clip.SimpleTokenizer(special_tokens=customized_tokens)
        self.num_regular_tokens = self.model.token_embedding.weight.shape[0]
        self.embedding_dim = self.model.ln_final.weight.shape[0]
        self.customized_token_embedding = nn.Embedding(
            len(customized_tokens), embedding_dim=self.embedding_dim)
        nn.init.normal_(self.customized_token_embedding.weight, std=0.02)

    def tokenize(self, texts):
        if isinstance(texts, str):
            texts = [texts]
        sot_token = self.tokenizer.encoder["<start_of_text>"]
        eot_token = self.tokenizer.encoder["<end_of_text>"]
        all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
        maxn = self.num_regular_tokens
        regular_tokens = [[ti if ti < maxn else 0 for ti in tokens] for tokens in all_tokens]
        token_mask = [[0 if ti < maxn else 1 for ti in tokens] for tokens in all_tokens]
        customized_tokens = [[ti - maxn if ti >= maxn else 0 for ti in tokens] for tokens in all_tokens]
        return regular_tokens, customized_tokens, token_mask

    def pad_to_length(self, tokens, context_length=77, eot_token=None):
        result = torch.zeros(len(tokens), context_length, dtype=torch.long)
        eot_token = self.tokenizer.encoder["<end_of_text>"] if eot_token is None else eot_token
        for i, tokens in enumerate(tokens):
            if len(tokens) > context_length:
                tokens = tokens[:context_length]  # Truncate
                tokens[-1] = eot_token
            result[i, :len(tokens)] = torch.tensor(tokens)
        return result

    def forward(self, text):
        self.device = self.model.ln_final.weight.device  # ugly trick
        regular_tokens, customized_tokens, token_mask = self.tokenize(text)
        regular_tokens = self.pad_to_length(regular_tokens).to(self.device)
        customized_tokens = self.pad_to_length(customized_tokens, eot_token=0).to(self.device)
        token_mask = self.pad_to_length(token_mask, eot_token=0).to(self.device)
        z0 = self.encode_with_transformer(regular_tokens)
        z1 = self.customized_token_embedding(customized_tokens)
        token_mask = token_mask[:, :, None].type(z0.dtype)
        z = z0 * (1 - token_mask) + z1 * token_mask
        return z

@register('openclip_text_context_encoder_sdv2_customized_tokenizer_v2')
class FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV2(FrozenOpenCLIPTextEmbedderSDv2):
    """
    Uses the OpenCLIP transformer encoder for text
    """
    def __init__(self, customized_tokens, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if isinstance(customized_tokens, str):
            customized_tokens = [customized_tokens]
        self.tokenizer = open_clip.SimpleTokenizer(special_tokens=customized_tokens)
        self.num_regular_tokens = self.model.token_embedding.weight.shape[0]
        self.embedding_dim = self.model.token_embedding.weight.shape[1]
        self.customized_token_embedding = nn.Embedding(
            len(customized_tokens), embedding_dim=self.embedding_dim)
        nn.init.normal_(self.customized_token_embedding.weight, std=0.02)

    def tokenize(self, texts):
        if isinstance(texts, str):
            texts = [texts]
        sot_token = self.tokenizer.encoder["<start_of_text>"]
        eot_token = self.tokenizer.encoder["<end_of_text>"]
        all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
        maxn = self.num_regular_tokens
        regular_tokens = [[ti if ti < maxn else 0 for ti in tokens] for tokens in all_tokens]
        token_mask = [[0 if ti < maxn else 1 for ti in tokens] for tokens in all_tokens]
        customized_tokens = [[ti - maxn if ti >= maxn else 0 for ti in tokens] for tokens in all_tokens]
        return regular_tokens, customized_tokens, token_mask

    def pad_to_length(self, tokens, context_length=77, eot_token=None):
        result = torch.zeros(len(tokens), context_length, dtype=torch.long)
        eot_token = self.tokenizer.encoder["<end_of_text>"] if eot_token is None else eot_token
        for i, tokens in enumerate(tokens):
            if len(tokens) > context_length:
                tokens = tokens[:context_length]  # Truncate
                tokens[-1] = eot_token
            result[i, :len(tokens)] = torch.tensor(tokens)
        return result

    def forward(self, text):
        self.device = self.model.token_embedding.weight.device  # ugly trick
        regular_tokens, customized_tokens, token_mask = self.tokenize(text)
        regular_tokens = self.pad_to_length(regular_tokens).to(self.device)
        customized_tokens = self.pad_to_length(customized_tokens, eot_token=0).to(self.device)
        token_mask = self.pad_to_length(token_mask, eot_token=0).to(self.device)
        z = self.encode_with_transformer(regular_tokens, customized_tokens, token_mask)
        return z

    def encode_with_transformer(self, token, customized_token, token_mask):
        x0 = self.model.token_embedding(token)
        x1 = self.customized_token_embedding(customized_token)
        token_mask = token_mask[:, :, None].type(x0.dtype)
        x = x0 * (1 - token_mask) + x1 * token_mask
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x


class ln_freezed_temp(nn.LayerNorm):
    def forward(self, x):
        self.weight.requires_grad = False
        self.bias.requires_grad = False
        return super().forward(x)


@register('openclip_text_context_encoder_sdv2_customized_tokenizer_v3')
class FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV3(FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV2):
    """
    Uses the OpenCLIP transformer encoder for text
    """
    def __init__(self, customized_tokens, texpand=4, lora_rank=None,
                 lora_bias_trainable=True, *args, **kwargs):
        super().__init__(customized_tokens, *args, **kwargs)
        if isinstance(customized_tokens, str):
            customized_tokens = [customized_tokens]
        self.texpand = texpand
        self.customized_token_embedding = nn.Embedding(
            len(customized_tokens) * texpand, embedding_dim=self.embedding_dim)
        nn.init.normal_(self.customized_token_embedding.weight, std=0.02)

        if lora_rank is not None:
            from .lora import freeze_param, freeze_module, to_lora

            def convert_resattnblock(module):
                module.ln_1.__class__ = ln_freezed_temp
                # freeze_module(module.ln_1)
                module.attn = to_lora(module.attn, lora_rank, lora_bias_trainable)
                module.ln_2.__class__ = ln_freezed_temp
                # freeze_module(module.ln_2)
                module.mlp.c_fc = to_lora(module.mlp.c_fc, lora_rank, lora_bias_trainable)
                module.mlp.c_proj = to_lora(module.mlp.c_proj, lora_rank, lora_bias_trainable)

            freeze_param(self.model, 'positional_embedding')
            freeze_param(self.model, 'text_projection')
            freeze_param(self.model, 'logit_scale')
            for idx, resattnblock in enumerate(self.model.transformer.resblocks):
                convert_resattnblock(resattnblock)
            freeze_module(self.model.token_embedding)
            self.model.ln_final.__class__ = ln_freezed_temp
            # freeze_module(self.model.ln_final)

    def tokenize(self, texts):
        if isinstance(texts, str):
            texts = [texts]
        sot_token = self.tokenizer.encoder["<start_of_text>"]
        eot_token = self.tokenizer.encoder["<end_of_text>"]
        all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
        maxn = self.num_regular_tokens
        regular_tokens = [[[ti] if ti < maxn else [0] * self.texpand
                           for ti in tokens] for tokens in all_tokens]
        token_mask = [[[0] if ti < maxn else [1] * self.texpand
                       for ti in tokens] for tokens in all_tokens]
        custom_tokens = [[[0] if ti < maxn else [(ti - maxn) * self.texpand + ii for ii in range(self.texpand)]
                          for ti in tokens] for tokens in all_tokens]
        from itertools import chain
        regular_tokens = [[i for i in chain(*tokens)] for tokens in regular_tokens]
        token_mask = [[i for i in chain(*tokens)] for tokens in token_mask]
        custom_tokens = [[i for i in chain(*tokens)] for tokens in custom_tokens]
        return regular_tokens, custom_tokens, token_mask
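
# --- usage sketch (not part of the original module) ---------------------------
# Hypothetical helper for the v2 customized tokenizer: a new placeholder token
# (here "<mytoken>", an assumed name) is mapped to a trainable embedding while
# all regular tokens go through the frozen OpenCLIP text tower. Assumes the
# open_clip SimpleTokenizer registers the extra special token as used in
# `tokenize()` above.
def _example_customized_tokenizer_v2():
    encoder = FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV2(
        customized_tokens=["<mytoken>"], device='cpu')
    z = encoder.encode(["a photo of <mytoken> on a beach"])
    print(z.shape)  # [1, 77, 1024]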

###################
# clip expandable #
###################

@register('clip_text_sdv1_customized_embedding')
class CLIPTextSD1CE(nn.Module):
    def __init__(
            self,
            replace_info="text|elon musk",
            version="openai/clip-vit-large-patch14",
            max_length=77):
        super().__init__()
        self.name = 'clip_text_sdv1_customized_embedding'
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.reset_replace_info(replace_info)
        self.max_length = max_length
        # Placeholder token; occurrences in the prompt are swapped for the BOS
        # marker in tokenize() below.
        self.special_token = ""

    def reset_replace_info(self, replace_info):
        rtype, rpara = replace_info.split("|")
        self.replace_what = rtype
        if rtype == "token_embedding":
            ce_num = int(rpara)
            ce_dim = self.transformer.text_model.embeddings.token_embedding.weight.size(1)
            self.cembedding = nn.Embedding(ce_num, ce_dim)
            self.cembedding = self.cembedding.to(self.get_device())
        elif rtype == "context_embedding":
            ce_num = int(rpara)
            ce_dim = self.transformer.text_model.encoder.layers[-1].layer_norm2.weight.size(0)
            self.cembedding = nn.Embedding(ce_num, ce_dim)
            self.cembedding = self.cembedding.to(self.get_device())
        else:
            assert rtype == "text"
            self.replace_what = "text"
            self.replace_string = rpara
            self.cembedding = None

    def get_device(self):
        return self.transformer.text_model.embeddings.token_embedding.weight.device

    def position_to_mask(self, tokens, positions):
        mask = torch.zeros_like(tokens)
        for idxb, idxs, idxe in zip(*positions):
            mask[idxb, idxs:idxe] = 1
        return mask

    def forward(self, text):
        tokens, positions = self.tokenize(text)
        mask = self.position_to_mask(tokens, positions)
        max_token_n = tokens.size(1)
        positional_ids = torch.arange(max_token_n)[None].to(self.get_device())

        if self.replace_what == 'token_embedding':
            cembeds = self.cembedding(tokens * mask)

            def embedding_customized_forward(
                    self, input_ids=None, position_ids=None, inputs_embeds=None, ):
                seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
                if position_ids is None:
                    position_ids = self.position_ids[:, :seq_length]
                if inputs_embeds is None:
                    inputs_embeds = self.token_embedding(input_ids)
                inputs_embeds = inputs_embeds * (1 - mask.float())[:, :, None]
                inputs_embeds = inputs_embeds + cembeds
                position_embeddings = self.position_embedding(position_ids)
                embeddings = inputs_embeds + position_embeddings
                return embeddings

            import types
            self.transformer.text_model.embeddings.forward = types.MethodType(
                embedding_customized_forward, self.transformer.text_model.embeddings)
        else:
            # TODO: Implement
            assert False

        outputs = self.transformer(
            input_ids=tokens,
            position_ids=positional_ids, )
        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)

    @torch.no_grad()
    def tokenize(self, text):
        if isinstance(text, str):
            text = [text]
        bos_special_text = "<|startoftext|>"
        text = [ti.replace(self.special_token, bos_special_text) for ti in text]
        batch_encoding = self.tokenizer(
            text, truncation=True, max_length=self.max_length, return_length=True,
            return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"]
        bosid = tokens[0, 0]
        eosid = tokens[0, -1]
        bs, maxn = tokens.shape
        if self.replace_what in ['token_embedding', 'context_embedding']:
            newtokens = []
            ce_num = self.cembedding.weight.size(0)
            idxi = []
            idxstart = []
            idxend = []
            for idxii, tokeni in enumerate(tokens):
                newtokeni = []
                idxjj = 0
                for ii, tokenii in enumerate(tokeni):
                    if (tokenii == bosid) and (ii != 0):
                        newtokeni.extend([i for i in range(ce_num)])
                        idxi.append(idxii)
                        idxstart.append(idxjj)
                        idxjj += ce_num
                        idxjj_record = idxjj if idxjj <= maxn - 1 else maxn - 1
                        idxend.append(idxjj_record)
                    else:
                        newtokeni.extend([tokenii])
                        idxjj += 1
                newtokeni = newtokeni[:maxn]
                newtokeni[-1] = eosid
                newtokens.append(newtokeni)
            return torch.LongTensor(newtokens).to(self.get_device()), (idxi, idxstart, idxend)
        else:
            # TODO: Implement
            assert False