# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/20_interpret.ipynb.

# %% ../nbs/20_interpret.ipynb 2
from __future__ import annotations
from .data.all import *
from .optimizer import *
from .learner import *
from .tabular.core import *
import sklearn.metrics as skm

# %% auto 0
__all__ = ['plot_top_losses', 'Interpretation', 'ClassificationInterpretation', 'SegmentationInterpretation']

# %% ../nbs/20_interpret.ipynb 7
@typedispatch
def plot_top_losses(x, y, *args, **kwargs):
    raise Exception(f"plot_top_losses is not implemented for {type(x)},{type(y)}")

# %% ../nbs/20_interpret.ipynb 8
_all_ = ["plot_top_losses"]
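
# How the stub above gets specialized: task-specific modules register overloads
# keyed on the (input, target) types via `@typedispatch`, and calls route to the
# most specific match; the stub only fires when no implementation exists for the
# given types. A hedged, hypothetical registration for an image classifier would
# look roughly like:
#
#     @typedispatch
#     def plot_top_losses(x:TensorImage, y:TensorCategory, samples, outs, raws, losses, **kwargs):
#         ...  # grid of images titled with target/prediction/probability/loss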

# %% ../nbs/20_interpret.ipynb 9
class Interpretation():
    "Interpretation base class, can be inherited for task specific Interpretation classes"
    def __init__(self,
        learn:Learner,
        dl:DataLoader, # `DataLoader` to run inference over
        losses:TensorBase, # Losses calculated from `dl`
        act=None # Activation function for prediction
    ): 
        store_attr()

    def __getitem__(self, idxs):
        "Return inputs, preds, targs, decoded outputs, and losses at `idxs`"
        if isinstance(idxs, Tensor): idxs = idxs.tolist()
        if not is_listy(idxs): idxs = [idxs]
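        # `dl.items` may be a DataFrame (tabular data) or a plain collection;
        # `iloc` indexes the former, an `L` wrap handles the latter.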
        items = getattr(self.dl.items, 'iloc', L(self.dl.items))[idxs]
        tmp_dl = self.learn.dls.test_dl(items, with_labels=True, process=not isinstance(self.dl, TabDataLoader))
        inps,preds,targs,decoded = self.learn.get_preds(dl=tmp_dl, with_input=True, with_loss=False, 
                                                        with_decoded=True, act=self.act, reorder=False)
        return inps, preds, targs, decoded, self.losses[idxs]

    @classmethod
    def from_learner(cls,
        learn, # Model used to create interpretation
        ds_idx:int=1, # Index of `learn.dls` when `dl` is None
        dl:DataLoader=None, # `DataLoader` used to make predictions
        act=None # Override default or set prediction activation function
    ):
        "Construct interpretation object from a learner"
        if dl is None: dl = learn.dls[ds_idx].new(shuffle=False, drop_last=False)
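        # Only the per-item losses are collected here; predictions and targets
        # are explicitly skipped.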
        _,_,losses = learn.get_preds(dl=dl, with_input=False, with_loss=True, with_decoded=False,
                                     with_preds=False, with_targs=False, act=act)
        return cls(learn, dl, losses, act)

    def top_losses(self,
        k:int|None=None, # Return `k` losses, defaults to all
        largest:bool=True, # Sort losses by largest or smallest
        items:bool=False # Whether to return input items
    ):
        "`k` largest(/smallest) losses and indexes, defaulting to all losses."
        losses, idx = self.losses.topk(ifnone(k, len(self.losses)), largest=largest)
        if items: return losses, idx, getattr(self.dl.items, 'iloc', L(self.dl.items))[idx]
        else:     return losses, idx

    def plot_top_losses(self,
        k:int|MutableSequence, # Number of losses to plot
        largest:bool=True, # Sort losses by largest or smallest
        **kwargs
    ):
        "Show `k` largest(/smallest) preds and losses. Implementation based on type dispatch"
        if is_listy(k) or isinstance(k, range):
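            # `k` is a sequence of ranks: sort all losses, then pick those positions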
            losses, idx = (o[k] for o in self.top_losses(None, largest))
        else: 
            losses, idx = self.top_losses(k, largest)
        inps, preds, targs, decoded, _ = self[idx]
        inps, targs, decoded = tuplify(inps), tuplify(targs), tuplify(decoded)
        x, y, its = self.dl._pre_show_batch(inps+targs, max_n=len(idx))
        x1, y1, outs = self.dl._pre_show_batch(inps+decoded, max_n=len(idx))
        if its is not None:
            plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), preds, losses, **kwargs)
        #TODO: figure out if this is needed
        #its None means that a batch knows how to show itself as a whole, so we pass x, x1
        #else: show_results(x, x1, its, ctxs=ctxs, max_n=max_n, **kwargs)

    def show_results(self,
        idxs:list, # Indices of predictions and targets
        **kwargs
    ):
        "Show predictions and targets of `idxs`"
        if isinstance(idxs, Tensor): idxs = idxs.tolist()
        if not is_listy(idxs): idxs = [idxs]
        inps, _, targs, decoded, _ = self[idxs]
        b = tuplify(inps)+tuplify(targs)
        self.dl.show_results(b, tuplify(decoded), max_n=len(idxs), **kwargs)
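
# A minimal usage sketch; `_example_interpretation_usage` is a hypothetical
# helper for illustration, not part of the generated module. It assumes a
# trained `learn`, e.g. from `vision_learner`.
def _example_interpretation_usage(learn):
    interp = Interpretation.from_learner(learn)  # computes per-item losses on the validation set
    losses, idxs = interp.top_losses(9)          # 9 largest losses and their indices into `dl.items`
    interp.plot_top_losses(9)                    # type-dispatched plot of the 9 worst predictions
    interp.show_results(idxs)                    # predictions vs. targets for those same items
    return losses, idxs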

# %% ../nbs/20_interpret.ipynb 22
class ClassificationInterpretation(Interpretation):
    "Interpretation methods for classification models."

    def __init__(self, 
        learn:Learner, 
        dl:DataLoader, # `DataLoader` to run inference over
        losses:TensorBase, # Losses calculated from `dl`
        act=None # Activation function for prediction
    ):
        super().__init__(learn, dl, losses, act)
        self.vocab = self.dl.vocab
        if is_listy(self.vocab): self.vocab = self.vocab[-1]

    def confusion_matrix(self):
        "Confusion matrix as an `np.ndarray`."
        x = torch.arange(0, len(self.vocab))
        _,targs,decoded = self.learn.get_preds(dl=self.dl, with_decoded=True, with_preds=True, 
                                               with_targs=True, act=self.act)
        d,t = flatten_check(decoded, targs)
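        # Broadcast against the class-index vector: entry (i,j) counts items
        # whose actual class is i and predicted class is j.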
        cm = ((d==x[:,None]) & (t==x[:,None,None])).long().sum(2)
        return to_np(cm)

    def plot_confusion_matrix(self, 
        normalize:bool=False, # Whether to normalize occurrences
        title:str='Confusion matrix', # Title of plot
        cmap:str="Blues", # Colormap from matplotlib
        norm_dec:int=2, # Decimal places for normalized occurrences
        plot_txt:bool=True, # Display occurrence in matrix
        **kwargs
    ):
        "Plot the confusion matrix, with `title` and using `cmap`."
        # This function is mainly copied from the sklearn docs
        cm = self.confusion_matrix()
        if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        fig = plt.figure(**kwargs)
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        tick_marks = np.arange(len(self.vocab))
        plt.xticks(tick_marks, self.vocab, rotation=90)
        plt.yticks(tick_marks, self.vocab, rotation=0)

        if plot_txt:
            thresh = cm.max() / 2.
            for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
                coeff = f'{cm[i, j]:.{norm_dec}f}' if normalize else f'{cm[i, j]}'
                plt.text(j, i, coeff, horizontalalignment="center", verticalalignment="center", color="white"
                         if cm[i, j] > thresh else "black")

        ax = fig.gca()
        ax.set_ylim(len(self.vocab)-.5,-.5)

        plt.tight_layout()
        plt.ylabel('Actual')
        plt.xlabel('Predicted')
        plt.grid(False)

    def most_confused(self, min_val=1):
        "Sorted descending largest non-diagonal entries of confusion matrix (actual, predicted, # occurrences"
        cm = self.confusion_matrix()
        np.fill_diagonal(cm, 0)
        res = [(self.vocab[i],self.vocab[j],cm[i,j]) for i,j in zip(*np.where(cm>=min_val))]
        return sorted(res, key=itemgetter(2), reverse=True)

    def print_classification_report(self):
        "Print scikit-learn classification report"
        _,targs,decoded = self.learn.get_preds(dl=self.dl, with_decoded=True, with_preds=True, 
                                               with_targs=True, act=self.act)
        d,t = flatten_check(decoded, targs)
        names = [str(v) for v in self.vocab]
        print(skm.classification_report(t, d, labels=list(self.vocab.o2i.values()), target_names=names))
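
# A minimal usage sketch; `_example_classification_workflow` is a hypothetical
# helper, not fastai API. It assumes a trained classification `learn`.
def _example_classification_workflow(learn):
    interp = ClassificationInterpretation.from_learner(learn)
    cm = interp.confusion_matrix()            # (n_classes, n_classes) ndarray, rows = actual class
    interp.plot_confusion_matrix(normalize=True, figsize=(6,6))  # kwargs pass through to `plt.figure`
    print(interp.most_confused(min_val=2))    # (actual, predicted, count) triples, most frequent first
    interp.print_classification_report()      # per-class precision/recall/F1 via scikit-learn
    return cm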

# %% ../nbs/20_interpret.ipynb 27
class SegmentationInterpretation(Interpretation):
    "Interpretation methods for segmentation models."
    pass
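
# A minimal usage sketch; `_example_segmentation_usage` is hypothetical. The
# subclass adds nothing of its own yet, so a segmentation learner (e.g. from
# `unet_learner`) reuses the base-class machinery unchanged.
def _example_segmentation_usage(learn):
    interp = SegmentationInterpretation.from_learner(learn)
    interp.plot_top_losses(3)  # dispatches on the batch's (image, mask) types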