from ganime.trainer.warmup.cosine import WarmUpCosine def create_warmup_scheduler(trainer_config, num_devices): len_x_train = trainer_config["len_x_train"] batch_size = trainer_config["batch_size"] n_epochs = trainer_config["n_epochs"] total_steps = int(len_x_train / batch_size * n_epochs / num_devices) warmup_epoch_percentage = trainer_config["warmup_epoch_percentage"] warmup_steps = int(total_steps * warmup_epoch_percentage) scheduled_lrs = WarmUpCosine( lr_start=trainer_config["lr_start"], lr_max=trainer_config["lr_max"], warmup_steps=warmup_steps, total_steps=total_steps, ) return scheduled_lrs