[mir_onnx] Support ConvTranspose operator (#7610)
author: Sergei Barannikov / AI Tools Lab / SRR / Engineer / Samsung Electronics <s.barannikov@samsung.com>
Thu, 19 Sep 2019 15:47:42 +0000 (18:47 +0300)
committer: Alexander Efimov / AI Tools Lab / Samsung Electronics <a.efimov@samsung.com>
Thu, 19 Sep 2019 15:47:42 +0000 (18:47 +0300)
Add support for `ConvTranspose` operator.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
compiler/mir-onnx-importer/CMakeLists.txt
compiler/mir-onnx-importer/ONNXOpRegistration.h
compiler/mir-onnx-importer/Op/Conv.cpp
compiler/mir-onnx-importer/Op/ConvTranspose.cpp [new file with mode: 0644]
compiler/mir-onnx-importer/Op/ConvTranspose.h [new file with mode: 0644]

index fc03306..7af14a0 100644 (file)
@@ -41,6 +41,8 @@ set(MIR_ONNX_IMPORTER_SOURCES
         Op/Constant.h
         Op/Conv.cpp
         Op/Conv.h
+        Op/ConvTranspose.cpp
+        Op/ConvTranspose.h
         Op/Dropout.cpp
         Op/Dropout.h
         Op/Flatten.cpp
index 9e29511..2014fbc 100644 (file)
@@ -25,6 +25,7 @@
 #include "Op/Concat.h"
 #include "Op/Constant.h"
 #include "Op/Conv.h"
+#include "Op/ConvTranspose.h"
 #include "Op/Dropout.h"
 #include "Op/Flatten.h"
 #include "Op/Gather.h"
@@ -64,6 +65,7 @@ inline void registerSupportedOps()
   registry.registerConverter("Concat", stdex::make_unique<ConcatNodeConverter>());
   registry.registerConverter("Constant", stdex::make_unique<ConstantNodeConverter>());
   registry.registerConverter("Conv", stdex::make_unique<ConvNodeConverter>());
+  registry.registerConverter("ConvTranspose", stdex::make_unique<ConvTransposeNodeConverter>());
   registry.registerConverter("Dropout", stdex::make_unique<DropoutNodeConverter>());
   registry.registerConverter("Flatten", stdex::make_unique<FlattenNodeConverter>());
   registry.registerConverter("Gather", stdex::make_unique<GatherNodeConverter>());
index 80c2920..06c31c3 100644 (file)
@@ -92,12 +92,12 @@ void ConvNodeConverter::convertV1(const onnx::NodeProto &onnx_node, ConverterCon
                      padding_after);
   }
 
-  // FIXME: It can be non-constant value.
-  auto *in_weights = dynamic_cast<mir::ops::ConstantOp *>(kernel->getNode());
-  assert(in_weights && "Weights could be a constant tensor only");
-  const auto &in_weights_tensor = in_weights->getValue();
-  // We should transpose ONNX MC(IO)HW to HWOI
-  auto kernel_tensor = mir::transposeTensor<2, 3, 1, 0>(in_weights_tensor);
+  // OIHW -> HWIO
+  // TODO Insert TransposeOp when mir2loco can handle it (i.e. when loco supports it).
+  auto kernel_op = dynamic_cast<mir::ops::ConstantOp *>(kernel->getNode());
+  if (kernel_op == nullptr)
+    throw std::runtime_error("Conv: non-constant kernel is not supported yet.");
+  auto kernel_tensor = mir::transposeTensor<2, 3, 1, 0>(kernel_op->getValue());
   auto in_group_size = kernel_tensor.getShape().dim(2);
   auto out_channels = kernel_tensor.getShape().dim(3);
 
@@ -120,6 +120,7 @@ void ConvNodeConverter::convertV1(const onnx::NodeProto &onnx_node, ConverterCon
     // first we need to convert kernel of grouped convolution to appropriate ordinary kernel
     if (group != 1)
       kernel_tensor = fixGroupedKernel(group, kernel_tensor);
+    // HWIO -> OHWI
     kernel_tensor = mir::transposeTensor<3, 0, 1, 2>(kernel_tensor);
     kernel = createOp<mir::ops::ConstantOp>(graph, kernel_tensor)->getOutput(0);
     result = createOp<mir::ops::Conv2DOp>(graph, input, kernel, strides, padding_before,
diff --git a/compiler/mir-onnx-importer/Op/ConvTranspose.cpp b/compiler/mir-onnx-importer/Op/ConvTranspose.cpp
new file mode 100644 (file)
index 0000000..4bd47cc
--- /dev/null
@@ -0,0 +1,148 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConvTranspose.h"
+
+#include "ONNXHelpers.h"
+#include "AttributeHelpers.h"
+#include "ConvPoolHelpers.h"
+
+#include "mir/TensorUtil.h"
+#include "mir/ops/AddOp.h"
+#include "mir/ops/ConstantOp.h"
+#include "mir/ops/Deconv2DOp.h"
+#include "mir/ops/ReshapeOp.h"
+
+namespace mir_onnx
+{
+
+// Entry point for the ONNX ConvTranspose operator: dispatches to a version-specific
+// converter based on the opset version registered for the node's domain.
+void ConvTransposeNodeConverter::convert(const onnx::NodeProto &onnx_node,
+                                         ConverterContext *context) const
+{
+  const auto opset_version = context->getOpsetVersion(onnx_node.domain());
+  // ConvTranspose has existed since opset 1; all known versions go through convertV1.
+  if (opset_version >= 1)
+    convertV1(onnx_node, context);
+  else
+    throw std::runtime_error("Not supported opset version on ConvTranspose operation!");
+}
+
+// Converts an ONNX ConvTranspose node (opset v1+) into MIR operations:
+// a Deconv2DOp, preceded by a constant-folded kernel transpose and optionally
+// followed by a Reshape+Add pair for the bias input.
+//
+// Supported subset: group == 1, 4-D (NCHW) input, all dilations == 1,
+// all output_padding == 0, constant kernel, and an explicit 'output_shape'
+// attribute. Anything else throws std::runtime_error.
+void ConvTransposeNodeConverter::convertV1(const onnx::NodeProto &onnx_node,
+                                           ConverterContext *context) const
+{
+  std::vector<mir::Operation::Output *> inputs = context->getNodeInputs(onnx_node);
+  mir::Graph *graph = context->getGraph();
+
+  // Inputs: [0] data, [1] kernel, optional [2] bias.
+  assert(inputs.size() >= 2);
+  auto input = inputs[0];
+  auto kernel = inputs[1];
+
+  const auto group = getAttributeValue<std::int64_t>(onnx_node, "group", 1);
+  if (group != 1)
+    throw std::runtime_error("ConvTranspose: attribute 'group' has unsupported value.");
+
+  const auto &input_shape = input->getShape();
+  if (input_shape.rank() != 4)
+    throw std::runtime_error("ConvTranspose: only 2-D input is supported.");
+
+  constexpr int num_spatial_dims = 2;
+
+  // Dilations other than 1 are rejected: Deconv2DOp here is built without them.
+  const auto dilations =
+      getAttributeValue(onnx_node, "dilations", std::vector<std::int32_t>(num_spatial_dims, 1));
+  if (dilations.size() != num_spatial_dims)
+    throw std::runtime_error("ConvTranspose: attribute 'dilations' has incorrect size.");
+  if (!std::all_of(dilations.cbegin(), dilations.cend(), [](std::int32_t x) { return x == 1; }))
+    throw std::runtime_error("ConvTranspose: attribute 'dilations' has unsupported value.");
+
+  const auto strides =
+      getAttributeValue(onnx_node, "strides", std::vector<std::int32_t>(num_spatial_dims, 1));
+  if (strides.size() != num_spatial_dims)
+    throw std::runtime_error("ConvTranspose: attribute 'strides' has incorrect size.");
+
+  // Non-zero output_padding would change the inferred output size; not supported.
+  const auto output_padding = getAttributeValue(onnx_node, "output_padding",
+                                                std::vector<std::int32_t>(num_spatial_dims, 0));
+  if (output_padding.size() != num_spatial_dims)
+    throw std::runtime_error("ConvTranspose: attribute 'output_padding' has incorrect size.");
+  if (!std::all_of(output_padding.cbegin(), output_padding.cend(),
+                   [](std::int32_t x) { return x == 0; }))
+    throw std::runtime_error("ConvTranspose: attribute 'output_padding' has unsupported value.");
+
+  // Assuming kernel has IOHW format.
+  assert(kernel->getShape().rank() == 4);
+  // Default kernel_shape is taken from the kernel's trailing (spatial) dims.
+  const auto kernel_size = getAttributeValue(
+      onnx_node, "kernel_shape",
+      std::vector<std::int32_t>{kernel->getShape().dim(2), kernel->getShape().dim(3)});
+  if (kernel_size.size() != num_spatial_dims)
+    throw std::runtime_error("ConvTranspose: attribute 'kernel_shape' has incorrect size.");
+
+  // ONNX IOHW -> MIR HWOI
+  // TODO Insert TransposeOp when mir2loco can handle it (i.e. when loco supports it).
+  // The transpose is folded at import time, so the kernel must be a ConstantOp.
+  auto kernel_op = dynamic_cast<mir::ops::ConstantOp *>(kernel->getNode());
+  if (kernel_op == nullptr)
+    throw std::runtime_error("ConvTranspose: non-constant kernel is not supported yet.");
+  auto kernel_tensor = mir::transposeTensor<2, 3, 1, 0>(kernel_op->getValue());
+  kernel = createOp<mir::ops::ConstantOp>(graph, kernel_tensor)->getOutput(0);
+
+  mir::Operation::Output *result;
+  if (const auto *output_shape_attr = findAttribute(onnx_node, "output_shape"))
+  {
+    // Explicit output spatial size: build the full NCHW output shape from it.
+    // dim(2) of the transposed (HWOI) kernel is the output channel count.
+    const auto output_size = getAttributeValue<std::vector<std::int32_t>>(*output_shape_attr);
+    if (output_size.size() != num_spatial_dims)
+      throw std::runtime_error("ConvTranspose: attribute 'output_shape' has incorrect size.");
+    const mir::Shape output_shape{input_shape.dim(0), kernel->getShape().dim(2), output_size[0],
+                                  output_size[1]};
+    result = createOp<mir::ops::DeConv2DOp>(graph, input, kernel, strides,
+                                            mir::ops::PaddingType::SameUpper, output_shape,
+                                            mir::DataFormat::NCHW)
+                 ->getOutput(0);
+  }
+  else
+  {
+    // TODO This code was not tested.
+    // NOTE(review): everything after this throw is intentionally unreachable —
+    // it is a draft of padding-based output inference, kept for when the
+    // 'output_shape'-less path gets enabled and tested.
+    throw std::runtime_error(
+        "ConvTranspose: absence of attribute 'output_shape' is not supported.");
+    std::vector<std::int32_t> padding_before(num_spatial_dims, 0);
+    std::vector<std::int32_t> padding_after(num_spatial_dims, 0);
+    if (const auto *pads_attr = findAttribute(onnx_node, "pads"))
+    {
+      // Explicit pads: first half is "before" per spatial dim, second half "after".
+      const auto pads = getAttributeValue<std::vector<std::int32_t>>(*pads_attr);
+      if (pads.size() != num_spatial_dims * 2)
+        throw std::runtime_error("ConvTranspose: attribute 'pads' has incorrect size.");
+      padding_before.assign(pads.cbegin(), std::next(pads.cbegin(), num_spatial_dims));
+      padding_after.assign(std::next(pads.cbegin(), num_spatial_dims), pads.cend());
+    }
+    else
+    {
+      // No pads given: derive padding from the 'auto_pad' attribute.
+      const auto auto_pad = getAttributeValue<std::string>(onnx_node, "auto_pad", "NOTSET");
+      inferAutoPadding(auto_pad, input_shape, dilations, strides, kernel_size, padding_before,
+                       padding_after);
+    }
+    result = createOp<mir::ops::DeConv2DOp>(graph, input, kernel, strides, padding_before,
+                                            padding_after, mir::DataFormat::NCHW)
+                 ->getOutput(0);
+  }
+
+  // Optional bias: reshape the 1-D bias to {1, C, 1, 1} so it broadcasts over NCHW.
+  if (inputs.size() > 2)
+  {
+    auto bias = inputs[2];
+    bias = createOp<mir::ops::ReshapeOp>(graph, bias, mir::Shape{1, bias->getShape().dim(0), 1, 1})
+               ->getOutput(0);
+    result = createOp<mir::ops::AddOp>(graph, result, bias)->getOutput(0);
+  }
+
+  context->setNodeOutputs(onnx_node, {result});
+}
+
+} // namespace mir_onnx
diff --git a/compiler/mir-onnx-importer/Op/ConvTranspose.h b/compiler/mir-onnx-importer/Op/ConvTranspose.h
new file mode 100644 (file)
index 0000000..25d15b7
--- /dev/null
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MIR_ONNX_OP_CONV_TRANSPOSE_H
+#define MIR_ONNX_OP_CONV_TRANSPOSE_H
+
+#include "ONNXNodeConverterRegistry.h"
+
+namespace mir_onnx
+{
+
+// Converter for the ONNX ConvTranspose operator: lowers it to MIR ops
+// (Deconv2DOp, plus Reshape/Add for the optional bias input).
+class ConvTransposeNodeConverter : public NodeConverter
+{
+public:
+  // Dispatches to a version-specific converter based on the node's opset version.
+  void convert(const onnx::NodeProto &onnx_node, ConverterContext *context) const override;
+
+private:
+  // Handles opset version 1 and later.
+  void convertV1(const onnx::NodeProto &onnx_node, ConverterContext *context) const;
+};
+
+} // namespace mir_onnx
+
+#endif // MIR_ONNX_OP_CONV_TRANSPOSE_H