from typing import Any, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint


class Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, hidden_size, num_attention_heads, attention_head_dim, attention_dropout=0.0):
        """
        Args:
            hidden_size: size of the input and output channel dimension.
            num_attention_heads: number of attention heads.
            attention_head_dim: channel size of each head; the projections map
                `hidden_size` to `num_attention_heads * attention_head_dim`.
            attention_dropout: dropout probability applied to the attention
                probabilities while the module is in training mode.
        """
        super().__init__()
        self.embed_dim = hidden_size
        self.num_heads = num_attention_heads
        self.head_dim = attention_head_dim
        self.scale = self.head_dim**-0.5
        self.dropout = attention_dropout
        self.inner_dim = self.head_dim * self.num_heads

        self.k_proj = nn.Linear(self.embed_dim, self.inner_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.inner_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.inner_dim)
        self.out_proj = nn.Linear(self.inner_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Split the channel dimension into heads: (bsz, seq, inner) -> (bsz, heads, seq, head_dim)."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _add_mask(
        self,
        attn_weights: torch.Tensor,
        mask: torch.Tensor,
        bsz: int,
        tgt_len: int,
        src_len: int,
    ) -> torch.Tensor:
        """Validate an additive mask and add it to the flattened attention scores."""
        expected_size = (bsz, 1, tgt_len, src_len)
        if mask.size() != expected_size:
            raise ValueError(f"Attention mask should be of size {expected_size}, but is {mask.size()}")
        # Un-flatten the head dimension so the mask broadcasts over all heads.
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + mask
        return attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: input of shape `(batch, seq_len, embed_dim)`.
            attention_mask: optional additive mask of size
                `(batch, 1, tgt_len, src_len)`, added to the scores before softmax.
            causal_attention_mask: optional additive mask of the same size,
                applied before `attention_mask`.
            output_attentions: when true, also return the attention weights.

        Returns:
            `(attn_output, attn_weights)` where `attn_output` has shape
            `(batch, seq_len, embed_dim)` and `attn_weights` has shape
            `(batch, num_heads, tgt_len, src_len)`, or is `None` when
            `output_attentions` is false.

        Raises:
            ValueError: if a mask or an intermediate tensor has an unexpected size.
        """
        bsz, tgt_len, _ = hidden_states.size()

        # Scale the queries up front instead of scaling the score matrix.
        query_states = self.q_proj(hidden_states) * self.scale
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        # Fold the heads into the batch dimension so that torch.bmm can be used.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # apply the causal_attention_mask first
        if causal_attention_mask is not None:
            attn_weights = self._add_mask(attn_weights, causal_attention_mask, bsz, tgt_len, src_len)

        if attention_mask is not None:
            attn_weights = self._add_mask(attn_weights, attention_mask, bsz, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        expected_output_size = (bsz * self.num_heads, tgt_len, self.head_dim)
        if attn_output.size() != expected_output_size:
            # Report the shape that is actually checked (heads folded into batch).
            raise ValueError(
                f"`attn_output` should be of size {expected_output_size}, but is"
                f" {attn_output.size()}"
            )

        # Move heads back next to the channel dimension and merge them.
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped


class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> SiLU -> Linear."""

    def __init__(self, hidden_size, intermediate_size, mult=4):
        """
        Args:
            hidden_size: size of the input and output channel dimension.
            intermediate_size: base width of the hidden layer.
            mult: multiplier applied to `intermediate_size`; the hidden layer
                has `intermediate_size * mult` units.
        """
        super().__init__()
        self.activation_fn = nn.SiLU()
        self.fc1 = nn.Linear(hidden_size, intermediate_size * mult)
        self.fc2 = nn.Linear(intermediate_size * mult, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward network to the last dimension of `hidden_states`."""
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class Transformer(nn.Module):
    """A stack of `depth` `TransformerBlock` layers applied sequentially."""

    def __init__(self, depth=12, **block_kwargs):
        """
        Args:
            depth: number of `TransformerBlock` layers.
            **block_kwargs: optional keyword arguments forwarded unchanged to
                every `TransformerBlock` (e.g. `hidden_size`). When omitted,
                the blocks use their own defaults.
        """
        super().__init__()
        self.layers = nn.ModuleList([TransformerBlock(**block_kwargs) for _ in range(depth)])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            causal_attention_mask (`torch.FloatTensor`): additive mask of the same size as
                `attention_mask`, applied before it.
            output_attentions (`bool`, *optional*):
                Forwarded to every layer. The layers do not return their attention
                weights, so this does not change the return value.

        Returns:
            The hidden states after the last layer, same shape as the input.
        """
        for layer in self.layers:
            hidden_states = layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                causal_attention_mask=causal_attention_mask,
                output_attentions=output_attentions,
            )

        return hidden_states


