# Copyright 2024 Vchitect/Latte
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified from Latte
#
# This file is adapted from the Latte project.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# Latte: https://github.com/Vchitect/Latte
# DiT: https://github.com/facebookresearch/DiT/tree/main
# --------------------------------------------------------

import torch
import torch.nn as nn
import transformers
from transformers import CLIPTextModel, CLIPTokenizer

from opensora.registry import MODELS

transformers.logging.set_verbosity_error()


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)."""

    def __init__(self, path="openai/clip-vit-huge-patch14", device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(path)
        self.transformer = CLIPTextModel.from_pretrained(path)
        self.device = device
        self.max_length = max_length
        self._freeze()

    def _freeze(self):
        # Keep the text encoder frozen: eval mode and no gradients.
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)
        z = outputs.last_hidden_state  # per-token embeddings, (B, max_length, hidden_size)
        pooled_z = outputs.pooler_output  # pooled sentence embedding, (B, hidden_size)
        return z, pooled_z

    def encode(self, text):
        return self(text)


@MODELS.register_module("clip")
class ClipEncoder:
    """
    Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance.
    """

    def __init__(
        self,
        from_pretrained,
        model_max_length=77,
        device="cuda",
        dtype=torch.float,
    ):
        super().__init__()
        assert from_pretrained is not None, "Please specify the path to the CLIP model"
        self.text_encoder = FrozenCLIPEmbedder(path=from_pretrained, max_length=model_max_length).to(device, dtype)
        self.y_embedder = None
        self.model_max_length = model_max_length
        self.output_dim = self.text_encoder.transformer.config.hidden_size

    def encode(self, text):
        _, pooled_embeddings = self.text_encoder.encode(text)
        # Expand to (B, 1, 1, hidden_size) so the pooled embedding acts as a length-1 conditioning sequence.
        y = pooled_embeddings.unsqueeze(1).unsqueeze(1)
        return dict(y=y)

    def null(self, n):
        # Null (unconditional) embedding for classifier-free guidance; y_embedder is attached by the diffusion model.
        null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None]
        return null_y

    def to(self, dtype):
        self.text_encoder = self.text_encoder.to(dtype)
        return self
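

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal example of how
# this encoder might be instantiated and queried. The checkpoint name
# "openai/clip-vit-large-patch14" and the "cuda" device are illustrative
# assumptions, not values mandated by this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    encoder = ClipEncoder(
        from_pretrained="openai/clip-vit-large-patch14",  # assumed checkpoint
        model_max_length=77,
        device="cuda",  # assumes a CUDA device is available
        dtype=torch.float,
    )
    out = encoder.encode(["a cat playing the piano"])
    # out["y"] has shape (batch, 1, 1, hidden_size) and is consumed downstream
    # as a one-token conditioning sequence.
    print(out["y"].shape)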