Add schedule for conv3d NDHWC layout (#4775)
authorAlex Gladkov <gladkova@lab126.com>
Sat, 1 Feb 2020 01:43:27 +0000 (17:43 -0800)
committerGitHub <noreply@github.com>
Sat, 1 Feb 2020 01:43:27 +0000 (17:43 -0800)
topi/python/topi/nn/conv3d.py
topi/python/topi/x86/__init__.py
topi/python/topi/x86/conv3d.py [new file with mode: 0644]

diff --git a/topi/python/topi/nn/conv3d.py b/topi/python/topi/nn/conv3d.py
index 21d893f..83c16da 100644
@@ -186,15 +186,15 @@ def conv3d_ndhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
     pad_before = [0, pad_front, pad_top, pad_left, 0]
     pad_after = [0, pad_back, pad_down, pad_right, 0]
     PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
+    rd = tvm.reduce_axis((0, kernel_d), name='rd')
+    rh = tvm.reduce_axis((0, kernel_h), name='rh')
+    rw = tvm.reduce_axis((0, kernel_w), name='rw')
     rc = tvm.reduce_axis((0, in_channel), name='rc')
-    rz = tvm.reduce_axis((0, kernel_d), name='rz')
-    ry = tvm.reduce_axis((0, kernel_h), name='ry')
-    rx = tvm.reduce_axis((0, kernel_w), name='rx')
     Output = tvm.compute(
         (batch, out_depth, out_height, out_width, out_channel),
-        lambda nn, zz, yy, xx, ff: tvm.sum(
-            PaddedInput[nn, zz * stride_d + rz * dilation_d, yy * stride_h + ry * dilation_h,
-                        xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
-            Filter[rz, ry, rx, rc, ff].astype(out_dtype), axis=[rz, ry, rx, rc]),
+        lambda nn, dd, hh, ww, cc: tvm.sum(
+            PaddedInput[nn, dd * stride_d + rd * dilation_d, hh * stride_h + rh * dilation_h,
+                        ww * stride_w + rw * dilation_w, rc].astype(out_dtype) *
+            Filter[rd, rh, rw, rc, cc].astype(out_dtype), axis=[rd, rh, rw, rc]),
         name="Conv3dOutput", tag="conv3d_ndhwc")
     return Output
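
For context, a minimal sketch of how this NDHWC compute can be declared (shapes and tensor names below are chosen purely for illustration and are not taken from the patch):

    import tvm
    import topi

    # Hypothetical workload: batch 1, 8x32x32 volume, 16 input / 32 output channels,
    # 3x3x3 kernel, unit stride and dilation, padding of 1 on every spatial side.
    Input = tvm.placeholder((1, 8, 32, 32, 16), name="Input")
    Filter = tvm.placeholder((3, 3, 3, 16, 32), name="Filter")
    Output = topi.nn.conv3d_ndhwc(Input, Filter, stride=1, padding=1, dilation=1)
    # With these settings the spatial size is preserved: Output has shape (1, 8, 32, 32, 32).
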
diff --git a/topi/python/topi/x86/__init__.py b/topi/python/topi/x86/__init__.py
index af7f974..d1c728d 100644
@@ -21,6 +21,7 @@ from __future__ import absolute_import as _abs
 
 from .conv1d import schedule_conv1d_nwc
 from .conv2d import schedule_conv2d, schedule_conv2d_nhwc
+from .conv3d import schedule_conv3d_ndhwc
 from .binarize_pack import schedule_binarize_pack
 from .binary_dense import schedule_binary_dense
 from .nn import *
diff --git a/topi/python/topi/x86/conv3d.py b/topi/python/topi/x86/conv3d.py
new file mode 100644
index 0000000..b7a88cb
--- /dev/null
@@ -0,0 +1,82 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, too-many-locals
+# pylint: disable=unused-argument, redefined-builtin, no-else-return
+"""Conv3D operators"""
+import tvm
+from .. import generic, tag
+from ..util import traverse_inline
+
+@generic.schedule_conv3d_ndhwc.register("cpu")
+def schedule_conv3d_ndhwc(outs):
+    """TOPI schedule callback for conv3d
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of conv3d
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for conv3d.
+    """
+    s = tvm.create_schedule([x.op for x in outs])
+    output_op = outs[0].op
+
+    def _traverse(op):
+        """Traverse operators from computation graph"""
+        if op in s.outputs and tag.is_broadcast(op.tag) and len(op.axis) == 5:
+            # schedule bias + bn + relu
+            n, d, h, w, c = op.axis
+            fused = s[op].fuse(n, d, h, w)
+            s[op].parallel(fused)
+            s[op].vectorize(c)
+
+        if 'conv3d_ndhwc' in op.tag:
+            conv = op.output(0)
+            kernel = op.input_tensors[1]
+            # dilation stage
+            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+                s[kernel].compute_inline()
+
+            # padding stage
+            data = op.input_tensors[0]
+            data_pad = None
+            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
+                # fuse pad h and w
+                data_pad = data
+                data = data_pad.op.input_tensors[0]
+                _, _, h_pad, w_pad, _ = data_pad.op.axis
+                pad_fused = s[data_pad].fuse(h_pad, w_pad)
+                s[data_pad].parallel(pad_fused)
+
+            # compute conv
+            C = conv
+            n, d, h, w, c = s[C].op.axis
+            s[C].vectorize(c)
+            if op != output_op: # fuse bias + bn + activation
+                _, _, _, _, c_out = output_op.axis
+                s[C].compute_at(s[output_op], c_out)
+            else:
+                # fuse batch, depth, height axes
+                fused = s[C].fuse(n, d, h)
+                s[C].parallel(fused)
+
+    traverse_inline(s, output_op, _traverse)
+    return s
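
A hedged end-to-end sketch of how the new schedule is expected to be used via the generic dispatcher (the target, shapes, and names here are illustrative assumptions, not part of the patch):

    import tvm
    import topi

    # Illustrative NDHWC workload; with an llvm/cpu target active, the generic
    # dispatcher should resolve to the x86 schedule registered above.
    A = tvm.placeholder((1, 8, 32, 32, 16), name="A")
    W = tvm.placeholder((3, 3, 3, 16, 32), name="W")
    B = topi.nn.conv3d_ndhwc(A, W, stride=1, padding=1, dilation=1)

    with tvm.target.create("llvm"):
        s = topi.generic.schedule_conv3d_ndhwc([B])
        func = tvm.build(s, [A, W, B], name="conv3d_ndhwc")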