File size: 7,552 Bytes
3be620b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
from abc import ABC, abstractclassmethod, abstractmethod
import glob
import math
import os
from typing import Dict
from typing_extensions import dataclass_transform

import numpy as np
import tensorflow as tf
from tqdm.auto import tqdm


def _bytes_feature(value):
    """Wrap a single string / bytes value in a ``tf.train.Feature``.

    If ``value`` is a tensor of the type produced by ``tf.constant``, its
    underlying value is extracted with ``.numpy()`` before wrapping.
    """
    tensor_type = type(tf.constant(0))
    payload = value.numpy() if isinstance(value, tensor_type) else value
    bytes_list = tf.train.BytesList(value=[payload])
    return tf.train.Feature(bytes_list=bytes_list)


def _float_feature(value):
    """Wrap a single float / double in a ``tf.train.Feature`` (float_list)."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)


def _int64_feature(value):
    """Wrap a single bool / enum / int / uint in a ``tf.train.Feature`` (int64_list)."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)


def serialize_array(array):
    """Return ``array`` serialized with ``tf.io.serialize_tensor``."""
    return tf.io.serialize_tensor(array)


class Dataset(ABC):
    """Base class for writing arrays to TFRecord files and reading them back.

    Subclasses define the record layout by implementing ``_get_features``
    (one element -> dict of ``tf.train.Feature``) and ``_parse_tfr_element``
    (one serialized record -> decoded value).
    """

    def __init__(self, dataset_path: str):
        """Remember the location that ``load`` reads from.

        Args:
            dataset_path: Path to a single TFRecord file, or to a folder that
                is searched recursively for ``*.tfrecords*`` files.
        """
        self.dataset_path = dataset_path

    @classmethod
    def _parse_single_element(cls, element) -> tf.train.Example:
        """Build the ``tf.train.Example`` for one element of the data."""
        features = tf.train.Features(feature=cls._get_features(element))
        return tf.train.Example(features=features)

    # ``abstractclassmethod`` is deprecated; stacking ``classmethod`` on top of
    # ``abstractmethod`` is the supported spelling and behaves the same.
    @classmethod
    @abstractmethod
    def _get_features(cls, element) -> Dict[str, tf.train.Feature]:
        """Return the feature dict describing ``element`` (subclass hook)."""

    @classmethod
    @abstractmethod
    def _parse_tfr_element(cls, element):
        """Decode one serialized record read from a TFRecord file (subclass hook)."""

    @classmethod
    def write_to_tfr(cls, data: np.ndarray, out_dir: str, filename: str):
        """Serialize ``data`` into one or more TFRecord files in ``out_dir``.

        Everything is first written to a single file. If that file is larger
        than the target shard size, it is deleted and the data is written
        again, split across several shard files.

        Args:
            data: Array whose first axis enumerates the elements to write.
            out_dir: Output directory; created if it does not exist.
            filename: Base name of the produced file(s).
        """
        os.makedirs(out_dir, exist_ok=True)

        # Write all elements to a single tfrecord file
        single_file_name = cls.__write_to_single_tfr(data, out_dir, filename)

        # The optimal size for a single tfrecord file is around 100 MB. Get the number of files that need to be created
        number_splits = cls.__get_number_splits(single_file_name)

        if number_splits > 1:
            os.remove(single_file_name)
            cls.__write_to_multiple_tfr(data, out_dir, filename, number_splits)

    @classmethod
    def __write_to_multiple_tfr(
        cls, data: np.ndarray, out_dir: str, filename: str, n_splits: int
    ):
        """Write ``data`` into ``n_splits`` consecutive shard files.

        Shard ``i`` receives the elements
        ``[i * max_files, (i + 1) * max_files)``, clipped to the data length,
        so trailing shards may hold fewer elements (or none).

        Args:
            data: Array whose first axis enumerates the elements to write.
            out_dir: Existing output directory.
            filename: Base name of the shard files.
            n_splits: Number of shard files to create.

        Returns:
            The total number of elements written.
        """
        file_count = 0
        total = len(data)

        max_files = math.ceil(data.shape[0] / n_splits)

        print(f"Creating {n_splits} files with {max_files} elements each.")

        index_width = len(str(n_splits))

        for i in tqdm(range(n_splits)):
            current_shard_name = os.path.join(
                out_dir,
                f"{filename}.tfrecords-{str(i).zfill(index_width)}-of-{n_splits}",
            )

            start = i * max_files
            # Clip to the data length: once the data is consumed, the
            # remaining shards receive no elements.
            stop = min(start + max_files, total)

            # One writer per shard, closed as soon as the shard is complete
            # (and also if serialization fails part-way through).
            with tf.io.TFRecordWriter(current_shard_name) as writer:
                for index in range(start, stop):
                    # create the required Example representation
                    out = cls._parse_single_element(element=data[index])
                    writer.write(out.SerializeToString())
                    file_count += 1

        print(f"\nWrote {file_count} elements to TFRecord")
        return file_count

    @classmethod
    def __get_number_splits(cls, filename: str):
        """Return how many ~100 MB shards the file at ``filename`` corresponds to."""
        target_size = 100 * 1024 * 1024  # 100mb

        single_file_size = os.path.getsize(filename)
        return math.ceil(single_file_size / target_size)

    @classmethod
    def __write_to_single_tfr(cls, data: np.ndarray, out_dir: str, filename: str):
        """Write every element of ``data`` to one TFRecord file.

        Returns:
            The path of the file that was written.
        """
        current_path_name = os.path.join(
            out_dir,
            f"{filename}.tfrecords-0-of-1",
        )

        # The context manager guarantees the file is closed before its size
        # is inspected by the caller, even if an element fails to serialize.
        with tf.io.TFRecordWriter(current_path_name) as writer:
            for element in tqdm(data):
                writer.write(cls._parse_single_element(element).SerializeToString())

        return current_path_name

    def load(self) -> tf.data.Dataset:
        """Open ``self.dataset_path`` and decode every record.

        Returns:
            A dataset whose elements are the results of ``_parse_tfr_element``.

        Raises:
            ValueError: If ``self.dataset_path`` is neither a file nor a folder.
        """
        path = self.dataset_path

        if os.path.isdir(path):
            dataset = self._load_folder(path)
        elif os.path.isfile(path):
            dataset = self._load_file(path)
        else:
            raise ValueError(f"Path {path} is not a valid file or folder.")

        return dataset.map(self._parse_tfr_element)

    def _load_file(self, path) -> tf.data.TFRecordDataset:
        """Return the raw record dataset for a single TFRecord file."""
        return tf.data.TFRecordDataset(path)

    def _load_folder(self, path) -> tf.data.TFRecordDataset:
        """Return the raw record dataset for all ``*.tfrecords*`` files under ``path``.

        The folder is searched recursively. File names are sorted because
        ``glob`` returns them in arbitrary order, which would otherwise make
        the record order differ between runs or machines.
        """
        shard_paths = sorted(
            glob.glob(os.path.join(path, "**/*.tfrecords*"), recursive=True)
        )
        return tf.data.TFRecordDataset(shard_paths)