class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention and MLP, each with a residual connection."""

    def __init__(
        self,
        hidden_size=512,
        num_attention_heads=12,
        attention_head_dim=64,
        attention_dropout=0.0,
        dropout=0.0,
        eps=1e-5,
    ):
        """
        Args:
            hidden_size: channel size of the block's input and output.
            num_attention_heads: number of attention heads.
            attention_head_dim: channel size of each attention head.
            attention_dropout: dropout probability on the attention probabilities.
            dropout: accepted for interface compatibility; not used by this block.
            eps: epsilon of both LayerNorm layers.
        """
        super().__init__()
        self.embed_dim = hidden_size
        self.self_attn = Attention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            # Previously dropped on the floor, which silently disabled attention dropout.
            attention_dropout=attention_dropout,
        )
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=eps)
        self.mlp = MLP(hidden_size=hidden_size, intermediate_size=hidden_size)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            causal_attention_mask (`torch.FloatTensor`): additive mask of the same size as
                `attention_mask`, applied before it.
            output_attentions (`bool`, *optional*):
                Forwarded to the attention layer. The attention weights are not
                returned by this block; only the hidden states are.

        Returns:
            The updated hidden states, same shape as the input.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class DiffusionTransformerBlock(nn.Module):
    """Transformer block that prepends a learned output token and returns only that token."""

    def __init__(
        self,
        hidden_size=512,
        num_attention_heads=12,
        attention_head_dim=64,
        attention_dropout=0.0,
        dropout=0.0,
        eps=1e-5,
    ):
        """
        Args:
            hidden_size: channel size of the block's input and output.
            num_attention_heads: number of attention heads.
            attention_head_dim: channel size of each attention head.
            attention_dropout: dropout probability on the attention probabilities.
            dropout: accepted for interface compatibility; not used by this block.
            eps: epsilon of both LayerNorm layers.
        """
        super().__init__()
        self.embed_dim = hidden_size
        self.self_attn = Attention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            # Previously dropped on the floor, which silently disabled attention dropout.
            attention_dropout=attention_dropout,
        )
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=eps)
        self.mlp = MLP(hidden_size=hidden_size, intermediate_size=hidden_size)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=eps)
        self.output_token = nn.Parameter(torch.randn(1, hidden_size))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                The output token is prepended before attention, so `tgt_len` and
                `src_len` must both be `seq_len + 1`.
            causal_attention_mask (`torch.FloatTensor`): additive mask of the same size as
                `attention_mask`, applied before it.
            output_attentions (`bool`, *optional*):
                Forwarded to the attention layer. The attention weights are not
                returned by this block.

        Returns:
            The hidden state of the prepended output token, shape `(batch, 1, embed_dim)`.
        """
        # Broadcast the learned token over the batch; torch.cat materializes the copy.
        output_token = self.output_token.unsqueeze(0).expand(hidden_states.shape[0], -1, -1)
        hidden_states = torch.cat([output_token, hidden_states], dim=1)
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        # Keep only the position of the prepended output token.
        return hidden_states[:, 0:1, ...]


class V2AMapperMLP(nn.Module):
    """Mapper MLP: Linear -> SiLU -> LayerNorm -> Linear."""

    def __init__(self, input_dim=512, output_dim=512, expansion_rate=4):
        """
        Args:
            input_dim: size of the input feature dimension.
            output_dim: size of the output feature dimension.
            expansion_rate: the hidden layer has `input_dim * expansion_rate` units.
        """
        super().__init__()
        self.linear = nn.Linear(input_dim, input_dim * expansion_rate)
        self.silu = nn.SiLU()
        self.layer_norm = nn.LayerNorm(input_dim * expansion_rate)
        self.linear2 = nn.Linear(input_dim * expansion_rate, output_dim)

    def forward(self, x):
        """Map the last dimension of `x` from `input_dim` to `output_dim`."""
        x = self.linear(x)
        x = self.silu(x)
        x = self.layer_norm(x)
        x = self.linear2(x)
        return x


class ImageProjModel(torch.nn.Module):
    """Projection Model"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        """
        Args:
            cross_attention_dim: channel size of each produced context token.
            clip_embeddings_dim: size of the incoming embedding.
            clip_extra_context_tokens: number of context tokens produced per embedding.
        """
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

        self.zero_initialize_last_layer()

    def zero_initialize_last_layer(self):
        """Zero the weight and bias of the last `torch.nn.Linear` found among the submodules."""
        last_layer = None
        for layer in self.modules():
            if isinstance(layer, torch.nn.Linear):
                last_layer = layer

        if last_layer is not None:
            last_layer.weight.data.zero_()
            # A Linear built with bias=False has no bias tensor to reset.
            if last_layer.bias is not None:
                last_layer.bias.data.zero_()

    def forward(self, image_embeds):
        """Project embeddings to `(-1, clip_extra_context_tokens, cross_attention_dim)` and layer-normalize."""
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens


class VisionAudioAdapter(torch.nn.Module):
    """Chains a `V2AMapperMLP` with an `ImageProjModel` to turn embeddings into context tokens."""

    def __init__(
        self,
        embedding_size=768,
        expand_dim=4,
        token_num=4,
    ):
        """
        Args:
            embedding_size: size of the input embedding and of each output token.
            expand_dim: expansion rate of the mapper MLP's hidden layer.
            token_num: number of context tokens produced per embedding.
        """
        super().__init__()

        self.mapper = V2AMapperMLP(
            embedding_size,
            embedding_size,
            expansion_rate=expand_dim,
        )

        self.proj = ImageProjModel(
            cross_attention_dim=embedding_size,
            clip_embeddings_dim=embedding_size,
            clip_extra_context_tokens=token_num,
        )

    def forward(self, image_embeds):
        """Run the mapper, then the projection; returns `(-1, token_num, embedding_size)` tokens."""
        image_embeds = self.mapper(image_embeds)
        image_embeds = self.proj(image_embeds)
        return image_embeds