1 '''Deal with the tensorflow dataset.'''
3 import tensorflow as tf
4 import tensorflow_datasets as tfds
5 from pathlib import Path
# Root directory handed to tfds as `data_dir`: a `data` folder located next to
# this module, expressed as an absolute path.
dataset_root_dir = Path(__file__).parent.absolute().joinpath('data')
class DatasetLoader():
    '''
    Loader of tensorflow datasets.

    Call `load` first: it fetches the train and test splits through
    `tensorflow_datasets` and stores them, together with the dataset
    metadata, on the instance as `ds_train`, `ds_test` and `ds_info`.
    Every other method except `get_dataset_names` reads those attributes
    and raises `AttributeError` if `load` has not been called yet.
    '''

    def load(self, dataset_name: str) -> None:
        '''
        Load the train and test splits of a dataset and normalize the images.

        The splits are stored in `self.ds_train` / `self.ds_test` and the
        dataset metadata in `self.ds_info`. The shapes of the first training
        example are printed as a quick sanity check.

        Args:
            dataset_name: Name of the dataset to load, as understood by
                `tfds.load` (see `get_dataset_names`).
        '''
        (ds_train, ds_test), ds_info = tfds.load(
            dataset_name,
            split=['train', 'test'],
            data_dir=dataset_root_dir,
            shuffle_files=True,
            # (image, label) tuples are what _normalize_img and the loop
            # below unpack.
            as_supervised=True,
            # Needed so that the call also returns the metadata object.
            with_info=True,
        )

        self.ds_info = ds_info

        def _normalize_img(image, label):
            '''Normalizes images: `uint8` -> `float32`.'''
            # Dividing by 255 maps pixel values to [0, 1], assuming the
            # images really are uint8 -- TODO confirm for non-image datasets.
            return tf.cast(image, tf.float32) / 255., label

        self.ds_train = ds_train.map(_normalize_img)
        self.ds_test = ds_test.map(_normalize_img)

        # Only the first example is inspected: take(1) keeps this from
        # iterating (and printing) the whole training split.
        for images, labels in self.ds_train.take(1):
            print(f'Shape of images : {images.shape}')
            print(f'Shape of labels: {labels.shape} {labels.dtype}')

    def get_dataset_names(self):
        '''
        Get the names of all datasets known to `tensorflow_datasets`.

        Returns:
            The result of `tfds.list_builders()`.
        '''
        return tfds.list_builders()

    def class_names(self):
        '''
        Get the class names of the loaded dataset.

        Returns:
            The `names` of the `label` feature in the dataset metadata.
        '''
        return self.ds_info.features['label'].names

    def num_classes(self):
        '''
        Get the number of classes

        Returns:
            The `num_classes` of the `label` feature in the dataset metadata.
        '''
        return self.ds_info.features['label'].num_classes

    def get_num_train_examples(self):
        '''
        Get the number of examples for training

        Returns:
            The `num_examples` of the `train` split in the dataset metadata.
        '''
        return self.ds_info.splits['train'].num_examples

    def get_num_test_examples(self):
        '''
        Get the number of examples for testing

        Returns:
            The `num_examples` of the `test` split in the dataset metadata.
        '''
        return self.ds_info.splits['test'].num_examples

    def prefetched_datasets(self, *, shuffle=False, prefetch=False):
        '''
        Get cached (optionally shuffled and prefetched) datasets for training.

        With the default arguments both splits are cached and nothing else,
        so existing callers see the same datasets as before. The optional
        steps are opt-in because shuffling changes the iteration order of
        the training data.

        Args:
            shuffle: If true, shuffle the training split using a buffer as
                large as the number of training examples. The test split is
                never shuffled.
            prefetch: If true, append a `prefetch` step with an autotuned
                buffer size to both splits.

        Returns:
            Datasets for training and testing.
        '''
        train_dataset = self.ds_train.cache()
        # The test pipeline must start from ds_test, not ds_train.
        test_dataset = self.ds_test.cache()

        if shuffle:
            train_dataset = train_dataset.shuffle(
                self.ds_info.splits['train'].num_examples)

        if prefetch:
            # The experimental alias is used because it also exists in
            # TensorFlow releases that predate `tf.data.AUTOTUNE`.
            autotune = tf.data.experimental.AUTOTUNE
            train_dataset = train_dataset.prefetch(autotune)
            test_dataset = test_dataset.prefetch(autotune)

        return train_dataset, test_dataset