# Author: Bingxin Ke
# Last modified: 2024-03-12

import logging
import os
import sys

import wandb
from tabulate import tabulate
from torch.utils.tensorboard import SummaryWriter


def config_logging(cfg_logging, out_dir=None):
    """Configure the root logger from a config mapping.

    Installs a console handler (stdout) and, when ``out_dir`` is given, a
    file handler as well. Existing handlers on the root logger are removed
    first so repeated calls do not duplicate output.

    Args:
        cfg_logging: Mapping with keys ``format`` (required), and optional
            ``file_level`` / ``console_level`` (int log levels, default 10,
            i.e. DEBUG) and ``filename`` (default "logging.log").
        out_dir: Directory for the log file; no file handler if ``None``.
    """
    file_level = cfg_logging.get("file_level", 10)
    console_level = cfg_logging.get("console_level", 10)

    log_formatter = logging.Formatter(cfg_logging["format"])

    root_logger = logging.getLogger()
    root_logger.handlers.clear()
    # Root level must admit the most verbose of the two handlers.
    root_logger.setLevel(min(file_level, console_level))

    if out_dir is not None:
        _logging_file = os.path.join(
            out_dir, cfg_logging.get("filename", "logging.log")
        )
        file_handler = logging.FileHandler(_logging_file)
        file_handler.setFormatter(log_formatter)
        file_handler.setLevel(file_level)
        root_logger.addHandler(file_handler)

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(log_formatter)
    console_handler.setLevel(console_level)
    root_logger.addHandler(console_handler)

    # Avoid pollution by packages
    logging.getLogger("PIL").setLevel(logging.INFO)
    logging.getLogger("matplotlib").setLevel(logging.INFO)


class MyTrainingLogger:
    """Tensorboard + wandb logger"""

    writer: SummaryWriter
    is_initialized = False

    def __init__(self) -> None:
        pass

    def set_dir(self, tb_log_dir):
        """Create the SummaryWriter once; raises on a second call."""
        if self.is_initialized:
            raise ValueError("Do not initialize writer twice")
        self.writer = SummaryWriter(tb_log_dir)
        self.is_initialized = True

    def log_dic(self, scalar_dic, global_step, walltime=None):
        """Log every (tag, value) pair in ``scalar_dic`` as a scalar."""
        for k, v in scalar_dic.items():
            self.writer.add_scalar(k, v, global_step=global_step, walltime=walltime)
        return


# global instance
tb_logger = MyTrainingLogger()


# -------------- wandb tools --------------
def init_wandb(enable: bool, **kwargs):
    """Start a wandb run (tensorboard-synced) or a disabled stub run."""
    if enable:
        run = wandb.init(sync_tensorboard=True, **kwargs)
    else:
        run = wandb.init(mode="disabled")
    return run


def log_slurm_job_id(step):
    """Record the SLURM job id (-1 when not running under SLURM)."""
    global tb_logger
    _jobid = os.getenv("SLURM_JOB_ID")
    if _jobid is None:
        _jobid = -1
    tb_logger.writer.add_scalar("job_id", int(_jobid), global_step=step)
    logging.debug(f"Slurm job_id: {_jobid}")


def load_wandb_job_id(out_dir):
    """Read the saved wandb run id from ``out_dir``/WANDB_ID."""
    with open(os.path.join(out_dir, "WANDB_ID"), "r") as f:
        wandb_id = f.read()
    return wandb_id


def save_wandb_job_id(run, out_dir):
    """Persist ``run.id`` to ``out_dir``/WANDB_ID for later resuming."""
    with open(os.path.join(out_dir, "WANDB_ID"), "w+") as f:
        f.write(run.id)


def eval_dic_to_text(val_metrics: dict, dataset_name: str, sample_list_path: str):
    """Render evaluation metrics as a human-readable text table."""
    eval_text = f"Evaluation metrics:\n\
     on dataset: {dataset_name}\n\
     over samples in: {sample_list_path}\n"

    eval_text += tabulate([val_metrics.keys(), val_metrics.values()])

    return eval_text