GANime / ganime /data /mnist.py
Kurokabe's picture
Upload 84 files
3be620b
raw
history blame
No virus
3.54 kB
import glob
import os
from typing import Literal
import numpy as np
from .base import SequenceDataset
import math
class MovingMNISTImage(SequenceDataset):
    """Moving MNIST served as pairs of consecutive frames.

    All sequences come from the single ``mnist_test_seq.npy`` archive; the
    last 1000 sequences are held out for any non-"train" split.
    """

    def load_data(self, dataset_path: str, split: str) -> np.ndarray:
        """Load the archive and return the sequences for ``split``.

        Args:
            dataset_path: Directory containing the ``moving_mnist`` folder.
            split: ``"train"`` keeps every sequence except the last 1000;
                any other value keeps only the last 1000.

        Returns:
            Array shaped (n_samples, window, width, height, 3): the single
            grey channel is replicated three times.
        """
        archive = os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy")
        held_out = 1000

        # On disk the layout is (window, n_samples, width, height); Keras
        # expects samples first, plus a trailing channel axis.
        sequences = np.expand_dims(np.moveaxis(np.load(archive), 0, 1), axis=-1)

        if split != "train":
            sequences = sequences[-held_out:]
        else:
            sequences = sequences[:-held_out]

        # Slice first, then triplicate, so only the kept split is copied.
        return np.concatenate((sequences, sequences, sequences), axis=-1)

    def __getitem__(self, idx):
        """Return batch ``idx`` as (frame 0, frame 1) of each selected sequence.

        Relies on ``self.data``, ``self.indices`` and ``self.batch_size``
        being set elsewhere (presumably by ``SequenceDataset`` — confirm).
        """
        start = idx * self.batch_size
        batch_ids = self.indices[start : start + self.batch_size]
        first_frames = self.data[batch_ids, 0, ...]
        second_frames = self.data[batch_ids, 1, ...]
        return first_frames, second_frames

    def preprocess_data(self, data: np.ndarray) -> np.ndarray:
        """Scale raw 0-255 pixel values down by 255."""
        scaled = data / 255
        return scaled
class MovingMNIST(SequenceDataset):
    """Moving MNIST stored as one ``.npy`` sequence file per sample.

    Files are read lazily, one batch at a time, from
    ``<dataset_path>/moving_mnist/<split>/*.npy``. Every batch is returned
    as ``(batch_x, batch_y)``:

    * ``batch_x`` - the first and the last frame of each sequence,
      concatenated along axis 1;
    * ``batch_y`` - every frame after the first one.
    """

    def __init__(
        self,
        dataset_path: str,
        batch_size: int,
        split: Literal["train", "validation", "test"] = "train",
    ):
        """Index the ``.npy`` files of ``split``.

        Args:
            dataset_path: Directory containing the ``moving_mnist`` folder.
            batch_size: Number of sequences per batch.
            split: Sub-folder of ``moving_mnist`` to read from.
        """
        self.batch_size = batch_size
        self.split = split
        root_path = os.path.join(dataset_path, "moving_mnist", split)
        # glob.glob yields paths in arbitrary, filesystem-dependent order.
        # Sorting makes the index -> file mapping (and thus the batch chosen
        # by get_fixed_batch) reproducible across machines and runs.
        self.paths = sorted(glob.glob(os.path.join(root_path, "*.npy")))
        self.indices = np.arange(len(self.paths))
        # on_epoch_end is not defined here; presumably SequenceDataset uses
        # it to (re)shuffle self.indices - TODO confirm against .base.
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch (the last may be partial)."""
        return math.ceil(len(self.paths) / self.batch_size)

    def __getitem__(self, idx):
        """Load batch ``idx`` and split it into ``(batch_x, batch_y)``."""
        inds = self.indices[idx * self.batch_size : (idx + 1) * self.batch_size]
        return self._split_sequences(self.load_indices(inds))

    def get_fixed_batch(self, idx):
        """Return the same batch on every call, e.g. for visual monitoring.

        The sample indices are chosen from ``idx`` on the first call only and
        cached in ``self.fixed_indices``; later calls ignore ``idx``.
        """
        if not hasattr(self, "fixed_indices"):
            # Copy so a later in-place reshuffle of self.indices cannot
            # change which samples make up the fixed batch.
            self.fixed_indices = self.indices[
                idx * self.batch_size : (idx + 1) * self.batch_size
            ].copy()
        return self._split_sequences(self.load_indices(self.fixed_indices))

    @staticmethod
    def _split_sequences(data):
        """Build (first+last frame, all frames after the first) from ``data``."""
        batch_x = np.concatenate([data[:, 0:1, ...], data[:, -1:, ...]], axis=1)
        batch_y = data[:, 1:, ...]
        return batch_x, batch_y

    def load_indices(self, indices):
        """Load the sequence files at ``indices`` and preprocess them."""
        data = np.array([np.load(self.paths[index]) for index in indices])
        return self.preprocess_data(data)

    def preprocess_data(self, data: np.ndarray) -> np.ndarray:
        """Scale raw 0-255 pixel values down by 255."""
        return data / 255