qnn_output = get_output(qnn_func, golden_inputs)
np.testing.assert_equal(qnn_output, golden_output)
-def no_zero_point_test():
+def test_no_zero_point():
# uint8 input
data_shape = (2, 1, 2, 4)
data_dtype = 'uint8'
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
-def kernel_zero_point_test():
+def test_kernel_zero_point():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
kernel_shape, kernel_dtype)
-def input_zero_point_test():
+def test_input_zero_point():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
-def both_zero_point_test():
+def test_both_zero_point():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
-def layout_test():
+def test_layout():
# uint8 input
data_shape = (2, 2, 4, 4) # NHWC
data_dtype = 'uint8'
-def padding_test():
+def test_padding():
# uint8 input
data_shape = (1, 4, 2, 2)
data_dtype = 'uint8'
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
-def dilation_test():
+def test_dilation():
# uint8 input
data_shape = (2, 4, 4, 4)
data_dtype = 'uint8'
kernel_shape, kernel_dtype)
-def const_folding_test():
+def test_const_folding():
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
kernel_shape = (3, 4, 2, 2)
folded_func = folded_mod["main"]
assert "reshape" not in folded_func.astext()
-def kernel_size_1x1_test():
+def test_kernel_size_1x1():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
-def tflite_large_irregular_test():
+def test_tflite_large_irregular():
# uint8 input
data_shape = (1, 1024, 1, 1)
data_dtype = 'uint8'
golden_output = np.array((124, -92, 164, -132)).reshape(1, 1, 2, 2)
np.testing.assert_equal(qnn_output, golden_output)
-def broadcast_layout_test():
+def test_broadcast_layout():
# Test broadcast support for NHWC layout.
data_shape = (1, 229, 229, 3) # NHWC
data_dtype = 'uint8'
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512")
+
+def test_conv2d_int8():
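+ # End-to-end smoke test: compile an int8 NHWC conv2d with relay.build and
+ # run it on the graph runtime.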
+ target = "llvm -mcpu=core-avx2"
+ if not tvm.module.enabled(target):
+ print("skip because %s is not enabled..." % target)
+ return
+
+ data = relay.var("data", shape=(1, 28, 28, 128), dtype='uint8')
+ kernel = relay.var("w", shape=(3, 3, 128, 256), dtype='int8')
+ conv = relay.nn.conv2d(
+ data,
+ kernel,
+ kernel_size=(3, 3),
+ out_dtype='int32',
+ data_layout='NHWC',
+ kernel_layout='HWIO')
+ func = relay.Function([data, kernel], conv)
+
+ with relay.build_config(opt_level=0):
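+ # All-zero weights make the expected output trivially zero (checked below).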
+ params = {"w": np.zeros((3, 3, 128, 256)).astype("int8")}
+ # -mcpu should be specified to avoid the llvm jitting error here:
+ # https://discuss.tvm.ai/t/segfault-in-llvm/3567
+ # To use VNNI, we need to specify the micro-architecture that supports
+ # it, e.g. cascadelake.
+ graph, lib, params = relay.build(func, target, params=params)
+ mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+ mod.set_input("data", np.zeros((1, 28, 28, 128)).astype("uint8"))
+ mod.set_input(**params)
+ mod.run()
+ qnn_output = mod.get_output(0).asnumpy()
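+ # A 3x3 valid convolution (default stride 1, no padding) over 28x28 input
+ # yields a 26x26 spatial output.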
+ golden_output = np.zeros((1, 26, 26, 256)).astype("int32")
+ np.testing.assert_equal(qnn_output, golden_output)
+
+
if __name__ == "__main__":
- no_zero_point_test()
- input_zero_point_test()
- kernel_zero_point_test()
- both_zero_point_test()
- layout_test()
- padding_test()
- dilation_test()
- const_folding_test()
- kernel_size_1x1_test()
- tflite_large_irregular_test()
- tflite_output_multiplier_greater_than_one()
- tflite_anistropic_strides()
- broadcast_layout_test()
+ test_no_zero_point()
+ test_input_zero_point()
+ test_kernel_zero_point()
+ test_both_zero_point()
+ test_layout()
+ test_padding()
+ test_dilation()
+ test_const_folding()
+ test_kernel_size_1x1()
+ test_tflite_large_irregular()
+ test_tflite_output_multiplier_greater_than_one()
+ test_tflite_anistropic_strides()
+ test_broadcast_layout()
+ test_conv2d_int8()