File size: 3,953 Bytes
212111c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import typing
import os
import logging
from abc import ABC, abstractmethod
from pathlib import Path
import torch.nn as nn

from tensorboardX import SummaryWriter

try:
    import wandb
    WANDB_FOUND = True
except ImportError:
    WANDB_FOUND = False

logger = logging.getLogger(__name__)


class TAPEVisualizer(ABC):
    """Base class for visualization in TAPE.

    Concrete subclasses (e.g. tensorboard- or wandb-backed) implement all
    four methods; callers obtain an instance via the module-level ``get``.
    """

    @abstractmethod
    def __init__(self, log_dir: typing.Union[str, Path], exp_name: str, debug: bool = False):
        """Set up the backing logger under ``log_dir`` for experiment ``exp_name``."""
        raise NotImplementedError

    @abstractmethod
    def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
        """Record the run configuration (hyperparameters etc.)."""
        raise NotImplementedError

    @abstractmethod
    def watch(self, model: nn.Module) -> None:
        """Register ``model`` for parameter/gradient tracking, if supported."""
        raise NotImplementedError

    @abstractmethod
    def log_metrics(self,
                    metrics_dict: typing.Dict[str, float],
                    split: str,
                    step: int):
        """Log scalar metrics for ``split`` (e.g. 'train'/'valid') at ``step``."""
        raise NotImplementedError


class DummyVisualizer(TAPEVisualizer):
    """No-op visualizer used on non-master ranks.

    Implements the full ``TAPEVisualizer`` interface but discards every call,
    so distributed workers can invoke logging unconditionally.
    """

    def __init__(self,
                 log_dir: typing.Union[str, Path] = '',
                 exp_name: str = '',
                 debug: bool = False):
        """Accept and ignore all arguments."""

    def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
        """Discard the config."""

    def watch(self, model: nn.Module) -> None:
        """Discard the model."""

    def log_metrics(self,
                    metrics_dict: typing.Dict[str, float],
                    split: str,
                    step: int):
        """Discard the metrics."""


class TBVisualizer(TAPEVisualizer):
    """Visualizer backed by a tensorboardX ``SummaryWriter``.

    Supports scalar metric logging only; config logging and model watching
    are wandb-only features and emit a warning here.
    """

    def __init__(self, log_dir: typing.Union[str, Path], exp_name: str, debug: bool = False):
        """Create a ``SummaryWriter`` writing under ``log_dir / exp_name``.

        ``debug`` is accepted for interface compatibility but unused.
        """
        log_dir = Path(log_dir) / exp_name
        # Lazy %-formatting so the message is only built if INFO is enabled.
        logger.info("tensorboard file at: %s", log_dir)
        self.logger = SummaryWriter(log_dir=str(log_dir))

    def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
        """Warn that config logging is unsupported by the tensorboard backend."""
        # logger.warn is deprecated since Python 3.3; use logger.warning.
        logger.warning("Cannot log config when using a TBVisualizer. "
                       "Configure wandb for this functionality")

    def watch(self, model: nn.Module) -> None:
        """Warn that model watching is unsupported by the tensorboard backend."""
        logger.warning("Cannot watch models when using a TBVisualizer. "
                       "Configure wandb for this functionality")

    def log_metrics(self,
                    metrics_dict: typing.Dict[str, float],
                    split: str,
                    step: int):
        """Write each metric as a scalar tagged ``<split>/<name>`` at ``step``."""
        for name, value in metrics_dict.items():
            self.logger.add_scalar(split + "/" + name, value, step)


class WandBVisualizer(TAPEVisualizer):
    """Visualizer backed by Weights & Biases (wandb)."""

    def __init__(self, log_dir: typing.Union[str, Path], exp_name: str, debug: bool = False):
        """Initialize a wandb run named ``exp_name`` storing files under ``log_dir``.

        Raises:
            ImportError: if the ``wandb`` package could not be imported.

        NOTE(review): ``debug`` is currently ignored (a previous dryrun-mode
        toggle was disabled); kept for interface compatibility.
        """
        if not WANDB_FOUND:
            raise ImportError("wandb module not available")
        wandb.init(dir=log_dir, name=exp_name)

    def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
        """Merge ``config`` into the wandb run configuration."""
        wandb.config.update(config)

    def watch(self, model: nn.Module):
        """Register ``model`` with wandb for parameter/gradient tracking."""
        wandb.watch(model)

    def log_metrics(self,
                    metrics_dict: typing.Dict[str, float],
                    split: str,
                    step: int):
        """Log metrics keyed as '<Split> <Name>' (both capitalized) at ``step``."""
        wandb.log({f"{split.capitalize()} {name.capitalize()}": value
                   for name, value in metrics_dict.items()}, step=step)


def get(log_dir: typing.Union[str, Path],
        exp_name: str,
        local_rank: int,
        debug: bool = False) -> TAPEVisualizer:
    """Return the appropriate visualizer for this process.

    Non-master ranks (``local_rank`` not in {-1, 0}) get a no-op
    ``DummyVisualizer``; the master rank gets a ``WandBVisualizer`` when
    wandb is importable, else a ``TBVisualizer``.
    """
    is_master = local_rank in (-1, 0)
    if not is_master:
        return DummyVisualizer(log_dir, exp_name, debug)
    visualizer_cls = WandBVisualizer if WANDB_FOUND else TBVisualizer
    return visualizer_cls(log_dir, exp_name, debug)