File size: 471 Bytes
3be620b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import numpy as np
from tqdm.auto import tqdm
import tensorflow as tf
import tensorflow_datasets as tfds


def dataset_statistics(ds):
    """Compute the global mean, variance and standard deviation of a dataset's inputs.

    The dataset is iterated once (with a progress bar); the first element of
    every yielded item is collected, all collected arrays are concatenated
    along axis 0, and the statistics are computed over every scalar of the
    concatenated array (no per-axis / per-channel reduction).

    Args:
        ds: A ``tf.data.Dataset`` (converted with ``tfds.as_numpy``), a
            ``tf.keras.utils.Sequence``, or any other iterable. Each yielded
            item must be indexable with the inputs at position 0, e.g.
            ``(X, y)`` or ``(X, y, sample_weight)``.
            NOTE(review): assumes every ``X`` is batched and shares its
            non-leading dimensions with the others, as required by
            ``np.concatenate`` -- confirm against callers.

    Returns:
        A tuple ``(mean, variance, std)`` of NumPy scalars.

    Raises:
        ValueError: If ``ds`` yields no items.
    """
    if isinstance(ds, tf.data.Dataset):
        # Turn tensors into NumPy arrays so they can be concatenated below.
        batches = tfds.as_numpy(ds)
    else:
        # Keras Sequences and plain iterables are consumed as-is. Previously
        # anything that was neither a Dataset nor a Sequence left the
        # iterable unbound and crashed with UnboundLocalError.
        batches = ds

    # Only the inputs are needed; indexing (rather than two-way unpacking)
    # also accepts items that carry extra entries such as sample weights.
    inputs = [batch[0] for batch in tqdm(batches)]
    if not inputs:
        # np.concatenate would raise a ValueError as well, but with a message
        # that does not point at the real cause.
        raise ValueError("dataset_statistics: dataset yielded no batches")

    # The whole input set is materialised in memory at this point.
    all_data = np.concatenate(inputs)
    return np.mean(all_data), np.var(all_data), np.std(all_data)