From b479769a4ae03ea80977cb52130569655ffb855d Mon Sep 17 00:00:00 2001
From: Sergei Barannikov/AI Tools Lab /SRR/Engineer/Samsung Electronics
Date: Thu, 19 Sep 2019 18:47:42 +0300
Subject: [PATCH] [mir_onnx] Support ConvTranspose operator (#7610)

Add support for the `ConvTranspose` operator.

Signed-off-by: Sergei Barannikov
---
 compiler/mir-onnx-importer/CMakeLists.txt       |   2 +
 compiler/mir-onnx-importer/ONNXOpRegistration.h |   2 +
 compiler/mir-onnx-importer/Op/Conv.cpp          |  13 ++-
 compiler/mir-onnx-importer/Op/ConvTranspose.cpp | 148 ++++++++++++++++++++++++
 compiler/mir-onnx-importer/Op/ConvTranspose.h   |  36 ++++++
 5 files changed, 195 insertions(+), 6 deletions(-)
 create mode 100644 compiler/mir-onnx-importer/Op/ConvTranspose.cpp
 create mode 100644 compiler/mir-onnx-importer/Op/ConvTranspose.h

diff --git a/compiler/mir-onnx-importer/CMakeLists.txt b/compiler/mir-onnx-importer/CMakeLists.txt
index fc03306..7af14a0 100644
--- a/compiler/mir-onnx-importer/CMakeLists.txt
+++ b/compiler/mir-onnx-importer/CMakeLists.txt
@@ -41,6 +41,8 @@ set(MIR_ONNX_IMPORTER_SOURCES
     Op/Constant.h
     Op/Conv.cpp
     Op/Conv.h
+    Op/ConvTranspose.cpp
+    Op/ConvTranspose.h
     Op/Dropout.cpp
     Op/Dropout.h
     Op/Flatten.cpp
diff --git a/compiler/mir-onnx-importer/ONNXOpRegistration.h b/compiler/mir-onnx-importer/ONNXOpRegistration.h
index 9e29511..2014fbc 100644
--- a/compiler/mir-onnx-importer/ONNXOpRegistration.h
+++ b/compiler/mir-onnx-importer/ONNXOpRegistration.h
@@ -25,6 +25,7 @@
 #include "Op/Concat.h"
 #include "Op/Constant.h"
 #include "Op/Conv.h"
+#include "Op/ConvTranspose.h"
 #include "Op/Dropout.h"
 #include "Op/Flatten.h"
 #include "Op/Gather.h"
@@ -64,6 +65,7 @@ inline void registerSupportedOps()
   registry.registerConverter("Concat", stdex::make_unique<ConcatNodeConverter>());
   registry.registerConverter("Constant", stdex::make_unique<ConstantNodeConverter>());
   registry.registerConverter("Conv", stdex::make_unique<ConvNodeConverter>());
+  registry.registerConverter("ConvTranspose", stdex::make_unique<ConvTransposeNodeConverter>());
   registry.registerConverter("Dropout", stdex::make_unique<DropoutNodeConverter>());
   registry.registerConverter("Flatten", stdex::make_unique<FlattenNodeConverter>());
   registry.registerConverter("Gather", stdex::make_unique<GatherNodeConverter>());
diff --git a/compiler/mir-onnx-importer/Op/Conv.cpp b/compiler/mir-onnx-importer/Op/Conv.cpp
index 80c2920..06c31c3 100644
--- a/compiler/mir-onnx-importer/Op/Conv.cpp
+++ b/compiler/mir-onnx-importer/Op/Conv.cpp
@@ -92,12 +92,12 @@ void ConvNodeConverter::convertV1(const onnx::NodeProto &onnx_node, ConverterCon
                      padding_after);
   }

-  // FIXME: It can be non-constant value.
-  auto *in_weights = dynamic_cast<mir::ops::ConstantOp *>(kernel->getNode());
-  assert(in_weights && "Weights could be a constant tensor only");
-  const auto &in_weights_tensor = in_weights->getValue();
-  // We should transpose ONNX MC(IO)HW to HWOI
-  auto kernel_tensor = mir::transposeTensor<2, 3, 1, 0>(in_weights_tensor);
+  // OIHW -> HWIO
+  // TODO Insert TransposeOp when mir2loco can handle it (i.e. when loco supports it).
+  auto kernel_op = dynamic_cast<mir::ops::ConstantOp *>(kernel->getNode());
+  if (kernel_op == nullptr)
+    throw std::runtime_error("Conv: non-constant kernel is not supported yet.");
+  auto kernel_tensor = mir::transposeTensor<2, 3, 1, 0>(kernel_op->getValue());
   auto in_group_size = kernel_tensor.getShape().dim(2);
   auto out_channels = kernel_tensor.getShape().dim(3);
@@ -120,6 +120,7 @@ void ConvNodeConverter::convertV1(const onnx::NodeProto &onnx_node, ConverterCon
   // first we need to convert kernel of grouped convolution to appropriate ordinary kernel
   if (group != 1)
     kernel_tensor = fixGroupedKernel(group, kernel_tensor);
+  // HWIO -> OHWI
   kernel_tensor = mir::transposeTensor<3, 0, 1, 2>(kernel_tensor);
   kernel = createOp<mir::ops::ConstantOp>(graph, kernel_tensor)->getOutput(0);
   result = createOp<mir::ops::Conv2DOp>(graph, input, kernel, strides, padding_before,
diff --git a/compiler/mir-onnx-importer/Op/ConvTranspose.cpp b/compiler/mir-onnx-importer/Op/ConvTranspose.cpp
new file mode 100644
index 0000000..4bd47cc
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/ConvTranspose.cpp
@@ -0,0 +1,148 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConvTranspose.h"
+
+#include "ONNXHelpers.h"
+#include "AttributeHelpers.h"
+#include "ConvPoolHelpers.h"
+
+#include "mir/TensorUtil.h"
+#include "mir/ops/AddOp.h"
+#include "mir/ops/ConstantOp.h"
+#include "mir/ops/Deconv2DOp.h"
+#include "mir/ops/ReshapeOp.h"
+
+namespace mir_onnx
+{
+
+void ConvTransposeNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                                         ConverterContext *context) const
+{
+  const auto opset_version = context->getOpsetVersion(onnx_node.domain());
+  if (opset_version >= 1)
+    convertV1(onnx_node, context);
+  else
+    throw std::runtime_error("Not supported opset version on ConvTranspose operation!");
+}
+
+void ConvTransposeNodeConverter::convertV1(const onnx::NodeProto &onnx_node,
+                                           ConverterContext *context) const
+{
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
+
+  assert(inputs.size() >= 2);
+  auto input = inputs[0];
+  auto kernel = inputs[1];
+
+  const auto group = getAttributeValue<std::int64_t>(onnx_node, "group", 1);
+  if (group != 1)
+    throw std::runtime_error("ConvTranspose: attribute 'group' has unsupported value.");
+
+  const auto &input_shape = input->getShape();
+  if (input_shape.rank() != 4)
+    throw std::runtime_error("ConvTranspose: only 2-D input is supported.");
+
+  constexpr int num_spatial_dims = 2;
+
+  const auto dilations =
+      getAttributeValue(onnx_node, "dilations", std::vector<std::int32_t>(num_spatial_dims, 1));
+  if (dilations.size() != num_spatial_dims)
+    throw std::runtime_error("ConvTranspose: attribute 'dilations' has incorrect size.");
+  if (!std::all_of(dilations.cbegin(), dilations.cend(), [](std::int32_t x) { return x == 1; }))
+    throw std::runtime_error("ConvTranspose: attribute 'dilations' has unsupported value.");
+
+  const auto strides =
+      getAttributeValue(onnx_node, "strides", std::vector<std::int32_t>(num_spatial_dims, 1));
+  if (strides.size() != num_spatial_dims)
+    throw std::runtime_error("ConvTranspose: attribute 'strides' has incorrect size.");
+
+  const auto output_padding = getAttributeValue(
+      onnx_node, "output_padding", std::vector<std::int32_t>(num_spatial_dims, 0));
+  if (output_padding.size() != num_spatial_dims)
+    throw std::runtime_error("ConvTranspose: attribute 'output_padding' has incorrect size.");
+  if (!std::all_of(output_padding.cbegin(), output_padding.cend(),
+                   [](std::int32_t x) { return x == 0; }))
+    throw std::runtime_error("ConvTranspose: attribute 'output_padding' has unsupported value.");
+
+  // Assuming kernel has IOHW format.
+  assert(kernel->getShape().rank() == 4);
+  const auto kernel_size = getAttributeValue(
+      onnx_node, "kernel_shape",
+      std::vector<std::int32_t>{kernel->getShape().dim(2), kernel->getShape().dim(3)});
+  if (kernel_size.size() != num_spatial_dims)
+    throw std::runtime_error("ConvTranspose: attribute 'kernel_shape' has incorrect size.");
+
+  // ONNX IOHW -> MIR HWOI
+  // TODO Insert TransposeOp when mir2loco can handle it (i.e. when loco supports it).
+  auto kernel_op = dynamic_cast<mir::ops::ConstantOp *>(kernel->getNode());
+  if (kernel_op == nullptr)
+    throw std::runtime_error("ConvTranspose: non-constant kernel is not supported yet.");
+  auto kernel_tensor = mir::transposeTensor<2, 3, 1, 0>(kernel_op->getValue());
+  kernel = createOp<mir::ops::ConstantOp>(graph, kernel_tensor)->getOutput(0);
+
+  mir::Operation::Output *result;
+  if (const auto *output_shape_attr = findAttribute(onnx_node, "output_shape"))
+  {
+    const auto output_size = getAttributeValue<std::vector<std::int32_t>>(*output_shape_attr);
+    if (output_size.size() != num_spatial_dims)
+      throw std::runtime_error("ConvTranspose: attribute 'output_shape' has incorrect size.");
+    const mir::Shape output_shape{input_shape.dim(0), kernel->getShape().dim(2), output_size[0],
+                                  output_size[1]};
+    result = createOp<mir::ops::DeConv2DOp>(graph, input, kernel, strides,
+                                            mir::ops::PaddingType::SameUpper, output_shape,
+                                            mir::DataFormat::NCHW)
+                 ->getOutput(0);
+  }
+  else
+  {
+    // TODO This code was not tested.
+    throw std::runtime_error(
+        "ConvTranspose: absence of attribute 'output_shape' is not supported.");
+    std::vector<std::int32_t> padding_before(num_spatial_dims, 0);
+    std::vector<std::int32_t> padding_after(num_spatial_dims, 0);
+    if (const auto *pads_attr = findAttribute(onnx_node, "pads"))
+    {
+      const auto pads = getAttributeValue<std::vector<std::int32_t>>(*pads_attr);
+      if (pads.size() != num_spatial_dims * 2)
+        throw std::runtime_error("ConvTranspose: attribute 'pads' has incorrect size.");
+      padding_before.assign(pads.cbegin(), std::next(pads.cbegin(), num_spatial_dims));
+      padding_after.assign(std::next(pads.cbegin(), num_spatial_dims), pads.cend());
+    }
+    else
+    {
+      const auto auto_pad = getAttributeValue<std::string>(onnx_node, "auto_pad", "NOTSET");
+      inferAutoPadding(auto_pad, input_shape, dilations, strides, kernel_size, padding_before,
+                       padding_after);
+    }
+    result = createOp<mir::ops::DeConv2DOp>(graph, input, kernel, strides, padding_before,
+                                            padding_after, mir::DataFormat::NCHW)
+                 ->getOutput(0);
+  }
+
+  if (inputs.size() > 2)
+  {
+    auto bias = inputs[2];
+    bias = createOp<mir::ops::ReshapeOp>(graph, bias, mir::Shape{1, bias->getShape().dim(0), 1, 1})
+               ->getOutput(0);
+    result = createOp<mir::ops::AddOp>(graph, result, bias)->getOutput(0);
+  }
+
+  context->setNodeOutputs(onnx_node, {result});
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/ConvTranspose.h b/compiler/mir-onnx-importer/Op/ConvTranspose.h
new file mode 100644
index 0000000..25d15b7
--- /dev/null
+++ b/compiler/mir-onnx-importer/Op/ConvTranspose.h
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MIR_ONNX_OP_CONV_TRANSPOSE_H
+#define MIR_ONNX_OP_CONV_TRANSPOSE_H
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+class ConvTransposeNodeConverter : public NodeConverter
+{
+public:
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
+
+private:
+  void convertV1(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
+};
+
+} // namespace mir_onnx
+
+#endif // MIR_ONNX_OP_CONV_TRANSPOSE_H
-- 
2.7.4
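
Reviewer note (not part of the patch): when the 'output_shape' attribute is absent, the ONNX specification derives each spatial output dimension of ConvTranspose as
stride * (input - 1) + output_padding + ((kernel - 1) * dilation + 1) - pad_begin - pad_end.
The standalone C++ sketch below only illustrates that formula; the function name and values are purely illustrative and nothing here comes from mir-onnx-importer itself.

    #include <cstdint>
    #include <iostream>

    // Spatial output size of ConvTranspose per the ONNX specification,
    // for the case where explicit 'pads' are given and 'output_shape' is absent.
    std::int32_t deconvOutputSize(std::int32_t input_size, std::int32_t kernel_size,
                                  std::int32_t stride, std::int32_t dilation,
                                  std::int32_t pad_begin, std::int32_t pad_end,
                                  std::int32_t output_padding)
    {
      return stride * (input_size - 1) + output_padding + ((kernel_size - 1) * dilation + 1) -
             pad_begin - pad_end;
    }

    int main()
    {
      // Example: 4x4 input, 3x3 kernel, stride 2, no padding -> 9x9 output.
      std::cout << deconvOutputSize(4, 3, 2, 1, 0, 0, 0) << '\n'; // prints 9
      return 0;
    }

In the patch itself only the 'output_shape' path is exercised: the requested output shape is handed directly to DeConv2DOp together with SameUpper padding, so the importer never has to evaluate this formula; the explicit-pads branch above the throw is kept but marked untested.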