File size: 3,953 Bytes
212111c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import typing
import os
import logging
from abc import ABC, abstractmethod
from pathlib import Path
import torch.nn as nn
from tensorboardX import SummaryWriter
# wandb is an optional dependency: record whether it can be imported so that
# `get` can fall back to TensorBoard when it is missing.
try:
    import wandb
except ImportError:
    WANDB_FOUND = False
else:
    WANDB_FOUND = True

logger = logging.getLogger(__name__)
class TAPEVisualizer(ABC):
    """Base class for visualization in TAPE.

    Concrete visualizers must implement every method below; the base
    implementations only raise ``NotImplementedError``.
    """

    @abstractmethod
    def __init__(self, log_dir: typing.Union[str, Path], exp_name: str, debug: bool = False):
        """Set up the visualizer.

        Args:
            log_dir: Directory in which the backend stores its output.
            exp_name: Name of the experiment / run.
            debug: Backend-specific debug flag.
        """
        raise NotImplementedError

    @abstractmethod
    def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
        """Record the run configuration ``config``."""
        raise NotImplementedError

    @abstractmethod
    def watch(self, model: nn.Module) -> None:
        """Attach the visualizer to ``model``."""
        raise NotImplementedError

    @abstractmethod
    def log_metrics(self,
                    metrics_dict: typing.Dict[str, float],
                    split: str,
                    step: int) -> None:
        """Log the scalar values in ``metrics_dict`` for ``split`` at ``step``."""
        raise NotImplementedError
class DummyVisualizer(TAPEVisualizer):
    """No-op visualizer: every method accepts its arguments and does nothing.

    Used for non-master branches so that only one process writes logs.
    """

    def __init__(self,
                 log_dir: typing.Union[str, Path] = '',
                 exp_name: str = '',
                 debug: bool = False):
        # Arguments exist only to match the TAPEVisualizer constructor.
        del log_dir, exp_name, debug

    def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
        del config

    def watch(self, model: nn.Module) -> None:
        del model

    def log_metrics(self,
                    metrics_dict: typing.Dict[str, float],
                    split: str,
                    step: int):
        del metrics_dict, split, step
class TBVisualizer(TAPEVisualizer):
    """Visualizer that writes scalar metrics through a tensorboardX SummaryWriter.

    Output goes to ``<log_dir>/<exp_name>``. Config logging and model watching
    are not supported by this backend; those calls only emit a warning.
    """

    def __init__(self, log_dir: typing.Union[str, Path], exp_name: str, debug: bool = False):
        # ``debug`` is accepted for interface compatibility and is not used here.
        log_dir = Path(log_dir) / exp_name
        # Lazy %-formatting: the message is only built if INFO is enabled.
        logger.info("tensorboard file at: %s", log_dir)
        self.logger = SummaryWriter(log_dir=str(log_dir))

    def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
        # Logger.warn is a deprecated alias; use Logger.warning.
        logger.warning("Cannot log config when using a TBVisualizer. "
                       "Configure wandb for this functionality")

    def watch(self, model: nn.Module) -> None:
        logger.warning("Cannot watch models when using a TBVisualizer. "
                       "Configure wandb for this functionality")

    def log_metrics(self,
                    metrics_dict: typing.Dict[str, float],
                    split: str,
                    step: int) -> None:
        """Write each metric as a scalar tagged ``"<split>/<name>"`` at ``step``."""
        for name, value in metrics_dict.items():
            self.logger.add_scalar(f"{split}/{name}", value, step)
class WandBVisualizer(TAPEVisualizer):
    """Visualizer backed by Weights & Biases (``wandb``).

    Starts a wandb run on construction. Config, model watching and metrics are
    all forwarded to the ``wandb`` module.
    """

    def __init__(self, log_dir: typing.Union[str, Path], exp_name: str, debug: bool = False):
        """Start a wandb run named ``exp_name`` that stores its files under ``log_dir``.

        Raises:
            ImportError: If the ``wandb`` package could not be imported.
        """
        if not WANDB_FOUND:
            raise ImportError("wandb module not available")
        # NOTE(review): ``debug`` is currently ignored. A previous version set
        # WANDB_MODE=dryrun when ``debug`` was true or WANDB_PROJECT was unset;
        # that logic is disabled. Restore it here if offline debug runs are needed.
        # ``str`` keeps the argument type consistent with TBVisualizer and
        # works with wandb versions that expect a plain string path.
        wandb.init(dir=str(log_dir), name=exp_name)

    def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
        """Merge ``config`` into the run's ``wandb.config``."""
        wandb.config.update(config)

    def watch(self, model: nn.Module) -> None:
        """Pass ``model`` to ``wandb.watch``."""
        wandb.watch(model)

    def log_metrics(self,
                    metrics_dict: typing.Dict[str, float],
                    split: str,
                    step: int) -> None:
        """Log each metric under the key ``"<Split> <Name>"`` (both capitalized) at ``step``."""
        wandb.log({f"{split.capitalize()} {name.capitalize()}": value
                   for name, value in metrics_dict.items()}, step=step)
def get(log_dir: typing.Union[str, Path],
        exp_name: str,
        local_rank: int,
        debug: bool = False) -> TAPEVisualizer:
    """Build the visualizer for the current process.

    Processes whose ``local_rank`` is neither -1 nor 0 get a no-op
    DummyVisualizer. Otherwise a WandBVisualizer is used when wandb is
    importable, and a TBVisualizer when it is not.
    """
    is_master = local_rank in (-1, 0)
    if not is_master:
        visualizer_cls = DummyVisualizer
    elif WANDB_FOUND:
        visualizer_cls = WandBVisualizer
    else:
        visualizer_cls = TBVisualizer
    return visualizer_cls(log_dir, exp_name, debug)
|