From 0e76a1aeecf1638abe479842b1d1b27f81d971f7 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9D=B4=EC=83=81=EA=B7=9C/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 4 Dec 2019 16:49:23 +0900 Subject: [PATCH] [gen_golden] Use u1 dtype for loading/dumping uint8 type (#9374) Previously it used float type for uint8 case. It fixes the bug. Signed-off-by: Sanggyu Lee --- tools/nnpackage_tool/gen_golden/gen_golden.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tools/nnpackage_tool/gen_golden/gen_golden.py b/tools/nnpackage_tool/gen_golden/gen_golden.py index dcd7245..b42fc7c 100755 --- a/tools/nnpackage_tool/gen_golden/gen_golden.py +++ b/tools/nnpackage_tool/gen_golden/gen_golden.py @@ -139,15 +139,17 @@ if __name__ == '__main__': # dump input and output in h5 import h5py supported_dtypes = ("float32", "uint8") + h5dtypes = {"float32": ">f4", "uint8": "u1"} with h5py.File(out_dir + "input.h5", 'w') as hf: name_grp = hf.create_group("name") val_grp = hf.create_group("value") for idx, t in enumerate(input_names): dtype = tf.compat.v1.as_dtype(input_dtypes[idx]) - if not dtype in supported_dtypes: + if not dtype.name in supported_dtypes: print("ERR: only float32 and uint8 is supported.") sys.exit(-1) - val_grp.create_dataset(str(idx), data=input_values[idx], dtype='>f4') + val_grp.create_dataset( + str(idx), data=input_values[idx], dtype=h5dtypes[dtype.name]) name_grp.attrs[str(idx)] = input_names[idx] with h5py.File(out_dir + "expected.h5", 'w') as hf: @@ -155,8 +157,9 @@ if __name__ == '__main__': val_grp = hf.create_group("value") for idx, t in enumerate(output_names): dtype = tf.compat.v1.as_dtype(output_dtypes[idx]) - if not dtype in supported_dtypes: + if not dtype.name in supported_dtypes: print("ERR: only float32 and uint8 is supported.") sys.exit(-1) - val_grp.create_dataset(str(idx), data=output_values[idx], dtype='>f4') + val_grp.create_dataset( + str(idx), data=output_values[idx], dtype=h5dtypes[dtype.name]) name_grp.attrs[str(idx)] = output_names[idx] -- 2.7.4