#
# This script checks whether the values recorded in a circle model dump (h5 file) are the
# same as the expected (golden) values. What is compared depends on --mode:
#   fake_quantization         : fake-quantized weights
#   record_minmax             : recorded min/max
#   quantization              : quantized weights, scale and zero_point
#   weights_only_quantization : quantized weights/scale/zero_point of the tensor named "ker"
#
# Only tensors that have a matching <expect_dir>/<tensor_name>.json file are compared.
# The exit status is the `failed_cases` counter (0 means no mismatch was recorded).
#
# Basic usage:
#   compare_tensors.py --input_h5 <path/to/input/h5> --expect_dir <path/to/expect/dir> --mode <compare_mode>
#   ex: compare_tensors.py --input_h5 Add_000.h5 --expect_dir expected_outputs/Add_000 --mode fake_quantization
#
supported_modes = [
    "fake_quantization", "record_minmax", "quantization", "weights_only_quantization"
]

parser = argparse.ArgumentParser()
parser.add_argument('--input_h5', type=str, required=True, help="h5 file to be checked")
parser.add_argument(
    '--expect_dir',
    type=str,
    required=True,
    help="directory holding <tensor_name>.json golden files")
parser.add_argument(
    '--mode', type=str, required=True, help="one of " + str(supported_modes))
args = parser.parse_args()

model = args.input_h5
expect_dir = args.expect_dir
mode = args.mode

# Number of mismatches found; used as the process exit status (sys.exit(failed_cases)).
failed_cases = 0

if mode not in supported_modes:
    raise SystemExit("Unsupported mode. --mode should be one of " + str(supported_modes))
def compare_fake_quantization(tensor, tensor_name, expect_dir):
    """Compare the fake-quantized weights of one tensor against its golden JSON.

    Reads ``<expect_dir>/<tensor_name>.json`` and checks its "weights" entry
    against ``tensor["weights"]`` with rtol = atol = 1e-5. On mismatch, both
    values are printed and the module-level ``failed_cases`` counter (the
    script's exit status) is incremented.

    Args:
        tensor: indexable object with a "weights" entry (in this script, the
            h5 entry for the tensor).
        tensor_name: name of the tensor; also the stem of the golden JSON file.
        expect_dir: directory holding the golden JSON files.
    """
    global failed_cases
    with open(os.path.join(expect_dir, tensor_name + ".json"), "r") as expect_file:
        json_load = json.load(expect_file)
    expected_weights = np.array(json_load["weights"])
    # [()] reads the whole dataset, including scalar (0-d) ones that [:] rejects.
    input_weights = tensor["weights"][()]
    if not np.allclose(input_weights, expected_weights, rtol=1.e-5, atol=1.e-5):
        print("Fake-quantized weights of " + tensor_name + " (" + str(input_weights) +
              ") do not match with expected value (" + str(expected_weights) + ").")
        failed_cases += 1
def compare_record_minmax(tensor, tensor_name, expect_dir):
    """Compare the recorded min/max of one tensor against its golden JSON.

    Reads ``<expect_dir>/<tensor_name>.json`` and checks its "min" and "max"
    entries against ``tensor["min"]`` / ``tensor["max"]`` with
    rtol = atol = 1e-5. Each mismatching value is printed and increments the
    module-level ``failed_cases`` counter (the script's exit status).

    Args:
        tensor: indexable object with "min" and "max" entries (in this script,
            the h5 entry for the tensor).
        tensor_name: name of the tensor; also the stem of the golden JSON file.
        expect_dir: directory holding the golden JSON files.
    """
    global failed_cases
    with open(os.path.join(expect_dir, tensor_name + ".json"), "r") as expect_file:
        json_load = json.load(expect_file)
    expected_min = np.array(json_load["min"])
    expected_max = np.array(json_load["max"])
    # [()] reads the whole dataset, including scalar (0-d) ones that [:] rejects.
    input_min = tensor["min"][()]
    input_max = tensor["max"][()]
    if not np.allclose(input_min, expected_min, rtol=1.e-5, atol=1.e-5):
        print("Recorded min of " + tensor_name + " (" + str(input_min) +
              ") does not match with expected value (" + str(expected_min) + ").")
        failed_cases += 1
    if not np.allclose(input_max, expected_max, rtol=1.e-5, atol=1.e-5):
        print("Recorded max of " + tensor_name + " (" + str(input_max) +
              ") does not match with expected value (" + str(expected_max) + ").")
        failed_cases += 1
