From: Samuel Date: Thu, 11 Jun 2020 09:12:35 +0000 (+0530) Subject: [TOPI][RELAY][PYTORCH]Conv3d_transpose op support added (#5737) X-Git-Tag: upstream/0.7.0~587 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=e2fb5039a2472403f2f6dfd65ac6b94de6758271;p=platform%2Fupstream%2Ftvm.git [TOPI][RELAY][PYTORCH]Conv3d_transpose op support added (#5737) * [TOPI][RELAY][PYTORCH]Conv3d_transpose op support added * Test cases in topi/relay * conv3d_transpose_ncdhw_python added * Review comments fixed --- diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst index f93f82f..960d946 100644 --- a/docs/api/python/topi.rst +++ b/docs/api/python/topi.rst @@ -214,6 +214,8 @@ topi.nn .. autofunction:: topi.nn.conv2d_hwcn .. autofunction:: topi.nn.depthwise_conv2d_nchw .. autofunction:: topi.nn.depthwise_conv2d_nhwc +.. autofunction:: topi.nn.conv3d_ncdhw +.. autofunction:: topi.nn.conv3d_transpose_ncdhw .. autofunction:: topi.nn.fifo_buffer topi.image @@ -233,6 +235,8 @@ topi.generic .. autofunction:: topi.generic.schedule_conv2d_nchw .. autofunction:: topi.generic.schedule_depthwise_conv2d_nchw +.. autofunction:: topi.generic.schedule_conv3d_ncdhw +.. autofunction:: topi.generic.schedule_conv3d_transpose_ncdhw .. autofunction:: topi.generic.schedule_reduce .. autofunction:: topi.generic.schedule_broadcast .. autofunction:: topi.generic.schedule_injective diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 7b20b3d..b3fdf1c 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -69,6 +69,8 @@ This level enables typical convnet models. tvm.relay.nn.conv2d tvm.relay.nn.conv2d_transpose + tvm.relay.nn.conv3d + tvm.relay.nn.conv3d_transpose tvm.relay.nn.dense tvm.relay.nn.max_pool2d tvm.relay.nn.max_pool3d @@ -225,4 +227,4 @@ This level supports dialect operators. :nosignatures: tvm.relay.qnn.op.requantize - tvm.relay.qnn.op.conv2d \ No newline at end of file + tvm.relay.qnn.op.conv2d diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index dcb4cb6..abe63e5 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -348,6 +348,82 @@ struct Conv3DAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes used in transposed convolution operator */ +struct Conv3DTransposeAttrs : public tvm::AttrsNode { + IndexExpr channels; + Array kernel_size; + Array strides; + Array padding; + Array output_padding; + Array dilation; + int groups; + std::string data_layout; + std::string kernel_layout; + std::string out_layout; + DataType out_dtype; + + TVM_DECLARE_ATTRS(Conv3DTransposeAttrs, "relay.attrs.Conv3DTransposeAttrs") { + TVM_ATTR_FIELD(channels) + .set_default(NullValue()) + .describe( + "The dimensionality of the output space" + "i.e. the number of output channels in the convolution."); + TVM_ATTR_FIELD(kernel_size) + .describe("The dimensions of the convolution window.") + .set_default(NullValue >()); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1, 1})) + .describe("The strides of the convolution."); + TVM_ATTR_FIELD(output_padding) + .set_default(Array({0, 0, 0})) + .describe( + "Zero-padding added to one side of the output." + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : front, bottom, right will use same padding as back, top, left" + "six int : padding width in the order of (front, top, left, back, bottom, right)"); + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : front, bottom, right will use same padding as back, top, left" + "six int : padding width in the order of (front, top, left, back, bottom, right)"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1, 1})) + .describe("Specifies the dilation rate to use for dilated convolution."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCDHW") + .describe( + "Dimension ordering of data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Convolution is applied on the 'D', 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIDHW") + .describe( + "Dimension ordering of data and weight. Can be 'OIDHW', 'OIDHW16o16i', etc." + "'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(out_dtype) + .set_default(NullValue()) + .describe("Output data type, set to explicit type under mixed precision setting"); + } +}; + /*! \brief Attributes used in 3d winograd convolution operators */ struct Conv3DWinogradAttrs : public tvm::AttrsNode { int tile_size; diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index e74d58e..7b96530 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -754,17 +754,24 @@ def _convolution(): if isinstance(dilation, _expr.Expr): dilation = _infer_shape(dilation) - data_layout = "NCHW" - kernel_layout = "OIHW" - conv_op = _op.nn.conv2d - if use_transpose: - assert len(kernel_size) == 2, "ConvTranspose 3D not supported" - conv_op = _op.nn.conv2d_transpose + if len(kernel_size) == 3: + conv_op = _op.nn.conv3d_transpose + else: + conv_op = _op.nn.conv2d_transpose + else: + if len(kernel_size) == 3: + conv_op = _op.nn.conv3d + else: + conv_op = _op.nn.conv2d + if len(kernel_size) == 3: - conv_op = _op.nn.conv3d data_layout = "NCDHW" kernel_layout = "OIDHW" + else: + data_layout = "NCHW" + kernel_layout = "OIHW" + conv_out = conv_op(data, weight, diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 0633451..c09b873 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -192,6 +192,31 @@ def legalize_conv2d_transpose(attrs, inputs, types): return topi.nn.conv2d_transpose_legalize(attrs, inputs, types) +# conv3d_transpose +reg.register_strategy("nn.conv3d_transpose", strategy.conv3d_transpose_strategy) +reg.register_pattern("nn.conv3d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE) + +@reg.register_legalize("nn.conv3d_transpose") +def legalize_conv3d_transpose(attrs, inputs, types): + """Legalize conv3d_transpose op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current Transposed convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + return topi.nn.conv3d_transpose_legalize(attrs, inputs, types) + + # conv3d reg.register_strategy("nn.conv3d", strategy.conv3d_strategy) reg.register_pattern("nn.conv3d", OpPattern.OUT_ELEMWISE_FUSABLE) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 0f1f158..34d07dc 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -372,6 +372,76 @@ def contrib_conv3d_winograd_without_weight_transform(data, groups, channels, kernel_size, data_layout, kernel_layout, out_layout, out_dtype) +def conv3d_transpose(data, + weight, + strides=(1, 1, 1), + padding=(0, 0, 0), + dilation=(1, 1, 1), + groups=1, + channels=None, + kernel_size=None, + data_layout="NCDHW", + kernel_layout="OIDHW", + out_layout="", + output_padding=(0, 0, 0), + out_dtype=""): + r"""3D transpose convolution. + + Parameters + ---------- + data : tvm.relay.Expr + The input data to the operator. + + weight : tvm.relay.Expr + The weight expressions. + + strides : Optional[Tuple[int]] + The strides of convolution. + + padding : Optional[int, Tuple[int]] + The padding of convolution on both sides of inputs before convolution. + + dilation : Optional[int, Tuple[int]] + Specifies the dilation rate to be used for dilated convolution. + + groups : Optional[int] + Number of groups for grouped convolution. + + channels : Optional[int] + Number of output channels of this convolution. + + kernel_size : Optional[int, Tuple[int]] + The spatial of the convolution kernel. + + data_layout : Optional[str] + Layout of the input. + + kernel_layout : Optional[str] + Layout of the weight. + + out_layout : Optional[str] + Layout of the output, by default, out_layout is the same as data_layout + + out_dtype : Optional[str] + Specifies the output data type for mixed precision conv3d. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + if isinstance(strides, int): + strides = (strides, strides, strides) + if isinstance(dilation, int): + dilation = (dilation, dilation, dilation) + padding = get_pad_tuple3d(padding) + + return _make.conv3d_transpose(data, weight, strides, padding, dilation, + groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, output_padding, out_dtype) def conv2d_transpose(data, weight, diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 0686125..8a7ab48 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -348,6 +348,9 @@ class BinaryDenseAttrs(Attrs): class Conv2DTransposeAttrs(Attrs): """Attributes used in Transposed Conv2D operators""" +@tvm._ffi.register_object("relay.attrs.Conv3DTransposeAttrs") +class Conv3DTransposeAttrs(Attrs): + """Attributes used in Transposed Conv3D operators""" @tvm._ffi.register_object("relay.attrs.DilateAttrs") class DilateAttrs(Attrs): diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 59d4ec9..5ffb7f1 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -313,6 +313,24 @@ def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target): name="conv2d_transpose_nchw.cuda") return strategy + +@conv3d_transpose_strategy.register(["cuda", "gpu"]) +def conv3d_transpose_strategy_cuda(attrs, inputs, out_type, target): + """conv3d_transpose cuda strategy""" + layout = attrs.data_layout + dilation = get_const_tuple(attrs.dilation) + groups = attrs.groups + assert layout == "NCDHW", "only support ncdhw for now" + assert dilation == (1, 1, 1), "not support dilate now" + assert groups == 1, "only support groups == 1 for now" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_conv3d_transpose(topi.cuda.conv3d_transpose_ncdhw), + wrap_topi_schedule(topi.cuda.schedule_conv3d_transpose_ncdhw), + name="conv3d_transpose_ncdhw.cuda") + return strategy + + @conv3d_strategy.register(["cuda", "gpu"]) def conv3d_strategy_cuda(attrs, inputs, out_type, target): """conv3d cuda strategy""" diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index f523f66..4fa2b11 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -345,6 +345,44 @@ def conv2d_transpose_strategy(attrs, inputs, out_type, target): name="conv2d_transpose_nchw.generic") return strategy + +# conv3d_transpose +def wrap_compute_conv3d_transpose(topi_compute): + """wrap conv3d_transpose topi compute""" + def compute_conv3d_transpose(attrs, inputs, out_dtype): + """Compute definition of conv3d_transpose""" + padding = get_const_tuple(attrs.padding) + strides = get_const_tuple(attrs.strides) + out_dtype = attrs.out_dtype + out_dtype = (inputs[0].dtype if out_dtype in ("same", "") + else out_dtype) + out = topi_compute( + inputs[0], inputs[1], strides, padding, out_dtype) + output_padding = get_const_tuple(attrs.output_padding) + out = topi.nn.pad(out, + [0, 0, 0, 0, 0], + [0, 0, output_padding[0], output_padding[1], output_padding[2]]) + return [out] + return compute_conv3d_transpose + + +@override_native_generic_func("conv3d_transpose_strategy") +def conv3d_transpose_strategy(attrs, inputs, out_type, target): + """conv3d_transpose generic strategy""" + logger.warning("conv3d_transpose is not optimized for this platform.") + layout = attrs.data_layout + dilation = get_const_tuple(attrs.dilation) + groups = attrs.groups + assert layout == "NCDHW", "only support ncdhw for now" + assert dilation == (1, 1, 1), "not support dilate now" + assert groups == 1, "only support groups == 1 for now" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_conv3d_transpose(topi.nn.conv3d_transpose_ncdhw), + wrap_topi_schedule(topi.generic.schedule_conv3d_transpose_ncdhw), + name="conv3d_transpose_ncdhw.generic") + return strategy + # conv3d def wrap_compute_conv3d(topi_compute, need_layout=False): """wrap conv3d topi compute""" diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index fbc2ed2..0984e40 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -202,6 +202,24 @@ def conv2d_transpose_strategy_cpu(attrs, inputs, out_type, target): name="conv2d_transpose_nchw.x86") return strategy + +@conv3d_transpose_strategy.register("cpu") +def conv3d_transpose_strategy_cpu(attrs, inputs, out_type, target): + """conv3d_transpose x86 strategy""" + layout = attrs.data_layout + dilation = get_const_tuple(attrs.dilation) + groups = attrs.groups + assert layout == "NCDHW", "only support ncdhw for now" + assert dilation == (1, 1, 1), "not support dilate now" + assert groups == 1, "only support groups == 1 for now" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_conv3d_transpose(topi.x86.conv3d_transpose_ncdhw), + wrap_topi_schedule(topi.x86.schedule_conv3d_transpose_ncdhw), + name="conv3d_transpose_ncdhw.x86") + return strategy + + @conv3d_strategy.register("cpu") def conv3d_strategy_cpu(attrs, inputs, out_type, target): """conv3d generic strategy""" diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 4770cd8..6c6eb1e 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -231,6 +231,51 @@ with the layer input to produce a tensor of outputs. .add_type_rel("Conv3D", Conv3DRel) .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); +// relay.nn.conv3d_transpose +TVM_REGISTER_NODE_TYPE(Conv3DTransposeAttrs); + +TVM_REGISTER_GLOBAL("relay.op.nn._make.conv3d_transpose") + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, String data_layout, String kernel_layout, + String out_layout, Array output_padding, DataType out_dtype) { + return MakeConvTranspose( + data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, output_padding, out_dtype, "nn.conv3d_transpose"); + }); + +RELAY_REGISTER_OP("nn.conv3d_transpose") + .describe(R"code(Transposed 3D convolution layer (sometimes called Deconvolution 3D). + +The need for transposed convolutions generally arises +from the desire to use a transformation going in the opposite direction +of a normal convolution, i.e., from something that has the shape of the +output of some convolution to something that has the shape of its input +while maintaining a connectivity pattern that is compatible with +said convolution. + +- **data**: This depends on the `layout` parameter. Input is 5D array of shape + (batch_size, in_channels, depth, height, width) if `layout` is `NCDHW`. +- **weight**: (in_channels, channels, kernel_size[0], kernel_size[1], kernel_size[2]) +- **bias**: (channels,) +- **out**: This depends on the `layout` parameter. Output is 5D array of shape + (batch_size, channels, out_depth, out_height, out_width) if `layout` is `NCDHW`. + + out_depth and out_height and out_width are calculated as:: + out_depth = (depth-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0] + out_height = (height-1)*strides[1]-2*padding[1]+kernel_size[1]+output_padding[1] + out_width = (width-1)*strides[2]-2*padding[2]+kernel_size[2]+output_padding[2] + +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .set_attr("FInferCorrectLayout", + ConvInferCorrectLayout) + .add_type_rel("Conv3DTranspose", Conv3DTransposeRel); + // relay.nn.conv2d_transpose TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs); diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index 5dc649b..0c5b20a 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -690,6 +690,103 @@ bool Conv1DTransposeRel(const Array& types, int num_inputs, const Attrs& a } template +bool Conv3DTransposeRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + const auto* weight = types[1].as(); + if (data == nullptr) return false; + + static const Layout kNCDHW("NCDHW"); + static const Layout kOIDHW("OIDHW"); + + const Conv3DTransposeAttrs* param = attrs.as(); + CHECK(param != nullptr); + const Layout in_layout(param->data_layout); + const Layout kernel_layout(param->kernel_layout); + + const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCDHW); + CHECK(trans_in_layout.defined()) + << "Conv3d_transpose only support input layouts that are convertible from NCDHW." + << " But got " << in_layout; + + const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW); + CHECK(trans_kernel_layout.defined()) + << "Conv3d_transpose only support kernel layouts that are convertible from OIDHW." + << " But got " << kernel_layout; + + Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); + const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCDHW); + CHECK(trans_out_layout.defined()) + << "Conv3d_transpose only support output layouts that are convertible from NCDHW." + << " But got " << out_layout; + + IndexExpr channels, dilated_ksize_d, dilated_ksize_y, dilated_ksize_x; + + auto dshape_ncdhw = trans_in_layout.ForwardShape(data->shape); + + // infer weight if the kernel_size and channels are defined + if (param->kernel_size.defined() && param->channels.defined()) { + CHECK_EQ(param->kernel_size.size(), 3); + CHECK_EQ(param->dilation.size(), 3); + + Array wshape({dshape_ncdhw[1], indexdiv(param->channels, param->groups), + param->kernel_size[0], param->kernel_size[1], param->kernel_size[2]}); + + wshape = trans_kernel_layout.BackwardShape(wshape); + dilated_ksize_d = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; + dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; + dilated_ksize_x = 1 + (param->kernel_size[2] - 1) * param->dilation[2]; + channels = param->channels; + + // assign result to reporter + reporter->Assign(types[1], TensorType(wshape, data->dtype)); + } else { + // use weight to infer the conv shape. + if (weight == nullptr) return false; + auto wshape = trans_kernel_layout.ForwardShape(weight->shape); + if (param->kernel_size.defined()) { + CHECK_EQ(param->kernel_size.size(), 3); + // check the size + CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && + reporter->AssertEQ(param->kernel_size[1], wshape[3]) && + reporter->AssertEQ(param->kernel_size[2], wshape[4])) + << "Conv3D: shape of weight is inconsistent with kernel_size, " + << " kernel_size=" << param->kernel_size << " wshape=" << Array(wshape); + } + if (param->channels.defined()) { + CHECK(reporter->AssertEQ(param->channels, wshape[1])) + << "Conv3D: shape of weight is inconsistent with channels, " + << " channels=" << param->channels << " wshape=" << Array(wshape); + } + CHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[0])); + channels = wshape[1]; + dilated_ksize_d = 1 + (wshape[2] - 1) * param->dilation[0]; + dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; + dilated_ksize_y = 1 + (wshape[4] - 1) * param->dilation[2]; + } + + // dilation + Array oshape({dshape_ncdhw[0], channels, 0, 0, 0}); + IndexExpr pad_d, pad_h, pad_w; + GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w); + oshape.Set(2, (param->strides[0] * (dshape_ncdhw[2] - 1) + dilated_ksize_d - pad_d + + param->output_padding[0])); + oshape.Set(3, (param->strides[1] * (dshape_ncdhw[3] - 1) + dilated_ksize_y - pad_h + + param->output_padding[1])); + oshape.Set(4, (param->strides[2] * (dshape_ncdhw[4] - 1) + dilated_ksize_x - pad_w + + param->output_padding[2])); + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + oshape = trans_out_layout.BackwardShape(oshape); + reporter->Assign(types[2], TensorType(oshape, out_dtype)); + return true; +} + +template bool Conv2DTransposeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index e41da7e..3c7ff4f 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1186,6 +1186,33 @@ def test_conv3d(): inp) +def test_conv3d_transpose(): + for ishape in [(1, 8, 10, 5, 10), + (1, 8, 5, 8, 8), + (1, 8, 13, 7, 7)]: + inp = torch.rand(ishape) + verify_model(torch.nn.ConvTranspose3d(in_channels=8, + out_channels=33, + kernel_size=3, + stride=2).eval(), + inp), + verify_model(torch.nn.ConvTranspose3d(in_channels=8, + out_channels=20, + kernel_size=(3, 5, 2), + stride=(2, 1, 1), + padding=(0, 4, 2)).eval(), + inp), + verify_model(torch.nn.ConvTranspose3d(in_channels=8, + out_channels=20, + kernel_size=1).eval(), + inp) + verify_model(torch.nn.ConvTranspose3d(in_channels=8, + out_channels=5, + kernel_size=1, + stride=2).eval(), + inp) + + # Model tests def test_resnet18(): torch.set_grad_enabled(False) @@ -2472,6 +2499,7 @@ if __name__ == "__main__": test_forward_replication_pad3d() test_adaptive_pool3d() test_conv3d() + test_conv3d_transpose() # Model tests test_resnet18() diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 2b5e67c..d45372e 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -612,6 +612,66 @@ def test_conv3d_winograd(): padding=(0, 2, 2), channels=120, kernel_size=(1, 5, 5)) +def test_conv3d_transpose_infer_type(): + # symbolic in batch dimension + n, c, d, h, w = te.size_var("n"), 10, 224, 224, 224 + x = relay.var("x", relay.ty.TensorType((n, c, d, h, w), "float32")) + w = relay.var("w") + y = relay.nn.conv3d_transpose(x, w, + kernel_size=(3, 3, 3), + padding=(1, 1, 1), + channels=2) + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType( + (n, 2, 224, 224, 224), "float32") + + assert yy.args[1].checked_type == relay.TensorType( + (10, 2, 3, 3, 3), "float32") + + # infer by shape of w, mixed precision + n, c, d, h, w = te.size_var("n"), 10, 224, 224, 224 + x = relay.var("x", relay.TensorType((n, c, d, h, w), "int8")) + w = relay.var("w", relay.TensorType((10, 12, 3, 3, 3), "int8")) + y = relay.nn.conv3d_transpose(x, w, out_dtype="int32") + assert "out_dtype=\"int32\"" in y.astext() + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType( + (n, 12, 226, 226, 226), "int32") + + # infer shape in case of different dtypes for input and weight. + n, c, d, h, w = te.size_var("n"), 10, 224, 224, 224 + x = relay.var("x", relay.TensorType((n, c, d, h, w), "uint8")) + w = relay.var("w", relay.TensorType((10, 12, 3, 3, 3), "int8")) + y = relay.nn.conv3d_transpose(x, w, out_dtype="int32") + assert "out_dtype=\"int32\"" in y.astext() + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType( + (n, 12, 226, 226, 226), "int32") + + +def test_conv3d_transpose_ncdhw_run(): + dshape = (1, 3, 24, 24, 24) + kshape = (3, 4, 2, 2, 2) + + x = relay.var("x", shape=dshape) + w = relay.var("w") + y = relay.nn.conv3d_transpose(x, w, + channels=4, kernel_size=(2, 2, 2), strides=(1, 1, 1), + padding=(1, 1, 1)) + func = relay.Function([x, w], y) + dtype = "float32" + + data = np.random.uniform(size=dshape).astype(dtype) + kernel = np.random.uniform(size=kshape).astype(dtype) + + ref_res = topi.testing.conv3d_transpose_ncdhw_python(data, kernel, 1, 1) + + for target, ctx in ctx_list(): + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(data, kernel) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + + def test_conv2d_transpose_infer_type(): # symbolic in batch dimension n, c, h, w = te.size_var("n"), 10, 10, 12 @@ -1397,6 +1457,8 @@ if __name__ == "__main__": test_flatten_infer_type() test_pad_infer_type() test_pad_run() + test_conv3d_transpose_infer_type() + test_conv3d_transpose_ncdhw_run() test_conv2d_transpose_infer_type() test_conv2d_transpose_nchw_run() test_conv2d_transpose_nhwc_run() diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py index 78e3680..90f4e60 100644 --- a/topi/python/topi/cuda/__init__.py +++ b/topi/python/topi/cuda/__init__.py @@ -30,6 +30,7 @@ from .depthwise_conv2d import * from .group_conv2d_nchw import * from . import conv2d_alter_op from .conv2d_transpose_nchw import * +from .conv3d_transpose_ncdhw import * from .deformable_conv2d import * from .conv3d import * from .conv3d_winograd import * diff --git a/topi/python/topi/cuda/conv3d.py b/topi/python/topi/cuda/conv3d.py index cc13aa5..f244c65 100644 --- a/topi/python/topi/cuda/conv3d.py +++ b/topi/python/topi/cuda/conv3d.py @@ -129,7 +129,7 @@ def schedule_conv3d_ndhwc(cfg, outs): The config for this template outs: Array of Tensor - The computation graph description of conv2d + The computation graph description of conv3d in the format of an array of tensors. Returns diff --git a/topi/python/topi/cuda/conv3d_transpose_ncdhw.py b/topi/python/topi/cuda/conv3d_transpose_ncdhw.py new file mode 100644 index 0000000..bcad3e4 --- /dev/null +++ b/topi/python/topi/cuda/conv3d_transpose_ncdhw.py @@ -0,0 +1,145 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Conv3d transpose template for cuda backend""" + +import tvm +from tvm import te +from tvm import autotvm +from .. import nn +from ..util import get_const_tuple, traverse_inline +from .conv3d_direct import schedule_direct_conv3d_cuda + + +@autotvm.register_topi_compute("conv3d_transpose_ncdhw.cuda") +def conv3d_transpose_ncdhw(cfg, data, kernel, stride, padding, out_dtype): + """Transposed 3D convolution ncdhw forward operator. + + Parameters + ---------- + cfg: ConfigEntity + The config for this template + Input : tvm.te.Tensor + 5-D with shape [batch, in_channel, in_depth, in_height, in_width] + Filter : tvm.te.Tensor + 5-D with shape [in_channel, num_filter, filter_depth, filter_height, filter_width] + strides : int or a list/tuple of three ints + The spatial stride along height and width + padding : int or str + Padding size, or ['VALID', 'SAME'] + out_dtype: str + The output type. This is used in mixed precision + + Returns + ------- + Output : tvm.te.Tensor + 5-D with shape [batch, out_channel, out_depth, out_height, out_width] + """ + batch, inp_channels, inp_depth, inp_height, inp_width = get_const_tuple(data.shape) + _, out_channels, kernel_depth, kernel_height, kernel_width = get_const_tuple(kernel.shape) + stride_depth, stride_height, stride_width = stride + cfg.stride = stride + pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = nn.get_pad_tuple3d( + padding, (kernel_depth, kernel_height, kernel_width)) + + out_depth = (inp_depth - 1) * stride_depth + \ + kernel_depth - pad_front - pad_back + pad_front = kernel_depth - 1 - pad_front + pad_back = kernel_depth - 1 - pad_back + dilated_depth = stride_depth * (inp_depth - 1) + 1 + + out_width = (inp_width - 1) * stride_width + \ + kernel_width - pad_left - pad_right + pad_left = kernel_width - 1 - pad_left + pad_right = kernel_width - 1 - pad_right + dilated_width = stride_width * (inp_width - 1) + 1 + + out_height = (inp_height - 1) * stride_height + \ + kernel_height - pad_top - pad_bottom + pad_top = kernel_height - 1 - pad_top + pad_bottom = kernel_height - 1 - pad_bottom + dilated_height = stride_height * (inp_height - 1) + 1 + + # compute pad + data = te.compute( + (batch, inp_channels, + pad_front + dilated_depth + pad_back, + pad_top + dilated_height + pad_bottom, + pad_left + dilated_width + pad_right), + lambda n, c, d, y, x: tvm.tir.if_then_else( + tvm.tir.all(x >= pad_left, + x < pad_left + dilated_width, + tvm.tir.indexmod(x - pad_left, stride_width).equal(0), + y >= pad_top, + y < pad_top + dilated_height, + tvm.tir.indexmod(y - pad_top, stride_height).equal(0), + d >= pad_front, + d < pad_front + dilated_depth, + tvm.tir.indexmod(d - pad_front, stride_depth).equal(0)), + data[n, c, + tvm.tir.indexdiv(d - pad_front, stride_depth), + tvm.tir.indexdiv(y - pad_top, stride_height), + tvm.tir.indexdiv(x - pad_left, stride_width)], + tvm.tir.const(0., "float32")), + name='data_pad') + + # compute transposed conv + dc = te.reduce_axis((0, inp_channels), name='dc') + dd = te.reduce_axis((0, kernel_depth), name='dd') + dh = te.reduce_axis((0, kernel_height), name='dh') + dw = te.reduce_axis((0, kernel_width), name='dw') + data_out = te.compute( + (batch, out_channels, out_depth, out_height, out_width), + lambda b, c, d, h, w: te.sum( + data[b, dc, d + dd, h + dh, w + dw].astype(out_dtype) * + kernel[dc, + c, + kernel_depth - 1 - dd, + kernel_height - 1 - dh, + kernel_width - 1 - dw].astype(out_dtype), + axis=[dc, dd, dh, dw]), tag="conv3d_transpose_ncdhw") + + return data_out + +@autotvm.register_topi_schedule("conv3d_transpose_ncdhw.cuda") +def schedule_conv3d_transpose_ncdhw(cfg, outs): + """TOPI Schedule callback for conv3d transpose operator. + + Parameters + ---------- + cfg: ConfigEntity + The parameters for this template + + outs: Array of Tensor + The computation graph description of conv3d transpose + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for conv3d transpose. + """ + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == 'conv3d_transpose_ncdhw': + schedule_direct_conv3d_cuda(cfg, s, op.output(0), "NCDHW", + "conv3d_transpose_ncdhw.cuda") + + traverse_inline(s, outs[0].op, _callback) + return s diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py index d0c165d..767087b 100644 --- a/topi/python/topi/generic/nn.py +++ b/topi/python/topi/generic/nn.py @@ -290,6 +290,24 @@ def schedule_conv3d_ndhwc(outs): """ return _default_schedule(outs, False) + +def schedule_conv3d_transpose_ncdhw(outs): + """Schedule for conv3d_transpose_ncdhw + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of conv3d_transpose_ncdhw + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) + + def schedule_conv2d_transpose_nchw(outs): """Schedule for conv2d_transpose_nchw diff --git a/topi/python/topi/nn/__init__.py b/topi/python/topi/nn/__init__.py index 3830bd0..a035f67 100644 --- a/topi/python/topi/nn/__init__.py +++ b/topi/python/topi/nn/__init__.py @@ -32,6 +32,7 @@ from .dense import * from .mapping import * from .pooling import * from .softmax import * +from .conv3d_transpose import * from .conv2d_transpose import * from .conv1d_transpose import * from .bnn import * diff --git a/topi/python/topi/nn/conv3d_transpose.py b/topi/python/topi/nn/conv3d_transpose.py new file mode 100644 index 0000000..29b9e53 --- /dev/null +++ b/topi/python/topi/nn/conv3d_transpose.py @@ -0,0 +1,169 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-variable, unused-argument +"""Transposed 3D convolution operators (sometimes called Deconvolution).""" +import tvm +from tvm import te +from tvm import relay +from .dilate import dilate +from .pad import pad +from .util import get_pad_tuple3d +from ..util import simplify + + +def conv3d_transpose_ncdhw(Input, Filter, strides, padding, out_dtype): + """Transposed 3D convolution ncdhw forward operator. + + Parameters + ---------- + Input : tvm.te.Tensor + 5-D with shape [batch, in_channel, in_depth, in_height, in_width] + + Filter : tvm.te.Tensor + 5-D with shape [in_channel, num_filter, filter_depth, filter_height, filter_width] + + strides : int or a list/tuple of three ints + The spatial stride along depth,height and width + + padding : int or str + Padding size, or ['VALID', 'SAME'] + + out_dtype : str + The output data type. This is used for mixed precision. + + Returns + ------- + Output : tvm.te.Tensor + 5-D with shape [batch, out_channel, out_depth, out_height, out_width] + """ + return declaration_conv3d_transpose_impl(Input, Filter, strides, padding, out_dtype) + + +def conv3d_transpose_ncdhw_preprocess(data, kernel, strides, padding, out_dtype): + """Preprocess data and kernel to make the compute pattern + of conv3d_transpose the same as conv3d""" + batch, in_c, in_d, in_h, in_w = data.shape + _, out_c, filter_d, filter_h, filter_w = kernel.shape + stride_d, stride_h, stride_w = strides + # dilate data + data_dilate = dilate(data, [1, 1, stride_d, stride_h, stride_w], name='data_dilate') + # pad data + fpad_front, fpad_top, fpad_left, fpad_back, fpad_bottom, fpad_right = get_pad_tuple3d( + padding, (filter_d, filter_h, filter_w)) + bpad_front = filter_d - 1 - fpad_front + bpad_back = filter_d - 1 - fpad_back + bpad_top = filter_h - 1 - fpad_top + bpad_bottom = filter_h - 1 - fpad_bottom + bpad_left = filter_w - 1 - fpad_left + bpad_right = filter_w - 1 - fpad_right + data_pad = pad(data_dilate, \ + [0, 0, bpad_front, bpad_top, bpad_left], \ + [0, 0, bpad_back, bpad_bottom, bpad_right], \ + name='data_pad') + # transform kernel layout from IODHW to OIDHW, and rotate kernel by 180 degrees + kernel_transform = te.compute((out_c, in_c, filter_d, filter_h, filter_w), \ + lambda o, i, d, h, w: kernel[i][o][filter_d-1-d] \ + [filter_h-1-h][filter_w-1-w], \ + name='kernel_transform') + return data_pad, kernel_transform + + +def declaration_conv3d_transpose_impl(data, kernel, strides, padding, out_dtype): + """Implementation of conv3d transpose""" + data_pad, kernel_transform = \ + conv3d_transpose_ncdhw_preprocess(data, kernel, strides, padding, out_dtype) + batch, in_c, in_d, in_h, in_w = data_pad.shape + out_c, _, filter_d, filter_h, filter_w = kernel_transform.shape + stride_d, stride_h, stride_w = strides + + # convolution stage + out_c = simplify(out_c) + out_d = simplify(in_d - filter_d + 1) + out_h = simplify(in_h - filter_h + 1) + out_w = simplify(in_w - filter_w + 1) + dc = te.reduce_axis((0, in_c), name='dc') + dd = te.reduce_axis((0, filter_d), name='dd') + dh = te.reduce_axis((0, filter_h), name='dh') + dw = te.reduce_axis((0, filter_w), name='dw') + + Output = te.compute( + (batch, out_c, out_d, out_h, out_w), + lambda b, c, d, h, w: te.sum( + data_pad[b, dc, d+dd, h+dh, w+dw].astype(out_dtype) * + kernel_transform[c, dc, dd, dh, dw].astype(out_dtype), + axis=[dc, dd, dh, dw]), tag="conv3d_transpose_ncdhw") + + return Output + + +@tvm.target.generic_func +def conv3d_transpose_legalize(attrs, inputs, types): + """Legalizes Transposed 3D convolution op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current Transposed 3D convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + if attrs['data_layout'] == 'NDHWC': + data, kernel = inputs + kernel_layout = attrs['kernel_layout'] + # Convert Kernel layout to IODHW + # kernel_layout is different from input kernel layout - IO is swapped + if kernel_layout == 'DHWIO': + # input kernel layout is swapped to DHWOI + # output kernel layout will be IODHW + kernel = relay.transpose(kernel, axes=(4, 3, 0, 1, 2)) + elif kernel_layout == 'DHWOI': + # input kernel layout is swapped to DHWIO + # output kernel layout will be IODHW + kernel = relay.transpose(kernel, axes=(3, 4, 0, 1, 2)) + elif kernel_layout == 'IODHW': + # input kernel layout is swapped to OIDHW + # output kernel layout will be IODHW + kernel = relay.transpose(kernel, axes=(1, 0, 2, 3, 4)) + elif kernel_layout == 'OIDHW': + # input kernel layout is swapped to IODHW + # output kernel layout will be IODHW + pass + else: + # Skip legalize. Let relay.nn.conv2d_transpose to handle the case + return None + + # Set new attrs for conv3d_transpose. + new_attrs = {k: attrs[k] for k in attrs.keys()} + new_attrs['data_layout'] = 'NCDHW' + # layout of kernel should be IODHW, but kernel_layout should be swapped - OIDHW + new_attrs['kernel_layout'] = 'OIDHW' + + # Convert data to NCDHW. + data = relay.transpose(data, axes=(0, 4, 1, 2, 3)) + deconv = relay.nn.conv3d_transpose(data, kernel, **new_attrs) + # Convert back to original NDHWC layout. + out = relay.transpose(deconv, axes=(0, 2, 3, 4, 1)) + return out + + return None diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py index e677a11..bd9825a 100644 --- a/topi/python/topi/testing/__init__.py +++ b/topi/python/topi/testing/__init__.py @@ -27,6 +27,7 @@ from .conv2d_nchw_python import conv2d_nchw_python from .conv2d_nhwc_python import conv2d_nhwc_python from .conv3d_ncdhw_python import conv3d_ncdhw_python from .conv3d_ndhwc_python import conv3d_ndhwc_python +from .conv3d_transpose_ncdhw_python import conv3d_transpose_ncdhw_python from .conv2d_transpose_python import conv2d_transpose_nchw_python, conv2d_transpose_nhwc_python from .conv1d_transpose_ncw_python import conv1d_transpose_ncw_python from .correlation_nchw_python import correlation_nchw_python diff --git a/topi/python/topi/testing/conv3d_transpose_ncdhw_python.py b/topi/python/topi/testing/conv3d_transpose_ncdhw_python.py new file mode 100644 index 0000000..8140eb7 --- /dev/null +++ b/topi/python/topi/testing/conv3d_transpose_ncdhw_python.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals, too-many-branches +"""Convolution 3D transpose in python""" +import numpy as np +import topi +from topi.nn.util import get_pad_tuple3d + + +def conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding): + """Transposed 3d convolution operator in NCDHW layout. + + Parameters + ---------- + a_np : numpy.ndarray + 5-D with shape [batch, in_channel, in_depth, in_height, in_width] + + w_np : numpy.ndarray + 5-D with shape [in_channel, num_filter, filter_depth, filter_height, filter_width] + + stride : int or a list/tuple of two ints + Stride size, or [stride_depth, stride_height, stride_width] + + padding : int or str + Padding size + + Returns + ------- + b_np : np.ndarray + 5-D with shape [batch, out_channel, out_depth, out_height, out_width] + """ + batch, in_c, in_d, in_h, in_w = a_np.shape + _, out_c, filter_d, filter_h, filter_w = w_np.shape + if isinstance(stride, int): + stride_d = stride_h = stride_w = stride + else: + stride_d, stride_h, stride_w = stride + + # dilate stage + dilated_a_np = topi.testing.dilate_python(a_np, [1, 1, stride_d, stride_h, stride_w]) + + # padding stage + fpad_front, fpad_top, fpad_left, fpad_back, fpad_bottom, fpad_right = get_pad_tuple3d( + padding, (filter_d, filter_h, filter_w)) + + bpad_front = filter_d - 1 - fpad_front + bpad_back = filter_d - 1 - fpad_back + bpad_top = filter_h - 1 - fpad_top + bpad_bottom = filter_h - 1 - fpad_bottom + bpad_left = filter_w - 1 - fpad_left + bpad_right = filter_w - 1 - fpad_right + + padded_a_np = np.zeros((batch, + in_c, + dilated_a_np.shape[2]+bpad_front+bpad_back, + dilated_a_np.shape[3]+bpad_top+bpad_bottom, + dilated_a_np.shape[4]+bpad_left+bpad_right)) + + padded_a_np[:, :, bpad_front:dilated_a_np.shape[2]+bpad_back, + bpad_top:dilated_a_np.shape[3]+bpad_top, + bpad_left:dilated_a_np.shape[4]+bpad_left] = dilated_a_np + + + # convolution stage + out_d = (in_d - 1) * stride_d - bpad_front - bpad_back + filter_d + out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h + out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + + w_np = np.flip(w_np, axis=[2, 3, 4]).transpose((1, 0, 2, 3, 4)) + b_np = topi.testing.conv3d_ncdhw_python(padded_a_np, w_np, stride=(1, 1, 1), padding=(0, 0, 0)) + + return b_np diff --git a/topi/python/topi/x86/__init__.py b/topi/python/topi/x86/__init__.py index ce07c19..659668c 100644 --- a/topi/python/topi/x86/__init__.py +++ b/topi/python/topi/x86/__init__.py @@ -36,5 +36,6 @@ from .dense import * from .batch_matmul import * from .roi_align import roi_align_nchw from .conv2d_transpose import * +from .conv3d_transpose import * from .sparse import * from .conv2d_alter_op import * diff --git a/topi/python/topi/x86/conv3d_transpose.py b/topi/python/topi/x86/conv3d_transpose.py new file mode 100644 index 0000000..ad035d3 --- /dev/null +++ b/topi/python/topi/x86/conv3d_transpose.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument,no-member +# pylint: disable=no-value-for-parameter + +"""Conv3D Transpose schedule on x86""" +from tvm import te +from ..util import traverse_inline +from .. import nn +from .conv3d import conv3d_ncdhw, schedule_conv3d_ncdhw + +def conv3d_transpose_ncdhw(data, kernel, strides, padding, out_dtype): + data_pad, kernel_transform = \ + nn.conv3d_transpose_ncdhw_preprocess(data, kernel, strides, padding, out_dtype) + + # reuse conv3d_ncdhw implementation + return conv3d_ncdhw(data_pad, kernel_transform, (1, 1, 1), + (0, 0, 0), (1, 1, 1), out_dtype) + +def schedule_conv3d_transpose_ncdhw(outs): + """Create schedule for tensors""" + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = schedule_conv3d_ncdhw(outs) + def _callback(op): + if 'unpack_ncdhwc' in op.tag: + conv_out = op.input_tensors[0] + # retrieve data + data_vec = conv_out.op.input_tensors[0] + data_pad = data_vec.op.input_tensors[0] + data_dilate = data_pad.op.input_tensors[0] + s[data_dilate].compute_inline() + s[data_pad].compute_inline() + # retrieve kernel + kernel_vec = conv_out.op.input_tensors[1] + kernel_transform = kernel_vec.op.input_tensors[0] + s[kernel_transform].compute_inline() + + traverse_inline(s, outs[0].op, _callback) + return s diff --git a/topi/tests/python/test_topi_conv3d_transpose_ncdhw.py b/topi/tests/python/test_topi_conv3d_transpose_ncdhw.py new file mode 100644 index 0000000..8b08198 --- /dev/null +++ b/topi/tests/python/test_topi_conv3d_transpose_ncdhw.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test code for transposed convolution.""" +import numpy as np +import tvm +from tvm import te +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple + +from common import get_all_backend + + +_conv3d_transpose_ncdhw_implement = { + "generic": (topi.nn.conv3d_transpose_ncdhw, topi.generic.schedule_conv3d_transpose_ncdhw), + "cpu": (topi.x86.conv3d_transpose_ncdhw, topi.x86.schedule_conv3d_transpose_ncdhw), + "gpu": (topi.cuda.conv3d_transpose_ncdhw, topi.cuda.schedule_conv3d_transpose_ncdhw), +} + +def verify_conv3d_transpose_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride, padding): + in_depth, in_height, in_width = in_size + kernel_depth, kernel_height, kernel_width = kernel + stride_depth, stride_height, stride_width = stride + pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = padding + + A = te.placeholder((batch, in_channel, in_depth, in_height, in_width), name='A') + W = te.placeholder((in_channel, num_filter, kernel_depth, kernel_height, kernel_width), name='W') + + a_shape = get_const_tuple(A.shape) + w_shape = get_const_tuple(W.shape) + dtype = A.dtype + + @memoize("topi.tests.test_topi_conv3d_transpose.verify_conv3d_transpose_ncdhw") + def get_ref_data(): + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(dtype) + b_np = topi.testing.conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding) + c_np = np.maximum(b_np, 0) + return a_np, w_np, b_np, c_np + + a_np, w_np, b_np, c_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + fcompute, fschedule = topi.testing.dispatch(device, _conv3d_transpose_ncdhw_implement) + B = fcompute(A, W, + [stride_depth, stride_height, stride_width], + [pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right], + A.dtype) + C = topi.nn.relu(B) + s1 = fschedule([B]) + s2 = fschedule([C]) + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) + + func1 = tvm.build(s1, [A, W, B], device) + func2 = tvm.build(s2, [A, W, C], device) + func1(a, w, b) + func2(a, w, c) + tvm.testing.assert_allclose(b.asnumpy(), b_np, atol=1e-4, rtol=1e-4) + tvm.testing.assert_allclose(c.asnumpy(), c_np, atol=1e-4, rtol=1e-4) + for device in get_all_backend(): + check_device(device) + + +def test_conv3d_transpose_ncdhw(): + verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 1, (1, 1, 1), (1, 1, 1), (0, 0, 0, 0, 0, 0)) + verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 2, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0)) + verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0)) + verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (3, 3, 3), (0, 0, 0, 0, 0, 0)) + verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0)) + verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (2, 2, 2), (1, 1, 1, 1, 1, 1)) + verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (2, 2, 2), (2, 2, 2), (0, 0, 0, 0, 0, 0)) + verify_conv3d_transpose_ncdhw(1, 8, (32, 32, 32), 32, (5, 5, 5), (1, 1, 1), (0, 0, 0, 0, 0, 0)) + verify_conv3d_transpose_ncdhw(1, 8, (32, 32, 32), 64, (5, 5, 5), (2, 2, 2), (1, 1, 1, 1, 1, 1)) + +if __name__ == "__main__": + test_conv3d_transpose_ncdhw()