File size: 471 Bytes
3be620b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import os

import numpy as np

from .base import SequenceDataset


class KNYImage(SequenceDataset):
    """KNY image dataset loaded from a single ``.npy`` file.

    The last ``NUM_HOLDOUT`` samples (along the first axis) form the
    held-out split; every sample before them belongs to the training split.
    """

    # Number of trailing samples reserved for any split other than "train".
    NUM_HOLDOUT = 5000

    def load_data(self, dataset_path: str, split: str) -> np.ndarray:
        """Load the requested split of the KNY image array.

        Args:
            dataset_path: Root directory that contains the ``kny`` folder.
            split: ``"train"`` selects all but the last ``NUM_HOLDOUT``
                samples; ANY other value selects the last ``NUM_HOLDOUT``
                samples (there is no separate validation split).

        Returns:
            A regular in-memory ndarray holding only the selected samples.
        """
        path = os.path.join(dataset_path, "kny", "kny_images_64x128.npy")
        # Memory-map the file so only the requested slice is read into RAM.
        # Slicing a fully loaded array would return a view that keeps the
        # whole dataset alive even when only the small holdout is wanted.
        full = np.load(path, mmap_mode="r")
        # Use an explicit split index rather than negative slicing: `x[:-k]`
        # is empty when k == 0. Clamping at 0 preserves the original result
        # for files shorter than the holdout (empty train, all data held out).
        n_train = max(len(full) - self.NUM_HOLDOUT, 0)
        if split == "train":
            selected = full[:n_train]
        else:
            selected = full[n_train:]
        # np.array copies the slice into a plain ndarray detached from the mmap.
        return np.array(selected)

    def preprocess_data(self, data: np.ndarray) -> np.ndarray:
        """Scale raw values by 1/255.

        Presumably maps 8-bit pixel intensities into [0, 1] — TODO confirm
        the stored dtype is uint8. True division returns a new floating-point
        array (float64 for integer input); ``data`` itself is not modified.
        """
        return data / 255