from abc import ABC, abstractmethod
import glob
import math
import os
from typing import Dict

import numpy as np
import tensorflow as tf
from tqdm.auto import tqdm


def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):  # if value is a tensor
        value = value.numpy()  # get the value of the tensor
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def serialize_array(array):
    """Serializes a tensor/array to a byte string."""
    return tf.io.serialize_tensor(array)


class Dataset(ABC):
    def __init__(self, dataset_path: str):
        self.dataset_path = dataset_path

    @classmethod
    def _parse_single_element(cls, element) -> tf.train.Example:
        features = tf.train.Features(feature=cls._get_features(element))
        return tf.train.Example(features=features)

    @classmethod
    @abstractmethod
    def _get_features(cls, element) -> Dict[str, tf.train.Feature]:
        """Builds the feature dict used to serialize a single element."""

    @classmethod
    @abstractmethod
    def _parse_tfr_element(cls, element):
        """Parses a single serialized example back into a tensor."""

    @classmethod
    def write_to_tfr(cls, data: np.ndarray, out_dir: str, filename: str):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        # Write all elements to a single tfrecord file first.
        single_file_name = cls.__write_to_single_tfr(data, out_dir, filename)
        # The optimal size for a single tfrecord file is around 100 MB;
        # determine how many shards are needed to stay below that.
        number_splits = cls.__get_number_splits(single_file_name)
        if number_splits > 1:
            os.remove(single_file_name)
            cls.__write_to_multiple_tfr(data, out_dir, filename, number_splits)

    @classmethod
    def __write_to_multiple_tfr(
        cls, data: np.ndarray, out_dir: str, filename: str, n_splits: int
    ):
        file_count = 0
        max_files = math.ceil(data.shape[0] / n_splits)
        print(f"Creating {n_splits} files with {max_files} elements each.")
        for i in tqdm(range(n_splits)):
            current_shard_name = os.path.join(
                out_dir,
                f"{filename}.tfrecords-{str(i).zfill(len(str(n_splits)))}-of-{n_splits}",
            )
            writer = tf.io.TFRecordWriter(current_shard_name)
            current_shard_count = 0
            while current_shard_count < max_files:  # as long as the shard is not full
                # get the index of the element that should be written next
                index = i * max_files + current_shard_count
                if index >= len(data):  # all data consumed, stop early
                    break
                current_element = data[index]
                # create the required Example representation
                out = cls._parse_single_element(element=current_element)
                writer.write(out.SerializeToString())
                current_shard_count += 1
                file_count += 1
            writer.close()
        print(f"\nWrote {file_count} elements to TFRecord")
        return file_count

    @classmethod
    def __get_number_splits(cls, filename: str):
        target_size = 100 * 1024 * 1024  # 100 MB
        single_file_size = os.path.getsize(filename)
        number_splits = math.ceil(single_file_size / target_size)
        return number_splits

    @classmethod
    def __write_to_single_tfr(cls, data: np.ndarray, out_dir: str, filename: str):
        current_path_name = os.path.join(out_dir, f"{filename}.tfrecords-0-of-1")
        writer = tf.io.TFRecordWriter(current_path_name)
        for element in tqdm(data):
            writer.write(cls._parse_single_element(element).SerializeToString())
        writer.close()
        return current_path_name

    def load(self) -> tf.data.TFRecordDataset:
        path = self.dataset_path
        if os.path.isdir(path):
            dataset = self._load_folder(path)
        elif os.path.isfile(path):
            dataset = self._load_file(path)
        else:
            raise ValueError(f"Path {path} is not a valid file or folder.")
        dataset = dataset.map(self._parse_tfr_element)
        return dataset

    def _load_file(self, path) -> tf.data.TFRecordDataset:
        return tf.data.TFRecordDataset(path)

    def _load_folder(self, path) -> tf.data.TFRecordDataset:
        return tf.data.TFRecordDataset(
            glob.glob(os.path.join(path, "**/*.tfrecords*"), recursive=True)
        )


class VideoDataset(Dataset):
    @classmethod
    def _get_features(cls, element) -> Dict[str, tf.train.Feature]:
        return {
            "frames": _int64_feature(element.shape[0]),
            "height": _int64_feature(element.shape[1]),
            "width": _int64_feature(element.shape[2]),
            "depth": _int64_feature(element.shape[3]),
            "raw_video": _bytes_feature(serialize_array(element)),
        }

    @classmethod
    def _parse_tfr_element(cls, element):
        # mirror the feature structure defined in _get_features above
        data = {
            "frames": tf.io.FixedLenFeature([], tf.int64),
            "height": tf.io.FixedLenFeature([], tf.int64),
            "width": tf.io.FixedLenFeature([], tf.int64),
            "raw_video": tf.io.FixedLenFeature([], tf.string),
            "depth": tf.io.FixedLenFeature([], tf.int64),
        }
        content = tf.io.parse_single_example(element, data)

        frames = content["frames"]
        height = content["height"]
        width = content["width"]
        depth = content["depth"]
        raw_video = content["raw_video"]

        # decode the raw bytes back into a tensor and restore its shape
        feature = tf.io.parse_tensor(raw_video, out_type=tf.uint8)
        feature = tf.reshape(feature, shape=[frames, height, width, depth])
        return feature


class ImageDataset(Dataset):
    @classmethod
    def _get_features(cls, element) -> Dict[str, tf.train.Feature]:
        return {
            "height": _int64_feature(element.shape[0]),
            "width": _int64_feature(element.shape[1]),
            "depth": _int64_feature(element.shape[2]),
            "raw_image": _bytes_feature(serialize_array(element)),
        }

    @classmethod
    def _parse_tfr_element(cls, element):
        # mirror the feature structure defined in _get_features above
        data = {
            "height": tf.io.FixedLenFeature([], tf.int64),
            "width": tf.io.FixedLenFeature([], tf.int64),
            "raw_image": tf.io.FixedLenFeature([], tf.string),
            "depth": tf.io.FixedLenFeature([], tf.int64),
        }
        content = tf.io.parse_single_example(element, data)

        height = content["height"]
        width = content["width"]
        depth = content["depth"]
        raw_image = content["raw_image"]

        # decode the raw bytes back into a tensor and restore its shape
        feature = tf.io.parse_tensor(raw_image, out_type=tf.uint8)
        feature = tf.reshape(feature, shape=[height, width, depth])
        return feature
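
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library): write a batch of images to
# sharded TFRecord files and read them back. The output directory, filename,
# and random example data below are hypothetical placeholders; substitute your
# own uint8 array of shape (num_images, height, width, channels).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical example data: 16 RGB images of size 64x64.
    images = np.random.randint(0, 256, size=(16, 64, 64, 3), dtype=np.uint8)

    # Write the images as "images.tfrecords-*" shards; with data this small,
    # everything stays in a single file well under the 100 MB target.
    ImageDataset.write_to_tfr(images, out_dir="./records", filename="images")

    # Load them back; `load` maps `_parse_tfr_element` over the records, so
    # each element is a decoded uint8 tensor of shape (64, 64, 3).
    dataset = ImageDataset("./records").load()
    for image in dataset.take(1):
        print(image.shape)  # (64, 64, 3)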