# Inpaint/src/util/logging_util.py
# Author: Bingxin Ke
# Last modified: 2024-03-12
import logging
import os
import sys

import wandb
from tabulate import tabulate
from torch.utils.tensorboard import SummaryWriter

def config_logging(cfg_logging, out_dir=None):
    """Configure the root logger: always a console handler, plus a file
    handler when `out_dir` is given. Both levels default to DEBUG."""
    file_level = cfg_logging.get("file_level", logging.DEBUG)
    console_level = cfg_logging.get("console_level", logging.DEBUG)
log_formatter = logging.Formatter(cfg_logging["format"])
root_logger = logging.getLogger()
root_logger.handlers.clear()
root_logger.setLevel(min(file_level, console_level))
if out_dir is not None:
_logging_file = os.path.join(
out_dir, cfg_logging.get("filename", "logging.log")
)
file_handler = logging.FileHandler(_logging_file)
file_handler.setFormatter(log_formatter)
file_handler.setLevel(file_level)
root_logger.addHandler(file_handler)
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(log_formatter)
console_handler.setLevel(console_level)
root_logger.addHandler(console_handler)
# Avoid pollution by packages
logging.getLogger("PIL").setLevel(logging.INFO)
logging.getLogger("matplotlib").setLevel(logging.INFO)
class MyTrainingLogger:
    """TensorBoard logger; wandb mirrors these scalars when initialized
    with `sync_tensorboard=True` (see `init_wandb` below)."""

    writer: SummaryWriter
    is_initialized = False
def set_dir(self, tb_log_dir):
if self.is_initialized:
raise ValueError("Do not initialize writer twice")
self.writer = SummaryWriter(tb_log_dir)
self.is_initialized = True
    def log_dic(self, scalar_dic, global_step, walltime=None):
        # Write each entry of the dict as a named scalar, e.g. {"train/loss": 0.5}.
        for k, v in scalar_dic.items():
            self.writer.add_scalar(k, v, global_step=global_step, walltime=walltime)
# global instance
tb_logger = MyTrainingLogger()
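
# Example usage (illustrative names; the real log directory and scalar keys
# come from the training script, which is not shown here):
#   tb_logger.set_dir("logs/tensorboard")
#   tb_logger.log_dic({"train/loss": 0.42, "lr": 3e-5}, global_step=100)
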
# -------------- wandb tools --------------
def init_wandb(enable: bool, **kwargs):
    # With sync_tensorboard=True, wandb mirrors everything written through
    # the TensorBoard SummaryWriter; when disabled, all wandb calls are no-ops.
    if enable:
        run = wandb.init(sync_tensorboard=True, **kwargs)
    else:
        run = wandb.init(mode="disabled")
    return run
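
# Example usage (a sketch; the project and run names are placeholders, and
# any standard wandb.init keyword argument can be forwarded via **kwargs):
#   run = init_wandb(enable=True, project="inpaint", name="run_01")
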
def log_slurm_job_id(step):
    # Record the SLURM job id as a scalar so that job restarts are visible
    # in TensorBoard; -1 means the process is not running under SLURM.
    _jobid = os.getenv("SLURM_JOB_ID")
    if _jobid is None:
        _jobid = -1
    tb_logger.writer.add_scalar("job_id", int(_jobid), global_step=step)
    logging.debug(f"Slurm job_id: {_jobid}")
def load_wandb_job_id(out_dir):
    # Read back the run id persisted by save_wandb_job_id, so that a
    # restarted job can resume the same wandb run.
    with open(os.path.join(out_dir, "WANDB_ID"), "r") as f:
        wandb_id = f.read()
    return wandb_id


def save_wandb_job_id(run, out_dir):
    with open(os.path.join(out_dir, "WANDB_ID"), "w") as f:
        f.write(run.id)
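
# Typical resume flow (a sketch; `id=` and `resume="must"` are standard
# wandb.init arguments, but whether this project resumes exactly this way
# is an assumption):
#   save_wandb_job_id(run, out_dir)          # on the first launch
#   run = init_wandb(                        # on a later restart
#       enable=True, id=load_wandb_job_id(out_dir), resume="must"
#   )
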
def eval_dic_to_text(val_metrics: dict, dataset_name: str, sample_list_path: str):
    eval_text = (
        f"Evaluation metrics:\n"
        f"on dataset: {dataset_name}\n"
        f"over samples in: {sample_list_path}\n"
    )
    eval_text += tabulate([list(val_metrics.keys()), list(val_metrics.values())])
    return eval_text
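
# Example (illustrative metric names and paths):
#   text = eval_dic_to_text(
#       {"abs_rel": 0.05, "delta1": 0.97}, "nyu_v2", "data_split/nyu/val.txt"
#   )
#   logging.info(text)  # header lines followed by a small keys/values table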