4 import tensorflow as tf
# This script generates a pack of input data (.h5) expected by the input tflite model,
# using explicit values read from text files (not randomly generated values).
# Basic usage:
#   gen_h5_explicit_inputs.py --model <path/to/model/file> --input <path/to/input/directory> --output <path/to/output/file>
#   ex: gen_h5_explicit_inputs.py --model Add_000.tflite --input Add_000 --output Add_000.input.h5
#   (This will create Add_000.input.h5)
# The input directory should be organized as follows:
# <input_directory>/
#   -> <record_index>.txt   (0.txt, 1.txt, 2.txt, ...)
#   ...
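# Each txt file holds the explicit values of all model inputs: one line per input,
# in model-input order, with the values of that input separated by commas.
# Example: if the model has two inputs whose shapes are both (1, 3),
# the first record file is named 0.txt and its contents look like this:
# 1, 2, 3
# 4, 5, 6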
# Command-line interface: the model path, the directory of record files and
# the destination .h5 path are all mandatory string options.
parser = argparse.ArgumentParser()
for option_name in ('--model', '--input', '--output'):
    parser.add_argument(option_name, type=str, required=True)
args = parser.parse_args()
import os

model = args.model
input_dir = args.input  # not named `input`, to avoid shadowing the builtin
output = args.output


def _record_sort_key(path):
    """Return a sort key that orders ``<record_index>.txt`` files numerically.

    A plain lexicographic sort would put ``10.txt`` before ``2.txt``. Records
    are stored under group names taken from their position, so that would file
    them under the wrong index once there are more than ten. Files whose name
    is not an integer sort after the numbered ones, by name.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    if stem.isdecimal():
        return (0, int(stem))
    return (1, stem)


def _read_record(path, input_details):
    """Parse one record file into a list of arrays, one per model input.

    Each non-blank line of *path* holds the comma-separated values of one
    model input, in the same order as *input_details*. The values are
    reshaped to that input's ``shape`` and cast to its ``dtype``.

    Raises:
        ValueError: if the number of non-blank lines differs from the number
            of model inputs.
    """
    with open(path, 'r') as f:
        # Ignore blank lines (e.g. a trailing newline at the end of the file),
        # which would otherwise be parsed as an extra, empty input.
        lines = [line for line in f if line.strip()]
    if len(lines) != len(input_details):
        raise ValueError(
            "%s has %d input line(s), but the model expects %d input(s)" %
            (path, len(lines), len(input_details)))
    arrays = []
    for line, input_detail in zip(lines, input_details):
        # Strip whitespace/newlines around each token before the dtype cast.
        values = np.array([token.strip() for token in line.split(',')])
        arrays.append(
            values.reshape(input_detail["shape"]).astype(input_detail["dtype"]))
    return arrays


# Build TFLite interpreter. (to get the information of model input)
interpreter = tf.lite.Interpreter(model)
input_details = interpreter.get_input_details()

# Record files, in record-index order. The directory name is escaped so that
# glob metacharacters in it are matched literally.
records = sorted(
    glob.glob(os.path.join(glob.escape(input_dir), "*.txt")),
    key=_record_sort_key)

# Create the h5 file. The context manager guarantees the file is flushed and
# closed, even if a record fails to parse.
with h5.File(output, 'w') as h5_file:
    group = h5_file.create_group("value")
    group.attrs['desc'] = "Input data for " + model
    for i, record in enumerate(records):
        sample = group.create_group(str(i))
        sample.attrs['desc'] = "Input data " + str(i)
        for j, input_data in enumerate(_read_record(record, input_details)):
            sample.create_dataset(str(j), data=input_data)