Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / tools / generate_datafile / tf_dataset_converter / main.py
1 ################################################################################
2 # Parse arguments
3 ################################################################################
4
5 from argparser import parse_args
6
7 # You can see arguments' information in argparser.py
8 args = parse_args()
9
10 ################################################################################
11 # Load a dataset of tensorflow
12 ################################################################################
13
14 # Disable tensorflow cpp warning log
15 import os
16
17 FILTERING_WARNING = '2'
18 os.environ['TF_CPP_MIN_LOG_LEVEL'] = FILTERING_WARNING
19
20 from datasets import DatasetLoader
21 from pathlib import Path
22 import tensorflow as tf
23 import numpy as np
24
25 ds_loader = DatasetLoader()
26
27 if args.show_datasets:
28     print('Dataset list :')
29     names = ',\n'.join(ds_loader.get_dataset_names())
30     print(f'[{names}]')
31     exit(0)
32
33 ds_loader.load(args.dataset_name)
34 ds_train, ds_test = ds_loader.prefetched_datasets()
35 nums_train_ds = ds_loader.get_num_train_examples()
36 nums_test_ds = ds_loader.get_num_test_examples()
37 print(f'class names       : {ds_loader.class_names()}')
38 print(f'train dataset len : {nums_train_ds}')
39 print(f'test dataset len  : {nums_test_ds}')
40
41 ################################################################################
42 # Convert tensorlfow dataset to onert format
43 ################################################################################
44 Path(f'{args.out_dir}').mkdir(parents=True, exist_ok=True)
45 prefix_name = f'{args.out_dir}/{args.prefix_name}'
46 if args.prefix_name != '':
47     prefix_name += '.'
48
49 nums_train = args.train_length
50 if (nums_train > nums_train_ds):
51     print(
52         f'Oops! The number of data for training in the dataset is less than {nums_train}')
53     exit(1)
54
55 nums_test = args.test_length
56 if (nums_test > nums_test_ds):
57     print(f'Oops! The number of data for test in the dataset is less than {nums_test}')
58     exit(1)
59
60
61 def _only_image(image, _):
62     return image
63
64
65 def _only_label(_, label):
66     return label
67
68
69 def _label_to_array(label):
70     arr = np.zeros(ds_loader.num_classes(), dtype=float)
71     arr[label] = 1.
72     tensor = tf.convert_to_tensor(arr, tf.float32)
73     return tensor
74
75
76 file_path_list = [
77     f'{prefix_name}train.input.{nums_train}.bin',
78     f'{prefix_name}test.input.{nums_test}.bin',
79     f'{prefix_name}train.output.{nums_train}.bin',
80     f'{prefix_name}test.output.{nums_test}.bin'
81 ]
82
83 ds_list = [
84     ds_train.take(nums_train).map(_only_image),
85     ds_test.take(nums_test).map(_only_image),
86     [_label_to_array(label) for label in ds_train.take(nums_train).map(_only_label)],
87     [_label_to_array(label) for label in ds_test.take(nums_test).map(_only_label)]
88 ]
89
90 for i in range(4):
91     file_path = file_path_list[i]
92     with open(file_path, 'wb') as f:
93         ds = ds_list[i]
94         for tensor in ds:
95             f.write(tensor.numpy().tobytes())
96         f.close()
97
98 print('The data files are created!')