import glob import os from typing import Literal import numpy as np from .base import SequenceDataset import math class MovingMNISTImage(SequenceDataset): def load_data(self, dataset_path: str, split: str) -> np.ndarray: data = np.load(os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy")) # Data is of shape (window, n_samples, width, height) # But we want for keras something of shape (n_samples, window, width, height) data = np.moveaxis(data, 0, 1) # Also expand dimensions to have channels at the end (n_samples, window, width, height, channels) data = np.expand_dims(data, axis=-1) if split == "train": data = data[:-1000] else: data = data[-1000:] data = np.concatenate([data, data, data], axis=-1) return data def __getitem__(self, idx): inds = self.indices[idx * self.batch_size : (idx + 1) * self.batch_size] batch_x = self.data[inds, 0, ...] batch_y = self.data[inds, 1, ...] return batch_x, batch_y def preprocess_data(self, data: np.ndarray) -> np.ndarray: return data / 255 class MovingMNIST(SequenceDataset): def __init__( self, dataset_path: str, batch_size: int, split: Literal["train", "validation", "test"] = "train", ): self.batch_size = batch_size self.split = split root_path = os.path.join(dataset_path, "moving_mnist", split) self.paths = glob.glob(os.path.join(root_path, "*.npy")) # self.data = self.preprocess_data(self.data) self.indices = np.arange(len(self.paths)) self.on_epoch_end() # def load_data(self, dataset_path: str, split: str) -> np.ndarray: # data = np.load(os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy")) # # Data is of shape (window, n_samples, width, height) # # But we want for keras something of shape (n_samples, window, width, height) # data = np.moveaxis(data, 0, 1) # # Also expand dimensions to have channels at the end (n_samples, window, width, height, channels) # data = np.expand_dims(data, axis=-1) # if split == "train": # data = data[:100] # else: # data = data[100:110] # data = np.concatenate([data, data, data], axis=-1) # return data def __len__(self): return math.ceil(len(self.paths) / self.batch_size) def __getitem__(self, idx): inds = self.indices[idx * self.batch_size : (idx + 1) * self.batch_size] data = self.load_indices(inds) batch_x = np.concatenate([data[:, 0:1, ...], data[:, -1:, ...]], axis=1) batch_y = data[:, 1:, ...] return batch_x, batch_y def get_fixed_batch(self, idx): self.fixed_indices = ( self.fixed_indices if hasattr(self, "fixed_indices") else self.indices[ idx * self.batch_size : (idx + 1) * self.batch_size ].copy() ) data = self.load_indices(self.fixed_indices) batch_x = np.concatenate([data[:, 0:1, ...], data[:, -1:, ...]], axis=1) batch_y = data[:, 1:, ...] return batch_x, batch_y def load_indices(self, indices): paths_to_load = [self.paths[index] for index in indices] data = [np.load(path) for path in paths_to_load] data = np.array(data) return self.preprocess_data(data) def preprocess_data(self, data: np.ndarray) -> np.ndarray: return data / 255