9863c807aa0fd9396fd58a0570598f2af74f3b43
[platform/core/ml/nnfw.git] / compiler / pota-quantization-value-test / gen_h5_explicit_inputs.py
1 #!/usr/bin/env python3
2 import h5py as h5
3 import numpy as np
4 import tensorflow as tf
5 import argparse
6 import glob
7
8 #
9 # This script generates a pack of random input data (.h5) expected by the input tflite model
10 #
11 # Basic usage:
12 #   gen_h5_explicit_inputs.py --model <path/to/model/file> --input <path/to/input/directory> --output <path/to/output/file>
13 #   ex: gen_h5_explicit_inputs.py --model Add_000.tflite --input Add_000 --output Add_000.input.h5
14 #   (This will create Add_000.input.h5)
15 #
16 # The input directory should be organized as follows
17 # <input_directory>/
18 #   -> <record_index>.txt
19 #     ...
20 # Each txt file has the explicit values of inputs
21 # Example. if the model has two inputs whose shapes are both (1, 3),
22 # the first record file name is 0.txt, and its contents is something like below
23 # 1, 2, 3
24 # 4, 5, 6
25 #
26 parser = argparse.ArgumentParser()
27 parser.add_argument('--model', type=str, required=True)
28 parser.add_argument('--input', type=str, required=True)
29 parser.add_argument('--output', type=str, required=True)
30 args = parser.parse_args()
31
32 model = args.model
33 input = args.input
34 output = args.output
35
36 # Build TFLite interpreter. (to get the information of model input)
37 interpreter = tf.lite.Interpreter(model)
38 input_details = interpreter.get_input_details()
39
40 # Create h5 file
41 h5_file = h5.File(output, 'w')
42 group = h5_file.create_group("value")
43 group.attrs['desc'] = "Input data for " + model
44
45 # Input files
46 records = sorted(glob.glob(input + "/*.txt"))
47 for i, record in enumerate(records):
48     sample = group.create_group(str(i))
49     sample.attrs['desc'] = "Input data " + str(i)
50     with open(record, 'r') as f:
51         lines = f.readlines()
52         for j, line in enumerate(lines):
53             data = np.array(line.split(','))
54             input_detail = input_details[j]
55             input_data = np.array(
56                 data.reshape(input_detail["shape"]), input_detail["dtype"])
57             sample.create_dataset(str(j), data=input_data)
58
59 h5_file.close()