Fixed issue #3069 by checking op tag (#3070)
author Ruizhe Zhao (Vincent) <kumasento@users.noreply.github.com>
Sat, 27 Apr 2019 02:15:21 +0000 (03:15 +0100)
committer Wuwei Lin <vincentl13x@gmail.com>
Sat, 27 Apr 2019 02:15:21 +0000 (10:15 +0800)
* Fixed issue #3069 by adding in_channels

* Registered group_conv2d_nchw as topi compute

* Improved by checking tag value (see the sketch below)

* Removed group_conv2d_nchw topi registration

* Added test for relay group_conv2d_nchw

* Added assertions to forbid small group size

* Removed hard-coded oc_block_factor

* Added explanatory comments to group_conv2d_nchw_cuda

* Updated group_conv2d_nchw_cuda schedule

Removed 'direct' CUDA tests

* Reverted an accidental change in a conv2d test

* Fixed indentation problems

* Fixed a mis-commented line

* Reverted change in group_conv2d_nchw tag

* Removed commented int8 group_conv2d test

* Fixed group size assertions in group_conv2d_nchw_cuda
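
The gist of the fix (see the _nn.py diff below): rather than inferring depthwise convolution from weight shapes, schedule_conv2d now walks the output ops for one whose tag contains 'conv2d' and then checks whether that tag also contains 'depthwise'. A minimal, self-contained sketch of the traversal; the FakeOp/FakeTensor stand-ins and tag strings are illustrative, not TVM API:

    from collections import namedtuple

    # Tiny stand-ins for TVM operations/tensors, just enough to exercise the
    # traversal introduced in _nn.py (the real code works on tvm.tensor objects).
    FakeOp = namedtuple('FakeOp', ['tag', 'input_tensors'])
    FakeTensor = namedtuple('FakeTensor', ['op'])

    def _find_conv2d_op(op):
        """Find the op with conv2d in its tag by traversing inputs."""
        if 'conv2d' in op.tag:
            return op
        for tensor in op.input_tensors:
            op_ = _find_conv2d_op(tensor.op)
            if op_ is not None:
                return op_
        return None

    # A broadcast add fused on top of a grouped convolution:
    conv = FakeOp(tag='group_conv2d_nchw', input_tensors=[])
    add = FakeOp(tag='broadcast_add', input_tensors=[FakeTensor(conv)])

    op = _find_conv2d_op(add)
    assert op is not None
    # 'depthwise' in the tag selects a depthwise schedule; otherwise group conv2d.
    print('depthwise' if 'depthwise' in op.tag else 'group')   # -> group

In a fused graph the conv2d op may not be outs[0].op itself, which is why the recursive search is needed.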

python/tvm/relay/frontend/onnx.py
python/tvm/relay/op/nn/_nn.py
python/tvm/target.py
tests/python/relay/test_op_level2.py
topi/python/topi/cuda/group_conv2d_nchw.py
topi/python/topi/generic/nn.py
topi/python/topi/nn/conv2d.py

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index ebedc20..53f104c 100644
@@ -169,7 +169,6 @@ class Conv(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        # get number of channels
         out = AttrCvt(op_name=dimension_picker('conv'),
                       transforms={
                           'kernel_shape': 'kernel_size',
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 272b751..5e9d5d7 100644
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#pylint: disable=invalid-name, unused-argument
+# pylint: disable=invalid-name, unused-argument
 """Backend compiler related feature registration"""
 from __future__ import absolute_import
 
@@ -34,16 +34,19 @@ def schedule_softmax(_, outputs, target):
     with target:
         return topi.generic.schedule_softmax(outputs)
 
+
 reg.register_pattern("nn.softmax", OpPattern.OPAQUE)
 
 schedule_broadcast = schedule_injective
 
+
 @reg.register_schedule("nn.log_softmax")
 def schedule_log_softmax(_, outputs, target):
     """Schedule definition of log_softmax"""
     with target:
         return topi.generic.schedule_softmax(outputs)
 
+
 reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)
 
 
@@ -55,12 +58,14 @@ def compute_dense(attrs, inputs, out_type, target):
     out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
     return [topi.nn.dense(inputs[0], inputs[1], out_dtype=out_dtype)]
 
+
 @reg.register_schedule("nn.dense")
 def schedule_dense(attrs, outputs, target):
     """Schedule definition of dense"""
     with target:
         return topi.generic.schedule_dense(outputs)
 
+
 reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
@@ -70,16 +75,29 @@ def compute_batch_matmul(attrs, inputs, out_type, target):
     """Compute definition of batch_matmul"""
     return [topi.nn.batch_matmul(inputs[0], inputs[1])]
 
+
 @reg.register_schedule("nn.batch_matmul")
 def schedule_batch_matmul(attrs, outputs, target):
     """Schedule definition of batch_matmul"""
     with target:
         return topi.generic.schedule_batch_matmul(outputs)
 
+
 reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 # conv2d
+def _find_conv2d_op(op):
+    """Find the op with conv2d in its tag by traversing."""
+    if 'conv2d' in op.tag:
+        return op
+    for tensor in op.input_tensors:
+        op_ = _find_conv2d_op(tensor.op)
+        if op_ is not None:
+            return op_
+    return None
+
+
 @reg.register_compute("nn.conv2d")
 def compute_conv2d(attrs, inputs, out_type, target):
     """Compute definition of conv2d"""
@@ -103,14 +121,14 @@ def compute_conv2d(attrs, inputs, out_type, target):
             inputs[0], inputs[1], strides, padding,
             dilation, layout, out_dtype=out_dtype)
     elif layout == "NCHW" and \
-         get_const_int(inputs[1].shape[0]) == groups and \
-         get_const_int(inputs[1].shape[1]) == 1:
+            get_const_int(inputs[1].shape[0]) == groups and \
+            get_const_int(inputs[1].shape[1]) == 1:
         out = topi.nn.depthwise_conv2d_nchw(
             inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
     elif layout == "NHWC" and \
-         kernel_layout == "HWOI" and\
-         get_const_int(inputs[1].shape[2]) == groups and \
-         get_const_int(inputs[1].shape[3]) == 1:
+            kernel_layout == "HWOI" and\
+            get_const_int(inputs[1].shape[2]) == groups and \
+            get_const_int(inputs[1].shape[3]) == 1:
         out = topi.nn.depthwise_conv2d_nhwc(
             inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
     elif layout in ['NCHW', 'NCHW4c']:
@@ -127,6 +145,7 @@ def schedule_conv2d(attrs, outs, target):
     groups = attrs.groups
     layout = attrs.data_layout
     kernel_layout = attrs.kernel_layout
+
     with target:
         if groups == 1 and layout == "NCHW":
             return topi.generic.schedule_conv2d_nchw(outs)
@@ -135,13 +154,20 @@ def schedule_conv2d(attrs, outs, target):
         if groups == 1 and layout == "NHWC":
             return topi.generic.schedule_conv2d_nhwc(outs)
         if groups != 1:
-            if layout == "NCHW":
-                # TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d.
-                return topi.generic.schedule_depthwise_conv2d_nchw(outs)
-            if layout == "NHWC" and kernel_layout == "HWOI":
-                return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
-            if layout == "NCHW4c":
-                return topi.generic.schedule_group_conv2d_nchw(outs)
+            # check the conv2d op tag to distinguish depthwise and group conv2d
+            op = _find_conv2d_op(outs[0].op)
+            assert op is not None
+
+            is_depthwise = 'depthwise' in op.tag
+            if is_depthwise:
+                if layout == "NCHW":
+                    # TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d.
+                    return topi.generic.schedule_depthwise_conv2d_nchw(outs)
+                if layout == "NHWC" and kernel_layout == "HWOI":
+                    return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
+            else:
+                if layout in ["NCHW", "NCHW4c"]:
+                    return topi.generic.schedule_group_conv2d_nchw(outs)
     raise ValueError("No compatible schedule")
 
 
@@ -151,6 +177,7 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos):
     from ... import op
     return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)
 
+
 reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
@@ -169,18 +196,21 @@ def compute_conv2d_transpose(attrs, inputs, out_dtype, target):
     assert layout == "NCHW", "only support nchw for now"
     assert dilation == (1, 1), "not support dilate now"
     assert groups == 1, "only support groups == 1 for now"
-    out = topi.nn.conv2d_transpose_nchw(inputs[0], inputs[1], strides, padding, out_dtype)
+    out = topi.nn.conv2d_transpose_nchw(
+        inputs[0], inputs[1], strides, padding, out_dtype)
     output_padding = get_const_tuple(attrs.output_padding)
     out = topi.nn.pad(out,
                       [0, 0, 0, 0], [0, 0, output_padding[0], output_padding[1]])
     return [out]
 
+
 @reg.register_schedule("nn.conv2d_transpose")
 def schedule_conv2d_transpose(attrs, outs, target):
     """Schedule definition of conv2d_transpose"""
     with target:
         return topi.generic.schedule_conv2d_transpose_nchw(outs)
 
+
 reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 # bias_add
@@ -196,6 +226,7 @@ def schedule_max_pool2d(attrs, outs, target):
     with target:
         return topi.generic.schedule_pool(outs, layout)
 
+
 reg.register_pattern("nn.max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
@@ -207,6 +238,7 @@ def schedule_avg_pool2d(attrs, outs, target):
     with target:
         return topi.generic.schedule_pool(outs, layout)
 
+
 reg.register_pattern("nn.avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
@@ -217,6 +249,7 @@ def schedule_global_max_pool2d(_, outs, target):
     with target:
         return topi.generic.schedule_global_pool(outs)
 
+
 reg.register_pattern("nn.global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
@@ -227,6 +260,7 @@ def schedule_global_avg_pool2d(_, outs, target):
     with target:
         return topi.generic.schedule_global_pool(outs)
 
+
 reg.register_pattern("nn.global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 # leaky_relu
@@ -250,12 +284,14 @@ def compute_lrn(attrs, inputs, out_dtype, target):
     return [topi.nn.lrn(inputs[0], attrs.size, attrs.axis,
                         attrs.alpha, attrs.beta, attrs.bias)]
 
+
 @reg.register_schedule("nn.lrn")
 def schedule_lrn(attrs, outs, target):
     """Schedule definition of lrn"""
     with target:
         return topi.generic.schedule_lrn(outs)
 
+
 reg.register_pattern("nn.lrn", OpPattern.OPAQUE)
 
 
@@ -265,20 +301,26 @@ def compute_l2_normalize(attrs, inputs, out_dtype, target):
     """Compute definition of l2 normalize"""
     return [topi.nn.l2_normalize(inputs[0], attrs.eps, attrs.axis)]
 
+
 @reg.register_schedule("nn.l2_normalize")
 def schedule_l2_normalize(attrs, outs, target):
     """Schedule definition of l2 normalize"""
     with target:
         return topi.generic.schedule_l2_normalize(outs)
 
+
 reg.register_pattern("nn.l2_normalize", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 # upsampling
 reg.register_schedule("nn.upsampling", reg.schedule_injective)
+
+
 def schedule_upsampling(_, outs, target):
     """Schedule definition of upsampling"""
     with target:
         return topi.generic.schedule_injective(outs)
+
+
 # pad
 reg.register_schedule("nn.pad", schedule_broadcast)
 
@@ -304,12 +346,14 @@ def compute_contrib_conv2d_winograd_without_weight_transform(attrs, inputs, out_
 
     return [out]
 
+
 @reg.register_schedule("nn.contrib_conv2d_winograd_without_weight_transform")
 def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, target):
     """Schedule definition of conv2d_winograd_without_weight_transform"""
     with target:
         return topi.generic.schedule_conv2d_winograd_without_weight_transform(outs)
 
+
 reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform",
                      OpPattern.OUT_ELEMWISE_FUSABLE)
 
@@ -317,15 +361,18 @@ reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform",
 @reg.register_compute("nn.contrib_conv2d_winograd_weight_transform")
 def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype, target):
     """Compute definition of contrib_conv2d_winograd_weight_transform"""
-    out = topi.nn.conv2d_winograd_weight_transform(inputs[0], attrs.get_int('tile_size'))
+    out = topi.nn.conv2d_winograd_weight_transform(
+        inputs[0], attrs.get_int('tile_size'))
     return [out]
 
+
 @reg.register_schedule("nn.contrib_conv2d_winograd_weight_transform")
 def schedule_contrib_conv2d_winograd_weight_transform(attrs, outs, target):
     """Schedule definition of contrib_conv2d_winograd_weight_transform"""
     with target:
         return topi.generic.schedule_conv2d_winograd_weight_transform(outs)
 
+
 reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform",
                      OpPattern.OUT_ELEMWISE_FUSABLE)
 
@@ -353,12 +400,14 @@ def compute_contrib_conv2d_winograd_nnpack_without_weight_transform(
 
     return [out]
 
+
 @reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
 def schedule_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, outs, target):
     """Schedule definition of conv2d_winograd_nnpack_without_weight_transform"""
     with target:
         return topi.generic.schedule_conv2d_winograd_nnpack_without_weight_transform(outs)
 
+
 reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_without_weight_transform",
                      OpPattern.OPAQUE)
 
@@ -371,12 +420,14 @@ def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, out_d
         inputs[0], convolution_algorithm, out_dtype)
     return [out]
 
+
 @reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_weight_transform")
 def schedule_contrib_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
     """Schedule definition of contrib_conv2d_winograd_nnpack_weight_transform"""
     with target:
         return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs)
 
+
 reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_weight_transform",
                      OpPattern.OPAQUE)
 
@@ -397,15 +448,18 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, out_dtype, target):
                                data_layout, out_layout, out_dtype)
     return [out]
 
+
 @reg.register_schedule("nn.contrib_conv2d_NCHWc")
 def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
     """Schedule definition of contrib_conv2d_NCHWc"""
     with target:
         return topi.generic.schedule_conv2d_NCHWc(outs)
 
+
 reg.register_pattern("nn.contrib_conv2d_NCHWc",
                      OpPattern.OUT_ELEMWISE_FUSABLE)
 
+
 @reg.register_compute("nn.contrib_depthwise_conv2d_NCHWc")
 def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target):
     """Compute definition of depthwise conv2d NCHWc"""
@@ -422,15 +476,18 @@ def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target):
                                          data_layout, out_layout, out_dtype)
     return [out]
 
+
 @reg.register_schedule("nn.contrib_depthwise_conv2d_NCHWc")
 def schedule_contrib_depthwise_conv2d_NCHWc(attrs, outs, target):
     """Schedule definition of contrib_conv2d_NCHWc"""
     with target:
         return topi.generic.schedule_depthwise_conv2d_NCHWc(outs)
 
+
 reg.register_pattern("nn.contrib_depthwise_conv2d_NCHWc",
                      OpPattern.OUT_ELEMWISE_FUSABLE)
 
+
 @reg.register_compute("nn.deformable_conv2d")
 def compute_deformable_conv2d(attrs, inputs, out_dtype, target):
     """Compute definition of deformable_conv2d"""
@@ -446,10 +503,12 @@ def compute_deformable_conv2d(attrs, inputs, out_dtype, target):
                                              dilation, deformable_groups, groups, out_dtype)
     return [out]
 
+
 @reg.register_schedule("nn.deformable_conv2d")
 def schedule_deformable_conv2d(attrs, outs, target):
     """Schedule definition of deformable_conv2d"""
     with target:
         return topi.generic.schedule_deformable_conv2d_nchw(outs)
 
+
 reg.register_pattern("nn.deformable_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
diff --git a/python/tvm/target.py b/python/tvm/target.py
index d3df3d7..eff0088 100644
@@ -296,7 +296,7 @@ def override_native_generic_func(func_name):
 def generic_func(fdefault):
     """Wrap a target generic function.
 
-    Generic function allows registeration of further functions
+    Generic function allows registration of further functions
     that can be dispatched on current target context.
     If no registered dispatch is matched, the fdefault will be called.
 
diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index a6efd8c..88963a6 100644
@@ -86,9 +86,13 @@ def test_conv2d_run():
                         fref=None,
                         groups=1,
                         dilation=(1, 1),
+                        except_targets=None,
                         **attrs):
-        x = relay.var("x", shape=dshape)
-        w = relay.var("w")
+        if except_targets is None:
+            except_targets = []
+
+        x = relay.var("x", shape=dshape, dtype=dtype)
+        w = relay.var("w", dtype=dtype)
         y = relay.nn.conv2d(x, w,
                             padding=padding,
                             dilation=dilation,
@@ -100,11 +104,15 @@ def test_conv2d_run():
         dkernel = topi.testing.dilate_python(kernel, (1, 1) + dilation)
         if fref is None:
             ref_res = topi.testing.conv2d_nchw_python(
-                data.astype(out_dtype), dkernel.astype(out_dtype), 1, padding)
+                data.astype(out_dtype), dkernel.astype(out_dtype), 1, padding,
+                groups=groups)
         else:
             ref_res = fref(data.astype(out_dtype), dkernel.astype(out_dtype))
 
+
         for target, ctx in ctx_list():
+            if target in except_targets:
+                continue
             intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
             op_res1 = intrp1.evaluate(func)(data, kernel)
             tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
@@ -117,6 +125,21 @@ def test_conv2d_run():
                     fref=lambda x, w: topi.testing.depthwise_conv2d_python_nchw(
                         x, w, (1, 1), "SAME"))
 
+    # CUDA is disabled for 'direct' schedule:
+    # https://github.com/dmlc/tvm/pull/3070#issuecomment-486597553
+    # group conv2d
+    dshape = (1, 32, 18, 18)
+    kshape = (32, 4, 3, 3)
+    run_test_conv2d("float32", "float32", 1, dshape, kshape,
+                    padding=(1, 1), channels=32, groups=8, kernel_size=(3, 3),
+                    except_targets=['cuda'])
+    # also group conv2d
+    dshape = (1, 32, 18, 18)
+    kshape = (64, 1, 3, 3)
+    run_test_conv2d("float32", "float32", 1, dshape, kshape,
+                    padding=(1, 1), channels=64, groups=32, kernel_size=(3, 3),
+                    except_targets=['cuda'])
+
     # normal conv2d
     dshape = (1, 3, 224, 224)
     kshape = (10, 3, 3, 3)
diff --git a/topi/python/topi/cuda/group_conv2d_nchw.py b/topi/python/topi/cuda/group_conv2d_nchw.py
index 601b9b6..be4ae35 100644
@@ -27,10 +27,13 @@ from ..util import traverse_inline, get_const_tuple, get_const_int
 from .. import nn, generic
 
 
-@autotvm.register_topi_compute(nn.group_conv2d_nchw, ['cuda', 'gpu'], ['direct', 'int8'])
+autotvm.register_topi_compute(nn.group_conv2d_nchw, ['cuda', 'gpu'], 'direct',
+                              nn.group_conv2d_nchw.fdefault)
+
+@autotvm.register_topi_compute(nn.group_conv2d_nchw, ['cuda', 'gpu'], ['int8'])
 def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups,
                            out_dtype='float32'):
-    """Group convolution operator in NCHW layout.
+    """Group convolution operator for 'group_conv2d_NCHWc_int8'.
 
     Parameters
     ----------
@@ -76,7 +79,7 @@ def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups,
         assert out_channels % groups == 0, "output channels must divide group size"
         assert channels % ic_block_factor == 0, \
             "Number of input channels per group must divide {}".format(ic_block_factor)
-        assert out_channels % 4 == 0, \
+        assert out_channels % oc_block_factor == 0, \
             "Number of output channels per group must divide {}".format(oc_block_factor)
 
         packed_data = tvm.compute((batch, channels // ic_block_factor, height, width,
@@ -99,6 +102,17 @@ def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups,
     oc_chunk, _, kernel_h, kernel_w, oc_block, ic_block = get_const_tuple(
         packed_kernel.shape)
 
+    # TODO(kumasento): these assertions ensure that the number of groups is
+    # no larger than the number of channel blocks (chunks), so that each
+    # group gets at least one block.
+    # Shall we pad the channels to avoid raising assertions?
+    assert groups <= oc_chunk, \
+        ('Number of groups {} should not exceed '
+         'output channel chunk size {}'.format(groups, oc_chunk))
+    assert groups <= ic_chunk, \
+        ('Number of groups {} should not exceed '
+         'input channel chunk size {}'.format(groups, ic_chunk))
+
     if isinstance(stride, int):
         stride_h = stride_w = stride
     else:
@@ -109,9 +123,9 @@ def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups,
     else:
         dilation_h, dilation_w = dilation
 
+    # pad the input data
     pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
         padding, (kernel_h, kernel_w))
-    # compute graph
     pad_before = [0, 0, pad_top, pad_left, 0]
     pad_after = [0, 0, pad_down, pad_right, 0]
     pad_data = pad(packed_data, pad_before, pad_after, name="pad_data")
@@ -129,6 +143,17 @@ def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups,
     kh = tvm.reduce_axis((0, kernel_h), name='kh')
     kw = tvm.reduce_axis((0, kernel_w), name='kw')
 
+    # NOTE(kumasento): explanation of this snippet -
+    # oc_chunk//groups and ic_chunk//groups give the number of output and
+    # input channel blocks (chunks) per group, respectively.
+    # occ is the ID of the output channel block, so occ//(oc_chunk//groups)
+    # is the ID of the group it belongs to.
+    # Multiplying that group ID by ic_chunk//groups gives the ID of the
+    # first input channel block of the corresponding group.
+    # Adding the block offset (icc) then gives the exact input block ID.
+    #
+    # Compared with a normal convolution, a group convolution only sums
+    # over input channels from the group that an output channel belongs to.
     conv = tvm.compute(oshape, lambda n, occ, oh, ow, ocb:
                        tvm.sum(pad_data[n, occ//(oc_chunk//groups)*(ic_chunk//groups)+icc,
                                         oh*stride_h+kh*dilation_h, ow*stride_w+kw*dilation_w, icb]
@@ -138,8 +163,10 @@ def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups,
                                .astype('int32'),
                                axis=[icc, kh, kw, icb]))
 
+    # Type conversion
     output = tvm.compute(oshape, lambda *index: conv(*index).astype(out_dtype),
                          tag='group_conv2d_NCHWc_int8')
+
     num_flop = batch * oc_chunk * oc_block * out_height * out_width * \
         ic_chunk * ic_block * kernel_h * kernel_w * 2 // groups
     cfg.add_flop(num_flop)
@@ -295,7 +322,7 @@ def schedule_group_conv2d_NCHWc_int8(cfg, s, output):
 
 
 @autotvm.register_topi_schedule(generic.schedule_group_conv2d_nchw,
-                                ["cuda", "gpu"], ["direct", "int8"])
+                                ["cuda", "gpu"], ["int8"])
 def schedule_conv2d_nchw_cuda(cfg, outs):
     """TOPI schedule callback of group conv2d for cuda gpu
 
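To make the NOTE(kumasento) comment and the group-size assertions above concrete, here is a small standalone check of the block-index arithmetic; the sizes are made up for illustration, not taken from the patch:

    # Illustrative check of the packed-layout index arithmetic used by
    # group_conv2d_nchw_cuda; oc_chunk, ic_chunk and groups are hypothetical.
    oc_chunk, ic_chunk, groups = 8, 8, 4

    # The new assertions require at least one channel block (chunk) per group.
    assert groups <= oc_chunk and groups <= ic_chunk

    for occ in range(oc_chunk):
        group_id = occ // (oc_chunk // groups)        # group this output block belongs to
        ic_start = group_id * (ic_chunk // groups)    # first input block of that group
        ic_blocks = [ic_start + icc for icc in range(ic_chunk // groups)]
        print('output block {} sums over input blocks {}'.format(occ, ic_blocks))
    # -> output blocks 0-1 read [0, 1], 2-3 read [2, 3], 4-5 read [4, 5], 6-7 read [6, 7]
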
diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py
index 70ce579..7bd9568 100644
@@ -242,7 +242,7 @@ def schedule_depthwise_conv2d_NCHWc(outs):
 
 @tvm.target.generic_func
 def schedule_group_conv2d_nchw(outs):
-    """Schedule for conv2d_nchw
+    """Schedule for group_conv2d_nchw
 
     Parameters
     ----------
diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py
index 49c0bd7..06d4074 100644
@@ -603,4 +603,4 @@ def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtyp
                  yy * stride_h + ry * dilation_h,
                  xx * stride_w + rx * dilation_w].astype(out_dtype) *
             Filter[ff, rc, ry, rx].astype(out_dtype),
-            axis=[rc, ry, rx]), tag="conv2d_nchw")
+            axis=[rc, ry, rx]), tag='group_conv2d_nchw')