Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / tools / generate_datafile / tf_dataset_converter / datasets.py
1 '''Deal with the tensorflow dataset.'''
2
3 import tensorflow as tf
4 import tensorflow_datasets as tfds
5 from pathlib import Path
6
# TFDS download/prepare directory: a `data` folder placed next to this module.
dataset_root_dir = Path(__file__).parent.absolute().joinpath('data')
8
9
class DatasetLoader():
    '''
    Loader of tensorflow datasets.

    Call `load()` first: the other accessors read the `ds_info`,
    `ds_train` and `ds_test` attributes that `load()` sets.
    '''

    def load(self, dataset_name):
        '''
        Load the 'train' and 'test' splits of a TFDS dataset.

        Data is downloaded/prepared under `dataset_root_dir`. Both splits are
        loaded as (image, label) pairs, and their images are cast to float32
        and divided by 255.

        Args:
            dataset_name: Name of a TFDS builder (see `get_dataset_names()`).
        '''
        (ds_train, ds_test), ds_info = tfds.load(
            dataset_name,
            split=['train', 'test'],
            data_dir=dataset_root_dir,
            shuffle_files=True,
            as_supervised=True,
            with_info=True,
        )

        self.ds_info = ds_info

        def _normalize_img(image, label):
            """Cast image to float32 and scale by 1/255 (presumably uint8 input)."""
            return tf.cast(image, tf.float32) / 255., label

        self.ds_train = ds_train.map(_normalize_img)
        self.ds_test = ds_test.map(_normalize_img)

        # Report the shapes of a single training element only.
        for images, labels in self.ds_train.take(1):
            print(f'Shape of images : {images.shape}')
            print(f'Shape of labels: {labels.shape} {labels.dtype}')

    def get_dataset_names(self):
        '''
        Get the names of all dataset builders known to TFDS.
        '''
        return tfds.list_builders()

    def class_names(self):
        '''
        Get class names
        '''
        return self.ds_info.features['label'].names

    def num_classes(self):
        '''
        Get the number of classes
        '''
        return self.ds_info.features['label'].num_classes

    def get_num_train_examples(self):
        '''
        Get the number of examples in the 'train' split
        '''
        return self.ds_info.splits['train'].num_examples

    def get_num_test_examples(self):
        '''
        Get the number of examples in the 'test' split
        '''
        return self.ds_info.splits['test'].num_examples

    def prefetched_datasets(self):
        '''
        Get the cached datasets for training and testing.

        Despite the method name, only `cache()` is applied: no shuffle or
        prefetch step is added to either dataset.

        Return:
           Tuple of (train dataset, test dataset).
        '''
        return self.ds_train.cache(), self.ds_test.cache()