def compare_quantization(tensor, tensor_name, expect_dir):
    """Compare quantized weights/scale/zero_point of one tensor against golden JSON.

    Reads ``<expect_dir>/<tensor_name>.json``. It checks each of the keys
    "weights", "scale" and "zero_point" that appear in the JSON against the
    entry of the same name in ``tensor``. Keys absent from the JSON are not
    checked. Tolerances:
      * weights:    rtol=0, atol=1 (atol=5 for int64 data)
      * scale:      rtol=1e-5, atol=1e-5
      * zero_point: rtol=0, atol=1
    Each mismatch is printed and increments the module-level ``failed_cases``
    counter (the script's exit status).

    Args:
        tensor: indexable object holding the entries named in the JSON (in
            this script, the h5 entry for the tensor).
        tensor_name: name of the tensor; also the stem of the golden JSON file.
        expect_dir: directory holding the golden JSON files.
    """
    global failed_cases
    with open(os.path.join(expect_dir, tensor_name + ".json"), "r") as expect_file:
        json_load = json.load(expect_file)

    if "weights" in json_load:
        expected_weights = np.array(json_load["weights"])
        input_weights = tensor["weights"][()]
        # NOTE(review): confirm these absolute tolerances (in quantized units)
        # against the golden-data generator.
        abs_tolerance = 1
        # We use higher tolerance for int64 data (bias of int16-quantized model)
        if tensor["weights"].dtype == 'int64':
            abs_tolerance = 5
        if not np.allclose(input_weights, expected_weights, rtol=0, atol=abs_tolerance):
            print("Quantized weights of " + tensor_name + " (" + str(input_weights) +
                  ") do not match with expected value (" + str(expected_weights) +
                  ").")
            failed_cases += 1

    if "scale" in json_load:
        expected_scale = np.array(json_load["scale"])
        input_scale = tensor["scale"][()]
        if not np.allclose(input_scale, expected_scale, rtol=1.e-5, atol=1.e-5):
            print("Quantized scale of " + tensor_name + " (" + str(input_scale) +
                  ") do not match with expected value (" + str(expected_scale) + ").")
            failed_cases += 1

    if "zero_point" in json_load:
        expected_zero_point = np.array(json_load["zero_point"])
        input_zero_point = tensor["zero_point"][()]
        if not np.allclose(input_zero_point, expected_zero_point, rtol=0, atol=1):
            print("Quantized zero_point of " + tensor_name + " (" +
                  str(input_zero_point) + ") do not match with expected value (" +
                  str(expected_zero_point) + ").")
            failed_cases += 1
# Compare every tensor in the h5 dump that has golden data in expect_dir, then
# exit with the accumulated mismatch counter.
with h5.File(model, "r") as input_file:
    for tensor_name in input_file.keys():
        # We only check the given golden data
        if not os.path.isfile(os.path.join(expect_dir, tensor_name + ".json")):
            continue
        print("Compare " + tensor_name)
        tensor = input_file[tensor_name]
        if mode == "fake_quantization":
            compare_fake_quantization(tensor, tensor_name, expect_dir)
        elif mode == "record_minmax":
            compare_record_minmax(tensor, tensor_name, expect_dir)
        elif mode == "quantization":
            compare_quantization(tensor, tensor_name, expect_dir)
        elif mode == "weights_only_quantization":
            # Assume weights have name "ker"
            if tensor_name == "ker":
                compare_quantization(tensor, tensor_name, expect_dir)
        else:
            # Defensive: mode is validated against supported_modes at startup.
            raise SystemExit("Unsupported mode.")

sys.exit(failed_cases)