[gen_golden] Use u1 dtype for loading/dumping uint8 type (#9374)
author이상규/On-Device Lab(SR)/Principal Engineer/삼성전자 <sg5.lee@samsung.com>
Wed, 4 Dec 2019 07:49:23 +0000 (16:49 +0900)
committer이춘석/On-Device Lab(SR)/Staff Engineer/삼성전자 <chunseok.lee@samsung.com>
Wed, 4 Dec 2019 07:49:23 +0000 (16:49 +0900)
Previously it used float type for uint8 case. It fixes the bug.

Signed-off-by: Sanggyu Lee <sg5.lee@samsung.com>
tools/nnpackage_tool/gen_golden/gen_golden.py

index dcd7245..b42fc7c 100755 (executable)
@@ -139,15 +139,17 @@ if __name__ == '__main__':
     # dump input and output in h5
     import h5py
     supported_dtypes = ("float32", "uint8")
+    h5dtypes = {"float32": ">f4", "uint8": "u1"}
     with h5py.File(out_dir + "input.h5", 'w') as hf:
         name_grp = hf.create_group("name")
         val_grp = hf.create_group("value")
         for idx, t in enumerate(input_names):
             dtype = tf.compat.v1.as_dtype(input_dtypes[idx])
-            if not dtype in supported_dtypes:
+            if not dtype.name in supported_dtypes:
                 print("ERR: only float32 and uint8 is supported.")
                 sys.exit(-1)
-            val_grp.create_dataset(str(idx), data=input_values[idx], dtype='>f4')
+            val_grp.create_dataset(
+                str(idx), data=input_values[idx], dtype=h5dtypes[dtype.name])
             name_grp.attrs[str(idx)] = input_names[idx]
 
     with h5py.File(out_dir + "expected.h5", 'w') as hf:
@@ -155,8 +157,9 @@ if __name__ == '__main__':
         val_grp = hf.create_group("value")
         for idx, t in enumerate(output_names):
             dtype = tf.compat.v1.as_dtype(output_dtypes[idx])
-            if not dtype in supported_dtypes:
+            if not dtype.name in supported_dtypes:
                 print("ERR: only float32 and uint8 is supported.")
                 sys.exit(-1)
-            val_grp.create_dataset(str(idx), data=output_values[idx], dtype='>f4')
+            val_grp.create_dataset(
+                str(idx), data=output_values[idx], dtype=h5dtypes[dtype.name])
             name_grp.attrs[str(idx)] = output_names[idx]