From d430fbb586269f28bb83f8087213d68951d4ef4d Mon Sep 17 00:00:00 2001 From: Alex Gladkov <53275205+alexgl-github@users.noreply.github.com> Date: Wed, 18 Dec 2019 09:35:22 -0800 Subject: [PATCH] Implement 1d deconvolution (#4476) --- include/tvm/relay/attrs/nn.h | 58 +++++++ python/tvm/autotvm/task/relay_integration.py | 1 + python/tvm/autotvm/task/topi_integration.py | 12 ++ python/tvm/relay/_parser.py | 1 + python/tvm/relay/frontend/mxnet.py | 20 +-- python/tvm/relay/op/nn/_nn.py | 31 ++++ python/tvm/relay/op/nn/nn.py | 66 ++++++++ src/relay/op/nn/convolution.cc | 157 +++++++++++++++++ src/relay/op/op_common.h | 12 ++ tests/python/relay/test_op_level2.py | 20 +++ topi/python/topi/cuda/__init__.py | 2 +- topi/python/topi/cuda/conv1d_transpose_ncw.py | 187 +++++++++++++++++++++ topi/python/topi/generic/nn.py | 18 ++ topi/python/topi/nn/__init__.py | 1 + topi/python/topi/nn/conv1d_transpose.py | 83 +++++++++ topi/python/topi/nn/util.py | 39 +++++ topi/python/topi/testing/__init__.py | 1 + .../topi/testing/conv1d_transpose_ncw_python.py | 71 ++++++++ .../tests/python/test_topi_conv1d_transpose_ncw.py | 87 ++++++++++ 19 files changed, 853 insertions(+), 14 deletions(-) create mode 100644 topi/python/topi/cuda/conv1d_transpose_ncw.py create mode 100644 topi/python/topi/nn/conv1d_transpose.py create mode 100644 topi/python/topi/testing/conv1d_transpose_ncw_python.py create mode 100644 topi/tests/python/test_topi_conv1d_transpose_ncw.py diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 4422fce..046e043 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -315,6 +315,64 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> { } }; +/*! \brief Attributes used in 1D transposed convolution operator */ +struct Conv1DTransposeAttrs : public tvm::AttrsNode<Conv1DTransposeAttrs> { + IndexExpr channels; + Array<IndexExpr> kernel_size; + Array<IndexExpr> strides; + Array<IndexExpr> padding; + Array<IndexExpr> output_padding; + Array<IndexExpr> dilation; + int groups; + std::string data_layout; + std::string kernel_layout; + std::string out_layout; + DataType out_dtype; + + TVM_DECLARE_ATTRS(Conv1DTransposeAttrs, "relay.attrs.Conv1DTransposeAttrs") { + TVM_ATTR_FIELD(channels) + .set_default(NullValue<IndexExpr>()) + .describe("The dimensionality of the output space" "i.e. the number of output channels in the convolution."); + TVM_ATTR_FIELD(kernel_size) + .describe("The dimensions of the convolution window.") + .set_default(NullValue<Array<IndexExpr> >()); + TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1})) + .describe("The strides of the convolution."); + TVM_ATTR_FIELD(output_padding).set_default(Array<IndexExpr>({0})) + .describe("Zero-padding added to one side of the output."); + TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0})) + .describe("Symmetric or asymmetric padding." + "Single value: the input is implicitly zero-padded on both sides." + "Two values: padding[0] is used for left input padding, " + "padding[1] is used for right input padding."); + TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1})) + .describe("Specifies the dilation rate to use for dilated convolution."); + TVM_ATTR_FIELD(groups).set_default(1) + .describe("Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(data_layout).set_default("NCW") + .describe("Dimension ordering of data. 
Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Convolution is applied on the" + "'W' dimension."); + TVM_ATTR_FIELD(kernel_layout).set_default("OIW") + .describe("Dimension ordering of data and weight. Can be 'OIW', 'OIW16o16i', etc." + "'O', 'I', 'W' stands for num_filter, input_channel, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout).set_default("") + .describe("Dimension ordering of output. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(out_dtype) + .set_default(NullValue()) + .describe("Output data type, set to explicit type under mixed precision setting"); + } +}; + /*! \brief Attributes for max pool operator */ struct MaxPool2DAttrs : public tvm::AttrsNode { Array pool_size; diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index b65c5d4..4a40771 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -128,6 +128,7 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None, tvm.relay.op.nn.dense: [topi.nn.dense], tvm.relay.op.nn.batch_matmul: [topi.nn.batch_matmul], tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw], + tvm.relay.op.nn.conv1d_transpose: [topi.nn.conv1d_transpose_ncw], } topi_funcs = [] diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py index 7bfc313..1b446e3 100644 --- a/python/tvm/autotvm/task/topi_integration.py +++ b/python/tvm/autotvm/task/topi_integration.py @@ -92,6 +92,7 @@ class TaskExtractEnv: topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc", topi.nn.bitserial_dense: "topi_nn_bitserial_dense", topi.nn.deformable_conv2d_nchw: "topi_nn_deformable_conv2d_nchw", + topi.nn.conv1d_transpose_ncw: "topi_nn_conv1d_transpose_ncw", } self.topi_to_schedule = { @@ -109,6 +110,7 @@ class TaskExtractEnv: topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc], topi.nn.bitserial_dense: [topi.generic.schedule_bitserial_dense], topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw], + topi.nn.conv1d_transpose_ncw: [topi.generic.schedule_conv1d_transpose_ncw], } # function reflection for tracing @@ -125,6 +127,7 @@ class TaskExtractEnv: topi.nn.bitserial_conv2d_nhwc: lambda x: setattr(topi.nn, 'bitserial_conv2d_nhwc', x), topi.nn.bitserial_dense: lambda x: setattr(topi.nn, 'bitserial_dense', x), topi.nn.deformable_conv2d_nchw: lambda x: setattr(topi.nn, 'deformable_conv2d_nchw', x), + topi.nn.conv1d_transpose_ncw: lambda x: setattr(topi.nn, 'conv1d_transpose_ncw', x), } self.allow_duplicate = allow_duplicate @@ -214,6 +217,15 @@ class TaskExtractEnv: s = topi.generic.schedule_conv2d_transpose_nchw([C]) return s, [A, W, C] + @register("topi_nn_conv1d_transpose_ncw") + def _topi_nn_conv1d_transpose_ncw(*args, **kwargs): + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + A, W = args[:2] + C = topi.nn.conv1d_transpose_ncw(*args, **kwargs) + s = topi.generic.schedule_conv1d_transpose_ncw([C]) + return s, [A, W, C] + @register("topi_nn_dense") def _topi_nn_dense(*args, **kwargs): assert not kwargs, "Do not support kwargs in template function call" diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 71e5bfa..6ec581b 100644 --- 
a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -141,6 +141,7 @@ FUNC_OPS = { "nn.softmax": op.nn.softmax, "reshape": op.reshape, "nn.conv2d_transpose": op.nn.conv2d_transpose, + "nn.conv1d_transpose": op.nn.conv1d_transpose, "concatenate": op.concatenate, "nn.dropout": op.nn.dropout_raw, "zeros": op.zeros, diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index abef45d..a1a3578 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -207,29 +207,23 @@ def _mx_conv1d_transpose(inputs, attrs): if data_layout != "NCW": raise tvm.error.OpAttributeInvalid( 'Only "NCW" data layout is supported for 1D Convolution') - data_layout = "NCHW" channel_axis = 1 - kernel_layout = "OIHW" - + kernel_layout = "OIW" new_attrs = {} new_attrs["channels"] = attrs.get_int("num_filter") - new_attrs["kernel_size"] = (1,) + attrs.get_int_tuple("kernel") - new_attrs["strides"] = (1,) + attrs.get_int_tuple("stride", (1,)) - new_attrs["output_padding"] = (0,) + attrs.get_int_tuple("adj", (0,)) - new_attrs["padding"] = (0,) + attrs.get_int_tuple("pad", (0,)) - new_attrs["dilation"] = (1,) + attrs.get_int_tuple("dilate", (1,)) + new_attrs["kernel_size"] = attrs.get_int_tuple("kernel") + new_attrs["strides"] = attrs.get_int_tuple("stride", (1,)) + new_attrs["output_padding"] = attrs.get_int_tuple("adj", (0,)) + new_attrs["padding"] = attrs.get_int_tuple("pad", (0,)) + new_attrs["dilation"] = attrs.get_int_tuple("dilate", (1,)) new_attrs["groups"] = attrs.get_int("num_group", 1) new_attrs["data_layout"] = data_layout new_attrs["kernel_layout"] = kernel_layout use_bias = not attrs.get_bool("no_bias", True) - data = _op.expand_dims(inputs[0], axis=2) - kernel = _op.expand_dims(inputs[1], axis=2) - res = _op.nn.conv2d_transpose(data, kernel, **new_attrs) - + res = _op.nn.conv1d_transpose(inputs[0], inputs[1], **new_attrs) if use_bias: assert len(inputs) == 3 res = _op.nn.bias_add(res, inputs[2], axis=channel_axis) - res = _op.squeeze(res, axis=[2]) return res diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index cd8a131..e1372ac 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -348,6 +348,37 @@ def legalize_conv2d_transpose(attrs, inputs, types): reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE) +# conv1d_transpose +@reg.register_compute("nn.conv1d_transpose") +def compute_conv1d_transpose(attrs, inputs, out_dtype, target): + """Compute definition of conv1d_transpose""" + padding = get_const_tuple(attrs.padding) + strides = get_const_tuple(attrs.strides) + dilation = get_const_tuple(attrs.dilation) + groups = attrs.groups + layout = attrs.data_layout + out_dtype = attrs.out_dtype + out_dtype = (inputs[0].dtype if out_dtype in ("same", "") + else out_dtype) + assert layout == "NCW", "conv1d_transpose ncw only supported" + assert dilation == (1,), "conv1d_transpose dilation is not supported" + assert groups == 1, "conv1d_transpose groups == 1 only supported" + out = topi.nn.conv1d_transpose_ncw( + inputs[0], inputs[1], strides, padding, out_dtype) + output_padding = get_const_tuple(attrs.output_padding) + out = topi.nn.pad(out, + [0, 0, 0], [0, 0, output_padding[0]]) + return [out] + + +@reg.register_schedule("nn.conv1d_transpose") +def schedule_conv1d_transpose(attrs, outs, target): + """Schedule definition of conv1d_transpose""" + with target: + return topi.generic.schedule_conv1d_transpose_ncw(outs) + +reg.register_pattern("nn.conv1d_transpose", 
OpPattern.OUT_ELEMWISE_FUSABLE) + # bias_add reg.register_schedule("nn.bias_add", schedule_injective) reg.register_pattern("nn.bias_add", OpPattern.BROADCAST) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 5e1c6a8..fda5027 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -257,6 +257,72 @@ def conv2d_transpose(data, kernel_layout, out_layout, output_padding, out_dtype) +def conv1d_transpose(data, + weight, + strides=(1,), + padding=(0,), + dilation=(1,), + groups=1, + channels=None, + kernel_size=None, + data_layout="NCW", + kernel_layout="OIW", + out_layout="", + output_padding=(0,), + out_dtype=""): + """One dimensional transposed convolution operator. + + Parameters + ---------- + data : tvm.relay.Expr + The input data to the operator. + + weight : tvm.relay.Expr + The weight expressions. + + strides : Tuple[int], optional + The strides of convolution. + + padding : Tuple[int], optional + The padding of convolution on both sides of inputs. + + dilation : Tuple[int], optional + Specifies the dilation rate to be used for dilated convolution. + + channels : int, optional + Number of output channels of this convolution. + + kernel_size : tuple of int, optional + The spatial of the convolution kernel. + + groups : int, optional + Number of groups for grouped convolution. + + data_layout : str, optional + Layout of the input. + + kernel_layout : str, optional + Layout of the weight. + + out_layout : Optional[str] + Layout of the output, by default, out_layout is the same as data_layout + + output_padding : Tuple[int], optional + Additional zero-padding to be added to one side of the output. + + out_dtype : str, optional + Specifies the output data type for mixed precision conv2d. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + return _make.conv1d_transpose(data, weight, strides, padding, dilation, + groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, output_padding, out_dtype) + + def softmax(data, axis=-1): r"""Computes softmax. diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 3c9bebc..534c2d1 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -28,6 +28,7 @@ #include #include "../../pass/alter_op_layout.h" +#include "../op_common.h" #include "convolution.h" namespace tvm { @@ -328,6 +329,162 @@ v (batch_size, channels, out_height, out_width) if `layout` is `NCHW` .add_type_rel("Conv2DTranspose", Conv2DTransposeRel); +// relay.nn.conv1d_transpose +TVM_REGISTER_NODE_TYPE(Conv1DTransposeAttrs); + +bool Conv1DTransposeRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + const auto* weight = types[1].as(); + if (data == nullptr) return false; + + static const Layout kNCW("NCW"); + static const Layout kOIW("OIW"); + + const Conv1DTransposeAttrs* param = attrs.as(); + CHECK(param != nullptr); + const Layout in_layout(param->data_layout); + const Layout kernel_layout(param->kernel_layout); + + const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCW); + CHECK(trans_in_layout.defined()) + << "Conv only support input layouts that are convertible from NCW." + << " But got " << in_layout; + + const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIW); + CHECK(trans_kernel_layout.defined()) + << "Conv only support kernel layouts that are convertible from OIW." 
+ << " But got "<< kernel_layout; + + Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); + const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCW); + CHECK(trans_out_layout.defined()) + << "Conv only support output layouts that are convertible from NCW." + << " But got " << out_layout; + + IndexExpr channels, dilated_ksize_y, dilated_ksize_x; + + auto dshape_ncw = trans_in_layout.ForwardShape(data->shape); + + // infer weight if the kernel_size and channels are defined + if (param->kernel_size.defined() && param->channels.defined()) { + CHECK_EQ(param->kernel_size.size(), 1); + CHECK_EQ(param->dilation.size(), 1); + + Array wshape({dshape_ncw[1], + indexdiv(param->channels, param->groups), + param->kernel_size[0]}); + + wshape = trans_kernel_layout.BackwardShape(wshape); + dilated_ksize_x = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; + channels = param->channels; + + // assign result to reporter + reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype)); + } else { + // use weight to infer the conv shape. + if (weight == nullptr) return false; + auto wshape = trans_kernel_layout.ForwardShape(weight->shape); + if (param->kernel_size.defined()) { + CHECK_EQ(param->kernel_size.size(), 1); + // check the size + CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2])) + << "Conv1D: shape of weight is inconsistent with kernel_size, " + << " kernel_size=" << param->kernel_size + << " wshape=" << Array(wshape); + } + if (param->channels.defined()) { + CHECK(reporter->AssertEQ(param->channels, wshape[1])) + << "Conv1D: shape of weight is inconsistent with channels, " + << " channels=" << param->channels + << " wshape=" << Array(wshape); + } + CHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0])); + channels = wshape[1]; + dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilation[0]; + } + // dilation + IndexExpr pad_w; + GetPaddingWidth(param->padding, &pad_w); + Array oshape({dshape_ncw[0], channels, 0}); + oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x - + pad_w + param->output_padding[0])); + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + oshape = trans_out_layout.BackwardShape(oshape); + reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); + return true; +} + + +Expr MakeConv1DTranspose(Expr data, + Expr weight, + Array strides, + Array padding, + Array dilation, + int groups, + IndexExpr channels, + Array kernel_size, + std::string data_layout, + std::string kernel_layout, + std::string out_layout, + Array output_padding, + DataType out_dtype) { + auto attrs = make_node(); + attrs->channels = std::move(channels); + attrs->kernel_size = std::move(kernel_size); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->output_padding = std::move(output_padding); + attrs->dilation = std::move(dilation); + attrs->groups = groups; + attrs->data_layout = std::move(data_layout); + attrs->kernel_layout = std::move(kernel_layout); + attrs->out_layout = std::move(out_layout); + attrs->out_dtype = std::move(out_dtype); + static const Op& op = Op::Get("nn.conv1d_transpose"); + return CallNode::make(op, {data, weight}, Attrs(attrs), {}); +} + + +TVM_REGISTER_API("relay.op.nn._make.conv1d_transpose") +.set_body_typed(MakeConv1DTranspose); + +RELAY_REGISTER_OP("nn.conv1d_transpose") +.describe(R"code(Transposed 1D convolution layer (sometimes called Deconvolution). 
+ +The need for transposed convolutions generally arises +from the desire to use a transformation going in the opposite direction +of a normal convolution, i.e., from something that has the shape of the +output of some convolution to something that has the shape of its input +while maintaining a connectivity pattern that is compatible with +said convolution. + +- **data**: This depends on the `layout` parameter. Input is 3D array of shape + (batch_size, in_channels, width) if `layout` is `NCW`. +- **weight**: (in_channels, channels, kernel_size[0]) +- **bias**: (channels,) +- **out**: This depends on the `layout` parameter. Output is 3D array of shape + (batch_size, channels, out_width) if `layout` is `NCW`. + + out_width is calculated as:: + out_width = (width-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0] + +)code" TVM_ADD_FILELINE) +.set_attrs_type() +.set_num_inputs(2) +.add_argument("data", "Tensor", "The input tensor.") +.add_argument("weight", "Tensor", "The weight tensor.") +.set_support_level(2) +.add_type_rel("Conv1DTranspose", Conv1DTransposeRel); + + // relay.nn.contrib_conv2d_winograd_without_weight_transform TVM_REGISTER_NODE_TYPE(Conv2DWinogradAttrs); diff --git a/src/relay/op/op_common.h b/src/relay/op/op_common.h index 5cf9851..b960c75 100644 --- a/src/relay/op/op_common.h +++ b/src/relay/op/op_common.h @@ -150,6 +150,18 @@ class OpMatch { MatchFunc default_; }; +/*! \brief A utility function to get padding width from a 1 or 2 ints tuple. */ +inline void GetPaddingWidth(const Array& padding, IndexExpr* pad_w) { + if (padding.size() == 1) { + *pad_w = padding[0] * 2; + } else if (padding.size() == 2) { + *pad_w = padding[0] + padding[1]; + } else { + CHECK_EQ(padding.size(), 4) << " Expected padding size of 1 or 2, found " + << padding.size(); + } +} + } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index bb16487..9257ef2 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -413,6 +413,25 @@ def test_conv2d_transpose_nhwc_run(): c_np = topi.testing.conv2d_transpose_nhwc_python(data, kernel, 'HWOI', 2, 1) d_np = np.zeros(shape=oshape_nhwc) d_np[:,0:c_np.shape[1],0:c_np.shape[2],:] = c_np + + +def test_conv1d_transpose_ncw_run(): + dshape = (1, 3, 18) + kshape = (3, 10, 3) + oshape = (1, 10, 37) + x = relay.var("x", shape=dshape) + w = relay.var("w") + y = relay.nn.conv1d_transpose(x, w, + channels=10, kernel_size=(3,), strides=(2,), + padding=(1,), output_padding=(2,)) + func = relay.Function([x, w], y) + dtype = "float32" + data = np.random.uniform(size=dshape).astype(dtype) + kernel = np.random.uniform(size=kshape).astype(dtype) + c_np = topi.testing.conv1d_transpose_ncw_python( + data, kernel, 2, 1) + d_np = np.zeros(shape=oshape) + d_np[:,:,0:c_np.shape[2]] = c_np ref_res = d_np for target, ctx in ctx_list(): @@ -893,6 +912,7 @@ if __name__ == "__main__": test_conv2d_transpose_infer_type() test_conv2d_transpose_nchw_run() test_conv2d_transpose_nhwc_run() + test_conv1d_transpose_ncw_run() test_conv2d_run() test_conv2d_winograd() test_conv3d_run() diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py index e6a342d..55255f4 100644 --- a/topi/python/topi/cuda/__init__.py +++ b/topi/python/topi/cuda/__init__.py @@ -20,7 +20,7 @@ from __future__ import absolute_import as _abs from . 
import conv2d, depthwise_conv2d, conv2d_transpose_nchw, deformable_conv2d, \ - group_conv2d_nchw, dense + group_conv2d_nchw, dense, conv1d_transpose_ncw from . import conv3d from .conv2d_hwcn import schedule_conv2d_hwcn from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc diff --git a/topi/python/topi/cuda/conv1d_transpose_ncw.py b/topi/python/topi/cuda/conv1d_transpose_ncw.py new file mode 100644 index 0000000..be7824e --- /dev/null +++ b/topi/python/topi/cuda/conv1d_transpose_ncw.py @@ -0,0 +1,187 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Conv1d transpose template for cuda backend""" + +import tvm +from tvm import autotvm +from .. import nn, generic +from ..util import get_const_tuple, traverse_inline + +@autotvm.task.register_topi_compute(nn.conv1d_transpose_ncw, ['cuda', 'gpu'], "direct") +def conv1d_transpose_ncw_cuda(cfg, data, kernel, stride, padding, out_dtype): + """Transposed 1D convolution ncw forward operator. + + Parameters + ---------- + cfg: ConfigEntity + The config for this template + Input : tvm.Tensor + 3-D with shape [batch, in_channel, inp_width] + Filter : tvm.Tensor + 3-D with shape [in_channel, num_filter, kernel_size] + stride : tuple of one int + The spatial stride along width + padding : int, tuple, or string + int: padding size + tuple of 2 ints: (pad_left, pad_right) for left and right padding + string: ['VALID', 'SAME'] + out_dtype: str + The output type. 
This is used in mixed precision + + Returns + ------- + Output : tvm.Tensor + u 3-D with shape [batch, out_channel, out_width] + """ + if isinstance(stride, (tuple, list)): + stride = stride[0] + cfg.stride = stride + batch, inp_channels, inp_width = get_const_tuple(data.shape) + _, out_channels, kernel_size = get_const_tuple(kernel.shape) + pad_left, pad_right = nn.get_pad_tuple1d(padding, kernel_size) + out_width = (inp_width - 1) * stride + kernel_size - pad_left - pad_right + pad_left = kernel_size - 1 - pad_left + pad_right = kernel_size - 1 - pad_right + dilated_width = stride * (inp_width - 1) + 1 + data = tvm.compute( + (batch, inp_channels, pad_left + dilated_width + pad_right), + lambda n, c, x: tvm.if_then_else( + tvm.all(x >= pad_left, + x < pad_left + dilated_width, + tvm.indexmod(x - pad_left, stride).equal(0)), + data[n, c, tvm.indexdiv(x - pad_left, stride)], + tvm.const(0., "float32")), + name='data_pad') + + dc = tvm.reduce_axis((0, inp_channels), name='dc') + dw = tvm.reduce_axis((0, kernel_size), name='dw') + data_out = tvm.compute( + (batch, out_channels, out_width), + lambda b, c, w: tvm.sum( + data[b, dc, w + dw].astype(out_dtype) * + kernel[dc, c, kernel_size - 1 - dw].astype(out_dtype), + axis=[dc, dw]), tag="conv1d_transpose_ncw") + + return data_out + +@autotvm.task.register_topi_schedule(generic.schedule_conv1d_transpose_ncw, + ['cuda', 'gpu'], 'direct') +def schedule_conv1d_transpose_ncw_cuda(cfg, outs): + """TOPI Schedule callback for conv1d_transpose operator. + + Parameters + ---------- + cfg: ConfigEntity + The parameters for this template + + outs: Array of Tensor + The computation graph description of conv1d transpose + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for conv1d transpose. 
+ """ + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs + s = tvm.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == 'conv1d_transpose_ncw': + pad_data = op.input_tensors[0] + kernel = op.input_tensors[1] + conv = op.output(0) + + ##### space definition begin ##### + n, f, x = s[conv].op.axis + rc = s[conv].op.reduce_axis[0] + cfg.define_split("tile_n", cfg.axis(n), num_outputs=4) + cfg.define_split("tile_f", cfg.axis(f), num_outputs=4) + cfg.define_split("tile_x", cfg.axis(x), num_outputs=4) + cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3) + cfg.define_knob("auto_unroll_max_step", [64, 512, 1500]) + + target = tvm.target.current_target() + if target.target_name in ['nvptx', 'rocm']: + cfg.define_knob("unroll_explicit", [1]) + else: + cfg.define_knob("unroll_explicit", [0, 1]) + + ##### space definition end ##### + + if isinstance(kernel.op, tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag: + s[kernel].compute_inline() + + if conv.op in s.outputs: + output = conv + OL = s.cache_write(conv, 'local') + else: + output = s.outputs[0].output(0) + s[conv].set_scope('local') + OL = conv + + # create cache stage + s[pad_data].set_scope('shared') + AA = pad_data + WW = s.cache_read(kernel, 'shared', [OL]) + + # tile and bind spatial axes + n, f, x = s[output].op.axis + kernel_scope, n = s[output].split(n, nparts=1) + bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n) + bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f) + bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) + + s[output].reorder(bn, bf, bx, vn, vf, vx, tn, tf, tx, ni, fi, xi) + s[output].bind(bn, tvm.thread_axis("blockIdx.z")) + s[output].bind(bf, tvm.thread_axis("blockIdx.y")) + s[output].bind(bx, tvm.thread_axis("blockIdx.x")) + s[output].bind(vn, tvm.thread_axis("vthread")) + s[output].bind(vf, tvm.thread_axis("vthread")) + s[output].bind(vx, tvm.thread_axis("vthread")) + + s[output].bind(tx, tvm.thread_axis("threadIdx.x")) + s[OL].compute_at(s[output], tx) + # number of threads + n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2] + n_tx = cfg["tile_x"].size[2] + + # tile reduction axes + n, f, x = s[OL].op.axis + rc, rx = s[OL].op.reduce_axis + rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc) + s[OL].reorder(rco, rcm, rx, rci, n, f, x) + + s[AA].compute_at(s[OL], rx) + s[WW].compute_at(s[OL], rx) + + # cooperative fetching + for load in [AA, WW]: + n, f, x = s[load].op.axis + fused = s[load].fuse(f, x) + tz, fused = s[load].split(fused, nparts=n_tz) + tx, fused = s[load].split(fused, nparts=n_tx) + s[load].bind(tz, tvm.thread_axis("threadIdx.y")) + s[load].bind(tx, tvm.thread_axis("threadIdx.x")) + + s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val) + s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val) + + traverse_inline(s, outs[0].op, _callback) + + return s diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py index 752cb5a..953be58 100644 --- a/topi/python/topi/generic/nn.py +++ b/topi/python/topi/generic/nn.py @@ -262,6 +262,24 @@ def schedule_conv2d_transpose_nchw(outs): @tvm.target.generic_func +def schedule_conv1d_transpose_ncw(outs): + """Schedule for conv1d_transpose_ncw + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of conv2d_transpose_ncw + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for the op. 
+ """ + return _default_schedule(outs, False) + + +@tvm.target.generic_func def schedule_depthwise_conv2d_nchw(outs): """Schedule for depthwise_conv2d_nchw diff --git a/topi/python/topi/nn/__init__.py b/topi/python/topi/nn/__init__.py index f42cde8..3aa4a4e 100644 --- a/topi/python/topi/nn/__init__.py +++ b/topi/python/topi/nn/__init__.py @@ -31,6 +31,7 @@ from .mapping import * from .pooling import * from .softmax import * from .conv2d_transpose import * +from .conv1d_transpose import * from .bnn import * from .upsampling import * from .local_response_norm import * diff --git a/topi/python/topi/nn/conv1d_transpose.py b/topi/python/topi/nn/conv1d_transpose.py new file mode 100644 index 0000000..39918e9 --- /dev/null +++ b/topi/python/topi/nn/conv1d_transpose.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-variable, unused-argument +"""Transposed 1D convolution operators (sometimes called Deconvolution).""" +from __future__ import absolute_import as _abs +import tvm +from .dilate import dilate +from .pad import pad +from ..util import simplify +from .util import get_pad_tuple1d + + +@tvm.target.generic_func +def conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype): + """Transposed 1D convolution ncw forward operator. + + Parameters + ---------- + data : tvm.Tensor + 3-D with shape [batch, in_channel, in_width] + + kernel : tvm.Tensor + 3-D with shape [in_channel, num_filter, filter_width] + + stride : ints + The spatial stride along width + + padding : int or str + Padding size, or ['VALID', 'SAME'] + + out_dtype : str + The output data type. This is used for mixed precision. 
+ + Returns + ------- + output : tvm.Tensor + 3-D with shape [batch, out_channel, out_width] + """ + + # dilate and pad + if isinstance(stride, (tuple, list)): + stride = stride[0] + batch, channels_in, data_width = data.shape + _, channels_out, kernel_width = kernel.shape + channels_out = simplify(channels_out) + data = dilate(data, [1, 1, stride], name='data_dilate') + pad_left, pad_right = get_pad_tuple1d(padding, (kernel_width,)) + pad_left = kernel_width - 1 - pad_left + pad_right = kernel_width - 1 - pad_right + data = pad(data, [0, 0, pad_left], [0, 0, pad_right], name='data_pad') + + # transpose kernel, switch kernel layout to IOW + kernel = tvm.compute((channels_out, channels_in, kernel_width), \ + lambda o, i, w: kernel[i][o][kernel_width-1-w],\ + name='kernel') + + # convolution + _, _, data_width = data.shape + out_w = simplify(data_width - kernel_width + 1) + dc = tvm.reduce_axis((0, channels_in), name='dc') + dw = tvm.reduce_axis((0, kernel_width), name='dw') + output = tvm.compute( + (batch, channels_out, out_w), + lambda b, c, w: tvm.sum( + data[b, dc, w+dw].astype(out_dtype) * + kernel[c, dc, dw].astype(out_dtype), + axis=[dc, dw]), tag="conv1d_transpose_ncw") + + return output diff --git a/topi/python/topi/nn/util.py b/topi/python/topi/nn/util.py index 463edaa..1ee0862 100644 --- a/topi/python/topi/nn/util.py +++ b/topi/python/topi/nn/util.py @@ -172,3 +172,42 @@ def get_pad_tuple3d(padding, kernel): pad_left = (pad_w + 1) // 2 pad_front = (pad_d + 1) // 2 return pad_front, pad_top, pad_left, pad_d - pad_front, pad_h - pad_top, pad_w - pad_left + + +def get_pad_tuple1d(padding, kernel): + """Common code to get the pad option + + Parameters + ---------- + padding : int or str + Padding size, or ['VALID', 'SAME'] + + kernel : tuple of int + Conv kernel size + + Returns + ------- + pad_left : int + Padding size on left + + pad_right : int + Padding size on right. 
+ """ + # compute the padding size + if isinstance(padding, (tuple, list)): + if len(padding) == 1: + pad_w = padding[0] * 2 + elif len(padding) == 2: + return padding[0], padding[1] + else: + raise ValueError("Size of padding can only be 2 or 4") + elif isinstance(padding, int): + pad_w = padding * 2 + elif padding == "VALID": + pad_w = 0 + elif padding == "SAME": + pad_w = kernel[0] - 1 + else: + raise ValueError("Unknown padding option %s" % padding) + pad_left = (pad_w + 1) // 2 + return pad_left, pad_w - pad_left diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py index 6c5ca6b..43e9f19 100644 --- a/topi/python/topi/testing/__init__.py +++ b/topi/python/topi/testing/__init__.py @@ -26,6 +26,7 @@ from .conv2d_nchw_python import conv2d_nchw_python from .conv2d_nhwc_python import conv2d_nhwc_python from .conv3d_ncdhw_python import conv3d_ncdhw_python from .conv2d_transpose_python import conv2d_transpose_nchw_python, conv2d_transpose_nhwc_python +from .conv1d_transpose_ncw_python import conv1d_transpose_ncw_python from .deformable_conv2d_nchw_python import deformable_conv2d_nchw_python from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc from .dilate_python import dilate_python diff --git a/topi/python/topi/testing/conv1d_transpose_ncw_python.py b/topi/python/topi/testing/conv1d_transpose_ncw_python.py new file mode 100644 index 0000000..cb78bbf --- /dev/null +++ b/topi/python/topi/testing/conv1d_transpose_ncw_python.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-variable +"""Transposed 1D convolution in python""" +import numpy as np +import scipy +import topi +from topi.nn.util import get_pad_tuple1d + +def conv1d_transpose_ncw_python(a_np, w_np, stride, padding): + """Transposed 1D convolution operator in NCW layout. 
+ + Parameters + ---------- + a_np : numpy.ndarray + 3-D with shape [batch, in_channel, in_width] + + w_np : numpy.ndarray + 3-D with shape [in_channel, num_filter, filter_width] + + stride : int or a list/tuple of one int + Stride size, or [stride_width] + + padding : int, tuple, or str + Single int for padding size, or + tuple of 2 ints for left and right padding, or + ['VALID', 'SAME'] + + Returns + ------- + b_np : np.ndarray + 3-D with shape [batch, out_channel, out_width] + """ + batch, in_c, in_w = a_np.shape + _, out_c, filter_w = w_np.shape + if isinstance(stride, int): + stride_w = stride + else: + stride_w = stride[0] + fpad_left, fpad_right = get_pad_tuple1d(padding, filter_w) + # dilate stage + dilated_a_np = topi.testing.dilate_python(a_np, [1, 1, stride_w]) + # padding stage + bpad_left = filter_w - 1 - fpad_left + bpad_right = filter_w - 1 - fpad_right + padded_a_np = np.zeros((batch, in_c, dilated_a_np.shape[2]+bpad_left+bpad_right)) + padded_a_np[:, :, bpad_left:dilated_a_np.shape[2]+bpad_left] = dilated_a_np + # convolution stage + out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + b_np = np.zeros((batch, out_c, out_w)) + for n in range(batch): + for f in range(out_c): + for c in range(in_c): + out = scipy.signal.convolve( + padded_a_np[n, c], w_np[c, f], mode='valid') + b_np[n, f] += out + return b_np diff --git a/topi/tests/python/test_topi_conv1d_transpose_ncw.py b/topi/tests/python/test_topi_conv1d_transpose_ncw.py new file mode 100644 index 0000000..9d6e9db --- /dev/null +++ b/topi/tests/python/test_topi_conv1d_transpose_ncw.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Test code for transposed convolution.""" +import numpy as np +import itertools +import tvm +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple +from common import get_all_backend + +def verify_conv1d_transpose_ncw(batch, in_channel, in_size, num_filter, kernel, stride, padding): + in_width = in_size + A = tvm.placeholder((batch, in_channel, in_width), name='A') + W = tvm.placeholder((in_channel, num_filter, kernel), name='W') + + a_shape = get_const_tuple(A.shape) + w_shape = get_const_tuple(W.shape) + dtype = A.dtype + + @memoize("topi.tests.test_topi_conv1d_transpose.verify_conv1d_transpose_ncw") + def get_ref_data(): + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(dtype) + b_np = topi.testing.conv1d_transpose_ncw_python(a_np, w_np, stride, padding) + c_np = np.maximum(b_np, 0) + return a_np, w_np, b_np, c_np + + a_np, w_np, b_np, c_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + with tvm.target.create(device): + B = topi.nn.conv1d_transpose_ncw(A, W, stride, padding, A.dtype) + C = topi.nn.relu(B) + s1 = topi.generic.schedule_conv1d_transpose_ncw([B]) + s2 = topi.generic.schedule_conv1d_transpose_ncw([C]) + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) + + func1 = tvm.build(s1, [A, W, B], device) + func2 = tvm.build(s2, [A, W, C], device) + func1(a, w, b) + func2(a, w, c) + tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) + + for device in get_all_backend(): + check_device(device) + + +def test_conv1d_transpose_ncw(): + verify_conv1d_transpose_ncw(1, 3, 224, 32, 5, 1, 0) + verify_conv1d_transpose_ncw(1, 3, 224, 32, 7, 1, 2) + verify_conv1d_transpose_ncw(1, 3, 224, 32, 5, 2, 1) + verify_conv1d_transpose_ncw(1, 3, 224, 32, 5, 2, 0) + verify_conv1d_transpose_ncw(1, 32, 32, 128, 5, 1, 0) + verify_conv1d_transpose_ncw(1, 32, 32, 128, 5, 2, 1) + verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 1, 256) + verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 2, 256) + verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 5, 256) + verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (0,3)) + verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (1,3)) + verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (2,3)) + +if __name__ == "__main__": + test_conv1d_transpose_ncw() -- 2.7.4