Improve NHWC depthwise convolution for AArch64 (#6095)
authorGiuseppe Rossini <giuseppe.rossini@arm.com>
Thu, 13 Aug 2020 14:41:45 +0000 (15:41 +0100)
committerGitHub <noreply@github.com>
Thu, 13 Aug 2020 14:41:45 +0000 (22:41 +0800)
* Improve NHWC depthwise convolution for aarch64

We created a default schedule (no auto-tuning or tensorization) named
depthwise_conv2d_nhwc which does a decent job at optimizing depthwise
for NHWC layouts (on aarch64).

Change-Id: I01e32903f6c1950623f33eae18484e70244fe0af

* Add tuning knobs in depthwise schedule

Change-Id: I15080e7f12b16e6c6aba99a04e42023845eeabf1

* Introduce padding policy

Change-Id: If12a6d05dce9153861550ddef1ee5216809dd1e1

* Vectorize padding

Change-Id: I7e2062a40358bf111c0366a449945eb077fb2e30

* Legalize depthwise convolution (2x improvement) and fix tuning issue

Change-Id: I4b82c58b167e40b0b7747d28293bbb488c505dd9

* Adding assert on padding

Change-Id: Idf8eeaaface5eb7799109cd00f437e404778b9cd

* Fix python linting

Change-Id: Iac16a8daea1268f0eb331fe4ec18a62408106cf9

* Removing commented code

Change-Id: I1412f22ad9864273d77a7bf38a6768694339b7f0

* Revert test file to make CI pass

Change-Id: Ica3eff8f9f0fd4c6f32f7ae80adc922f8b16cec9

* Enabling only arm_cpu tests

Change-Id: Icbaafcb39e892a5d1a4685133c1699e4d1a8e07e

* Rebasing

Change-Id: Ibb23f1d4e0d0107e4e3b3571437161cdc2ee2909

python/tvm/relay/op/strategy/arm_cpu.py
python/tvm/relay/qnn/op/legalizations.py
python/tvm/topi/arm_cpu/depthwise_conv2d.py
tests/python/topi/python/test_topi_depthwise_conv2d.py

index 8143cc5..0c4edbb 100644 (file)
@@ -167,11 +167,10 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                     name="depthwise_conv2d_nchw.x86")
         elif layout == "NHWC":
             assert kernel_layout == "HWOI"
-            logger.warning("depthwise_conv2d with layout NHWC is not optimized for arm cpu.")
             strategy.add_implementation(
-                wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
-                wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc),
-                name="depthwise_conv2d_nhwc.generic")
+                wrap_compute_conv2d(topi.arm_cpu.compute_depthwise_conv2d_nhwc),
+                wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc),
+                name="depthwise_conv2d_nhwc.arm_cpu")
         else:
             raise RuntimeError("Unsupported depthwise_conv2d layout {} for arm cpu".
                                format(layout))
index af5072e..62bee30 100644 (file)
@@ -248,7 +248,13 @@ def is_aarch64_arm():
 @qnn_conv2d_legalize.register('arm_cpu')
 def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
     # ARM prefers the dtypes to be same.
-    if (is_aarch64_arm() and attrs["data_layout"] == "NHWC") or is_fast_int8_on_arm():
+    is_depthwise = relay.op.strategy.is_depthwise_conv2d(types[0].shape,
+                                                         attrs['data_layout'],
+                                                         types[1].shape,
+                                                         attrs['kernel_layout'],
+                                                         attrs['groups'])
+    use_int8_on_arm = (not is_depthwise) and is_aarch64_arm() and attrs["data_layout"] == "NHWC"
+    if use_int8_on_arm or is_fast_int8_on_arm():
         return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
     return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
 
index 802b3df..07749ee 100644 (file)
@@ -20,6 +20,7 @@
 import tvm
 from tvm import te
 from tvm import autotvm
+from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
 
 from .. import nn
 from ..util import traverse_inline, get_const_tuple, get_const_int
@@ -31,7 +32,6 @@ def depthwise_conv2d_nchw(_, data, kernel, strides, padding, dilation, out_dtype
     """Compute depthwise_conv2d with NCHW layout"""
     return nn.depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)
 
