[x86 schedule] Fallback schedule for Int8 depthwise. (#4733)
authorAnimesh Jain <anijain@umich.edu>
Fri, 17 Jan 2020 18:21:45 +0000 (10:21 -0800)
committerYizhi Liu <liuyizhi@apache.org>
Fri, 17 Jan 2020 18:21:45 +0000 (10:21 -0800)
tests/python/relay/test_op_level2.py
topi/python/topi/x86/conv2d_int8.py

index a098b5c..27eab0e 100644 (file)
@@ -1182,6 +1182,35 @@ def test_conv2d_int8_intrinsics():
     assert "vpmulld" in asm and "vpadd" in asm
 
 
+def test_depthwise_conv2d_int8():
+    input_dtype = 'uint8'
+    weight_dtype = 'int8'
+    output_dtype = 'int32'
+
+    data_shape = (1, 64, 56, 56)
+    x = relay.var("x", relay.TensorType(data_shape, input_dtype))
+
+    kernel_shape = (64, 1, 3, 3)
+    weight = relay.var("weight", relay.TensorType(kernel_shape, weight_dtype))
+
+    y = relay.nn.conv2d(x, weight,
+                        kernel_size=(3, 3),
+                        groups=64,
+                        padding=(1, 1),
+                        dilation=(1, 1),
+                        out_dtype=output_dtype)
+    func = relay.Function([x, weight], y)
+    wdata = np.random.rand(*kernel_shape) * 10
+    parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))}
+
+    targets = ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"]
+    llvm_version = tvm.codegen.llvm_version_major()
+    for target in targets:
+        if llvm_version >= 8:
+            with relay.build_config(opt_level=3):
+                graph, lib, params = relay.build(func, target, params=parameters)
+
+
 def test_bitserial_conv2d_infer_type():
     # Basic shape test with ambiguous batch.
     n, c, h, w = tvm.size_var("n"), 32, 224, 224
@@ -1234,3 +1263,4 @@ if __name__ == "__main__":
     test_upsampling()
     test_upsampling3d()
     test_conv2d_int8_intrinsics()
+    test_depthwise_conv2d_int8()
index cb23eec..79527a7 100644 (file)
@@ -28,6 +28,7 @@ from ..generic import conv2d as conv2d_generic
 from ..nn.util import get_pad_tuple
 from ..util import get_const_tuple
 from ..nn.conv2d import conv2d_NCHWc_int8
+from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
 from .. import nn
 from . import conv2d_avx_1x1, conv2d_avx_common
 
@@ -36,15 +37,20 @@ def _get_default_config_int8(cfg, data, kernel, strides, padding, out_dtype, is_
     """
     Get default schedule config for the workload
     """
-    assert not is_depthwise, "Depthwise Int8 not supported"
-    wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout)
-    is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
-    if is_kernel_1x1:
-        conv2d_generic.fallback_schedule_cpu_1x1_int8(
-            cfg, wkl, int32_lanes=16, num_int8_elements=4)
+    if is_depthwise:
+        # Fallback to FP32 default config until a VNNI schedule is defined.
+        wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, out_dtype)
+        from .depthwise_conv2d import _fallback_schedule
+        _fallback_schedule(cfg, wkl)
     else:
-        conv2d_generic.fallback_schedule_cpu_common_int8(
-            cfg, wkl, int32_lanes=16, num_int8_elements=4)
+        wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout)
+        is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
+        if is_kernel_1x1:
+            conv2d_generic.fallback_schedule_cpu_1x1_int8(
+                cfg, wkl, int32_lanes=16, num_int8_elements=4)
+        else:
+            conv2d_generic.fallback_schedule_cpu_common_int8(
+                cfg, wkl, int32_lanes=16, num_int8_elements=4)
 
 
 def _is_int8_hw_support(data_dtype, kernel_dtype):