Kurokabe's picture
Upload 84 files
3be620b
raw
history blame
No virus
4.18 kB
from tensorflow.keras import layers
from tensorflow.keras import Model
import tensorflow as tf
from transformers import TFPreTrainedModel
# GPT-2 checkpoint names accepted for config["transformer_type"]; the first
# entry is used when the config does not specify one.
valid_types = ["gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"]
class Transformer(Model):
    """Keras wrapper around a pretrained GPT-2 language-model head.

    ``call`` receives a ``(remaining_frames, last_frame_indices,
    previous_frame_indices)`` triple. It concatenates the index tensors along
    axis 1 (prefixed by ``remaining_frames`` for the ``"concat"`` method),
    runs them through GPT-2 and returns the logits for the last
    ``last_frame_indices.shape[1]`` positions.

    ``config`` is queried for two optional keys:

    * ``"remaining_frames_method"``: how the remaining-frame information is
      given to GPT-2. One of ``"concat"`` (default), ``"token_type_ids"`` or
      ``"own_embeddings"``.
    * ``"transformer_type"``: checkpoint name passed to ``from_pretrained``.
      Must be one of the module-level ``valid_types`` and defaults to its
      first entry.
    """

    # Every method understood by ``call_transformer``. It is checked at
    # construction so a typo fails before a checkpoint is loaded, rather than
    # silently loading the stock model and failing on the first forward pass.
    _REMAINING_FRAMES_METHODS = ("concat", "token_type_ids", "own_embeddings")

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.remaining_frames_method = self.get_remaining_frames_method(config)
        self.transformer_type = self.get_transformer_type(config)
        self.transformer = self.load_transformer(
            self.remaining_frames_method, self.transformer_type
        )

    def get_transformer_type(self, config) -> str:
        """Return the GPT-2 checkpoint name requested by ``config``.

        Falls back to ``valid_types[0]`` when ``"transformer_type"`` is absent.

        Raises:
            ValueError: if the configured name is not in ``valid_types``.
        """
        if "transformer_type" not in config:
            return valid_types[0]
        transformer_type = config["transformer_type"]
        if transformer_type not in valid_types:
            raise ValueError(
                f"transformer_type {transformer_type} is not valid. Valid types are {valid_types}"
            )
        return transformer_type

    def get_remaining_frames_method(self, config) -> str:
        """Get the method to use for remaining frames.

        Check if the method is set inside the configuration, otherwise use
        ``"concat"`` as the default.

        Raises:
            ValueError: if the configured method is not one of
                ``_REMAINING_FRAMES_METHODS``.
        """
        method = (
            config["remaining_frames_method"]
            if "remaining_frames_method" in config
            else "concat"
        )
        if method not in self._REMAINING_FRAMES_METHODS:
            raise ValueError(
                f"Unknown remaining_frames_method {method}. "
                f"Valid methods are {list(self._REMAINING_FRAMES_METHODS)}"
            )
        return method

    def load_transformer(self, method: str, transformer_type: str) -> TFPreTrainedModel:
        """Load the pretrained GPT-2 LM-head model used for ``method``.

        ``"own_embeddings"`` uses the project's GPT-2 variant, which is called
        with a ``remaining_frames_ids`` argument in ``call_transformer``. Every
        other method uses the stock Hugging Face ``TFGPT2LMHeadModel``.
        """
        print("using method ", method)
        if method == "own_embeddings":
            from ganime.model.vqgan_clean.experimental.gpt2_embedding import (
                TFGPT2LMHeadModel,
            )
        else:
            from transformers import TFGPT2LMHeadModel
        return TFGPT2LMHeadModel.from_pretrained(transformer_type)

    def concatenate_inputs(
        self, remaining_frames, last_frame_indices, previous_frame_indices
    ) -> tf.Tensor:
        """Build the GPT-2 input sequence by concatenating along axis 1.

        Only the ``"concat"`` method puts ``remaining_frames`` into the
        sequence itself, as the leading position. The other methods hand it to
        the model through a separate keyword argument instead.
        """
        pieces = [last_frame_indices, previous_frame_indices]
        if self.remaining_frames_method == "concat":
            pieces.insert(0, remaining_frames)
        return tf.concat(pieces, axis=1)

    def call_transformer(
        self, transformer_input, remaining_frames, training, attention_mask
    ):
        """Run GPT-2, routing ``remaining_frames`` according to the method.

        Raises:
            ValueError: if ``self.remaining_frames_method`` is not recognised.
        """
        method = self.remaining_frames_method
        if method == "concat":
            # remaining_frames is already the first position of transformer_input.
            extra_kwargs = {}
        elif method == "token_type_ids":
            extra_kwargs = {"token_type_ids": remaining_frames}
        elif method == "own_embeddings":
            extra_kwargs = {"remaining_frames_ids": remaining_frames}
        else:
            raise ValueError(f"Unknown remaining_frames_method {method}")
        return self.transformer(
            transformer_input,
            training=training,
            attention_mask=attention_mask,
            **extra_kwargs,
        )

    def call(self, inputs, training=True, mask=None):
        """Return GPT-2 logits for the positions covered by ``last_frame_indices``.

        Args:
            inputs: ``(remaining_frames, last_frame_indices,
                previous_frame_indices)``. ``remaining_frames`` gets a new
                axis 1 before being concatenated with the index tensors, so it
                presumably has one rank less than them (e.g. ``(batch,)``).
                TODO confirm with callers.
            training: forwarded to the GPT-2 model. NOTE(review): this defaults
                to ``True`` rather than the usual Keras ``None``. The default is
                kept because existing callers may rely on it.
            mask: unused, present for the Keras signature. The attention mask
                is always derived from ``remaining_frames``.

        Returns:
            ``logits[:, -last_frame_indices.shape[1]:]``, i.e. the trailing
            positions matching the length of ``last_frame_indices`` along axis 1.
        """
        remaining_frames, last_frame_indices, previous_frame_indices = inputs
        remaining_frames = tf.expand_dims(remaining_frames, axis=1)
        shape_to_keep = tf.shape(last_frame_indices)[1]

        transformer_input = self.concatenate_inputs(
            remaining_frames, last_frame_indices, previous_frame_indices
        )

        # Broadcasts bool(remaining_frames) over the sequence axis. When
        # remaining_frames is (batch, 1), each row is all ones if that sample's
        # value is non-zero and all zeros otherwise. The cast targets
        # transformer_input's dtype so the product stays well-typed for
        # methods where remaining_frames is not part of transformer_input.
        # NOTE(review): confirm that fully masking rows with zero remaining
        # frames is intended.
        attention_mask = tf.ones_like(transformer_input) * tf.cast(
            tf.cast(remaining_frames, dtype=tf.bool), dtype=transformer_input.dtype
        )

        logits = self.call_transformer(
            transformer_input, remaining_frames, training, attention_mask
        ).logits
        return logits[:, -shape_to_keep:]