File size: 3,539 Bytes
3be620b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import glob
import os
from typing import Literal

import numpy as np

from .base import SequenceDataset
import math


class MovingMNISTImage(SequenceDataset):
    """Moving MNIST served as (first frame, second frame) image pairs.

    The whole ``mnist_test_seq.npy`` archive is loaded by :meth:`load_data`;
    batches are then sliced out of ``self.data``.
    """

    def load_data(self, dataset_path: str, split: str) -> np.ndarray:
        """Load the Moving MNIST archive and return the requested split.

        Args:
            dataset_path: Directory that contains the ``moving_mnist`` folder.
            split: ``"train"`` selects everything except the last 1000
                samples; any other value selects the last 1000 samples.

        Returns:
            Array with the sample axis first and a trailing channel axis of
            size 3 (the single channel repeated three times).
        """
        archive = os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy")
        # The file stores (window, n_samples, width, height); swap the first
        # two axes to get (n_samples, window, width, height), then append a
        # trailing channel axis.
        frames = np.moveaxis(np.load(archive), 0, 1)[..., np.newaxis]

        # The last 1000 samples are held out; "train" gets everything else.
        subset = slice(None, -1000) if split == "train" else slice(-1000, None)
        frames = frames[subset]

        # Triplicate the single channel so the images have three channels.
        return np.concatenate((frames,) * 3, axis=-1)

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(frame 0, frame 1)`` of each sample."""
        start = idx * self.batch_size
        batch_inds = self.indices[start : start + self.batch_size]
        first_frames = self.data[batch_inds, 0, ...]
        second_frames = self.data[batch_inds, 1, ...]
        return first_frames, second_frames

    def preprocess_data(self, data: np.ndarray) -> np.ndarray:
        """Divide the values by 255."""
        scaled = data / 255
        return scaled


class MovingMNIST(SequenceDataset):
    """Moving MNIST sequences read lazily from individual ``.npy`` files.

    Every ``*.npy`` file under ``<dataset_path>/moving_mnist/<split>`` is one
    sample. Files are only read from disk when a batch that contains them is
    requested.
    """

    def __init__(
        self,
        dataset_path: str,
        batch_size: int,
        split: Literal["train", "validation", "test"] = "train",
    ):
        """Index the sample files of one split.

        Args:
            dataset_path: Directory that contains the ``moving_mnist`` folder.
            batch_size: Number of samples per batch.
            split: Name of the sub-folder of ``moving_mnist`` to read from.
        """
        # NOTE(review): the base-class initializer is not invoked here;
        # presumably it would eagerly load data, which this class avoids —
        # confirm against SequenceDataset.
        self.batch_size = batch_size
        self.split = split
        root_path = os.path.join(dataset_path, "moving_mnist", split)
        # glob returns paths in a filesystem-dependent order; sort them so the
        # index -> file mapping (and therefore every batch) is reproducible
        # across machines and runs.
        self.paths = sorted(glob.glob(os.path.join(root_path, "*.npy")))

        self.indices = np.arange(len(self.paths))
        # on_epoch_end is not defined in this class; presumably inherited
        # from SequenceDataset — confirm what it does to self.indices.
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches, counting a final partial batch."""
        return math.ceil(len(self.paths) / self.batch_size)

    def __getitem__(self, idx):
        """Load batch ``idx`` from disk and return ``(batch_x, batch_y)``.

        See :meth:`_split_inputs_targets` for the layout of the two arrays.
        """
        start = idx * self.batch_size
        inds = self.indices[start : start + self.batch_size]
        return self._split_inputs_targets(self.load_indices(inds))

    def get_fixed_batch(self, idx):
        """Return the same batch on every call.

        The sample indices of batch ``idx`` are captured on the first call
        and stored in ``self.fixed_indices``; later calls reuse them and
        ignore ``idx``.
        """
        if not hasattr(self, "fixed_indices"):
            start = idx * self.batch_size
            # Copy so later changes to self.indices do not alter the batch.
            self.fixed_indices = self.indices[
                start : start + self.batch_size
            ].copy()
        return self._split_inputs_targets(self.load_indices(self.fixed_indices))

    @staticmethod
    def _split_inputs_targets(data):
        """Split loaded sequences into model inputs and targets.

        Args:
            data: Batch of sequences; axis 0 indexes samples and axis 1
                indexes the positions within a sequence.

        Returns:
            ``(batch_x, batch_y)`` where ``batch_x`` holds the first and the
            last position of every sequence (stacked along axis 1) and
            ``batch_y`` holds every position from index 1 onward.
        """
        batch_x = np.concatenate([data[:, 0:1, ...], data[:, -1:, ...]], axis=1)
        batch_y = data[:, 1:, ...]
        return batch_x, batch_y

    def load_indices(self, indices):
        """Load the files at ``indices`` and return them as one scaled array.

        Args:
            indices: Iterable of positions into ``self.paths``.

        Returns:
            The loaded arrays stacked along a new leading axis and passed
            through :meth:`preprocess_data`.
        """
        data = np.array([np.load(self.paths[index]) for index in indices])
        return self.preprocess_data(data)

    def preprocess_data(self, data: np.ndarray) -> np.ndarray:
        """Divide the values by 255."""
        return data / 255