[Relay][Topi][AutoTVM] Winograd support for Conv3D (#5186)
author Josh Fromm <jwfromm@uw.edu>
Sun, 5 Apr 2020 21:59:38 +0000 (14:59 -0700)
committer GitHub <noreply@github.com>
Sun, 5 Apr 2020 21:59:38 +0000 (14:59 -0700)
* Functional conv3d winograd working.

* Formatted python code.

* Registered conv3d winograd compute and started adding the relay without_weight_transform operator.

* Add topi testing for conv3d winograd.

* Format file.

* Small tweak to unrolling to prevent the build from sticking.

* Refactoring convolution ops in relay.

* Refactored relay convolutions.

* Bug fixes.

* Fixed static bug in convolution.

* Added conv3d alter op layout and related support.

* Bug fixes and testing done.

* Fix a few autotvm bugs.

* Drop silly debug print.

* Removed debug_skip_region.

* Add variant of conv3d_winograd that doesn't transform depth.

* Initial infrastructure done for depthless conv.

* Fix no_depth schedule bugs.

* Automatic topi switching between depth and depthless winograd.

* Fixed bug in schedule.

* Lint fixes.

* Removed indents in convolution.cc.

* Fixed a few missed indents.

* Fixed flop count.

* One more small tweak.

* Change kernel pack inner axes order.

* Style changes.

* Comment fixes.

17 files changed:
docs/langref/relay_op.rst
include/tvm/relay/attrs/nn.h
python/tvm/relay/op/nn/_nn.py
python/tvm/relay/op/nn/nn.py
python/tvm/relay/op/nn/util.py
python/tvm/relay/op/op_attrs.py
python/tvm/relay/op/strategy/cuda.py
python/tvm/relay/op/strategy/generic.py
src/relay/op/nn/convolution.cc
src/relay/op/nn/convolution.h
tests/python/relay/test_op_level2.py
topi/python/topi/cuda/__init__.py
topi/python/topi/cuda/conv3d_alter_op.py [new file with mode: 0644]
topi/python/topi/cuda/conv3d_winograd.py [new file with mode: 0644]
topi/python/topi/generic/nn.py
topi/python/topi/nn/conv3d.py
topi/tests/python/test_topi_conv3d_winograd.py [new file with mode: 0644]

index ac636f8..f1d7d44 100644 (file)
@@ -82,8 +82,13 @@ This level enables typical convnet models.
    tvm.relay.nn.pad
    tvm.relay.nn.lrn
    tvm.relay.nn.l2_normalize
+   tvm.relay.nn.bitpack
+   tvm.relay.nn.bitserial_dense
+   tvm.relay.nn.bitserial_conv2d
    tvm.relay.nn.contrib_conv2d_winograd_without_weight_transform
    tvm.relay.nn.contrib_conv2d_winograd_weight_transform
+   tvm.relay.nn.contrib_conv3d_winograd_without_weight_transform
+   tvm.relay.nn.contrib_conv3d_winograd_weight_transform
 
 
 **Level 3: Additional Math And Transform Operators**
index 5794ddd..536e414 100644 (file)
@@ -156,12 +156,12 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
 };
 
 /*! \brief Attributes used in winograd weight transformation operators */
-struct Conv2DWinogradWeightTransformAttrs :
-    public tvm::AttrsNode<Conv2DWinogradWeightTransformAttrs> {
+struct ConvWinogradWeightTransformAttrs :
+    public tvm::AttrsNode<ConvWinogradWeightTransformAttrs> {
   int tile_size;
 
-  TVM_DECLARE_ATTRS(Conv2DWinogradWeightTransformAttrs,
-      "relay.attrs.Conv2DWinogradWeightTransformAttrs") {
+  TVM_DECLARE_ATTRS(ConvWinogradWeightTransformAttrs,
+      "relay.attrs.ConvWinogradWeightTransformAttrs") {
     TVM_ATTR_FIELD(tile_size)
       .describe("Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)");
   }
@@ -306,6 +306,69 @@ struct Conv3DAttrs : public tvm::AttrsNode<Conv3DAttrs> {
   }
 };
 
+/*! \brief Attributes used in 3d winograd convolution operators */
+struct Conv3DWinogradAttrs : public tvm::AttrsNode<Conv3DWinogradAttrs> {
+  int tile_size;
+  Array<IndexExpr> strides;
+  Array<IndexExpr> padding;
+  Array<IndexExpr> dilation;
+  int groups;
+  IndexExpr channels;
+  Array<IndexExpr> kernel_size;
+  std::string data_layout;
+  std::string kernel_layout;
+  std::string out_layout;
+  DataType out_dtype;
+
+  TVM_DECLARE_ATTRS(Conv3DWinogradAttrs, "relay.attrs.Conv3DWinogradAttrs") {
+    TVM_ATTR_FIELD(tile_size)
+      .describe("The tile size of winograd. E.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3)");
+    TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1, 1}))
+        .describe("Specifies the strides of the convolution.");
+    TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0, 0}))
+        .describe("If padding is non-zero, then the input is implicitly zero-padded"
+                  "Padding support both symmetric and asymmetric as"
+                  "one int : same padding used on all sides"
+                  "three int : back, bottom, right will use same padding as front, top, left"
+                  "six int : padding width in the order of (front, top, left, back, bottom,"
+                  "right)");
+    TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1, 1}))
+        .describe("Specifies the dilation rate to use for dilated convolution.");
+    TVM_ATTR_FIELD(groups).set_default(1)
+        .describe("Controls the connections between inputs and outputs."
+                  "At groups=1, all inputs are convolved to all outputs."
+                  "At groups=2, the operation becomes equivalent to having two convolution"
+                  "layers side by side, each seeing half the input channels, and producing"
+                  "half the output channels, and both subsequently concatenated.");
+    TVM_ATTR_FIELD(channels)
+        .describe("The number of output channels in the convolution."
+                  " If it is not set, inferred by shape of the weight.")
+        .set_default(NullValue<IndexExpr>());
+    TVM_ATTR_FIELD(kernel_size)
+        .describe("Specifies the dimensions of the convolution window.")
+        .set_default(NullValue<Array<IndexExpr> >());
+    TVM_ATTR_FIELD(data_layout).set_default("NCDHW")
+        .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
+                  "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
+                  "dimensions respectively. Convolution is applied on the 'D', 'H' and"
+                  "'W' dimensions.");
+    TVM_ATTR_FIELD(kernel_layout).set_default("OIDHW")
+        .describe("Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc."
+                  "'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height,"
+                  "and width dimensions respectively.");
+    TVM_ATTR_FIELD(out_layout).set_default("")
+        .describe("Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc."
+                  "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
+                  "dimensions respectively. Default to be same as input layout.");
+
+    // use 0 bits to indicate none.
+    TVM_ATTR_FIELD(out_dtype)
+        .set_default(NullValue<DataType>())
+        .describe("Output data type, set to explicit type under mixed precision setting");
+  }
+};
+
+
 /*! \brief Attributes used in softmax operators */
 struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {
   int axis;
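
A note on the tile_size attribute introduced above: in winograd notation F(m x m x m, r x r x r), each output tile of extent m per dimension is computed from an input tile of extent m + r - 1, so the transformed tiles have spatial extent tile_size + kernel_size - 1 in every dimension. A small illustrative calculation, not part of the patch:

    # Illustrative winograd tile arithmetic for F(m, r), per dimension.
    m, r = 2, 3          # tile_size=2 with a 3x3x3 kernel -> F(2x2x2, 3x3x3)
    alpha = m + r - 1    # transformed tile extent per dimension: 4
    print(alpha ** 3)    # points per transformed 3D tile: 64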
index 65a1162..39d98c0 100644 (file)
@@ -178,6 +178,29 @@ def legalize_conv2d_transpose(attrs, inputs, types):
 reg.register_strategy("nn.conv3d", strategy.conv3d_strategy)
 reg.register_pattern("nn.conv3d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
+@reg.register_alter_op_layout("nn.conv3d")
+def alter_op_layout_conv3d(attrs, inputs, tinfos, out_type):
+    """Alternate the layout of conv3d"""
+    return topi.nn.conv3d_alter_layout(attrs, inputs, tinfos, out_type)
+
+# conv3d_winograd related operators
+reg.register_strategy("nn.contrib_conv3d_winograd_without_weight_transform",
+                      strategy.conv3d_winograd_without_weight_transfrom_strategy)
+reg.register_pattern("nn.contrib_conv3d_winograd_without_weight_transform",
+                     OpPattern.OUT_ELEMWISE_FUSABLE)
+
+@reg.register_compute("nn.contrib_conv3d_winograd_weight_transform")
+def compute_contrib_conv3d_winograd_weight_transform(attrs, inputs, out_dtype):
+    """Compute definition of contrib_conv3d_winograd_weight_transform"""
+    out = topi.nn.conv3d_winograd_weight_transform(
+        inputs[0], attrs.get_int('tile_size'))
+    return [out]
+
+reg.register_schedule("nn.contrib_conv3d_winograd_weight_transform",
+                      strategy.schedule_conv3d_winograd_weight_transform)
+reg.register_pattern("nn.contrib_conv3d_winograd_weight_transform",
+                     OpPattern.OUT_ELEMWISE_FUSABLE)
+
 
 # conv1d_transpose
 reg.register_strategy("nn.conv1d_transpose", strategy.conv1d_transpose_strategy)
index 64148bb..a126e8d 100644 (file)
@@ -19,7 +19,7 @@
 from __future__ import absolute_import as _abs
 from ...expr import TupleWrapper
 from . import _make
-from .util import get_pad_tuple2d
+from .util import get_pad_tuple2d, get_pad_tuple3d
 
 
 def conv1d(data,
@@ -295,13 +295,84 @@ def conv3d(data,
         strides = (strides, strides, strides)
     if isinstance(dilation, int):
         dilation = (dilation, dilation, dilation)
-    if isinstance(padding, int):
-        padding = (padding, padding, padding)
+    padding = get_pad_tuple3d(padding)
     return _make.conv3d(data, weight, strides, padding, dilation,
                         groups, channels, kernel_size, data_layout,
                         kernel_layout, out_layout, out_dtype)
 
 
+def contrib_conv3d_winograd_without_weight_transform(data,
+                                                     weight,
+                                                     tile_size,
+                                                     strides=(1, 1, 1),
+                                                     padding=(0, 0, 0),
+                                                     dilation=(1, 1, 1),
+                                                     groups=1,
+                                                     channels=None,
+                                                     kernel_size=None,
+                                                     data_layout="NCDHW",
+                                                     kernel_layout="OIDHW",
+                                                     out_layout="",
+                                                     out_dtype=""):
+    r"""3D convolution with winograd algorithm.
+
+    The basic parameters are the same as the ones in vanilla conv3d.
+    It assumes the weight is pre-transformed by nn.contrib_conv3d_winograd_weight_transform
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    weight : tvm.relay.Expr
+        The weight expressions.
+
+    tile_size : int
+        The tile size of winograd, e.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3).
+
+    strides : tuple of int, optional
+        The strides of convolution.
+
+    padding : tuple of int, optional
+        The padding applied to both sides of each spatial dimension of the input.
+
+    dilation : tuple of int, optional
+        Specifies the dilation rate to be used for dilated convolution.
+
+    groups : int, optional
+        Number of groups for grouped convolution.
+
+    channels : int, optional
+        Number of output channels of this convolution.
+
+    kernel_size : tuple of int, optional
+        The spatial dimensions of the convolution kernel.
+
+    data_layout : str, optional
+        Layout of the input.
+
+    kernel_layout : str, optional
+        Layout of the weight.
+
+    out_layout : str, optional
+        Layout of the output, by default, out_layout is the same as data_layout
+
+    out_dtype : str, optional
+        Specifies the output data type for mixed precision conv3d.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    # convert 3-way padding to 6-way padding
+    padding = get_pad_tuple3d(padding)
+    return _make.contrib_conv3d_winograd_without_weight_transform(
+        data, weight, tile_size, strides, padding, dilation,
+        groups, channels, kernel_size, data_layout,
+        kernel_layout, out_layout, out_dtype)
+
+
 def conv2d_transpose(data,
                      weight,
                      strides=(1, 1),
@@ -1952,6 +2023,29 @@ def contrib_conv2d_winograd_weight_transform(weight,
     return _make.contrib_conv2d_winograd_weight_transform(weight, tile_size)
 
 
+def contrib_conv3d_winograd_weight_transform(weight,
+                                             tile_size):
+    r"""Weight Transformation part for 3D convolution with winograd algorithm.
+
+    We separate this as a single op to enable pre-compute for inference.
+    Use this together with nn.contrib_conv3d_winograd_without_weight_transform
+
+    Parameters
+    ----------
+    weight : tvm.relay.Expr
+        The weight expressions.
+
+    tile_size : int
+        The tile size of winograd, e.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3).
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    return _make.contrib_conv3d_winograd_weight_transform(weight, tile_size)
+
+
 def contrib_conv2d_winograd_nnpack_weight_transform(weight,
                                                     convolution_algorithm,
                                                     out_dtype=""):
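
The two new relay ops are meant to be used as a pair: transform the weight once (so the precompute pass can fold it away for inference), then run the transform-free convolution. A minimal sketch, with assumed shapes and tile size:

    from tvm import relay

    data = relay.var("data", shape=(1, 16, 8, 32, 32))     # NCDHW
    weight = relay.var("weight", shape=(32, 16, 3, 3, 3))  # OIDHW
    wt = relay.nn.contrib_conv3d_winograd_weight_transform(weight, tile_size=4)
    out = relay.nn.contrib_conv3d_winograd_without_weight_transform(
        data, wt, tile_size=4, channels=32, kernel_size=(3, 3, 3),
        padding=(1, 1, 1))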
index 323ef7f..1fdcad7 100644 (file)
@@ -54,3 +54,46 @@ def get_pad_tuple2d(padding):
     pad_top = (pad_h + 1) // 2
     pad_left = (pad_w + 1) // 2
     return pad_top, pad_left, pad_h - pad_top, pad_w - pad_left
+
+
+def get_pad_tuple3d(padding):
+    """Common code to get the pad option
+    Parameters
+    ----------
+    padding : Union[int, Tuple[int, ...]]
+        Padding size
+    Returns
+    -------
+    pad_front : int
+        Padding size on front
+    pad_top : int
+        Padding size on top
+    pad_left : int
+        Padding size on left
+    pad_back : int
+        Padding size on back
+    pad_down : int
+        Padding size on bottom.
+    pad_right : int
+        Padding size on right.
+    """
+    # compute the padding size
+    if isinstance(padding, container.Array):
+        padding = list(padding)
+    if isinstance(padding, (tuple, list)):
+        if len(padding) == 3:
+            pad_d = padding[0] * 2
+            pad_h = padding[1] * 2
+            pad_w = padding[2] * 2
+        elif len(padding) == 6:
+            return padding[0], padding[1], padding[2], padding[3], padding[4], padding[5]
+        else:
+            raise ValueError("Size of padding can only be 3 or 6")
+    elif isinstance(padding, int):
+        pad_d = pad_h = pad_w = padding * 2
+    else:
+        raise ValueError("Unknown padding option %s" % padding)
+    pad_front = (pad_d + 1) // 2
+    pad_top = (pad_h + 1) // 2
+    pad_left = (pad_w + 1) // 2
+    return pad_front, pad_top, pad_left, pad_d - pad_front, pad_h - pad_top, pad_w - pad_left
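
The helper normalizes every accepted padding form to the 6-way (front, top, left, back, bottom, right) order. Following the logic above:

    get_pad_tuple3d(1)                   # -> (1, 1, 1, 1, 1, 1)
    get_pad_tuple3d((1, 2, 3))           # -> (1, 2, 3, 1, 2, 3)
    get_pad_tuple3d((0, 1, 2, 3, 4, 5))  # 6-way padding is returned unchanged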
index edc2160..1a07486 100644 (file)
@@ -34,9 +34,19 @@ class Conv2DWinogradAttrs(Attrs):
     """Attributes for nn.contrib_conv2d_winograd_without_weight_transform"""
 
 
-@tvm._ffi.register_object("relay.attrs.Conv2DWinogradWeightTransformAttrs")
-class Conv2DWinogradWeightTransformAttrs(Attrs):
-    """Attributes for nn.contrib_conv2d_winograd_weight_transform"""
+@tvm._ffi.register_object("relay.attrs.Conv3DAttrs")
+class Conv3DAttrs(Attrs):
+    """Attributes for nn.conv3d"""
+
+
+@tvm._ffi.register_object("relay.attrs.Conv3DWinogradAttrs")
+class Conv3DWinogradAttrs(Attrs):
+    """Attributes for nn.contrib_conv3d_winograd_without_weight_transform"""
+
+
+@tvm._ffi.register_object("relay.attrs.ConvWinogradWeightTransformAttrs")
+class ConvWinogradWeightTransformAttrs(Attrs):
+    """Attributes for nn.contrib_convNd_winograd_weight_transform"""
 
 
 @tvm._ffi.register_object("relay.attrs.Conv2DWinogradNNPACKWeightTransformAttrs")
index db03c59..45ee701 100644 (file)
@@ -233,13 +233,25 @@ def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target):
 def conv3d_strategy_cuda(attrs, inputs, out_type, target):
     """conv3d cuda strategy"""
     strategy = _op.OpStrategy()
+    _, kernel = inputs
     layout = attrs.data_layout
+    _, stride_h, stride_w = attrs.get_int_tuple("strides")
+    _, dilation_h, dilation_w = attrs.get_int_tuple("dilation")
     assert layout in ["NCDHW", "NDHWC"], "Not support this layout {} yet".format(layout)
     if layout == "NCDHW":
         strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_ncdhw),
                                     wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw),
                                     name="conv3d_ncdhw.cuda",
                                     plevel=10)
+        _, _, _, kh, kw = get_const_tuple(kernel.shape)
+        if 2 < kh < 8 and 2 < kw < 8 and kh == kw and \
+            stride_h == 1 and stride_w == 1 and \
+            dilation_h == 1 and dilation_w == 1:
+            strategy.add_implementation(
+                wrap_compute_conv3d(topi.cuda.conv3d_ncdhw_winograd),
+                wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw_winograd),
+                name="conv3d_ncdhw_winograd.cuda",
+                plevel=5)
     else: # layout == "NDHWC":
         strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_ndhwc),
                                     wrap_topi_schedule(topi.cuda.schedule_conv3d_ndhwc),
@@ -252,6 +264,26 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target):
                                     plevel=15)
     return strategy
 
+@conv3d_winograd_without_weight_transfrom_strategy.register(["cuda", "gpu"])
+def conv3d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_type, target):
+    """conv3d_winograd_without_weight_transfrom cuda strategy"""
+    dilation = attrs.get_int_tuple("dilation")
+    groups = attrs.get_int("groups")
+    layout = attrs.data_layout
+    assert dilation == (1, 1, 1), "Do not support dilation yet"
+    assert groups == 1, "Do not support arbitrary group number"
+    strategy = _op.OpStrategy()
+    if layout == "NCDHW":
+        strategy.add_implementation(
+            wrap_compute_conv3d(topi.cuda.conv3d_ncdhw_winograd_without_weight_transform),
+            wrap_topi_schedule(
+                topi.cuda.schedule_conv3d_ncdhw_winograd_without_weight_transform),
+            name="conv3d_ncdhw_winograd_without_weight_transform.cuda")
+    else:
+        raise RuntimeError("Unsupported conv3d_winograd_without_weight_transfrom layout {}".
+                           format(layout))
+    return strategy
+
 @conv1d_strategy.register(["cuda", "gpu"])
 def conv1d_strategy_cuda(attrs, inputs, out_type, target):
     """conv1d cuda strategy"""
index 573df36..388e104 100644 (file)
@@ -374,6 +374,19 @@ def conv3d_strategy(attrs, inputs, out_type, target):
         raise ValueError("Not support this layout {} yet".format(layout))
     return strategy
 
+# conv3d_winograd_without_weight_transform
+@override_native_generic_func("conv3d_winograd_without_weight_transform_strategy")
+def conv3d_winograd_without_weight_transfrom_strategy(attrs, inputs, out_type, target):
+    """conv3d_winograd_without_weight_transfrom generic strategy"""
+    raise ValueError("No generic implemenation for conv3d_winograd_without_weight_transform")
+
+# conv3d_winograd_weight_transform
+@generic_func
+def schedule_conv3d_winograd_weight_transform(attrs, outs, target):
+    """Schedule conv3d_winograd_weight_transform"""
+    with target:
+        return topi.generic.schedule_conv3d_winograd_weight_transform(outs)
+
 # conv1d
 def wrap_compute_conv1d(topi_compute):
     """wrap conv1d topi compute"""
index 547d5a6..66dab57 100644 (file)
@@ -59,10 +59,113 @@ Expr MakeConv(Expr data,
   attrs->kernel_layout = std::move(kernel_layout);
   attrs->out_layout = std::move(out_layout);
   attrs->out_dtype = std::move(out_dtype);
-  static const Op& op = Op::Get(op_name);
+  const Op& op = Op::Get(op_name);
   return Call(op, {data, weight}, Attrs(attrs), {});
 }
 
+template <typename T>
+Expr MakeConvWinograd(Expr data,
+                      Expr weight,
+                      int tile_size,
+                      Array<IndexExpr> strides,
+                      Array<IndexExpr> padding,
+                      Array<IndexExpr> dilation,
+                      int groups,
+                      IndexExpr channels,
+                      Array<IndexExpr> kernel_size,
+                      std::string data_layout,
+                      std::string kernel_layout,
+                      std::string out_layout,
+                      DataType out_dtype,
+                      std::string op_name) {
+  auto attrs = make_object<T>();
+  attrs->tile_size = tile_size;
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->dilation = std::move(dilation);
+  attrs->groups = groups;
+  attrs->channels = std::move(channels);
+  attrs->kernel_size = std::move(kernel_size);
+  attrs->data_layout = std::move(data_layout);
+  attrs->kernel_layout = std::move(kernel_layout);
+  attrs->out_layout = std::move(out_layout);
+  attrs->out_dtype = std::move(out_dtype);
+  const Op& op = Op::Get(op_name);
+  return Call(op, {data, weight}, Attrs(attrs), {});
+}
+
+Expr MakeConvWinogradWeightTransform(Expr weight,
+                                     int tile_size,
+                                     std::string op_name) {
+  auto attrs = make_object<ConvWinogradWeightTransformAttrs>();
+  attrs->tile_size = tile_size;
+  const Op& op = Op::Get(op_name);
+  return Call(op, {weight}, Attrs(attrs), {});
+}
+
+template <typename T>
+Expr MakeConvTranspose(Expr data,
+                       Expr weight,
+                       Array<IndexExpr> strides,
+                       Array<IndexExpr> padding,
+                       Array<IndexExpr> dilation,
+                       int groups,
+                       IndexExpr channels,
+                       Array<IndexExpr> kernel_size,
+                       std::string data_layout,
+                       std::string kernel_layout,
+                       std::string out_layout,
+                       Array<IndexExpr> output_padding,
+                       DataType out_dtype,
+                       std::string op_name) {
+  auto attrs = make_object<T>();
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->dilation = std::move(dilation);
+  attrs->groups = groups;
+  attrs->channels = std::move(channels);
+  attrs->kernel_size = std::move(kernel_size);
+  attrs->data_layout = std::move(data_layout);
+  attrs->kernel_layout = std::move(kernel_layout);
+  attrs->out_layout = std::move(out_layout);
+  attrs->output_padding = std::move(output_padding);
+  attrs->out_dtype = std::move(out_dtype);
+  const Op& op = Op::Get(op_name);
+  return Call(op, {data, weight}, Attrs(attrs), {});
+}
+
+template <typename T>
+Expr MakeDeformableConv(Expr data,
+                        Expr offset,
+                        Expr weight,
+                        Array<IndexExpr> strides,
+                        Array<IndexExpr> padding,
+                        Array<IndexExpr> dilation,
+                        int deformable_groups,
+                        int groups,
+                        int channels,
+                        Array<IndexExpr> kernel_size,
+                        std::string data_layout,
+                        std::string kernel_layout,
+                        std::string out_layout,
+                        DataType out_dtype,
+                        std::string op_name) {
+  auto attrs = make_object<T>();
+  attrs->strides = strides;
+  attrs->padding = padding;
+  attrs->dilation = dilation;
+  attrs->deformable_groups = deformable_groups;
+  attrs->groups = groups;
+  attrs->channels = channels;
+  attrs->kernel_size = kernel_size;
+  attrs->data_layout = data_layout;
+  attrs->kernel_layout = kernel_layout;
+  attrs->out_layout = out_layout;
+  attrs->out_dtype = out_dtype;
+  const Op& op = Op::Get(op_name);
+  return Call(op, {data, offset, weight}, Attrs{attrs}, {});
+}
+
 
 // relay.nn.conv1d
 TVM_REGISTER_NODE_TYPE(Conv1DAttrs);
@@ -153,6 +256,7 @@ with the layer input to produce a tensor of outputs.
 .add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);
 
+
 // relay.nn.conv3d
 TVM_REGISTER_NODE_TYPE(Conv3DAttrs);
 
@@ -198,138 +302,29 @@ with the layer input to produce a tensor of outputs.
 .add_type_rel("Conv3D", Conv3DRel<Conv3DAttrs>)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv3DAttrs>);
 
+
 // relay.nn.conv2d_transpose
 TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);
 
-bool Conv2DTransposeRel(const Array<Type>& types,
-                        int num_inputs,
-                        const Attrs& attrs,
-                        const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 3);
-  const auto* data = types[0].as<TensorTypeNode>();
-  const auto* weight = types[1].as<TensorTypeNode>();
-  if (data == nullptr) return false;
-
-  static const Layout kNCHW("NCHW");
-  static const Layout kOIHW("OIHW");
-
-  const Conv2DTransposeAttrs* param = attrs.as<Conv2DTransposeAttrs>();
-  CHECK(param != nullptr);
-  const Layout in_layout(param->data_layout);
-  const Layout kernel_layout(param->kernel_layout);
-
-  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
-  CHECK(trans_in_layout.defined())
-    << "Conv only support input layouts that are convertible from NCHW."
-    << " But got " << in_layout;
-
-  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
-  CHECK(trans_kernel_layout.defined())
-    << "Conv only support kernel layouts that are convertible from OIHW."
-    << " But got "<< kernel_layout;
-
-  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
-  const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
-  CHECK(trans_out_layout.defined())
-    << "Conv only support output layouts that are convertible from NCHW."
-    << " But got " << out_layout;
-
-  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
-
-  auto dshape_nchw = trans_in_layout.ForwardShape(data->shape);
-
-  // infer weight if the kernel_size and channels are defined
-  if (param->kernel_size.defined() && param->channels.defined()) {
-    CHECK_EQ(param->kernel_size.size(), 2);
-    CHECK_EQ(param->dilation.size(), 2);
-
-    Array<IndexExpr> wshape({dshape_nchw[1],
-            indexdiv(param->channels, param->groups),
-            param->kernel_size[0],
-            param->kernel_size[1]});
-
-    wshape = trans_kernel_layout.BackwardShape(wshape);
-    dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
-    dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
-    channels = param->channels;
-
-    // assign result to reporter
-    reporter->Assign(types[1], TensorType(wshape, data->dtype));
-  } else {
-    // use weight to infer the conv shape.
-    if (weight == nullptr) return false;
-    auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
-    if (param->kernel_size.defined()) {
-      CHECK_EQ(param->kernel_size.size(), 2);
-      // check the size
-      CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
-            reporter->AssertEQ(param->kernel_size[1], wshape[3]))
-          << "Conv2D: shape of weight is inconsistent with kernel_size, "
-          << " kernel_size=" << param->kernel_size
-          << " wshape=" << Array<IndexExpr>(wshape);
-    }
-    if (param->channels.defined()) {
-      CHECK(reporter->AssertEQ(param->channels, wshape[1]))
-          << "Conv2D: shape of weight is inconsistent with channels, "
-          << " channels=" << param->channels
-          << " wshape=" << Array<IndexExpr>(wshape);
-    }
-    CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0]));
-    channels = wshape[1];
-    dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
-    dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
-  }
-  // dilation
-  Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
-  IndexExpr pad_h, pad_w;
-  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
-  oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
-                 pad_h + param->output_padding[0]));
-  oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
-                 pad_w + param->output_padding[1]));
-
-  DataType out_dtype = param->out_dtype;
-  if (out_dtype.bits() == 0) {
-    out_dtype = data->dtype;
-  }
-  oshape = trans_out_layout.BackwardShape(oshape);
-  reporter->Assign(types[2], TensorType(oshape, out_dtype));
-  return true;
-}
-
-
-Expr MakeConv2DTranspose(Expr data,
-                         Expr weight,
-                         Array<IndexExpr> strides,
-                         Array<IndexExpr> padding,
-                         Array<IndexExpr> dilation,
-                         int groups,
-                         IndexExpr channels,
-                         Array<IndexExpr> kernel_size,
-                         std::string data_layout,
-                         std::string kernel_layout,
-                         std::string out_layout,
-                         Array<IndexExpr> output_padding,
-                         DataType out_dtype) {
-  auto attrs = make_object<Conv2DTransposeAttrs>();
-  attrs->channels = std::move(channels);
-  attrs->kernel_size = std::move(kernel_size);
-  attrs->strides = std::move(strides);
-  attrs->padding = std::move(padding);
-  attrs->output_padding = std::move(output_padding);
-  attrs->dilation = std::move(dilation);
-  attrs->groups = groups;
-  attrs->data_layout = std::move(data_layout);
-  attrs->kernel_layout = std::move(kernel_layout);
-  attrs->out_layout = std::move(out_layout);
-  attrs->out_dtype = std::move(out_dtype);
-  static const Op& op = Op::Get("nn.conv2d_transpose");
-  return Call(op, {data, weight}, Attrs(attrs), {});
-}
-
-
 TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d_transpose")
-.set_body_typed(MakeConv2DTranspose);
+.set_body_typed([](Expr data,
+                   Expr weight,
+                   Array<IndexExpr> strides,
+                   Array<IndexExpr> padding,
+                   Array<IndexExpr> dilation,
+                   int groups,
+                   IndexExpr channels,
+                   Array<IndexExpr> kernel_size,
+                   std::string data_layout,
+                   std::string kernel_layout,
+                   std::string out_layout,
+                   Array<IndexExpr> output_padding,
+                   DataType out_dtype) {
+  return MakeConvTranspose<Conv2DTransposeAttrs>(
+    data, weight, strides, padding, dilation,
+    groups, channels, kernel_size, data_layout,
+    kernel_layout, out_layout, output_padding, out_dtype, "nn.conv2d_transpose");
+});
 
 RELAY_REGISTER_OP("nn.conv2d_transpose")
 .describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution).
@@ -359,136 +354,31 @@ v            (batch_size, channels, out_height, out_width) if `layout` is `NCHW`
 .add_argument("weight", "Tensor", "The weight tensor.")
 .set_support_level(2)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
-                               ConvInferCorrectLayout<Conv2DTransposeAttrs>)
-.add_type_rel("Conv2DTranspose", Conv2DTransposeRel);
-
+                                ConvInferCorrectLayout<Conv2DTransposeAttrs>)
+.add_type_rel("Conv2DTranspose", Conv2DTransposeRel<Conv2DTransposeAttrs>);
 
 // relay.nn.conv1d_transpose
 TVM_REGISTER_NODE_TYPE(Conv1DTransposeAttrs);
 
-bool Conv1DTransposeRel(const Array<Type>& types,
-                        int num_inputs,
-                        const Attrs& attrs,
-                        const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 3);
-  const auto* data = types[0].as<TensorTypeNode>();
-  const auto* weight = types[1].as<TensorTypeNode>();
-  if (data == nullptr) return false;
-
-  static const Layout kNCW("NCW");
-  static const Layout kOIW("OIW");
-
-  const Conv1DTransposeAttrs* param = attrs.as<Conv1DTransposeAttrs>();
-  CHECK(param != nullptr);
-  const Layout in_layout(param->data_layout);
-  const Layout kernel_layout(param->kernel_layout);
-
-  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCW);
-  CHECK(trans_in_layout.defined())
-    << "Conv only support input layouts that are convertible from NCW."
-    << " But got " << in_layout;
-
-  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW);
-  CHECK(trans_kernel_layout.defined())
-    << "Conv only support kernel layouts that are convertible from OIW."
-    << " But got "<< kernel_layout;
-
-  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
-  const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCW);
-  CHECK(trans_out_layout.defined())
-    << "Conv only support output layouts that are convertible from NCW."
-    << " But got " << out_layout;
-
-  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
-
-  auto dshape_ncw = trans_in_layout.ForwardShape(data->shape);
-
-  // infer weight if the kernel_size and channels are defined
-  if (param->kernel_size.defined() && param->channels.defined()) {
-    CHECK_EQ(param->kernel_size.size(), 1);
-    CHECK_EQ(param->dilation.size(), 1);
-
-    Array<IndexExpr> wshape({dshape_ncw[1],
-            indexdiv(param->channels, param->groups),
-            param->kernel_size[0]});
-
-    wshape = trans_kernel_layout.BackwardShape(wshape);
-    dilated_ksize_x = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
-    channels = param->channels;
-
-    // assign result to reporter
-    reporter->Assign(types[1], TensorType(wshape, data->dtype));
-  } else {
-    // use weight to infer the conv shape.
-    if (weight == nullptr) return false;
-    auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
-    if (param->kernel_size.defined()) {
-      CHECK_EQ(param->kernel_size.size(), 1);
-      // check the size
-      CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]))
-          << "Conv1D: shape of weight is inconsistent with kernel_size, "
-          << " kernel_size=" << param->kernel_size
-          << " wshape=" << Array<IndexExpr>(wshape);
-    }
-    if (param->channels.defined()) {
-      CHECK(reporter->AssertEQ(param->channels, wshape[1]))
-          << "Conv1D: shape of weight is inconsistent with channels, "
-          << " channels=" << param->channels
-          << " wshape=" << Array<IndexExpr>(wshape);
-    }
-    CHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0]));
-    channels = wshape[1];
-    dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilation[0];
-  }
-  // dilation
-  IndexExpr pad_w;
-  GetPaddingWidth(param->padding, &pad_w);
-  Array<IndexExpr> oshape({dshape_ncw[0], channels, 0});
-  oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x -
-                 pad_w + param->output_padding[0]));
-
-  DataType out_dtype = param->out_dtype;
-  if (out_dtype.bits() == 0) {
-    out_dtype = data->dtype;
-  }
-  oshape = trans_out_layout.BackwardShape(oshape);
-  reporter->Assign(types[2], TensorType(oshape, out_dtype));
-  return true;
-}
-
-
-Expr MakeConv1DTranspose(Expr data,
-                         Expr weight,
-                         Array<IndexExpr> strides,
-                         Array<IndexExpr> padding,
-                         Array<IndexExpr> dilation,
-                         int groups,
-                         IndexExpr channels,
-                         Array<IndexExpr> kernel_size,
-                         std::string data_layout,
-                         std::string kernel_layout,
-                         std::string out_layout,
-                         Array<IndexExpr> output_padding,
-                         DataType out_dtype) {
-  auto attrs = make_object<Conv1DTransposeAttrs>();
-  attrs->channels = std::move(channels);
-  attrs->kernel_size = std::move(kernel_size);
-  attrs->strides = std::move(strides);
-  attrs->padding = std::move(padding);
-  attrs->output_padding = std::move(output_padding);
-  attrs->dilation = std::move(dilation);
-  attrs->groups = groups;
-  attrs->data_layout = std::move(data_layout);
-  attrs->kernel_layout = std::move(kernel_layout);
-  attrs->out_layout = std::move(out_layout);
-  attrs->out_dtype = std::move(out_dtype);
-  static const Op& op = Op::Get("nn.conv1d_transpose");
-  return Call(op, {data, weight}, Attrs(attrs), {});
-}
-
-
 TVM_REGISTER_GLOBAL("relay.op.nn._make.conv1d_transpose")
-.set_body_typed(MakeConv1DTranspose);
+.set_body_typed([](Expr data,
+                   Expr weight,
+                   Array<IndexExpr> strides,
+                   Array<IndexExpr> padding,
+                   Array<IndexExpr> dilation,
+                   int groups,
+                   IndexExpr channels,
+                   Array<IndexExpr> kernel_size,
+                   std::string data_layout,
+                   std::string kernel_layout,
+                   std::string out_layout,
+                   Array<IndexExpr> output_padding,
+                   DataType out_dtype) {
+  return MakeConvTranspose<Conv1DTransposeAttrs>(
+    data, weight, strides, padding, dilation,
+    groups, channels, kernel_size, data_layout,
+    kernel_layout, out_layout, output_padding, out_dtype, "nn.conv1d_transpose");
+});
 
 RELAY_REGISTER_OP("nn.conv1d_transpose")
 .describe(R"code(Transposed 1D convolution layer (sometimes called Deconvolution).
@@ -516,128 +406,30 @@ said convolution.
 .add_argument("data", "Tensor", "The input tensor.")
 .add_argument("weight", "Tensor", "The weight tensor.")
 .set_support_level(2)
-.add_type_rel("Conv1DTranspose", Conv1DTransposeRel);
-
+.add_type_rel("Conv1DTranspose", Conv1DTransposeRel<Conv1DTransposeAttrs>);
 
 // relay.nn.contrib_conv2d_winograd_without_weight_transform
 TVM_REGISTER_NODE_TYPE(Conv2DWinogradAttrs);
 
-template<class Param>
-bool Conv2DWinogradRel(const Array<Type>& types,
-                       int num_inputs,
-                       const Attrs& attrs,
-                       const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 3);
-  const auto* data = types[0].as<TensorTypeNode>();
-  if (data == nullptr) return false;
-  static const Layout kNCHW("NCHW");
-  static const Layout kOIHW("OIHW");
-
-  const Param* param = attrs.as<Param>();
-  CHECK(param != nullptr);
-  const Layout in_layout(param->data_layout);
-  const Layout kernel_layout(param->kernel_layout);
-
-  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
-  CHECK(trans_in_layout.defined())
-    << "Conv only support input layouts that are convertible from NCHW."
-    << " But got " << in_layout;
-
-  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
-  CHECK(trans_kernel_layout.defined())
-    << "Conv only support kernel layouts that are convertible from OIHW."
-    << " But got "<< kernel_layout;
-
-  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
-  const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
-  CHECK(trans_out_layout.defined())
-      << "Conv only support output layouts that are convertible from NCHW."
-      << " But got " << out_layout;
-
-  Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
-
-  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
-
-  CHECK(param->kernel_size.defined() && param->channels.defined())
-      << "The kernel size and channels of a Conv must be set or infered by previous pass";
-
-  CHECK_EQ(param->kernel_size.size(), 2);
-  CHECK_EQ(param->dilation.size(), 2);
-
-  channels = param->channels;
-  dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
-  dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
-
-  // NOTE: Do not check weight shape here!
-  // Different backend requires different layout to compute
-  // the batch gemm stage in winograd efficiently, but we want to
-  // make this op work for all backends.
-  // So we accept all weight shapes, and assume the TOPI developers
-  // can handle this correctly in alter_op_layout.
-
-  // dilation
-  Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
-
-  IndexExpr pad_h, pad_w;
-  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
-  if (!dshape_nchw[2].as<tir::AnyNode>()) {
-    oshape.Set(2, (dshape_nchw[2] + pad_h
-                   - dilated_ksize_y) / param->strides[0] + 1);
-  } else {
-    oshape.Set(2, dshape_nchw[2]);
-  }
-  if (!dshape_nchw[3].as<tir::AnyNode>()) {
-    oshape.Set(3, (dshape_nchw[3] + pad_w
-                   - dilated_ksize_x) / param->strides[1] + 1);
-  } else {
-    oshape.Set(3, dshape_nchw[3]);
-  }
-
-  DataType out_dtype = param->out_dtype;
-  if (out_dtype.bits() == 0) {
-    out_dtype = data->dtype;
-  }
-  oshape = trans_out_layout.BackwardShape(oshape);
-  // assign output type
-  reporter->Assign(types[2], TensorType(oshape, out_dtype));
-  return true;
-}
-
-
-// Positional relay function to create conv2d winograd operator
-// used by frontend FFI.
-Expr MakeConv2DWinograd(Expr data,
-                        Expr weight,
-                        int tile_size,
-                        Array<IndexExpr> strides,
-                        Array<IndexExpr> padding,
-                        Array<IndexExpr> dilation,
-                        int groups,
-                        IndexExpr channels,
-                        Array<IndexExpr> kernel_size,
-                        std::string data_layout,
-                        std::string kernel_layout,
-                        std::string out_layout,
-                        DataType out_dtype) {
-  auto attrs = make_object<Conv2DWinogradAttrs>();
-  attrs->tile_size = tile_size;
-  attrs->strides = std::move(strides);
-  attrs->padding = std::move(padding);
-  attrs->dilation = std::move(dilation);
-  attrs->groups = groups;
-  attrs->channels = channels;
-  attrs->kernel_size = std::move(kernel_size);
-  attrs->data_layout = std::move(data_layout);
-  attrs->kernel_layout = std::move(kernel_layout);
-  attrs->out_layout = std::move(out_layout);
-  attrs->out_dtype = std::move(out_dtype);
-  static const Op& op = Op::Get("nn.contrib_conv2d_winograd_without_weight_transform");
-  return Call(op, {data, weight}, Attrs(attrs), {});
-}
-
-
 TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_without_weight_transform")
-.set_body_typed(MakeConv2DWinograd);
+.set_body_typed([](Expr data,
+                   Expr weight,
+                   int tile_size,
+                   Array<IndexExpr> strides,
+                   Array<IndexExpr> padding,
+                   Array<IndexExpr> dilation,
+                   int groups,
+                   IndexExpr channels,
+                   Array<IndexExpr> kernel_size,
+                   std::string data_layout,
+                   std::string kernel_layout,
+                   std::string out_layout,
+                   DataType out_dtype) {
+  return MakeConvWinograd<Conv2DWinogradAttrs>(
+    data, weight, tile_size, strides, padding, dilation,
+    groups, channels, kernel_size, data_layout,
+    kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_winograd_without_weight_transform");
+});
 
 
 RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
@@ -662,46 +454,14 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
         ConvInferCorrectLayout<Conv2DWinogradAttrs>);
 
 // relay.nn.contrib_conv2d_winograd_weight_transform
-TVM_REGISTER_NODE_TYPE(Conv2DWinogradWeightTransformAttrs);
-
-bool Conv2DWinogradWeightTransformRel(const Array<Type>& types,
-                                      int num_inputs,
-                                      const Attrs& attrs,
-                                      const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 2);
-  const auto* data = types[0].as<TensorTypeNode>();
-  if (data == nullptr) return false;
-
-  const Conv2DWinogradWeightTransformAttrs* param = attrs.as<Conv2DWinogradWeightTransformAttrs>();
-  CHECK(param != nullptr);
-
-  CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout";
-
-  // each pad width element should be a pair of positive integers
-  std::vector<IndexExpr> oshape {
-      param->tile_size + data->shape[2] - 1,
-      param->tile_size + data->shape[3] - 1,
-      data->shape[0],
-      data->shape[1],
-  };
-
-  reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape),
-                                                  data->dtype));
-  return true;
-}
-
-Expr MakeConv2DWinogradWeightTransform(Expr weight,
-                                       int tile_size) {
-  auto attrs = make_object<Conv2DWinogradWeightTransformAttrs>();
-  attrs->tile_size = tile_size;
-  static const Op& op = Op::Get("nn.contrib_conv2d_winograd_weight_transform");
-  return Call(op, {weight}, Attrs(attrs), {});
-}
-
+TVM_REGISTER_NODE_TYPE(ConvWinogradWeightTransformAttrs);
 
 TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_weight_transform")
-.set_body_typed(MakeConv2DWinogradWeightTransform);
-
+.set_body_typed([](Expr weight,
+                   int tile_size) {
+  return MakeConvWinogradWeightTransform(
+    weight, tile_size, "nn.contrib_conv2d_winograd_weight_transform");
+});
 
 RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform")
 .describe(R"code(Weight transformation of winograd fast convolution algorithm.
@@ -711,47 +471,82 @@ weight transformation in advance.
 
 - **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
 )code" TVM_ADD_FILELINE)
-.set_attrs_type<Conv2DWinogradWeightTransformAttrs>()
+.set_attrs_type<ConvWinogradWeightTransformAttrs>()
 .set_num_inputs(1)
 .add_argument("weight", "Tensor", "The weight tensor.")
 .set_support_level(10)
 .add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel);
 
+// relay.nn.contrib_conv3d_winograd_without_weight_transform
+TVM_REGISTER_NODE_TYPE(Conv3DWinogradAttrs);
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv3d_winograd_without_weight_transform")
+.set_body_typed([](Expr data,
+                   Expr weight,
+                   int tile_size,
+                   Array<IndexExpr> strides,
+                   Array<IndexExpr> padding,
+                   Array<IndexExpr> dilation,
+                   int groups,
+                   IndexExpr channels,
+                   Array<IndexExpr> kernel_size,
+                   std::string data_layout,
+                   std::string kernel_layout,
+                   std::string out_layout,
+                   DataType out_dtype) {
+  return MakeConvWinograd<Conv3DWinogradAttrs>(
+    data, weight, tile_size, strides, padding, dilation,
+    groups, channels, kernel_size, data_layout,
+    kernel_layout, out_layout, out_dtype, "nn.contrib_conv3d_winograd_without_weight_transform");
+});
+
+RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_without_weight_transform")
+.describe(R"code(Compute conv3d with winograd algorithm. Only supports NCDHW layout.
+              This operator assumes the weight tensor is already pre-transformed by
+              nn.contrib_conv3d_winograd_weight_transform.
+
+- **data**: Input is 5D array of shape  (batch_size, in_channels, depth, height, width)
+- **weight**: Any shape
+            We do not check the shape for this input tensor. Since different backend
+            has different layout strategy.
+
+- **out**:  Output is 5D array of shape (batch_size, channels, depth, out_height, out_width)
+)code" TVM_ADD_FILELINE)
+.set_attrs_type<Conv3DWinogradAttrs>()
+.set_num_inputs(2)
+.add_argument("data", "Tensor", "The input tensor.")
+.add_argument("weight", "Tensor", "The weight tensor.")
+.set_support_level(10)
+.add_type_rel("Conv3DWinograd", Conv3DWinogradRel<Conv3DWinogradAttrs>)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                                ConvInferCorrectLayout<Conv3DWinogradAttrs>);
+
+// relay.nn.contrib_conv3d_winograd_weight_transform
+TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv3d_winograd_weight_transform")
+.set_body_typed([](Expr weight,
+                   int tile_size) {
+  return MakeConvWinogradWeightTransform(
+    weight, tile_size, "nn.contrib_conv3d_winograd_weight_transform");
+});
+
+RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_weight_transform")
+    .describe(R"code(Weight transformation of winograd fast 3d convolution algorithm.
+
+Separate this into another operator in order to enable Precompute Pass to compute the
+weight transformation in advance.
+
+- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1], kernel_size[2])
+)code" TVM_ADD_FILELINE)
+.set_attrs_type<ConvWinogradWeightTransformAttrs>()
+.set_num_inputs(1)
+.add_argument("weight", "Tensor", "The weight tensor.")
+.set_support_level(10)
+.add_type_rel("Conv3DWinogradWeightTransform", Conv3DWinogradWeightTransformRel);
+
 
 // relay.nn.contrib_conv2d_winograd_nnpack_weight_transform
 TVM_REGISTER_NODE_TYPE(Conv2DWinogradNNPACKWeightTransformAttrs);
 
-bool Conv2DWinogradNNPACKWeightTransformRel(const Array<Type>& types,
-                                            int num_inputs,
-                                            const Attrs& attrs,
-                                            const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 2);
-  const auto* data = types[0].as<TensorTypeNode>();
-  if (data == nullptr) {
-    return false;
-  }
-
-  const Conv2DWinogradNNPACKWeightTransformAttrs* param =
-      attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>();
-  CHECK(param != nullptr);
-
-  CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout";
-
-  std::vector<IndexExpr> oshape{
-      data->shape[0],
-      data->shape[1],
-      8,
-      8,
-  };
-
-  DataType out_dtype = param->out_dtype;
-  if (out_dtype.bits() == 0) {
-    out_dtype = data->dtype;
-  }
-  reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), out_dtype));
-  return true;
-}
-
 Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight,
                                              int convolution_algorithm,
                                              DataType out_dtype) {
@@ -779,38 +574,27 @@ weight transformation in advance.
 .set_support_level(10)
 .add_type_rel("Conv2DWinogradNNPACKWeightTransform", Conv2DWinogradNNPACKWeightTransformRel);
 
+
 // Positional relay function to create conv2d NCHWc operator
 // used by frontend FFI.
-Expr MakeConv2DNCHWc(Expr data,
-                     Expr kernel,
-                     Array<IndexExpr> strides,
-                     Array<IndexExpr> padding,
-                     Array<IndexExpr> dilation,
-                     int groups,
-                     IndexExpr channels,
-                     Array<IndexExpr> kernel_size,
-                     std::string data_layout,
-                     std::string kernel_layout,
-                     std::string out_layout,
-                     DataType out_dtype) {
-  auto attrs = make_object<Conv2DAttrs>();
-  attrs->strides = std::move(strides);
-  attrs->padding = std::move(padding);
-  attrs->dilation = std::move(dilation);
-  attrs->groups = groups;
-  attrs->channels = channels;
-  attrs->kernel_size = std::move(kernel_size);
-  attrs->data_layout = std::move(data_layout);
-  attrs->kernel_layout = std::move(kernel_layout);
-  attrs->out_layout = std::move(out_layout);
-  attrs->out_dtype = std::move(out_dtype);
-  static const Op& op = Op::Get("nn.contrib_conv2d_NCHWc");
-  return Call(op, {data, kernel}, Attrs(attrs), {});
-}
-
 TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_NCHWc")
-.set_body_typed(MakeConv2DNCHWc);
-
+.set_body_typed([](Expr data,
+                   Expr weight,
+                   Array<IndexExpr> strides,
+                   Array<IndexExpr> padding,
+                   Array<IndexExpr> dilation,
+                   int groups,
+                   IndexExpr channels,
+                   Array<IndexExpr> kernel_size,
+                   std::string data_layout,
+                   std::string kernel_layout,
+                   std::string out_layout,
+                   DataType out_dtype) {
+  return MakeConv<Conv2DAttrs>(
+    data, weight, strides, padding, dilation,
+    groups, channels, kernel_size, data_layout,
+    kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_NCHWc");
+});
 
 RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
 .describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout.
@@ -831,35 +615,24 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
 
 // Positional relay function to create depthwise conv2d NCHWc operator
 // used by frontend FFI.
-Expr MakeDepthwiseConv2DNCHWc(Expr data,
-                              Expr kernel,
-                              Array<IndexExpr> strides,
-                              Array<IndexExpr> padding,
-                              Array<IndexExpr> dilation,
-                              int groups,
-                              IndexExpr channels,
-                              Array<IndexExpr> kernel_size,
-                              std::string data_layout,
-                              std::string kernel_layout,
-                              std::string out_layout,
-                              DataType out_dtype) {
-  auto attrs = make_object<Conv2DAttrs>();
-  attrs->strides = std::move(strides);
-  attrs->padding = std::move(padding);
-  attrs->dilation = std::move(dilation);
-  attrs->groups = groups;
-  attrs->channels = channels;
-  attrs->kernel_size = std::move(kernel_size);
-  attrs->data_layout = std::move(data_layout);
-  attrs->kernel_layout = std::move(kernel_layout);
-  attrs->out_layout = std::move(out_layout);
-  attrs->out_dtype = std::move(out_dtype);
-  static const Op& op = Op::Get("nn.contrib_depthwise_conv2d_NCHWc");
-  return Call(op, {data, kernel}, Attrs(attrs), {});
-}
-
 TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_depthwise_conv2d_NCHWc")
-.set_body_typed(MakeDepthwiseConv2DNCHWc);
+.set_body_typed([](Expr data,
+                   Expr weight,
+                   Array<IndexExpr> strides,
+                   Array<IndexExpr> padding,
+                   Array<IndexExpr> dilation,
+                   int groups,
+                   IndexExpr channels,
+                   Array<IndexExpr> kernel_size,
+                   std::string data_layout,
+                   std::string kernel_layout,
+                   std::string out_layout,
+                   DataType out_dtype) {
+  return MakeConv<Conv2DAttrs>(
+    data, weight, strides, padding, dilation,
+    groups, channels, kernel_size, data_layout,
+    kernel_layout, out_layout, out_dtype, "nn.contrib_depthwise_conv2d_NCHWc");
+});
 
 
 RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc")
@@ -879,85 +652,6 @@ RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc")
         ConvInferCorrectLayout<Conv2DAttrs>);
 
 
-bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
-                         const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 4);
-  const auto* data = types[0].as<TensorTypeNode>();
-  const auto* weight = types[2].as<TensorTypeNode>();
-
-  CHECK(data);
-  auto* param = attrs.as<DeformableConv2DAttrs>();
-  CHECK_EQ(param->data_layout, "NCHW") << "data layout not supported.";
-  CHECK_EQ(param->kernel_layout, "OIHW") << "kernel_layout not supported.";
-
-  IndexExpr channels, dilated_ksize_y, dilated_ksize_x, ksize_y, ksize_x;
-
-  // infer weight shape if kernel_size and channels are defiend
-  if (param->kernel_size.defined() && param->channels.defined()) {
-    CHECK_EQ(param->kernel_size.size(), 2);
-    CHECK_EQ(param->dilation.size(), 2);
-    Array<IndexExpr> wshape(
-       {param->channels,
-         indexdiv(data->shape[1], param->groups),
-         param->kernel_size[0],
-         param->kernel_size[1]});
-    channels = param->channels;
-    ksize_y = param->kernel_size[0];
-    ksize_x = param->kernel_size[1];
-    dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
-    dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
-    // assign result to reporter
-    reporter->Assign(types[2], TensorType(wshape, data->dtype));
-  } else {
-    // use weight to infer the conv shape.
-    if (weight == nullptr) return false;
-    auto wshape = weight->shape;
-    if (param->kernel_size.defined()) {
-      CHECK_EQ(param->kernel_size.size(), 2);
-      // check the size
-      CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
-            reporter->AssertEQ(param->kernel_size[1], wshape[3]))
-          << "DeformableConv2D: shape of weight is inconsistent with kernel_size, "
-          << " kernel_size=" << param->kernel_size
-          << " wshape=" << wshape;
-    }
-    if (param->channels.defined()) {
-      CHECK(reporter->AssertEQ(param->channels, wshape[0]))
-          << "DeformableConv2D: shape of weight is inconsistent with channels, "
-          << " channels=" << param->channels
-          << " wshape=" << wshape;
-    }
-    CHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), wshape[1]));
-    channels = wshape[0];
-    ksize_y = wshape[2];
-    ksize_x = wshape[3];
-    dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
-    dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
-  }
-  // dilation
-  Array<IndexExpr> oshape({data->shape[0], channels, 0, 0});
-
-  IndexExpr pad_h, pad_w;
-  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
-  oshape.Set(2, indexdiv(data->shape[2] + pad_h - dilated_ksize_y,
-                         param->strides[0]) + 1);
-  oshape.Set(3, indexdiv(data->shape[3] + pad_w - dilated_ksize_x,
-                         param->strides[1]) + 1);
-  DataType out_dtype = param->out_dtype;
-
-  // infer offset shape
-  Array<IndexExpr> offset_shape({data->shape[0], 2 * ksize_y * ksize_x * param->deformable_groups,
-          oshape[2], oshape[3]});
-  reporter->Assign(types[1], TensorType(offset_shape, data->dtype));
-  if (out_dtype.bits() == 0) {
-    out_dtype = data->dtype;
-  }
-
-  reporter->Assign(types[3], TensorType(oshape, out_dtype));
-  return true;
-}
-
-
 TVM_REGISTER_NODE_TYPE(DeformableConv2DAttrs);
 
 RELAY_REGISTER_OP("nn.deformable_conv2d")
@@ -986,42 +680,30 @@ by concating all the *g* results.
 .add_argument("offset", "Tensor", "The offset tensor.")
 .add_argument("weight", "Tensor", "The weight tensor.")
 .set_support_level(5)
-.add_type_rel("DeformableConv2D", DeformableConv2DRel);
+.add_type_rel("DeformableConv2D", DeformableConv2DRel<DeformableConv2DAttrs>);
 
 // Positional relay function to create deformable_conv2d operator
 // used by frontend FFI.
-Expr MakeDeformableConv2D(Expr data,
-                          Expr offset,
-                          Expr weight,
-                          Array<IndexExpr> strides,
-                          Array<IndexExpr> padding,
-                          Array<IndexExpr> dilation,
-                          int deformable_groups,
-                          int groups,
-                          int channels,
-                          Array<IndexExpr> kernel_size,
-                          std::string data_layout,
-                          std::string kernel_layout,
-                          std::string out_layout,
-                          DataType out_dtype) {
-  auto attrs = make_object<DeformableConv2DAttrs>();
-  attrs->strides = strides;
-  attrs->padding = padding;
-  attrs->dilation = dilation;
-  attrs->deformable_groups = deformable_groups;
-  attrs->groups = groups;
-  attrs->channels = channels;
-  attrs->kernel_size = kernel_size;
-  attrs->data_layout = data_layout;
-  attrs->kernel_layout = kernel_layout;
-  attrs->out_layout = out_layout;
-  attrs->out_dtype = out_dtype;
-  static const Op& op = Op::Get("nn.deformable_conv2d");
-  return Call(op, {data, offset, weight}, Attrs{attrs}, {});
-}
-
 TVM_REGISTER_GLOBAL("relay.op.nn._make.deformable_conv2d")
-.set_body_typed(MakeDeformableConv2D);
+.set_body_typed([](Expr data,
+                   Expr offset,
+                   Expr weight,
+                   Array<IndexExpr> strides,
+                   Array<IndexExpr> padding,
+                   Array<IndexExpr> dilation,
+                   int deformable_groups,
+                   int groups,
+                   int channels,
+                   Array<IndexExpr> kernel_size,
+                   std::string data_layout,
+                   std::string kernel_layout,
+                   std::string out_layout,
+                   DataType out_dtype) {
+  return MakeDeformableConv<DeformableConv2DAttrs>(
+    data, offset, weight, strides, padding, dilation,
+    deformable_groups, groups, channels, kernel_size, data_layout,
+    kernel_layout, out_layout, out_dtype, "nn.deformable_conv2d");
+});
 
 }  // namespace relay
 }  // namespace tvm
index 05c1171..6c5aebe 100644 (file)
 
 #include <string>
 #include <utility>
+#include <vector>
 
 #include "../op_common.h"
 
 namespace tvm {
 namespace relay {
 
+
+// Standard convolution operator shape relations
 template <typename AttrType>
 bool Conv1DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                const TypeReporter& reporter) {
@@ -363,6 +366,533 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   return true;
 }
 
+
+// Winograd convolution shape relations
+inline bool Conv2DWinogradWeightTransformRel(const Array<Type>& types, int num_inputs,
+                                             const Attrs& attrs, const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  const ConvWinogradWeightTransformAttrs* param = attrs.as<ConvWinogradWeightTransformAttrs>();
+  CHECK(param != nullptr);
+
+  CHECK_EQ(data->shape.size(), 4) << "Only supports NCHW normal kernel layout";
+
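+  // With an OIHW kernel of shape (CO, CI, KH, KW) and alpha = tile_size + KH - 1,
+  // the packed weight below has shape (alpha, alpha, CO, CI); e.g. F(2x2, 3x3)
+  // has tile_size = 2 and produces a (4, 4, CO, CI) result.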
+  std::vector<IndexExpr> oshape {
+      param->tile_size + data->shape[2] - 1,
+      param->tile_size + data->shape[3] - 1,
+      data->shape[0],
+      data->shape[1],
+  };
+
+  reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), data->dtype));
+  return true;
+}
+
+inline bool Conv3DWinogradWeightTransformRel(const Array<Type>& types, int num_inputs,
+                                             const Attrs& attrs, const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  const ConvWinogradWeightTransformAttrs* param = attrs.as<ConvWinogradWeightTransformAttrs>();
+  CHECK(param != nullptr);
+
+  CHECK_EQ(data->shape.size(), 5) << "Only supports NCDHW normal kernel layout";
+
+  // Shape of packed weights depends on whether depth is being transformed or not.
+  Array<IndexExpr> oshape({0, 0, 0, data->shape[0], data->shape[1]});
+  auto* depth_imm = data->shape[2].as<IntImmNode>();
+  CHECK(depth_imm != nullptr) << "Winograd weight transform requires a static kernel depth.";
+  bool transform_depth = (depth_imm->value > 2) && (depth_imm->value < 8);
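+  // Kernel depths in (2, 8) are transformed together with height and width;
+  // otherwise depth is left untouched and only the spatial dims are transformed
+  // (the depthless winograd variant elsewhere in this patch).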
+  if (transform_depth) {
+    oshape.Set(0, param->tile_size + data->shape[2] - 1);
+    oshape.Set(1, param->tile_size + data->shape[3] - 1);
+    oshape.Set(2, param->tile_size + data->shape[4] - 1);
+  } else {
+    oshape.Set(0, param->tile_size + data->shape[3] - 1);
+    oshape.Set(1, param->tile_size + data->shape[4] - 1);
+    oshape.Set(2, data->shape[2]);
+  }
+
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
+  return true;
+}
+
+inline bool Conv2DWinogradNNPACKWeightTransformRel(const Array<Type>& types, int num_inputs,
+                                                   const Attrs& attrs,
+                                                   const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    return false;
+  }
+
+  const Conv2DWinogradNNPACKWeightTransformAttrs* param =
+      attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>();
+  CHECK(param != nullptr);
+
+  CHECK_EQ(data->shape.size(), 4) << "Only supports NCHW normal kernel layout";
+
+  std::vector<IndexExpr> oshape{
+      data->shape[0],
+      data->shape[1],
+      8,
+      8,
+  };
+
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+  reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), out_dtype));
+  return true;
+}
+
+template<typename AttrType>
+bool Conv2DWinogradRel(const Array<Type>& types,
+                       int num_inputs,
+                       const Attrs& attrs,
+                       const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+  static const Layout kNCHW("NCHW");
+  static const Layout kOIHW("OIHW");
+
+  const AttrType* param = attrs.as<AttrType>();
+  CHECK(param != nullptr);
+  const Layout in_layout(param->data_layout);
+  const Layout kernel_layout(param->kernel_layout);
+
+  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
+  CHECK(trans_in_layout.defined())
+    << "Conv only support input layouts that are convertible from NCHW."
+    << " But got " << in_layout;
+
+  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
+  CHECK(trans_kernel_layout.defined())
+    << "Conv only support kernel layouts that are convertible from OIHW."
+    << " But got "<< kernel_layout;
+
+  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
+  const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
+  CHECK(trans_out_layout.defined())
+      << "Conv only support output layouts that are convertible from NCHW."
+      << " But got " << out_layout;
+
+  Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
+
+  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
+
+  CHECK(param->kernel_size.defined() && param->channels.defined())
+      << "The kernel size and channels of a Conv must be set or inferred by previous pass";
+
+  CHECK_EQ(param->kernel_size.size(), 2);
+  CHECK_EQ(param->dilation.size(), 2);
+
+  channels = param->channels;
+  dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
+  dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
+
+  // NOTE: Do not check weight shape here!
+  // Different backends require different layouts to compute
+  // the batch gemm stage of winograd efficiently, but we want to
+  // make this op work for all backends.
+  // So we accept all weight shapes, and assume the TOPI developers
+  // can handle this correctly in alter_op_layout.
+
+  // dilation
+  Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
+
+  IndexExpr pad_h, pad_w;
+  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
+  if (!dshape_nchw[2].as<tir::AnyNode>()) {
+    oshape.Set(2, (dshape_nchw[2] + pad_h
+                   - dilated_ksize_y) / param->strides[0] + 1);
+  } else {
+    oshape.Set(2, dshape_nchw[2]);
+  }
+  if (!dshape_nchw[3].as<tir::AnyNode>()) {
+    oshape.Set(3, (dshape_nchw[3] + pad_w
+                   - dilated_ksize_x) / param->strides[1] + 1);
+  } else {
+    oshape.Set(3, dshape_nchw[3]);
+  }
+
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+  oshape = trans_out_layout.BackwardShape(oshape);
+  // assign output type
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
+  return true;
+}
+
+
+template<typename AttrType>
+bool Conv3DWinogradRel(const Array<Type>& types,
+                       int num_inputs,
+                       const Attrs& attrs,
+                       const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+  static const Layout kNCDHW("NCDHW");
+  static const Layout kOIDHW("OIDHW");
+
+  const AttrType* param = attrs.as<AttrType>();
+  CHECK(param != nullptr);
+  const Layout in_layout(param->data_layout);
+  const Layout kernel_layout(param->kernel_layout);
+
+  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCDHW);
+  CHECK(trans_in_layout.defined())
+    << "Conv only support input layouts that are convertible from NCDHW."
+    << " But got " << in_layout;
+
+  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW);
+  CHECK(trans_kernel_layout.defined())
+    << "Conv only support kernel layouts that are convertible from OIDHW."
+    << " But got "<< kernel_layout;
+
+  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
+  const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCDHW);
+  CHECK(trans_out_layout.defined())
+      << "Conv only support output layouts that are convertible from NCDHW."
+      << " But got " << out_layout;
+
+  Array<IndexExpr> dshape_ncdhw = trans_in_layout.ForwardShape(data->shape);
+
+  IndexExpr channels, dilated_ksize_d, dilated_ksize_y, dilated_ksize_x;
+
+  CHECK(param->kernel_size.defined() && param->channels.defined())
+      << "The kernel size and channels of a Conv must be set or inferred by previous pass";
+
+  CHECK_EQ(param->kernel_size.size(), 3);
+  CHECK_EQ(param->dilation.size(), 3);
+
+  channels = param->channels;
+  dilated_ksize_d = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
+  dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
+  dilated_ksize_x = 1 + (param->kernel_size[2] - 1) * param->dilation[2];
+
+  // NOTE: Do not check weight shape here!
+  // Different backends require different layouts to compute
+  // the batch gemm stage of winograd efficiently, but we want to
+  // make this op work for all backends.
+  // So we accept all weight shapes, and assume the TOPI developers
+  // can handle this correctly in alter_op_layout.
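+  // (For example, the CUDA schedules in this patch pack weights to
+  // (alpha, alpha, alpha, CO, CI), or to (alpha, alpha, KD, CO, CI) when the
+  // depth axis is not transformed.)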
+
+  // dilation
+  Array<IndexExpr> oshape({dshape_ncdhw[0], channels, 0, 0, 0});
+
+  IndexExpr pad_d, pad_h, pad_w;
+  GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w);
+  if (!dshape_ncdhw[2].as<tir::AnyNode>()) {
+    oshape.Set(2, (dshape_ncdhw[2] + pad_d
+                   - dilated_ksize_d) / param->strides[0] + 1);
+  } else {
+    oshape.Set(2, dshape_ncdhw[2]);
+  }
+  if (!dshape_ncdhw[3].as<tir::AnyNode>()) {
+    oshape.Set(3, (dshape_ncdhw[3] + pad_h
+                   - dilated_ksize_y) / param->strides[1] + 1);
+  } else {
+    oshape.Set(3, dshape_ncdhw[3]);
+  }
+  if (!dshape_ncdhw[4].as<tir::AnyNode>()) {
+    oshape.Set(4, (dshape_ncdhw[4] + pad_w
+                   - dilated_ksize_x) / param->strides[2] + 1);
+  } else {
+    oshape.Set(4, dshape_ncdhw[4]);
+  }
+
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+  oshape = trans_out_layout.BackwardShape(oshape);
+  // assign output type
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
+  return true;
+}
+
+
+// Transposed convolution shape relations
+template <typename AttrType>
+bool Conv1DTransposeRel(const Array<Type>& types,
+                        int num_inputs,
+                        const Attrs& attrs,
+                        const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* weight = types[1].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  static const Layout kNCW("NCW");
+  static const Layout kOIW("OIW");
+
+  const AttrType* param = attrs.as<AttrType>();
+  CHECK(param != nullptr);
+  const Layout in_layout(param->data_layout);
+  const Layout kernel_layout(param->kernel_layout);
+
+  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCW);
+  CHECK(trans_in_layout.defined())
+    << "Conv only support input layouts that are convertible from NCW."
+    << " But got " << in_layout;
+
+  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW);
+  CHECK(trans_kernel_layout.defined())
+    << "Conv only support kernel layouts that are convertible from OIW."
+    << " But got "<< kernel_layout;
+
+  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
+  const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCW);
+  CHECK(trans_out_layout.defined())
+    << "Conv only support output layouts that are convertible from NCW."
+    << " But got " << out_layout;
+
+  IndexExpr channels, dilated_ksize_x;
+
+  auto dshape_ncw = trans_in_layout.ForwardShape(data->shape);
+
+  // infer weight if the kernel_size and channels are defined
+  if (param->kernel_size.defined() && param->channels.defined()) {
+    CHECK_EQ(param->kernel_size.size(), 1);
+    CHECK_EQ(param->dilation.size(), 1);
+
+    Array<IndexExpr> wshape({dshape_ncw[1],
+            indexdiv(param->channels, param->groups),
+            param->kernel_size[0]});
+
+    wshape = trans_kernel_layout.BackwardShape(wshape);
+    dilated_ksize_x = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
+    channels = param->channels;
+
+    // assign result to reporter
+    reporter->Assign(types[1], TensorType(wshape, data->dtype));
+  } else {
+    // use weight to infer the conv shape.
+    if (weight == nullptr) return false;
+    auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
+    if (param->kernel_size.defined()) {
+      CHECK_EQ(param->kernel_size.size(), 1);
+      // check the size
+      CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]))
+          << "Conv1D: shape of weight is inconsistent with kernel_size, "
+          << " kernel_size=" << param->kernel_size
+          << " wshape=" << Array<IndexExpr>(wshape);
+    }
+    if (param->channels.defined()) {
+      CHECK(reporter->AssertEQ(param->channels, wshape[1]))
+          << "Conv1D: shape of weight is inconsistent with channels, "
+          << " channels=" << param->channels
+          << " wshape=" << Array<IndexExpr>(wshape);
+    }
+    CHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0]));
+    channels = wshape[1];
+    dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilation[0];
+  }
+  // dilation
+  IndexExpr pad_w;
+  GetPaddingWidth(param->padding, &pad_w);
+  Array<IndexExpr> oshape({dshape_ncw[0], channels, 0});
+  oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x -
+                 pad_w + param->output_padding[0]));
+
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+  oshape = trans_out_layout.BackwardShape(oshape);
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
+  return true;
+}
+
+
+template <typename AttrType>
+bool Conv2DTransposeRel(const Array<Type>& types,
+                        int num_inputs,
+                        const Attrs& attrs,
+                        const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* weight = types[1].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  static const Layout kNCHW("NCHW");
+  static const Layout kOIHW("OIHW");
+
+  const AttrType* param = attrs.as<AttrType>();
+  CHECK(param != nullptr);
+  const Layout in_layout(param->data_layout);
+  const Layout kernel_layout(param->kernel_layout);
+
+  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
+  CHECK(trans_in_layout.defined())
+    << "Conv only support input layouts that are convertible from NCHW."
+    << " But got " << in_layout;
+
+  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
+  CHECK(trans_kernel_layout.defined())
+    << "Conv only support kernel layouts that are convertible from OIHW."
+    << " But got "<< kernel_layout;
+
+  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
+  const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
+  CHECK(trans_out_layout.defined())
+    << "Conv only support output layouts that are convertible from NCHW."
+    << " But got " << out_layout;
+
+  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
+
+  auto dshape_nchw = trans_in_layout.ForwardShape(data->shape);
+
+  // infer weight if the kernel_size and channels are defined
+  if (param->kernel_size.defined() && param->channels.defined()) {
+    CHECK_EQ(param->kernel_size.size(), 2);
+    CHECK_EQ(param->dilation.size(), 2);
+
+    Array<IndexExpr> wshape({dshape_nchw[1],
+            indexdiv(param->channels, param->groups),
+            param->kernel_size[0],
+            param->kernel_size[1]});
+
+    wshape = trans_kernel_layout.BackwardShape(wshape);
+    dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
+    dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
+    channels = param->channels;
+
+    // assign result to reporter
+    reporter->Assign(types[1], TensorType(wshape, data->dtype));
+  } else {
+    // use weight to infer the conv shape.
+    if (weight == nullptr) return false;
+    auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
+    if (param->kernel_size.defined()) {
+      CHECK_EQ(param->kernel_size.size(), 2);
+      // check the size
+      CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
+            reporter->AssertEQ(param->kernel_size[1], wshape[3]))
+          << "Conv2D: shape of weight is inconsistent with kernel_size, "
+          << " kernel_size=" << param->kernel_size
+          << " wshape=" << Array<IndexExpr>(wshape);
+    }
+    if (param->channels.defined()) {
+      CHECK(reporter->AssertEQ(param->channels, wshape[1]))
+          << "Conv2D: shape of weight is inconsistent with channels, "
+          << " channels=" << param->channels
+          << " wshape=" << Array<IndexExpr>(wshape);
+    }
+    CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0]));
+    channels = wshape[1];
+    dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
+    dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
+  }
+  // dilation
+  Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
+  IndexExpr pad_h, pad_w;
+  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
+  oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
+                 pad_h + param->output_padding[0]));
+  oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
+                 pad_w + param->output_padding[1]));
+
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+  oshape = trans_out_layout.BackwardShape(oshape);
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
+  return true;
+}
+
+
+// Deformable Convolution shape relations.
+template <typename AttrType>
+bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                         const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 4);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* weight = types[2].as<TensorTypeNode>();
+
+  CHECK(data);
+  auto* param = attrs.as<AttrType>();
+  CHECK(param != nullptr);
+  CHECK_EQ(param->data_layout, "NCHW") << "data layout not supported.";
+  CHECK_EQ(param->kernel_layout, "OIHW") << "kernel_layout not supported.";
+
+  IndexExpr channels, dilated_ksize_y, dilated_ksize_x, ksize_y, ksize_x;
+
+  // infer weight shape if kernel_size and channels are defined
+  if (param->kernel_size.defined() && param->channels.defined()) {
+    CHECK_EQ(param->kernel_size.size(), 2);
+    CHECK_EQ(param->dilation.size(), 2);
+    Array<IndexExpr> wshape(
+       {param->channels,
+         indexdiv(data->shape[1], param->groups),
+         param->kernel_size[0],
+         param->kernel_size[1]});
+    channels = param->channels;
+    ksize_y = param->kernel_size[0];
+    ksize_x = param->kernel_size[1];
+    dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
+    dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
+    // assign result to reporter
+    reporter->Assign(types[2], TensorType(wshape, data->dtype));
+  } else {
+    // use weight to infer the conv shape.
+    if (weight == nullptr) return false;
+    auto wshape = weight->shape;
+    if (param->kernel_size.defined()) {
+      CHECK_EQ(param->kernel_size.size(), 2);
+      // check the size
+      CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
+            reporter->AssertEQ(param->kernel_size[1], wshape[3]))
+          << "DeformableConv2D: shape of weight is inconsistent with kernel_size, "
+          << " kernel_size=" << param->kernel_size
+          << " wshape=" << wshape;
+    }
+    if (param->channels.defined()) {
+      CHECK(reporter->AssertEQ(param->channels, wshape[0]))
+          << "DeformableConv2D: shape of weight is inconsistent with channels, "
+          << " channels=" << param->channels
+          << " wshape=" << wshape;
+    }
+    CHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), wshape[1]));
+    channels = wshape[0];
+    ksize_y = wshape[2];
+    ksize_x = wshape[3];
+    dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
+    dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
+  }
+  // dilation
+  Array<IndexExpr> oshape({data->shape[0], channels, 0, 0});
+
+  IndexExpr pad_h, pad_w;
+  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
+  oshape.Set(2, indexdiv(data->shape[2] + pad_h - dilated_ksize_y,
+                         param->strides[0]) + 1);
+  oshape.Set(3, indexdiv(data->shape[3] + pad_w - dilated_ksize_x,
+                         param->strides[1]) + 1);
+  DataType out_dtype = param->out_dtype;
+
+  // infer offset shape
+  Array<IndexExpr> offset_shape({data->shape[0], 2 * ksize_y * ksize_x * param->deformable_groups,
+          oshape[2], oshape[3]});
+  reporter->Assign(types[1], TensorType(offset_shape, data->dtype));
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+
+  reporter->Assign(types[3], TensorType(oshape, out_dtype));
+  return true;
+}
+
+
 template<typename T>
 Array<Array<Layout> > ConvInferCorrectLayout(
     const Attrs& attrs,
@@ -378,6 +908,7 @@ Array<Array<Layout> > ConvInferCorrectLayout(
                                    params->data_layout : params->out_layout}};
 }
 
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_OP_NN_CONVOLUTION_H_
index 7a42fc3..771a63d 100644 (file)
@@ -25,6 +25,7 @@ from tvm.relay import transform
 from tvm.relay.testing import ctx_list, run_infer_type
 from tvm.contrib import util
 import topi.testing
+from topi.cuda.conv3d_winograd import _infer_tile_size
 
 
 def test_conv1d_infer_type():
@@ -326,7 +327,7 @@ def test_conv2d_winograd():
             cfg['tile_y'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
             cfg['tile_x'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
             cfg['tile_rc'] = autotvm.task.space.SplitEntity([-1, 1])
-            cfg['auto_unroll_max_setp'] = autotvm.task.space.OtherOptionEntity(1500)
+            cfg['auto_unroll_max_step'] = autotvm.task.space.OtherOptionEntity(1500)
             cfg['unroll_explicit'] = autotvm.task.space.OtherOptionEntity(1)
             self.memory[key] = cfg
             return cfg
@@ -522,6 +523,94 @@ def test_conv3d_ndhwc_run():
     run_test_conv3d("float32", "float32", 1, dshape, kshape,
             padding=(1, 1, 1), channels=10, kernel_size=(3, 3, 3), except_targets=["cuda"])
 
+def test_conv3d_winograd():
+    class WinogradFallback(autotvm.FallbackContext):
+        def _query_inside(self, target, workload):
+            key = (target, workload)
+            if key in self.memory:
+                return self.memory[key]
+            cfg = autotvm.task.space.FallbackConfigEntity()
+            cfg.is_fallback = False
+            cfg.cost = 0.1 if 'winograd' in workload[0] else 1
+            cfg['tile_b'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
+            cfg['tile_y'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
+            cfg['tile_x'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
+            cfg['tile_rc'] = autotvm.task.space.SplitEntity([-1, 1])
+            cfg['auto_unroll_max_step'] = autotvm.task.space.OtherOptionEntity(0)
+            cfg['unroll_explicit'] = autotvm.task.space.OtherOptionEntity(1)
+            self.memory[key] = cfg
+            return cfg
+
+    def run_test_conv3d_cuda(dtype, out_dtype, scale, dshape, kshape,
+                             padding=(1, 1, 1),
+                             groups=1,
+                             dilation=(1, 1, 1),
+                             prepack=False,
+                             **attrs):
+
+        x = relay.var("x", shape=dshape, dtype=dtype)
+        w = relay.var("w", shape=kshape, dtype=dtype)
+        if prepack:
+            tile_size = _infer_tile_size(np.zeros(shape=dshape), np.zeros(shape=kshape))
+            w_packed = relay.nn.contrib_conv3d_winograd_weight_transform(w, tile_size)
+
+            y = relay.nn.contrib_conv3d_winograd_without_weight_transform(
+                x, w_packed, tile_size,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+                channels=kshape[0],
+                **attrs)
+        else:
+            y = relay.nn.conv3d(x, w,
+                                padding=padding,
+                                dilation=dilation,
+                                groups=groups,
+                                **attrs)
+        func = relay.Function([x, w], y)
+        mod = tvm.IRModule()
+        mod['main'] = func
+        mod = relay.transform.InferType()(mod)
+
+        data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
+        kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
+        ref_res = topi.testing.conv3d_ncdhw_python(
+            data.astype(out_dtype), kernel.astype(out_dtype), 1, padding,
+            groups=groups)
+
+        with WinogradFallback(), relay.build_config(opt_level=3):
+            for target, ctx in ctx_list():
+                if target != 'cuda':
+                    continue
+                params = {'w': tvm.nd.array(kernel)}
+                graph, lib, params = relay.build_module.build(mod, target=target, params=params)
+                module = tvm.contrib.graph_runtime.create(graph, lib, ctx)
+                module.set_input('x', tvm.nd.array(data))
+                module.set_input(**params)
+                module.run()
+                op_res1 = module.get_output(0)
+                tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-3, atol=1e-3)
+
+    # normal winograd: stride 1, padding 1, kernel 3x3x3
+    dshape = (1, 32, 16, 16, 16)
+    kshape = (64, 32, 3, 3, 3)
+    run_test_conv3d_cuda("float32", "float32", 1, dshape, kshape,
+                         padding=(1, 1, 1), kernel_size=(3, 3, 3))
+    # Without depth transform using 1x3x3 kernel.
+    kshape = (64, 32, 1, 3, 3)
+    run_test_conv3d_cuda("float32", "float32", 1, dshape, kshape,
+                         padding=(0, 1, 1), kernel_size=(1, 3, 3))
+
+    # extended winograd: stride 1, padding N, kernel NxNxN
+    dshape = (1, 61, 20, 20, 20)
+    kshape = (120, 61, 5, 5, 5)
+    run_test_conv3d_cuda("float32", "float32", 1, dshape, kshape,
+                         padding=(2, 2, 2), channels=120, kernel_size=(5, 5, 5))
+    # Without depth transform
+    kshape = (120, 61, 1, 5, 5)
+    run_test_conv3d_cuda("float32", "float32", 1, dshape, kshape,
+                         padding=(0, 2, 2), channels=120, kernel_size=(1, 5, 5))
+
 
 def test_conv2d_transpose_infer_type():
     # symbolic in batch dimension
@@ -1268,6 +1357,7 @@ if __name__ == "__main__":
     test_conv2d_winograd()
     test_conv3d_run()
     test_conv3d_ndhwc_run()
+    test_conv3d_winograd()
     test_bitserial_conv2d_infer_type()
     test_batch_flatten()
     test_upsampling()
index 302171e..83ddedc 100644 (file)
@@ -31,6 +31,8 @@ from . import conv2d_alter_op
 from .conv2d_transpose_nchw import *
 from .deformable_conv2d import *
 from .conv3d import *
+from .conv3d_winograd import *
+from . import conv3d_alter_op
 from .reduction import schedule_reduce
 from .softmax import schedule_softmax
 from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
diff --git a/topi/python/topi/cuda/conv3d_alter_op.py b/topi/python/topi/cuda/conv3d_alter_op.py
new file mode 100644 (file)
index 0000000..fbda456
--- /dev/null
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument
+"""Conv3D alter op and legalize functions for cuda backend"""
+
+import logging
+import tvm
+from tvm import te
+from tvm import relay
+from tvm import autotvm
+
+from .. import nn
+from ..util import get_const_tuple
+from .conv3d_winograd import _infer_tile_size
+
+logger = logging.getLogger('topi')
+
+@nn.conv3d_alter_layout.register(["cuda", "gpu"])
+def _alter_conv3d_layout(attrs, inputs, tinfos, out_type):
+    target = tvm.target.Target.current(allow_none=False)
+    dispatch_ctx = autotvm.task.DispatchContext.current
+
+    _, outs = relay.backend.compile_engine.select_implementation(
+        relay.op.get("nn.conv3d"), attrs, tinfos, out_type, target)
+    workload = autotvm.task.get_workload(outs)
+    if workload is None:
+        # The best implementation is not an AutoTVM template,
+        # so we assume it's not necessary to alter this op.
+        return None
+    cfg = dispatch_ctx.query(target, workload)
+    if cfg.is_fallback:  # if in fallback mode, clear the query cache and return None
+        autotvm.task.clear_fallback_cache(target, workload)
+        return None
+
+    topi_tmpl = workload[0]
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
+
+    strides = attrs.get_int_tuple("strides")
+    padding = attrs.get_int_tuple("padding")
+    dilation = attrs.get_int_tuple("dilation")
+    groups = attrs.get_int('groups')
+    data_layout = attrs["data_layout"]
+    kernel_layout = attrs["kernel_layout"]
+    data, kernel = tinfos
+    out_dtype = out_type.dtype
+
+    if topi_tmpl == "conv3d_ncdhw_winograd.cuda":
+        if dilation != (1, 1, 1):
+            logger.warning("Weight pre-transform is not supported for dilated 3D convolution.")
+            return None
+
+        assert data_layout == "NCDHW" and kernel_layout == "OIDHW"
+        N, CI, D, H, W = get_const_tuple(data.shape)
+        CO, _, KD, KH, KW = get_const_tuple(kernel.shape)
+
+        # Pre-compute weight transformation in winograd
+        tile_size = _infer_tile_size(tinfos[0], tinfos[1])
+
+        weight = relay.nn.contrib_conv3d_winograd_weight_transform(inputs[1], tile_size=tile_size)
+        new_attrs['tile_size'] = tile_size
+        new_attrs['channels'] = CO
+
+        # Store the same config for the altered operators (workload)
+        new_data = data
+        # Check if depth is transformed or not
+        if 2 < KD < 8 and KD == KH:
+            new_weight = te.placeholder(
+                (KD + tile_size - 1, KH + tile_size - 1, KW + tile_size - 1, CO, CI),
+                dtype=kernel.dtype)
+        else:
+            new_weight = te.placeholder(
+                (KH + tile_size - 1, KW + tile_size - 1, KD, CO, CI),
+                dtype=kernel.dtype)
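+        # E.g. an OIDHW kernel of shape (64, 32, 3, 3, 3) with tile_size 2 packs
+        # to (4, 4, 4, 64, 32), while a (64, 32, 1, 3, 3) kernel keeps its depth
+        # axis and packs to (4, 4, 1, 64, 32).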
+        new_workload = autotvm.task.args_to_workload(
+            [new_data, new_weight, strides, padding, dilation, out_dtype],
+            "conv3d_ncdhw_winograd_without_weight_transform.cuda")
+        dispatch_ctx.update(target, new_workload, cfg)
+        return relay.nn.contrib_conv3d_winograd_without_weight_transform(
+            inputs[0], weight, **new_attrs)
+
+    return None
diff --git a/topi/python/topi/cuda/conv3d_winograd.py b/topi/python/topi/cuda/conv3d_winograd.py
new file mode 100644 (file)
index 0000000..c9e8446
--- /dev/null
@@ -0,0 +1,627 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument
+"""Winograd template for cuda backend"""
+
+import logging
+import tvm
+from tvm import te
+from tvm import autotvm
+
+from .. import nn
+from ..util import get_const_int, get_const_tuple, traverse_inline, simplify
+from ..nn.winograd_util import winograd_transform_matrices
+
+logger = logging.getLogger('conv3d_winograd')
+
+
+def _infer_tile_size(data, kernel):
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
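+    # Heuristic: use the larger F(4x4x4, 3x3x3) transform when the input height
+    # is a multiple of 8; otherwise fall back to F(2x2x2, 3x3x3).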
+    if H % 8 == 0:
+        return 4
+    return 2
+
+
+def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed):
+    """Compute declaration for winograd"""
+    tile_size = _infer_tile_size(data, kernel)
+
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_d = dilation_h = dilation_w = dilation
+    else:
+        dilation_d, dilation_h, dilation_w = dilation
+    DSTR, HSTR, WSTR = (strides, strides, strides) if isinstance(strides, int) else strides
+
+    if not pre_computed:  # kernel tensor is raw tensor, do strict check
+        if dilation_d != 1 or dilation_h != 1 or dilation_w != 1:
+            kernel = nn.dilate(kernel, (1, 1, dilation_d, dilation_h, dilation_w))
+        CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
+        alpha = KW + tile_size - 1
+        assert DSTR == 1 and HSTR == 1 and WSTR == 1 and KD == KH and KH == KW
+    else:
+        # Kernel tensor is pre-transformed; this op is created by alter_op_layout.
+        # Dilation is not supported.
+        alpha, _, _, CO, CI = get_const_tuple(kernel.shape)
+        KD = KH = KW = alpha + 1 - tile_size
+        assert DSTR == 1 and HSTR == 1 and WSTR == 1 and \
+               dilation_d == 1 and dilation_h == 1 and dilation_w == 1
+
+    pf, pt, pl, pb, pd, pr = nn.get_pad_tuple3d(padding, (KD, KH, KW))
+    data_pad = nn.pad(data, (0, 0, pf, pt, pl), (0, 0, pb, pd, pr), name="data_pad")
+
+    r = KW
+    m = tile_size
+    A, B, G = winograd_transform_matrices(m, r, out_dtype)
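+    # The 3D transform is separable: the 1D matrices G (kernel), B (data) and
+    # A (inverse) are applied independently along depth, height and width.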
+
+    D = (D + pf + pb - KD) // DSTR + 1
+    H = (H + pt + pd - KH) // HSTR + 1
+    W = (W + pl + pr - KW) // WSTR + 1
+    nD, nH, nW = (D + m - 1) // m, (H + m - 1) // m, (W + m - 1) // m
+    P = N * nD * nH * nW
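+    # P is the total number of m x m x m output tiles across the batch; each
+    # tile becomes one column of the batched GEMM below.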
+
+    # transform kernel
+    if not pre_computed:
+        # Check if we are currently tuning; if so, avoid counting prepacking
+        # in time costs by using a placeholder with the packed shape instead.
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            kernel_pack = te.placeholder((alpha, alpha, alpha, CO, CI),
+                                         dtype=kernel.dtype,
+                                         name='kernel_pack')
+        else:
+            r_kd = te.reduce_axis((0, KD), name='r_kd')
+            r_kh = te.reduce_axis((0, KH), name='r_kh')
+            r_kw = te.reduce_axis((0, KW), name='r_kw')
+            kernel_pack = te.compute(
+                (alpha, alpha, alpha, CO, CI),
+                lambda omg, eps, nu, co, ci: te.sum(
+                    kernel[co][ci][r_kd][r_kh][r_kw] * G[omg][r_kd] * G[eps][r_kh] * G[nu][r_kw],
+                    axis=[r_kd, r_kh, r_kw]),
+                name='kernel_pack')
+    else:
+        kernel_pack = kernel
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+    # pack input tile
+    input_tile = te.compute((CI, P, alpha, alpha, alpha),
+                            lambda c, p, omg, eps, nu: data_pad[idxdiv(p, (nD * nH * nW))]
+                            [c]
+                            [idxmod(idxdiv(p, nH * nW), nD) * m + omg]
+                            [idxmod(idxdiv(p, nW), nH) * m + eps]
+                            [idxmod(p, nW) * m + nu],
+                            name='d')
+
+    # transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    r_c = te.reduce_axis((0, alpha), 'r_c')
+    data_pack = te.compute(
+        (alpha, alpha, alpha, CI, P),
+        lambda omg, eps, nu, ci, p: te.sum(
+            input_tile[ci][p][r_a][r_b][r_c] * B[r_a][omg] * B[r_b][eps] * B[r_c][nu],
+            axis=[r_a, r_b, r_c]),
+        name='data_pack')
+
+    # do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    bgemm = te.compute(
+        (alpha, alpha, alpha, CO, P),
+        lambda omg, eps, nu, co, p: te.sum(
+            kernel_pack[omg][eps][nu][co][ci] * data_pack[omg][eps][nu][ci][p], axis=[ci]),
+        name='bgemm')
+
+    # inverse transform
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    r_c = te.reduce_axis((0, alpha), 'r_c')
+    inverse = te.compute((CO, P, m, m, m),
+                         lambda co, p, vd, vh, vw: te.sum(
+                             bgemm[r_a][r_b][r_c][co][p] * A[r_a][vd] * A[r_b][vh] * A[r_c][vw],
+                             axis=[r_a, r_b, r_c]),
+                         name='inverse')
+
+    # output
+    output = te.compute((N, CO, D, H, W),
+                        lambda n, co, d, h, w: inverse[co, n * nD * nH * nW + idxdiv(d, m) * nH * nW
+                                                       + idxdiv(h, m) * nW + idxdiv(w, m),
+                                                       idxmod(d, m),
+                                                       idxmod(h, m),
+                                                       idxmod(w, m)],
+                        name='output',
+                        tag='conv3d_ncdhw_winograd')
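+    # Model FLOPs as the equivalent direct convolution: one multiply and one
+    # add per (output element, reduction element) pair.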
+    cfg.add_flop(2 * N * CO * D * H * W * CI * KD * KH * KW)
+
+    return output
+
+
+def winograd_without_depth_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
+                                pre_computed):
+    """Compute declaration for winograd without transforming depth"""
+    tile_size = _infer_tile_size(data, kernel)
+
+    N, CI, D, H, W = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_d = dilation_h = dilation_w = dilation
+    else:
+        dilation_d, dilation_h, dilation_w = dilation
+    DSTR, HSTR, WSTR = (strides, strides, strides) if isinstance(strides, int) else strides
+
+    if not pre_computed:  # kernel tensor is raw tensor, do strict check
+        if dilation_d != 1 or dilation_h != 1 or dilation_w != 1:
+            kernel = nn.dilate(kernel, (1, 1, dilation_d, dilation_h, dilation_w))
+        CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
+        alpha = KW + tile_size - 1
+        assert HSTR == 1 and WSTR == 1 and KH == KW
+    else:
+        # Kernel tensor is pre-transformed; this op is created by alter_op_layout.
+        # Dilation is not supported.
+        alpha, _, KD, CO, CI = get_const_tuple(kernel.shape)
+        KH = KW = alpha + 1 - tile_size
+        assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1
+
+    pf, pt, pl, pb, pd, pr = nn.get_pad_tuple3d(padding, (KD, KH, KW))
+    data_pad = nn.pad(data, (0, 0, pf, pt, pl), (0, 0, pb, pd, pr), name="data_pad")
+    out_depth = simplify((D - KD + pf + pb) // DSTR + 1)
+    D += pf + pb
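+    # Depth is handled as a direct convolution: the padded depth axis is kept
+    # and reduced over (rz) in the batched GEMM instead of being transformed.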
+
+    r = KW
+    m = tile_size
+    A, B, G = winograd_transform_matrices(m, r, out_dtype)
+
+    H = (H + pt + pd - KH) // HSTR + 1
+    W = (W + pl + pr - KW) // WSTR + 1
+    nH, nW = (H + m - 1) // m, (W + m - 1) // m
+    P = N * nH * nW
+
+    # transform kernel
+    if not pre_computed:
+        # During autotuning, don't count kernel packing as a time cost,
+        # as it will later be removed via alter_op_layout.
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            kernel_pack = te.placeholder((alpha, alpha, KD, CO, CI),
+                                         dtype=kernel.dtype,
+                                         name='kernel_pack')
+        else:
+            r_kh = te.reduce_axis((0, KH), name='r_kh')
+            r_kw = te.reduce_axis((0, KW), name='r_kw')
+            kernel_pack = te.compute(
+                (alpha, alpha, KD, CO, CI),
+                lambda eps, nu, d, co, ci: te.sum(
+                    kernel[co][ci][d][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]),
+                name='kernel_pack')
+    else:
+        kernel_pack = kernel
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+    # pack input tile
+    input_tile = te.compute((CI, D, P, alpha, alpha), lambda c, d, p, eps, nu:
+                            data_pad[idxdiv(p, (nH * nW))][c][d]
+                            [idxmod(idxdiv(p, nW), nH) * m + eps]
+                            [idxmod(p, nW) * m + nu], name='d')
+
+    # transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    data_pack = te.compute((alpha, alpha, CI, D, P), lambda eps, nu, ci, d, p:
+                           te.sum(input_tile[ci][d][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu],
+                                  axis=[r_a, r_b]), name='data_pack')
+
+    # do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    rz = te.reduce_axis((0, KD), name='rz')
+    bgemm = te.compute((alpha, alpha, CO, out_depth, P), lambda eps, nu, co, d, p:
+                       te.sum(kernel_pack[eps][nu][rz][co][ci] *
+                              data_pack[eps][nu][ci][d * DSTR + rz][p],
+                              axis=[ci, rz]), name='bgemm')
+
+    # inverse transform
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    inverse = te.compute((CO, out_depth, P, m, m), lambda co, d, p, vh, vw:
+                         te.sum(bgemm[r_a][r_b][co][d][p] * A[r_a][vh] * A[r_b][vw],
+                                axis=[r_a, r_b]), name='inverse')
+
+    # output
+    output = te.compute((N, CO, out_depth, H, W), lambda n, co, d, h, w:
+                        inverse[co,
+                                d,
+                                n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
+                                idxmod(h, m),
+                                idxmod(w, m)],
+                        name='output', tag='conv3d_ncdhw_winograd_without_depth')
+    cfg.add_flop(2 * N * CO * D * H * W * CI * KD * KH * KW)
+
+    return output
+
+
+def schedule_winograd_cuda(cfg, s, output, pre_computed):
+    """Schedule winograd template"""
+    # get stages
+    inverse = s[output].op.input_tensors[0]
+    bgemm, A = s[inverse].op.input_tensors
+    kernel_pack, data_pack = s[bgemm].op.input_tensors
+    input_tile, B = s[data_pack].op.input_tensors
+    pad_data = s[input_tile].op.input_tensors[0]
+
+    # data transform
+    s[B].compute_inline()
+
+    data_l = s.cache_write(data_pack, 'local')
+    omg, eps, nu, c, p = s[data_l].op.axis
+    r_a, r_b, r_c = s[data_l].op.reduce_axis
+    # TODO: unrolling over omg, eps, nu may improve performance, but in some
+    # cases it causes extremely long build times due to imperfect tiling.
+    for axis in [r_a, r_b, r_c]:
+        s[data_l].unroll(axis)
+
+    omg, eps, nu, c, p = s[data_pack].op.axis
+    p, pi = s[data_pack].split(p, 1)
+    fused = s[data_pack].fuse(c, p)
+    bb, tt = s[data_pack].split(fused, 128)
+    s[data_pack].reorder(bb, tt, pi, omg, eps, nu)
+    s[data_pack].bind(bb, te.thread_axis("blockIdx.x"))
+    s[data_pack].bind(tt, te.thread_axis("threadIdx.x"))
+
+    s[data_l].compute_at(s[data_pack], pi)
+    s[input_tile].compute_at(s[data_pack], pi)
+    s[pad_data].compute_inline()
+
+    # transform kernel
+    if not pre_computed and not autotvm.GLOBAL_SCOPE.in_tuning:
+        kernel, G = s[kernel_pack].op.input_tensors
+        omg, eps, nu, co, ci = s[kernel_pack].op.axis
+        s[G].compute_inline()
+        r_a, r_b, r_c = s[kernel_pack].op.reduce_axis
+        # Could add additional unrolling by omg, eps, nu in the future.
+        for axis in [r_a, r_b, r_c]:
+            s[kernel_pack].unroll(axis)
+
+        fused = s[kernel_pack].fuse(co, ci)
+        bb, tt = s[kernel_pack].split(fused, 128)
+        s[kernel_pack].reorder(bb, tt, omg, eps, nu, r_a, r_b, r_c)
+        s[kernel_pack].bind(bb, te.thread_axis("blockIdx.x"))
+        s[kernel_pack].bind(tt, te.thread_axis("threadIdx.x"))
+    else:
+        kernel = kernel_pack
+
+    if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
+        s[kernel].compute_inline()
+
+    ##### space definition begin #####
+    b1, b2, b3, y, x = s[bgemm].op.axis
+    rc = s[bgemm].op.reduce_axis[0]
+    alpha = get_const_int(b1.dom.extent)
+
+    cfg.define_split(
+        "tile_b",
+        cfg.axis(alpha * alpha * alpha),
+        num_outputs=4,
+        filter=lambda x: x.size[-3:] == [1, 1, 1])
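+    # Restricting the three inner factors to 1 maps the fused transform axis
+    # entirely to blockIdx.z.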
+    cfg.define_split("tile_y", y, num_outputs=4)
+    cfg.define_split("tile_x", x, num_outputs=4)
+    cfg.define_split("tile_rc", rc, num_outputs=2)
+    cfg.define_knob("auto_unroll_max_step", [0, 128, 1500])
+    target = tvm.target.Target.current()
+    if target.target_name in ['nvptx', 'rocm']:
+        cfg.define_knob("unroll_explicit", [1])
+    else:
+        cfg.define_knob("unroll_explicit", [0, 1])
+    ##### space definition end #####
+
+    # batch gemm
+    C = bgemm
+    A0, B0 = kernel_pack, data_pack
+
+    OL = s.cache_write(C, 'local')
+    AA = s.cache_read(A0, 'shared', [OL])
+    BB = s.cache_read(B0, 'shared', [OL])
+
+    b = s[bgemm].fuse(b1, b2, b3)
+
+    # tile and bind spatial axes
+    bgemm_scope, b = s[bgemm].split(b, nparts=1)
+    bz, vz, tz, zi = cfg["tile_b"].apply(s, C, b)
+    by, vy, ty, yi = cfg["tile_y"].apply(s, C, y)
+    bx, vx, tx, xi = cfg["tile_x"].apply(s, C, x)
+    s[C].bind(bz, te.thread_axis("blockIdx.z"))
+    s[C].bind(by, te.thread_axis("blockIdx.y"))
+    s[C].bind(bx, te.thread_axis("blockIdx.x"))
+    s[C].bind(vz, te.thread_axis("vthread"))
+    s[C].bind(vy, te.thread_axis("vthread"))
+    s[C].bind(vx, te.thread_axis("vthread"))
+    s[C].bind(tz, te.thread_axis("threadIdx.z"))
+    s[C].bind(ty, te.thread_axis("threadIdx.y"))
+    s[C].bind(tx, te.thread_axis("threadIdx.x"))
+    s[C].reorder(bgemm_scope, bz, by, bx, vz, vy, vx, tz, ty, tx, zi, yi, xi)
+
+    # tile reduction axes
+    s[OL].compute_at(s[C], tx)
+    b1, b2, b3, y, x = s[OL].op.axis
+    b = s[OL].fuse(b1, b2, b3)
+    rc, = s[OL].op.reduce_axis
+    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
+    s[OL].reorder(rco, rci, b, y, x)
+
+    s[AA].compute_at(s[OL], rco)
+    s[BB].compute_at(s[OL], rco)
+
+    # cooperative fetching
+    for load in [AA, BB]:
+        fused = s[load].fuse(*list(s[load].op.axis))
+        fused, tx = s[load].split(fused, cfg["tile_x"].size[2])
+        fused, ty = s[load].split(fused, cfg["tile_y"].size[2])
+        fused, tz = s[load].split(fused, cfg["tile_b"].size[2])
+        s[load].bind(tz, te.thread_axis("threadIdx.z"))
+        s[load].bind(ty, te.thread_axis("threadIdx.y"))
+        s[load].bind(tx, te.thread_axis("threadIdx.x"))
+
+    s[C].pragma(bgemm_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
+    s[C].pragma(bgemm_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+
+    # schedule inverse, output and fusion
+    if output.op in s.outputs:
+        OL = None
+    else:
+        OL = output
+        s[OL].set_scope('local')
+        output = s.outputs[0]
+
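+    # Recover the tile size from alpha, assuming a size-3 kernel (alpha = m + r - 1).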
+    m = alpha - 3 + 1
+    n, co, d, h, w = s[output].op.axis
+    do, di = s[output].split(d, m)
+    ho, hi = s[output].split(h, m)
+    wo, wi = s[output].split(w, m)
+    s[output].reorder(n, co, do, ho, wo, di, hi, wi)
+    inverse_scope, n = s[output].split(n, nparts=1)
+
+    fused = s[output].fuse(n, co, do, ho, wo)
+    bb, tt = s[output].split(fused, 128)
+
+    s[output].bind(bb, te.thread_axis("blockIdx.x"))
+    s[output].bind(tt, te.thread_axis("threadIdx.x"))
+
+    if OL is not None:
+        s[OL].compute_at(s[output], tt)
+
+    s[A].compute_inline()
+    co, p, vd, vh, vw = s[inverse].op.axis
+    r_a, r_b, r_c = s[inverse].op.reduce_axis
+    # Could add additional unrolling of vd, vh, vw in the future.
+    for axis in [r_a, r_b, r_c]:
+        s[inverse].unroll(axis)
+    s[inverse].compute_at(s[output], tt)
+
+    return s
+
+
+def schedule_winograd_no_depth_cuda(cfg, s, output, pre_computed):
+    """Schedule winograd template"""
+    # get stages
+    inverse = s[output].op.input_tensors[0]
+    bgemm, A = s[inverse].op.input_tensors
+    kernel_pack, data_pack = s[bgemm].op.input_tensors
+    input_tile, B = s[data_pack].op.input_tensors
+    pad_data = s[input_tile].op.input_tensors[0]
+
+    # data transform
+    s[B].compute_inline()
+
+    data_l = s.cache_write(data_pack, 'local')
+    eps, nu, c, d, p = s[data_l].op.axis
+    r_a, r_b = s[data_l].op.reduce_axis
+    for axis in [eps, nu, r_a, r_b]:
+        s[data_l].unroll(axis)
+
+    eps, nu, c, d, p = s[data_pack].op.axis
+    p, pi = s[data_pack].split(p, 1)
+    fused = s[data_pack].fuse(c, d, p)
+    bb, tt = s[data_pack].split(fused, 128)
+    s[data_pack].reorder(bb, tt, pi, eps, nu)
+    s[data_pack].bind(bb, te.thread_axis("blockIdx.x"))
+    s[data_pack].bind(tt, te.thread_axis("threadIdx.x"))
+
+    s[data_l].compute_at(s[data_pack], pi)
+    s[input_tile].compute_at(s[data_pack], pi)
+    s[pad_data].compute_inline()
+
+    # transform kernel
+    if not pre_computed and not autotvm.GLOBAL_SCOPE.in_tuning:
+        kernel, G = s[kernel_pack].op.input_tensors
+        eps, nu, kd, co, ci = s[kernel_pack].op.axis
+        s[G].compute_inline()
+        r_a, r_b = s[kernel_pack].op.reduce_axis
+        for axis in [eps, nu, r_a, r_b]:
+            s[kernel_pack].unroll(axis)
+
+        fused = s[kernel_pack].fuse(kd, co, ci)
+        bb, tt = s[kernel_pack].split(fused, 128)
+        s[kernel_pack].reorder(bb, tt, eps, nu, r_a, r_b)
+        s[kernel_pack].bind(bb, te.thread_axis("blockIdx.x"))
+        s[kernel_pack].bind(tt, te.thread_axis("threadIdx.x"))
+    else:
+        kernel = kernel_pack
+
+    if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
+        s[kernel].compute_inline()
+
+    ##### space definition begin #####
+    b1, b2, z, y, x = s[bgemm].op.axis
+    # Depth adds a second reduction axis (rz) alongside the channel axis (rc).
+    rc = s[bgemm].op.reduce_axis[0]
+    rz = s[bgemm].op.reduce_axis[1]
+    alpha = get_const_int(b1.dom.extent)
+
+    cfg.define_split("tile_b", cfg.axis(alpha * alpha), num_outputs=4,
+                     filter=lambda x: x.size[-3:] == [1, 1, 1])
+    cfg.define_split("tile_y", y, num_outputs=4)
+    cfg.define_split("tile_x", x, num_outputs=4)
+    cfg.define_split("tile_rc", rc, num_outputs=2)
+    cfg.define_split("tile_rz", rz, num_outputs=2)
+    cfg.define_knob("auto_unroll_max_step", [0, 128, 1500])
+    target = tvm.target.Target.current()
+    if target.target_name in ['nvptx', 'rocm']:
+        cfg.define_knob("unroll_explicit", [1])
+    else:
+        cfg.define_knob("unroll_explicit", [0, 1])
+    ##### space definition end #####
+
+    # batch gemm
+    C = bgemm
+    A0, B0 = kernel_pack, data_pack
+
+    OL = s.cache_write(C, 'local')
+    AA = s.cache_read(A0, 'shared', [OL])
+    BB = s.cache_read(B0, 'shared', [OL])
+
+    b = s[bgemm].fuse(b1, b2)
+    y = s[bgemm].fuse(z, y)
+
+    # tile and bind spatial axes
+    bgemm_scope, b = s[bgemm].split(b, nparts=1)
+    bz, vz, tz, zi = cfg["tile_b"].apply(s, C, b)
+    by, vy, ty, yi = cfg["tile_y"].apply(s, C, y)
+    bx, vx, tx, xi = cfg["tile_x"].apply(s, C, x)
+    s[C].bind(bz, te.thread_axis("blockIdx.z"))
+    s[C].bind(by, te.thread_axis("blockIdx.y"))
+    s[C].bind(bx, te.thread_axis("blockIdx.x"))
+    s[C].bind(vz, te.thread_axis("vthread"))
+    s[C].bind(vy, te.thread_axis("vthread"))
+    s[C].bind(vx, te.thread_axis("vthread"))
+    s[C].bind(tz, te.thread_axis("threadIdx.z"))
+    s[C].bind(ty, te.thread_axis("threadIdx.y"))
+    s[C].bind(tx, te.thread_axis("threadIdx.x"))
+    s[C].reorder(bgemm_scope, bz, by, bx, vz, vy, vx, tz, ty, tx, zi, yi, xi)
+
+    # tile reduction axes
+    s[OL].compute_at(s[C], tx)
+    b1, b2, y1, y2, x = s[OL].op.axis
+    y = s[OL].fuse(y1, y2)
+    b = s[OL].fuse(b1, b2)
+    rc, rz = s[OL].op.reduce_axis
+    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
+    rzo, rzi = cfg['tile_rz'].apply(s, OL, rz)
+    s[OL].reorder(rco, rzo, rci, rzi, b, y, x)
+
+    s[AA].compute_at(s[OL], rco)
+    s[BB].compute_at(s[OL], rco)
+
+    # cooperative fetching
+    for load in [AA, BB]:
+        fused = s[load].fuse(*list(s[load].op.axis))
+        fused, tx = s[load].split(fused, cfg["tile_x"].size[2])
+        fused, ty = s[load].split(fused, cfg["tile_y"].size[2])
+        fused, tz = s[load].split(fused, cfg["tile_b"].size[2])
+        s[load].bind(tz, te.thread_axis("threadIdx.z"))
+        s[load].bind(ty, te.thread_axis("threadIdx.y"))
+        s[load].bind(tx, te.thread_axis("threadIdx.x"))
+
+    s[C].pragma(bgemm_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
+    s[C].pragma(bgemm_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+
+    # schedule inverse, output and fusion
+    if output.op in s.outputs:
+        OL = None
+    else:
+        OL = output
+        s[OL].set_scope('local')
+        output = s.outputs[0]
+
+    m = alpha - 3 + 1  # output tile size; assumes a 3x3 spatial kernel (alpha = m + 3 - 1)
+    n, co, d, h, w = s[output].op.axis
+    do, di = s[output].split(d, m)
+    ho, hi = s[output].split(h, m)
+    wo, wi = s[output].split(w, m)
+    s[output].reorder(n, co, do, ho, wo, di, hi, wi)
+    inverse_scope, n = s[output].split(n, nparts=1)
+
+    fused = s[output].fuse(n, co, do, ho, wo)
+    bb, tt = s[output].split(fused, 128)
+
+    s[output].bind(bb, te.thread_axis("blockIdx.x"))
+    s[output].bind(tt, te.thread_axis("threadIdx.x"))
+
+    if OL is not None:
+        s[OL].compute_at(s[output], tt)
+
+    s[A].compute_inline()
+    co, d, p, vh, vw = s[inverse].op.axis
+    r_a, r_b = s[inverse].op.reduce_axis
+    for axis in [vh, vw, r_a, r_b]:
+        s[inverse].unroll(axis)
+    s[inverse].compute_at(s[output], tt)
+
+    return s
+
+
+@autotvm.register_topi_compute("conv3d_ncdhw_winograd.cuda")
+def conv3d_ncdhw_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    """Compute conv3d with winograd, choosing between the full and depthless
+    depth transforms based on the kernel shape."""
+    CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
+    # The depth dimension can be transformed only for cubic kernels of size 3-7.
+    if 2 < KD < 8 and KD == KH:
+        return winograd_cuda(
+            cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed=False)
+
+    return winograd_without_depth_cuda(
+        cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed=False)
+
+
+@autotvm.register_topi_schedule("conv3d_ncdhw_winograd.cuda")
+def schedule_conv3d_ncdhw_winograd(cfg, outs):
+    """Dispatch to schedule approriate for conv3d winograd algorithm used."""
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
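+        # 'conv3d_ncdhw_winograd' is a substring of the without_depth tag,
+        # so the more specific tag must be matched first.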
+        if 'conv3d_ncdhw_winograd_without_depth' in op.tag:
+            schedule_winograd_no_depth_cuda(cfg, s, op.output(0), pre_computed=False)
+        elif 'conv3d_ncdhw_winograd' in op.tag:
+            schedule_winograd_cuda(cfg, s, op.output(0), pre_computed=False)
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+@autotvm.register_topi_compute("conv3d_ncdhw_winograd_without_weight_transform.cuda")
+def conv3d_ncdhw_winograd_without_weight_transform(cfg, data, kernel, strides, padding, dilation,
+                                                   out_dtype):
+    """Compute conv3d with winograd when the kernel is already transformed."""
+    alpha_d, alpha_h, alpha_w, _, _ = get_const_tuple(kernel.shape)
+    # The kernel is packed as [alpha, alpha, alpha, CO, CI] only when depth was transformed.
+    if alpha_d == alpha_h == alpha_w:
+        return winograd_cuda(
+            cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed=True)
+
+    return winograd_without_depth_cuda(
+        cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed=True)
+
+
+@autotvm.register_topi_schedule("conv3d_ncdhw_winograd_without_weight_transform.cuda")
+def schedule_conv3d_ncdhw_winograd_without_weight_transform(cfg, outs):
+    """TOPI schedule callback"""
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if 'conv3d_ncdhw_winograd_without_depth' in op.tag:
+            schedule_winograd_no_depth_cuda(cfg, s, op.output(0), pre_computed=True)
+        elif 'conv3d_ncdhw_winograd' in op.tag:
+            schedule_winograd_cuda(cfg, s, op.output(0), pre_computed=True)
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
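
For reference, a minimal sketch (not part of this patch) of how the registered compute above selects between the two transforms. The shapes are illustrative and a fallback AutoTVM config is assumed:

    import tvm
    from tvm import te
    import topi

    with tvm.target.create('cuda'):
        data = te.placeholder((1, 32, 8, 56, 56), name='data')
        # 3x3x3 kernel: 2 < KD < 8 and KD == KH, so depth is transformed.
        kernel = te.placeholder((64, 32, 3, 3, 3), name='kernel')
        out = topi.cuda.conv3d_ncdhw_winograd(
            data, kernel, (1, 1, 1), (1, 1, 1), (1, 1, 1), 'float32')
        # 1x3x3 kernel: depth cannot be transformed, so the depthless
        # variant is selected instead.
        kernel2 = te.placeholder((64, 32, 1, 3, 3), name='kernel2')
        out2 = topi.cuda.conv3d_ncdhw_winograd(
            data, kernel2, (1, 1, 1), (0, 1, 1), (1, 1, 1), 'float32')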
index 43b1282..2be4bbb 100644 (file)
@@ -187,6 +187,43 @@ def schedule_conv2d_winograd_weight_transform(outs):
     return s
 
 
+def schedule_conv3d_winograd_weight_transform(outs):
+    """Schedule for weight transformation of 3D winograd
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of this operator
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    # Typically this transform is computed in the PreCompute pass,
+    # so we only provide a simple fallback schedule for the CPU (LLVM) target.
+    s = te.create_schedule([x.op for x in outs])
+    output = outs[0]
+    _, G = s[output].op.input_tensors
+    s[G].compute_inline()
+    transform_depth = len(s[output].op.reduce_axis) == 3
+    if transform_depth:
+        omg, eps, nu, co, ci = s[output].op.axis
+        r_kd, r_kh, r_kw = s[output].op.reduce_axis
+        s[output].reorder(co, ci, omg, eps, nu, r_kd, r_kh, r_kw)
+        for axis in [r_kd, r_kh, r_kw]:
+            s[output].unroll(axis)
+    else:
+        eps, nu, d, co, ci = s[output].op.axis
+        r_kh, r_kw = s[output].op.reduce_axis
+        s[output].reorder(co, ci, d, eps, nu, r_kh, r_kw)
+        for axis in [r_kh, r_kw]:
+            s[output].unroll(axis)
+    s[output].parallel(co)
+    return s
+
+
 def schedule_conv2d_winograd_without_weight_transform(outs):
     """Schedule for winograd without weight transformation
 
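
A quick way to exercise the fallback schedule above — a sketch assuming the module paths introduced in this patch, with illustrative shapes:

    import tvm
    from tvm import te
    import topi

    # 3x3x3 kernel with tile_size=2, so alpha = 4 and depth is transformed.
    kernel = te.placeholder((64, 32, 3, 3, 3), name='kernel')
    w = topi.nn.conv3d_winograd_weight_transform(kernel, tile_size=2)
    s = topi.generic.schedule_conv3d_winograd_weight_transform([w])
    print(tvm.lower(s, [kernel, w], simple_mode=True))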
index d6bd642..2bac284 100644 (file)
 # pylint: disable=invalid-name, unused-variable, too-many-locals
 # pylint: disable=unused-argument, redefined-builtin, no-else-return
 """Conv3D operators"""
+import tvm
 from tvm import te
 
 from .pad import pad
 from .util import get_pad_tuple3d
-from ..util import simplify
+from ..util import simplify, get_const_tuple
+from .winograd_util import winograd_transform_matrices
 
 
 def conv3d_ncdhw(Input, Filter, stride, padding, dilation, out_dtype=None):
@@ -159,3 +161,74 @@ def conv3d_ndhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
             Filter[rd, rh, rw, rc, cc].astype(out_dtype), axis=[rd, rh, rw, rc]),
         name="Conv3dOutput", tag="conv3d_ndhwc")
     return Output
+
+
+def conv3d_winograd_weight_transform(kernel, tile_size):
+    """Weight transformation for 3D winograd
+
+    Parameters
+    ----------
+    kernel: Tensor
+        The raw kernel tensor with layout "OIDHW", i.e. [CO, CI, KD, KH, KW].
+    tile_size: int
+        Tile size of winograd transform, e.g. 2 for F(2x2x2, 3x3x3) and
+        4 for F(4x4x4, 3x3x3)
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        5-D with shape [alpha, alpha, alpha, CO, CI] when the depth is
+        transformed, or [alpha, alpha, KD, CO, CI] when it is not
+    """
+    CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
+
+    depth_transform = 2 < KD < 8 and KD == KH
+
+    if depth_transform:
+        assert KD == KH == KW, "The full 3D transform only supports cubic (NxNxN) kernels"
+    else:
+        assert KH == KW, "The depthless transform only supports DxNxN kernels"
+
+    r = tile_size + KH - 1
+
+    r_kh = te.reduce_axis((0, KH), name='r_kh')
+    r_kw = te.reduce_axis((0, KW), name='r_kw')
+    _, _, G = winograd_transform_matrices(tile_size, KH, kernel.dtype)
+    if depth_transform:
+        shape = (r, r, r, CO, CI)
+        r_kd = te.reduce_axis((0, KD), name='r_kd')
+        return te.compute(
+            shape,
+            lambda omg, eps, nu, co, ci: te.sum(
+                kernel[co][ci][r_kd][r_kh][r_kw] * G[omg][r_kd] * G[eps][r_kh] * G[nu][r_kw],
+                axis=[r_kd, r_kh, r_kw]),
+            name='transform_weight')
+    else:
+        shape = (r, r, KD, CO, CI)
+        return te.compute(
+            shape,
+            lambda eps, nu, d, co, ci: te.sum(
+                kernel[co][ci][d][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]),
+            name='transform_weight')
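+# Shape sketch for the two branches above (illustrative values): with
+# (CO, CI) = (64, 32) and tile_size = 4, so r = 4 + 3 - 1 = 6,
+#   a 3x3x3 kernel packs to (6, 6, 6, 64, 32)  -- depth transformed
+#   a 1x3x3 kernel packs to (6, 6, 1, 64, 32)  -- depth axis kept as-is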
+
+
+@tvm.target.generic_func
+def conv3d_alter_layout(attrs, inputs, tinfos, out_type):
+    """Change Conv3D layout.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current convolution
+    inputs : tvm.relay.Expr
+        Grouped input symbols
+    tinfos : list
+        Input shape and dtype
+    out_type: type
+        The output type
+
+    Note
+    ----
+    Unlike other TOPI functions, this function operates on both graph level and operator level.
+    """
+    # Do not change the layout by default.
+    return None
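
Targets override this generic hook to rewrite conv3d into its winograd form. A minimal sketch of the registration pattern — the concrete CUDA override lives in topi/python/topi/cuda/conv3d_alter_op.py, added by this patch:

    import topi

    @topi.nn.conv3d_alter_layout.register("cpu")
    def _alter_conv3d_layout_cpu(attrs, inputs, tinfos, out_type):
        # Returning None keeps the convolution unchanged; a real override
        # returns a rewritten relay expression instead.
        return None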
diff --git a/topi/tests/python/test_topi_conv3d_winograd.py b/topi/tests/python/test_topi_conv3d_winograd.py
new file mode 100644 (file)
index 0000000..6d0d99d
--- /dev/null
@@ -0,0 +1,151 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for 3d convolution with winograd."""
+
+import numpy as np
+import tvm
+from tvm import te
+from tvm import autotvm
+import topi
+import topi.testing
+from tvm.contrib.pickle_memoize import memoize
+from topi.nn.util import get_pad_tuple3d
+from topi.util import get_const_tuple
+
+from common import get_all_backend
+
+_conv3d_ncdhw_implement = {
+    "gpu": (topi.cuda.conv3d_ncdhw_winograd, topi.cuda.schedule_conv3d_ncdhw_winograd),
+}
+
+
+def verify_conv3d_ncdhw(batch,
+                        in_channel,
+                        in_size,
+                        num_filter,
+                        depth_kernel,
+                        space_kernel,
+                        stride,
+                        padding,
+                        dilation=1,
+                        add_bias=False,
+                        add_relu=False):
+    pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = get_pad_tuple3d(
+        padding, (depth_kernel, space_kernel, space_kernel))
+    padding_sum = pad_front + pad_back + pad_top + pad_left + pad_bottom + pad_right
+    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" %
+          (batch, in_channel, in_size, num_filter, space_kernel, stride, padding_sum, dilation))
+
+    in_depth = in_height = in_width = in_size
+
+    A = te.placeholder((batch, in_channel, in_depth, in_height, in_width), name='A')
+    W = te.placeholder((num_filter, in_channel, depth_kernel, space_kernel, space_kernel), name='W')
+    bias = te.placeholder((num_filter, 1, 1, 1), name='bias')
+
+    a_shape = get_const_tuple(A.shape)
+    w_shape = get_const_tuple(W.shape)
+    bias_shape = get_const_tuple(bias.shape)
+    dtype = A.dtype
+
+    @memoize("topi.tests.test_topi_conv3d_ncdhw.verify_conv3d_ncdhw")
+    def get_ref_data():
+        a_np = np.random.uniform(size=a_shape).astype(dtype)
+        w_np = np.random.uniform(size=w_shape).astype(dtype)
+        b_np = np.random.uniform(size=bias_shape).astype(dtype)
+        dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation, dilation))
+        c_np = topi.testing.conv3d_ncdhw_python(a_np, dw_np, stride, padding)
+        if add_bias:
+            c_np += b_np
+        if add_relu:
+            c_np = np.maximum(c_np, 0)
+        return a_np, w_np, b_np, c_np
+
+    a_np, w_np, b_np, c_np = get_ref_data()
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        fcompute, fschedule = topi.testing.dispatch(device, _conv3d_ncdhw_implement)
+        with tvm.target.create(device):
+            C = fcompute(A, W, (stride, stride, stride), padding, (dilation, dilation, dilation),
+                         dtype)
+            if add_bias:
+                C = topi.add(C, bias)
+            if add_relu:
+                C = topi.nn.relu(C)
+            s = fschedule([C])
+
+        a = tvm.nd.array(a_np, ctx)
+        w = tvm.nd.array(w_np, ctx)
+        b = tvm.nd.array(b_np, ctx)
+        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
+        if add_bias:
+            func = tvm.build(
+                s, [A, W, bias, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
+                (batch, in_channel, in_size, num_filter, space_kernel, stride, padding_sum, dilation))
+            func(a, w, b, c)
+        else:
+            func = tvm.build(
+                s, [A, W, C],
+                device,
+                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
+                (batch, in_channel, in_size, num_filter, space_kernel, stride, padding_sum, dilation))
+            func(a, w, c)
+        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4)
+
+    for device in ["cuda"]:
+        with autotvm.tophub.context(device):  # load tophub pre-tuned parameters
+            check_device(device)
+
+
+def test_conv3d_ncdhw():
+    # 3D CNN workloads, exercising both the full and depthless depth transforms
+    verify_conv3d_ncdhw(1, 61, 20, 120, 3, 3, 1, 0)
+    verify_conv3d_ncdhw(1, 61, 20, 120, 1, 3, 1, 0)
+    verify_conv3d_ncdhw(1, 61, 20, 120, 5, 3, 1, 0)
+    verify_conv3d_ncdhw(1, 61, 20, 120, 5, 5, 1, 2)
+    verify_conv3d_ncdhw(1, 61, 20, 120, 1, 5, 1, 2)
+    verify_conv3d_ncdhw(1, 61, 20, 120, 7, 7, 1, 3)
+    verify_conv3d_ncdhw(1, 128, 12, 256, 3, 3, 1, 1)
+    verify_conv3d_ncdhw(1, 64, 12, 128, 3, 3, 1, 1)
+
+    # bias, relu
+    verify_conv3d_ncdhw(1, 64, 12, 128, 3, 3, 1, 1, add_relu=True)
+    verify_conv3d_ncdhw(1, 64, 12, 128, 3, 3, 1, 1, add_relu=True, add_bias=True)
+    verify_conv3d_ncdhw(1, 64, 12, 128, 1, 3, 1, 1, add_relu=True, add_bias=True)
+
+    # dilation = 2
+    verify_conv3d_ncdhw(1, 16, 12, 16, 3, 3, 1, "VALID", dilation=2)
+    verify_conv3d_ncdhw(1, 16, 12, 16, 1, 3, 1, "VALID", dilation=2)
+
+    # batch size
+    verify_conv3d_ncdhw(4, 32, 12, 64, 3, 3, 1, 1)
+    verify_conv3d_ncdhw(4, 32, 12, 64, 1, 3, 1, 1)
+
+    # weird workloads
+    verify_conv3d_ncdhw(2, 2, 2, 2, 3, 3, 1, 2)
+    verify_conv3d_ncdhw(3, 3, 3, 3, 3, 3, 1, 3)
+
+
+if __name__ == "__main__":
+    test_conv3d_ncdhw()