File size: 3,222 Bytes
7f2690b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
from glob import glob

import joblib
import numpy as np
import torch
from sklearn.cluster import MiniBatchKMeans
from torch.utils.data import DataLoader
from tqdm import tqdm
from train import instantiate_from_config


class FeatClusterStage(object):
    """K-means 'quantization' condition stage with a VQ-model-like interface.

    Instead of a learned codebook, feature vectors are snapped to their
    nearest k-means cluster center.  Exposes ``eval`` / ``encode`` /
    ``decode`` / ``get_input`` so it can stand in for a cond-stage model.
    """

    def __init__(self, num_clusters=None, cached_kmeans_path=None, feats_dataset_config=None, num_workers=None):
        """Load a cached clusterer if available, otherwise fit a new one.

        Args:
            num_clusters: number of clusters K (needed only when fitting).
            cached_kmeans_path: path to a joblib-dumped, pre-fitted clusterer.
            feats_dataset_config: dataset config used to fit the clusterer.
            num_workers: DataLoader workers used during fitting.

        Raises:
            ValueError: if neither source of a clusterer is provided.
        """
        if cached_kmeans_path is not None and os.path.exists(cached_kmeans_path):
            print(f'Precalculated Clusterer already exists, loading from {cached_kmeans_path}')
            self.clusterer = joblib.load(cached_kmeans_path)
        elif feats_dataset_config is not None:
            self.clusterer = self.load_or_precalculate_kmeans(num_clusters, feats_dataset_config, num_workers)
        else:
            # ValueError is more precise than a bare Exception and remains
            # backward-compatible for callers catching Exception.
            raise ValueError('Neither `feats_dataset_config` nor `cached_kmeans_path` are defined')

    def eval(self):
        # No-op; kept for API compatibility with torch modules used as stages.
        return self

    def encode(self, c):
        """Quantize features to their nearest cluster centers.

        Args:
            c: float tensor of shape (B, D, T).

        Returns:
            Tuple ``(c_quant, None, (None, None, c_ind))`` where ``c_quant``
            is the (B, D, T) tensor of cluster centers and ``c_ind`` the
            (B*T, 1) long tensor of cluster indices.
        """
        # c_quant: cluster centers, c_ind: cluster index
        B, D, T = c.shape
        # (B*T, D) <- (B, T, D) <- (B, D, T)
        # FIX: use reshape, not view — permute() yields a non-contiguous
        # tensor, on which .view() raises a RuntimeError.
        c_flat = c.permute(0, 2, 1).reshape(B*T, D).cpu().numpy()

        c_ind = self.clusterer.predict(c_flat)
        c_quant = self.clusterer.cluster_centers_[c_ind]

        c_ind = torch.from_numpy(c_ind).to(c.device)
        c_quant = torch.from_numpy(c_quant).to(c.device)

        c_ind = c_ind.long().unsqueeze(-1)
        # (B, D, T) <- (B, T, D) <- (B*T, D); fresh numpy tensor is contiguous.
        c_quant = c_quant.view(B, T, D).permute(0, 2, 1)

        info = None, None, c_ind
        # (B, D, T), (), ((), (768, 1024), (768, 1))
        return c_quant, None, info

    def decode(self, c):
        # Cluster centers already live in feature space; nothing to decode.
        return c

    def get_input(self, batch, k):
        """Fetch batch[k] and reorder (B, T, D) -> (B, D, T) as float32."""
        x = batch[k]
        x = x.permute(0, 2, 1).to(memory_format=torch.contiguous_format)
        return x.float()

    def load_or_precalculate_kmeans(self, num_clusters, dataset_cfg, num_workers):
        """Load a cached clusterer for this (K, dataset) pair or fit and cache one."""
        batch_size = 64
        dataset_name = dataset_cfg.target.split('.')[-1]
        cached_path = os.path.join('./specvqgan/modules/misc/', f'kmeans_K{num_clusters}_{dataset_name}.sklearn')
        # FIX: honour the "load" half of the method name — reuse a clusterer
        # previously computed for the same K and dataset instead of refitting.
        if os.path.exists(cached_path):
            print(f'Precalculated Clusterer already exists, loading from {cached_path}')
            return joblib.load(cached_path)

        print(f'Calculating clustering K={num_clusters}')
        feat_depth = dataset_cfg.params.condition_dataset_cfg.feat_depth
        feat_crop_len = dataset_cfg.params.condition_dataset_cfg.feat_crop_len

        feat_loading_dset = instantiate_from_config(dataset_cfg)
        feat_loading_dset = DataLoader(feat_loading_dset, batch_size, num_workers=num_workers, shuffle=True)

        # batch_size*feat_crop_len rows per partial_fit chunk, since every
        # dataset item contributes feat_crop_len feature vectors.
        clusterer = MiniBatchKMeans(num_clusters, batch_size=batch_size*feat_crop_len, random_state=0)

        for item in tqdm(feat_loading_dset):
            batch = item['feature'].reshape(-1, feat_depth).float().numpy()
            clusterer.partial_fit(batch)

        # Ensure the cache directory exists before dumping.
        os.makedirs(os.path.dirname(cached_path), exist_ok=True)
        joblib.dump(clusterer, cached_path)
        print(f'Saved the calculated Clusterer @ {cached_path}')
        return clusterer


if __name__ == '__main__':
    from omegaconf import OmegaConf

    # Smoke test: build the cond-stage model from the transformer config
    # with a concrete first-stage checkpoint path patched in.
    cfg = OmegaConf.load('./configs/vggsound_featcluster_transformer.yaml')
    cfg.model.params.first_stage_config.params.ckpt_path = (
        './logs/2021-05-19T22-16-54_vggsound_specs_vqgan/checkpoints/epoch_39.ckpt'
    )
    cond_stage = instantiate_from_config(cfg.model.params.cond_stage_config)
    print(cond_stage)