[TOPI] Move conv2d spatial pack schedule to dedicated file (#3972)
author 黎明灰烬 <i@jackwish.net>
Wed, 25 Sep 2019 23:02:19 +0000 (07:02 +0800)
committer Lianmin Zheng <lianminzheng@gmail.com>
Wed, 25 Sep 2019 23:02:19 +0000 (16:02 -0700)
More schedules are making the conv2d.py file too large, so
we'd like to move the spatial pack schedule to a dedicated file
before introducing the NHWC schedule. No logic change in this patch.
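
After this patch, both the arm_cpu and mali backends import the shared NCHW
implementation from the new module rather than reaching into each other's
files. A minimal sketch of the resulting import surface, using the names
introduced in the diff below:

    # shared spatial-pack compute/schedule pair, now in its own module
    from topi.arm_cpu.conv2d_spatial_pack import (
        conv2d_spatial_pack_nchw,             # compute declaration
        schedule_conv2d_spatial_pack_nchw,    # matching schedule
    )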

topi/python/topi/arm_cpu/conv2d.py
topi/python/topi/arm_cpu/conv2d_spatial_pack.py [new file with mode: 0644]
topi/python/topi/arm_cpu/conv2d_transpose.py
topi/python/topi/mali/conv2d.py

index 55d6bba..73a97d2 100644 (file)
@@ -35,6 +35,8 @@ from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
 from ..nn import conv2d_legalize
 from ..nn.util import get_const_int, get_pad_tuple
 from ..nn.winograd_util import winograd_transform_matrices
+from .conv2d_spatial_pack import conv2d_spatial_pack_nchw, \
+                                 schedule_conv2d_spatial_pack_nchw
 
 logger = logging.getLogger('topi')
 
@@ -75,8 +77,11 @@ def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dt
     output : tvm.Tensor
         4-D with shape [batch, out_channel, out_height, out_width]
     """
-    return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
-                              num_tile=2)
+    if layout == 'NCHW':
+        return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding,
+                                        dilation, out_dtype, num_tile=2)
+    else:
+        raise ValueError("Unsupported layout {}".format(layout))
 
 
 @autotvm.register_topi_schedule(
@@ -119,7 +124,8 @@ def schedule_conv2d_nchw_arm_cpu(cfg, outs):
             if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
                 s[kernel].compute_inline()
 
-            _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, outs[0])
+            schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec,
+                                              conv, output, outs[0])
 
         if 'winograd_conv2d_output' in op.tag:
             output = op.output(0)
@@ -132,179 +138,6 @@ def schedule_conv2d_nchw_arm_cpu(cfg, outs):
     traverse_inline(s, outs[0].op, _callback)
     return s
 
-
-def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, num_tile):
-    assert layout == "NCHW", "Only support NCHW"
-    # create workload according to raw arguments
-    out_dtype = out_dtype or data.dtype
-    N, CI, IH, IW = get_const_tuple(data.shape)
-
-    if isinstance(dilation, int):
-        dilation_h = dilation_w = dilation
-    else:
-        dilation_h, dilation_w = dilation
-
-    if len(kernel.shape) == 4:
-        pre_packed = False
-        CO, _, KH, KW = get_const_tuple(kernel.shape)
-    else:  # kernel tensor is pre packed
-        pre_packed = True
-        CO, _, KH, KW, VC = get_const_tuple(kernel.shape)
-        CO = CO * VC
-
-    dilated_kernel_h = (KH - 1) * dilation_h + 1
-    dilated_kernel_w = (KW - 1) * dilation_w + 1
-    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
-        padding, (dilated_kernel_h, dilated_kernel_w))
-    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
-    OH = (IH + pad_top + pad_bottom - dilated_kernel_h) // HSTR + 1
-    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
-    data_pad = pad(data, [0, 0, pad_top, pad_left], [0, 0, pad_bottom, pad_right])
-
-    # ==================== define configuration space ====================
-    n, co, oh, ow = cfg.axis(N), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW)
-    ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
-
-    if num_tile == 2:     # for arm cpu
-        co, vc = cfg.define_split('tile_co', co, num_outputs=2)
-        oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2)
-        ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2)
-    elif num_tile == 3:   # for mali gpu
-        co, _, vc = cfg.define_split('tile_co', co, num_outputs=3)
-        oh, _, vh = cfg.define_split('tile_oh', oh, num_outputs=3)
-        ow, _, vw = cfg.define_split('tile_ow', ow, num_outputs=3)
-    else:
-        raise RuntimeError("Invalid num_tile")
-
-    cfg.define_reorder("reorder_0",
-                       [n, co, oh, ow, ci, kh, kw, vh, vw, vc],
-                       policy='candidate', candidate=[
-                           [n, co, oh, ow, ci, kh, kw, vh, vw, vc],
-                           [n, co, oh, ow, ci, kh, kw, vc, vh, vw]])
-
-    cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll')
-    cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec')
-
-    # fallback support
-    if cfg.is_fallback:
-        if num_tile == 2:     # arm cpu
-            ref_log = autotvm.tophub.load_reference_log('arm_cpu', 'rk3399', 'conv2d', 'direct')
-            cfg.fallback_with_reference_log(ref_log)
-        elif num_tile == 3:  # mali gpu
-            ref_log = autotvm.tophub.load_reference_log('mali', 'rk3399', 'conv2d', 'direct')
-            cfg.fallback_with_reference_log(ref_log)
-    # ====================================================================
-
-    VC = cfg["tile_co"].size[-1]
-    VH = cfg["tile_oh"].size[-1]
-    VW = cfg["tile_ow"].size[-1]
-
-    kvshape = (CO // VC, CI, KH, KW, VC)
-    ovshape = (N, CO // VC, OH // VH, OW // VW, VH, VW, VC)
-    oshape = (N, CO, OH, OW)
-
-    if dilation_h != 1 or dilation_w != 1:
-        # undilate input data
-        dvshape = (N, OH // VH, OW // VW, CI, KH, KW, VH, VW)
-        data_vec = tvm.compute(dvshape, lambda n, h, w, ci, kh, kw, vh, vw:
-                               data_pad[n][ci][(h*VH+vh)*HSTR+kh*dilation_h]
-                               [(w*VW+vw)*WSTR+kw*dilation_w],
-                               name='data_vec_undilated')
-    else:
-        dvshape = (N, OH // VH, OW // VW, CI, VH*HSTR + KH-1, VW*WSTR + KW-1)
-        data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw:
-                               data_pad[n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw],
-                               name='data_vec')
-
-    if pre_packed:
-        kernel_vec = kernel
-    else:
-        kernel_vec = tvm.compute(kvshape, lambda co, ci, kh, kw, vc:
-                                 kernel[co*VC+vc][ci][kh][kw],
-                                 name='kernel_vec')
-
-    ci = tvm.reduce_axis((0, CI), name='ci')
-    kh = tvm.reduce_axis((0, KH), name='kh')
-    kw = tvm.reduce_axis((0, KW), name='kw')
-
-    if dilation_h != 1 or dilation_w != 1:
-        conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
-            tvm.sum(data_vec[n, h, w, ci, kh, kw, vh, vw].astype(out_dtype) *
-                    kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
-                    axis=[ci, kh, kw]), name='conv')
-    else:
-        conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
-            tvm.sum(data_vec[n, h, w, ci, vh*HSTR+kh, vw*WSTR+kw].astype(out_dtype) *
-                    kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
-                    axis=[ci, kh, kw]), name='conv')
-
-    output = tvm.compute(oshape, lambda n, co, h, w:
-                         conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
-                         name='output_unpack', tag='spatial_conv2d_output')
-    return output
-
-def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
-                           conv, output, last):
-    """schedule implementation"""
-    n, co, oh, ow, vh, vw, vc = s[conv].op.axis
-    ci, kh, kw = s[conv].op.reduce_axis
-
-    # schedule conv
-    cfg["reorder_0"].apply(s, conv, [n, co, oh, ow, ci, kh, kw, vh, vw, vc])
-    cfg["ann_reduce"].apply(s, conv, [kh, kw],
-                            axis_lens=[get_const_int(kh.dom.extent),
-                                       get_const_int(kw.dom.extent)],
-                            max_unroll=16,
-                            cfg=cfg)
-    cfg["ann_spatial"].apply(s, conv, [vh, vw, vc],
-                             axis_lens=[cfg['tile_oh'].size[-1],
-                                        cfg['tile_ow'].size[-1],
-                                        cfg['tile_co'].size[-1]],
-                             max_unroll=16,
-                             cfg=cfg)
-
-    # schedule fusion
-    n, co, h, w = s[last].op.axis
-    co, vc = cfg['tile_co'].apply(s, last, co)
-    oh, vh = cfg['tile_oh'].apply(s, last, h)
-    ow, vw = cfg['tile_ow'].apply(s, last, w)
-    s[last].reorder(n, co, oh, ow, vh, vw, vc)
-    if last != output:
-        s[output].compute_inline()
-        cfg["ann_spatial"].apply(s, last, [vh, vw, vc],
-                                 axis_lens=[cfg['tile_oh'].size[-1],
-                                            cfg['tile_ow'].size[-1],
-                                            cfg['tile_co'].size[-1]],
-                                 max_unroll=16,
-                                 cfg=cfg)
-    s[conv].compute_at(s[last], ow)
-
-    # mark parallel
-    p = s[last].fuse(n, co)
-    s[last].parallel(p)
-
-    if data_vec.op.name == 'data_vec_undilated':
-        n, h, _, _, _, _, _, _ = s[data_vec].op.axis
-    else:
-        n, h, _, _, _, _ = s[data_vec].op.axis
-    p = s[data_vec].fuse(n, h)
-    s[data_vec].parallel(p)
-
-    if kernel_vec.op.name == 'kernel_vec':
-        co, _, _, _, _ = s[kernel_vec].op.axis
-        if autotvm.GLOBAL_SCOPE.in_tuning:
-            # kernel packing will be pre-computed during compilation, so we skip
-            # this part to make tuning records correct
-            s[kernel_vec].pragma(co, 'debug_skip_region')
-        else:
-            s[kernel_vec].parallel(co)
-    elif kernel_vec.op.name == 'kernel_vec_conv2d_transpose':  # for conv2d transpose
-        co, _, _, _, _ = s[kernel_vec].op.axis
-        s[kernel_vec].parallel(co)
-
-    return s
-
-
 @autotvm.register_topi_compute(conv2d, 'arm_cpu', ['winograd'])
 def conv2d_arm_cpu_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
     """ TOPI compute callback. Use winograd template """
diff --git a/topi/python/topi/arm_cpu/conv2d_spatial_pack.py b/topi/python/topi/arm_cpu/conv2d_spatial_pack.py
new file mode 100644 (file)
index 0000000..2066b9a
--- /dev/null
@@ -0,0 +1,193 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,no-else-return
+"""Conv2D spatial pack implementation for ARM CPU"""
+from __future__ import absolute_import as _abs
+import tvm
+from tvm import autotvm
+from .. import nn
+from ..util import get_const_tuple
+from ..nn.util import get_const_int, get_pad_tuple
+
+def conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, dilation,
+                             out_dtype, num_tile):
+    """compute define for Conv2d Spatial Pack with NCHW layout"""
+    out_dtype = out_dtype or data.dtype
+    N, CI, IH, IW = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    if len(kernel.shape) == 4:
+        pre_packed = False
+        CO, _, KH, KW = get_const_tuple(kernel.shape)
+    else:  # kernel tensor is pre packed
+        pre_packed = True
+        CO, _, KH, KW, VC = get_const_tuple(kernel.shape)
+        CO = CO * VC
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+    OH = (IH + pad_top + pad_bottom - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+    data_pad = nn.pad(data, [0, 0, pad_top, pad_left], [0, 0, pad_bottom, pad_right])
+
+    # ==================== define configuration space ====================
+    n, co, oh, ow = cfg.axis(N), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW)
+    ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
+
+    if num_tile == 2:     # for arm cpu
+        co, vc = cfg.define_split('tile_co', co, num_outputs=2)
+        oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2)
+        ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2)
+    elif num_tile == 3:   # for mali gpu
+        co, _, vc = cfg.define_split('tile_co', co, num_outputs=3)
+        oh, _, vh = cfg.define_split('tile_oh', oh, num_outputs=3)
+        ow, _, vw = cfg.define_split('tile_ow', ow, num_outputs=3)
+    else:
+        raise RuntimeError("Invalid num_tile")
+
+    cfg.define_reorder("reorder_0",
+                       [n, co, oh, ow, ci, kh, kw, vh, vw, vc],
+                       policy='candidate', candidate=[
+                           [n, co, oh, ow, ci, kh, kw, vh, vw, vc],
+                           [n, co, oh, ow, ci, kh, kw, vc, vh, vw]])
+
+    cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll')
+    cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec')
+
+    # fallback support
+    if cfg.is_fallback:
+        if num_tile == 2:     # arm cpu
+            ref_log = autotvm.tophub.load_reference_log('arm_cpu', 'rk3399', 'conv2d', 'direct')
+            cfg.fallback_with_reference_log(ref_log)
+        elif num_tile == 3:  # mali gpu
+            ref_log = autotvm.tophub.load_reference_log('mali', 'rk3399', 'conv2d', 'direct')
+            cfg.fallback_with_reference_log(ref_log)
+    # ====================================================================
+
+    VC = cfg["tile_co"].size[-1]
+    VH = cfg["tile_oh"].size[-1]
+    VW = cfg["tile_ow"].size[-1]
+
+    kvshape = (CO // VC, CI, KH, KW, VC)
+    ovshape = (N, CO // VC, OH // VH, OW // VW, VH, VW, VC)
+    oshape = (N, CO, OH, OW)
+
+    if dilation_h != 1 or dilation_w != 1:
+        # undilate input data
+        dvshape = (N, OH // VH, OW // VW, CI, KH, KW, VH, VW)
+        data_vec = tvm.compute(dvshape, lambda n, h, w, ci, kh, kw, vh, vw:
+                               data_pad[n][ci][(h*VH+vh)*HSTR+kh*dilation_h]
+                               [(w*VW+vw)*WSTR+kw*dilation_w],
+                               name='data_vec_undilated')
+    else:
+        dvshape = (N, OH // VH, OW // VW, CI, VH*HSTR + KH-1, VW*WSTR + KW-1)
+        data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw:
+                               data_pad[n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw],
+                               name='data_vec')
+
+    if pre_packed:
+        kernel_vec = kernel
+    else:
+        kernel_vec = tvm.compute(kvshape, lambda co, ci, kh, kw, vc:
+                                 kernel[co*VC+vc][ci][kh][kw],
+                                 name='kernel_vec')
+
+    ci = tvm.reduce_axis((0, CI), name='ci')
+    kh = tvm.reduce_axis((0, KH), name='kh')
+    kw = tvm.reduce_axis((0, KW), name='kw')
+
+    if dilation_h != 1 or dilation_w != 1:
+        conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
+            tvm.sum(data_vec[n, h, w, ci, kh, kw, vh, vw].astype(out_dtype) *
+                    kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
+                    axis=[ci, kh, kw]), name='conv')
+    else:
+        conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
+            tvm.sum(data_vec[n, h, w, ci, vh*HSTR+kh, vw*WSTR+kw].astype(out_dtype) *
+                    kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
+                    axis=[ci, kh, kw]), name='conv')
+
+    output = tvm.compute(oshape, lambda n, co, h, w:
+                         conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
+                         name='output_unpack', tag='spatial_conv2d_output')
+    return output
+
+def schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec,
+                                      conv, output, last):
+    """schedule implementation"""
+    n, co, oh, ow, vh, vw, vc = s[conv].op.axis
+    ci, kh, kw = s[conv].op.reduce_axis
+
+    # schedule conv
+    cfg["reorder_0"].apply(s, conv, [n, co, oh, ow, ci, kh, kw, vh, vw, vc])
+    cfg["ann_reduce"].apply(s, conv, [kh, kw],
+                            axis_lens=[get_const_int(kh.dom.extent),
+                                       get_const_int(kw.dom.extent)],
+                            max_unroll=16,
+                            cfg=cfg)
+    cfg["ann_spatial"].apply(s, conv, [vh, vw, vc],
+                             axis_lens=[cfg['tile_oh'].size[-1],
+                                        cfg['tile_ow'].size[-1],
+                                        cfg['tile_co'].size[-1]],
+                             max_unroll=16,
+                             cfg=cfg)
+
+    # schedule fusion
+    n, co, h, w = s[last].op.axis
+    co, vc = cfg['tile_co'].apply(s, last, co)
+    oh, vh = cfg['tile_oh'].apply(s, last, h)
+    ow, vw = cfg['tile_ow'].apply(s, last, w)
+    s[last].reorder(n, co, oh, ow, vh, vw, vc)
+    if last != output:
+        s[output].compute_inline()
+        cfg["ann_spatial"].apply(s, last, [vh, vw, vc],
+                                 axis_lens=[cfg['tile_oh'].size[-1],
+                                            cfg['tile_ow'].size[-1],
+                                            cfg['tile_co'].size[-1]],
+                                 max_unroll=16,
+                                 cfg=cfg)
+    s[conv].compute_at(s[last], ow)
+
+    # mark parallel
+    s[last].parallel(co)
+
+    if data_vec.op.name == 'data_vec_undilated':
+        _, h, _, _, _, _, _, _ = s[data_vec].op.axis
+    else:
+        _, h, _, _, _, _ = s[data_vec].op.axis
+    s[data_vec].parallel(h)
+
+    if kernel_vec.op.name == 'kernel_vec':
+        co, _, _, _, _ = s[kernel_vec].op.axis
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            # kernel packing will be pre-computed during compilation, so we skip
+            # this part to make tuning records correct
+            s[kernel_vec].pragma(co, 'debug_skip_region')
+        else:
+            s[kernel_vec].parallel(co)
+    elif kernel_vec.op.name == 'kernel_vec_conv2d_transpose':  # for conv2d transpose
+        co, _, _, _, _ = s[kernel_vec].op.axis
+        s[kernel_vec].parallel(co)
+
+    return s
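
As a sanity check of the shape arithmetic in conv2d_spatial_pack_nchw above,
here is the packing worked out for one concrete configuration (shapes and
tile sizes are chosen purely for illustration; in practice autotvm picks
VC/VH/VW):

    # 1x16x56x56 input, 32x16x3x3 kernel, stride 1, pad 1, no dilation
    N, CI, IH, IW = 1, 16, 56, 56
    CO, KH, KW = 32, 3, 3
    HSTR = WSTR = 1
    pad_top = pad_left = pad_bottom = pad_right = 1
    dkh = (KH - 1) * 1 + 1                               # dilated kernel h: 3
    dkw = (KW - 1) * 1 + 1                               # dilated kernel w: 3
    OH = (IH + pad_top + pad_bottom - dkh) // HSTR + 1   # 56
    OW = (IW + pad_left + pad_right - dkw) // WSTR + 1   # 56
    VC, VH, VW = 4, 8, 8                                 # stand-in tile sizes
    kvshape = (CO // VC, CI, KH, KW, VC)                     # (8, 16, 3, 3, 4)
    ovshape = (N, CO // VC, OH // VH, OW // VW, VH, VW, VC)  # (1, 8, 7, 7, 8, 8, 4)
    oshape = (N, CO, OH, OW)                                 # (1, 32, 56, 56)
    # output_unpack then maps (n, co, h, w) back to
    # conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC]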
index 99cb046..bd56ef3 100644 (file)
@@ -24,7 +24,7 @@ from tvm import autotvm
 from ..generic import schedule_conv2d_transpose_nchw
 from ..nn import conv2d_transpose_nchw, dilate, pad, get_pad_tuple
 from ..util import get_const_tuple, traverse_inline
-from .conv2d import _schedule_spatial_pack
+from .conv2d_spatial_pack import schedule_conv2d_spatial_pack_nchw
 
 @autotvm.task.register_topi_compute(conv2d_transpose_nchw, "arm_cpu", "direct")
 def conv2d_transpose_nchw_arm(cfg, Input, Filter, strides, padding, out_dtype):
@@ -154,7 +154,8 @@ def schedule_conv2d_transpose_arm(cfg, outs):
             if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
                 s[kernel].compute_inline()
 
-            _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, outs[0])
+            schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec,
+                                              conv, output, outs[0])
 
     traverse_inline(s, outs[0].op, _callback)
     return s
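
conv2d_transpose can reuse schedule_conv2d_spatial_pack_nchw because the
shared schedule branches on the packed kernel's op name (see the
'kernel_vec' / 'kernel_vec_conv2d_transpose' cases in the new file). A
hypothetical, self-contained fragment showing what opting into the transpose
branch looks like; the real declaration lives in conv2d_transpose.py and is
not shown in this diff:

    import tvm

    # stand-in shapes just to make the fragment runnable
    CO, CI, KH, KW, VC = 32, 16, 3, 3, 4
    kernel = tvm.placeholder((CO, CI, KH, KW), name='kernel')
    kvshape = (CO // VC, CI, KH, KW, VC)
    kernel_vec = tvm.compute(
        kvshape,
        lambda co, ci, kh, kw, vc: kernel[co*VC + vc][ci][kh][kw],
        # this op name routes the shared schedule into its conv2d_transpose
        # branch, bypassing the tuning-time 'debug_skip_region' pragma
        name='kernel_vec_conv2d_transpose')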
index ddb3462..2382127 100644 (file)
@@ -27,7 +27,8 @@ from ..nn import conv2d, conv2d_winograd_without_weight_transform, \
 from ..nn.winograd_util import winograd_transform_matrices
 
 # reuse some compute declarations from ARM CPU
-from ..arm_cpu.conv2d import _decl_spatial_pack, _alter_conv2d_layout_arm
+from ..arm_cpu.conv2d import _alter_conv2d_layout_arm
+from ..arm_cpu.conv2d_spatial_pack import conv2d_spatial_pack_nchw
 
 
 @autotvm.register_topi_compute(conv2d, 'mali', ['direct'])
@@ -67,8 +68,11 @@ def conv2d_mali(cfg, data, kernel, strides, padding, dilation, layout, out_dtype
     output : tvm.Tensor
         4-D with shape [batch, out_channel, out_height, out_width]
     """
-    return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
-                              num_tile=3)
+    if layout == 'NCHW':
+        return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding,
+                                        dilation, out_dtype, num_tile=3)
+    else:
+        raise ValueError("Unsupported layout {}".format(layout))
 
 @autotvm.register_topi_schedule(schedule_conv2d_nchw, 'mali', ['direct', 'winograd'])
 def schedule_conv2d_nchw_mali(cfg, outs):
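
For reference, a minimal sketch of driving the relocated compute and schedule
through the public TOPI entry points (the target string and shapes are
illustrative; an untuned run falls back to the rk3399 tophub logs loaded in
conv2d_spatial_pack_nchw):

    import tvm
    import topi

    data = tvm.placeholder((1, 16, 56, 56), name='data')
    kernel = tvm.placeholder((32, 16, 3, 3), name='kernel')
    with tvm.target.create('llvm -device=arm_cpu -mtriple=aarch64-linux-gnu'):
        # dispatches to conv2d_arm_cpu -> conv2d_spatial_pack_nchw
        out = topi.nn.conv2d(data, kernel, strides=1, padding=1, dilation=1,
                             layout='NCHW', out_dtype='float32')
        # dispatches to schedule_conv2d_nchw_arm_cpu
        s = topi.generic.schedule_conv2d_nchw([out])
        func = tvm.build(s, [data, kernel, out])

Swapping in a Mali target (e.g. 'opencl -device=mali') routes the same call
through conv2d_mali, which hits the shared compute with num_tile=3.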