-
 @autotvm.register_topi_schedule("depthwise_conv2d_nchw.arm_cpu")
 def schedule_depthwise_conv2d_nchw(cfg, outs):
     """Schedule depthwise conv2d
@@ -181,6 +181,171 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
 
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left or pad_down or pad_right:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(cfg, outs):
+    """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+
+    ##### space definition begin #####
+    n, h, w, c = s[out].op.axis
+    cfg.define_split('tile_c', c, num_outputs=2)
+    _, hi = cfg.define_split('tile_h', h, num_outputs=2)
+    _, wi = cfg.define_split('tile_w', w, num_outputs=2)
+    cfg.define_knob('locate_output', [0, 1])
+
+    # fallback support
+    if cfg.is_fallback:
+        cfg['tile_c'] = SplitEntity([-1, 8])
+        cfg['tile_h'] = SplitEntity([-1, 2])
+        cfg['tile_w'] = SplitEntity([-1, 2])
+        cfg['locate_output'] = OtherOptionEntity(1)
+    ##### space definition end #####
+
+    def schedule_conv(conv):
+        conv_data = conv.op.input_tensors[0]
+
+        n, w, h, c = conv.op.axis
+        r_h, r_w = conv.op.reduce_axis
+        ho, hi = cfg['tile_h'].apply(s, conv, h)
+        wo, wi = cfg['tile_w'].apply(s, conv, w)
+        co, ci = cfg['tile_c'].apply(s, conv, c)
+
+        if conv_data.name == "data_pad":
+            assert isinstance(conv_data.op, tvm.te.ComputeOp)
+            # Define a policy for padding computation
+            cfg.define_knob('data_pad_inline', [1, 2, 3])
+            if cfg.is_fallback:
+                cfg['data_pad_inline'] = OtherOptionEntity(3)
+            if cfg['data_pad_inline'].val == 1:
+                s[conv_data].vectorize(list(s[conv_data].op.axis)[-1])
+                s[conv_data].compute_at(s[conv], ho)
+            if cfg['data_pad_inline'].val == 2:
+                s[conv_data].vectorize(list(s[conv_data].op.axis)[-1])
+                s[conv_data].compute_at(s[conv], wo)
+            if cfg['data_pad_inline'].val == 3:
+                s[conv_data].compute_inline()
+
+        s[conv].reorder(n, ho, wo, co, hi, wi, r_h, r_w, ci)
+        fused_n_ho = s[conv].fuse(n, ho)
+        s[conv].vectorize(ci)
+        return fused_n_ho
+
+    def schedule_conv_out(out):
+        n, h, w, c = out.op.axis
+        co, ci = cfg['tile_c'].apply(s, out, c)
+        wo, wi = cfg['tile_w'].apply(s, out, w)
+        ho, hi = cfg['tile_h'].apply(s, out, h)
+        s[out].reorder(n, ho, wo, co, hi, wi)
+
+        if out.dtype in ['int8', 'uint8']:
+            # In case of quantized convolution further split the channel in batches of 4 elements
+            # so that we can use arm intrinsics to run fixed_point_multiplication
+            ci_outer, ci_inner = s[out].split(ci, 4)
+            s[out].vectorize(ci_inner)
+
+        fused_n_ho = s[out].fuse(n, ho)
+        return hi, wi, fused_n_ho
+
+    def _callback(op):
+        if op.name == 'depthwise_conv2d_nhwc_output':
+            conv = op.output(0)
+            if conv != out:
+                hi, wi, p_axis = schedule_conv_out(out)
+                schedule_conv(conv)
+                if cfg['locate_output'].val == 0:
+                    s[conv].compute_at(s[out], hi)
+                if cfg['locate_output'].val == 1:
+                    s[conv].compute_at(s[out], wi)
+            else:
+                p_axis = schedule_conv(out)
+
+            s[out].parallel(p_axis)
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
 
 @autotvm.register_topi_schedule("depthwise_conv2d_nchw_spatial_pack.arm_cpu")
 def schedule_depthwise_conv2d_nchw_spatial_pack(cfg, outs):
index 3978617..93a166d 100644 (file)
@@ -40,6 +40,7 @@ _depthwise_conv2d_nchw_implement = {
 
 _depthwise_conv2d_nhwc_implement = {
     "generic": (topi.nn.depthwise_conv2d_nhwc, topi.generic.schedule_depthwise_conv2d_nhwc),
+    "arm_cpu": (topi.arm_cpu.compute_depthwise_conv2d_nhwc, topi.arm_cpu.schedule_depthwise_conv2d_nhwc),
     "gpu": (topi.nn.depthwise_conv2d_nhwc, topi.cuda.schedule_depthwise_conv2d_nhwc),
 }