# Source: GANime / ganime/trainer/ganime.py
# Hugging Face file-viewer header (not part of the module):
#   uploaded by Kurokabe, commit "Upload 84 files" (3be620b), 3.08 kB, no virus found.
import os
import numpy as np
from omegaconf import OmegaConf
import omegaconf
from ray.tune import Trainable
from ganime.model.base import load_model
from ganime.data.base import load_dataset
import tensorflow as tf
from ganime.utils.callbacks import TensorboardImage
class TrainableGANime(Trainable):
    """Ray Tune ``Trainable`` that trains a GANime model one epoch per step.

    ``setup`` builds the datasets, the model and the Keras callbacks;
    ``step`` runs a single training epoch followed by an evaluation on the
    validation dataset and reports the resulting metrics to Tune.
    """

    def setup(self, config):
        """Prepare datasets, model and callbacks for a trial.

        Args:
            config: Trial configuration mapping. The keys read are
                ``"config_file"`` and ``"hyperparameters"`` (through
                :meth:`load_config_file_and_replace`), ``"dataset_name"``,
                ``"dataset_path"`` and ``"model"``.

        Raises:
            ValueError: If the training or validation dataset yields no
                element to use as a TensorBoard sample.
        """
        strategy = tf.distribute.MirroredStrategy()
        tune_config = self.load_config_file_and_replace(config)

        # The dataset is batched with the configured batch size multiplied by
        # the number of replicas the strategy keeps in sync.
        self.batch_size = tune_config["trainer"]["batch_size"]
        self.n_devices = strategy.num_replicas_in_sync
        self.global_batch_size = self.batch_size * self.n_devices

        (
            self.train_dataset,
            self.validation_dataset,
            self.test_dataset,
        ) = load_dataset(
            dataset_name=config["dataset_name"],
            dataset_path=config["dataset_path"],
            batch_size=self.global_batch_size,
        )

        self.model = load_model(
            config["model"], config=tune_config, strategy=strategy
        )

        # One element of each split is handed to the image callback.
        train_sample_data = self._first_element(self.train_dataset, "train")
        validation_sample_data = self._first_element(
            self.validation_dataset, "validation"
        )

        tensorboard_image_callback = TensorboardImage(
            self.logdir, train_sample_data, validation_sample_data
        )
        checkpointing = tf.keras.callbacks.ModelCheckpoint(
            os.path.join(self.logdir, "checkpoint", "checkpoint"),
            monitor="total_loss",
            save_best_only=True,
            save_weights_only=True,
        )
        self.callbacks = [tensorboard_image_callback, checkpointing]

    @staticmethod
    def _first_element(dataset, name):
        """Return the first element of *dataset*.

        Raises:
            ValueError: If the dataset yields nothing, instead of the opaque
                ``UnboundLocalError`` an empty loop would otherwise cause.
        """
        for element in dataset.take(1):
            return element
        raise ValueError(
            f"The {name} dataset is empty; cannot draw a sample from it."
        )

    def load_config_file_and_replace(self, config):
        """Load the OmegaConf file and override it with tuned hyperparameters.

        Args:
            config: Mapping with ``"config_file"`` (path passed to
                ``OmegaConf.load``) and ``"hyperparameters"`` (mapping of
                key name to replacement value).

        Returns:
            The loaded configuration with every hyperparameter applied via
            :meth:`replace_item`.
        """
        cfg = OmegaConf.load(config["config_file"])
        for hp_key, hp_value in config["hyperparameters"].items():
            cfg = self.replace_item(cfg, hp_key, hp_value)
        return cfg

    def replace_item(self, obj, key, replace_value):
        """Set ``key`` to ``replace_value`` at every nesting level of ``obj``.

        The lookup is by bare key name, not by path: each nested mapping that
        already contains ``key`` gets the new value. Mappings that do not
        contain ``key`` are left without it. ``obj`` is modified in place.

        Args:
            obj: A ``dict`` or OmegaConf ``DictConfig`` to update.
            key: Name of the entry to overwrite wherever it occurs.
            replace_value: Value to store under ``key``.

        Returns:
            The same ``obj``, after modification.
        """
        mapping_types = (dict, omegaconf.dictconfig.DictConfig)
        # Iterate over a snapshot so that assigning into ``obj`` can never
        # disturb the iteration.
        for child_key, child in list(obj.items()):
            if isinstance(child, mapping_types):
                obj[child_key] = self.replace_item(child, key, replace_value)
        if key in obj:
            obj[key] = replace_value
        return obj

    def step(self):
        """Train for one epoch, evaluate, and return the validation metrics.

        Returns:
            A dict mapping each name in ``self.model.metrics_names`` to the
            corresponding value returned by ``self.model.evaluate``.
        """
        self.model.fit(
            self.train_dataset,
            initial_epoch=self.training_iteration,
            epochs=self.training_iteration + 1,
            callbacks=self.callbacks,
            verbose=0,
        )

        scores = self.model.evaluate(self.validation_dataset, verbose=0)
        # ``evaluate`` may return a bare scalar instead of a list.
        if not isinstance(scores, (list, tuple)):
            scores = [scores]

        # ``np.nan in scores`` never matches computed NaNs because NaN does
        # not compare equal to itself; test with ``np.isnan`` instead.
        if np.isnan(np.asarray(scores, dtype=float)).any():
            # NOTE(review): confirm that calling ``stop()`` from inside
            # ``step`` is the intended way to end a diverged trial.
            self.stop()

        return dict(zip(self.model.metrics_names, scores))

    def save_checkpoint(self, tmp_checkpoint_dir):
        """Tune checkpoint hook; currently a no-op that returns ``None``.

        Model weights are written by the ``ModelCheckpoint`` callback created
        in :meth:`setup`, not by this hook.
        TODO: confirm the installed Ray version accepts a ``None`` return.
        """
        return None

    def load_checkpoint(self, tmp_checkpoint_dir):
        """Tune restore hook; currently a no-op (nothing is restored)."""
        return None