# dump input and output in h5
import h5py
supported_dtypes = ("float32", "uint8")
+ h5dtypes = {"float32": ">f4", "uint8": "u1"}
with h5py.File(out_dir + "input.h5", 'w') as hf:
name_grp = hf.create_group("name")
val_grp = hf.create_group("value")
for idx, t in enumerate(input_names):
dtype = tf.compat.v1.as_dtype(input_dtypes[idx])
- if not dtype in supported_dtypes:
+ if not dtype.name in supported_dtypes:
print("ERR: only float32 and uint8 is supported.")
sys.exit(-1)
- val_grp.create_dataset(str(idx), data=input_values[idx], dtype='>f4')
+ val_grp.create_dataset(
+ str(idx), data=input_values[idx], dtype=h5dtypes[dtype.name])
name_grp.attrs[str(idx)] = input_names[idx]
with h5py.File(out_dir + "expected.h5", 'w') as hf:
val_grp = hf.create_group("value")
for idx, t in enumerate(output_names):
dtype = tf.compat.v1.as_dtype(output_dtypes[idx])
- if not dtype in supported_dtypes:
+ if not dtype.name in supported_dtypes:
print("ERR: only float32 and uint8 is supported.")
sys.exit(-1)
- val_grp.create_dataset(str(idx), data=output_values[idx], dtype='>f4')
+ val_grp.create_dataset(
+ str(idx), data=output_values[idx], dtype=h5dtypes[dtype.name])
name_grp.attrs[str(idx)] = output_names[idx]