class VideoDataset(Dataset):
    """Dataset whose elements are 4-axis arrays stored as (frames, height, width, depth)."""

    @classmethod
    def _get_features(cls, element) -> Dict[str, tf.train.Feature]:
        """Describe one video: its four axis lengths plus the serialized array."""
        dims = element.shape
        features = {}
        features["frames"] = _int64_feature(dims[0])
        features["height"] = _int64_feature(dims[1])
        features["width"] = _int64_feature(dims[2])
        features["depth"] = _int64_feature(dims[3])
        features["raw_video"] = _bytes_feature(serialize_array(element))
        return features

    @classmethod
    def _parse_tfr_element(cls, element):
        """Decode one serialized record back into a uint8 video tensor."""
        # Schema mirroring the features written by ``_get_features``.
        scalar_int = tf.io.FixedLenFeature([], tf.int64)
        scalar_str = tf.io.FixedLenFeature([], tf.string)
        schema = {
            "frames": scalar_int,
            "height": scalar_int,
            "width": scalar_int,
            "raw_video": scalar_str,
            "depth": scalar_int,
        }

        parsed = tf.io.parse_single_example(element, schema)

        # Deserialize the stored tensor, then restore the recorded shape.
        video = tf.io.parse_tensor(parsed["raw_video"], out_type=tf.uint8)
        target_shape = [
            parsed["frames"],
            parsed["height"],
            parsed["width"],
            parsed["depth"],
        ]
        return tf.reshape(video, shape=target_shape)


class ImageDataset(Dataset):
    """Dataset whose elements are 3-axis arrays stored as (height, width, depth)."""

    @classmethod
    def _get_features(cls, element) -> Dict[str, tf.train.Feature]:
        """Describe one image: its three axis lengths plus the serialized array."""
        dims = element.shape
        features = {}
        features["height"] = _int64_feature(dims[0])
        features["width"] = _int64_feature(dims[1])
        features["depth"] = _int64_feature(dims[2])
        features["raw_image"] = _bytes_feature(serialize_array(element))
        return features

    @classmethod
    def _parse_tfr_element(cls, element):
        """Decode one serialized record back into a uint8 image tensor."""
        # Schema mirroring the features written by ``_get_features``.
        scalar_int = tf.io.FixedLenFeature([], tf.int64)
        scalar_str = tf.io.FixedLenFeature([], tf.string)
        schema = {
            "height": scalar_int,
            "width": scalar_int,
            "raw_image": scalar_str,
            "depth": scalar_int,
        }

        parsed = tf.io.parse_single_example(element, schema)

        # Deserialize the stored tensor, then restore the recorded shape.
        image = tf.io.parse_tensor(parsed["raw_image"], out_type=tf.uint8)
        target_shape = [parsed["height"], parsed["width"], parsed["depth"]]
        return tf.reshape(image, shape=target_shape)