[nnc] Make convolution operations treat the second tensor as an ordinary argument (...
author Sergei Barannikov/AI Tools Lab /SRR/Engineer/Samsung Electronics <s.barannikov@samsung.com>
Thu, 17 Jan 2019 08:36:29 +0000 (11:36 +0300)
committer Efimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Thu, 17 Jan 2019 08:36:29 +0000 (11:36 +0300)
* Change the signatures of Conv2DOp, DepthwiseConv2DOp and DeConv2DOp so that both inputs (data and kernel) are handled identically (see the sketch below).
* Refactor uses of Conv2DOp, DepthwiseConv2DOp and DeConv2DOp.
* Rename the corresponding files in the interpreter backend to follow the coding style.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
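
In graph-construction terms the change looks roughly as follows (a minimal sketch distilled from the frontend and test changes below; the names g, input, kernel_tensor and the stride/padding values are illustrative, not part of the patch):

    // Before: the kernel was baked into the operation as a TensorVariant.
    // auto* conv = g.create<mir::ops::Conv2DOp>(
    //     "conv", input, kernel_tensor, strides, padding_before, padding_after);

    // After: the kernel is an ordinary argument. A constant kernel is wrapped
    // in a ConstantOp, and its output descriptor is passed as the second input.
    auto kernel = g.create<mir::ops::ConstantOp>("", kernel_tensor)->getOutput(0);
    auto* conv = g.create<mir::ops::Conv2DOp>(
        "conv", input, kernel, strides, padding_before, padding_after);
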
34 files changed:
contrib/nnc/core/modelIR/IrDotDumper.cpp
contrib/nnc/core/modelIR/ir_dot_node_info.cpp
contrib/nnc/core/modelIR/operations/Conv2DOp.cpp
contrib/nnc/core/modelIR/operations/DeConv2DOp.cpp
contrib/nnc/core/modelIR/operations/DepthwiseConv2DOp.cpp
contrib/nnc/include/core/modelIR/ir_dot_node_info.h
contrib/nnc/include/core/modelIR/operations/Conv2DOp.h
contrib/nnc/include/core/modelIR/operations/Deconv2DOp.h
contrib/nnc/include/core/modelIR/operations/DepthwiseConv2DOp.h
contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.cpp
contrib/nnc/passes/caffe2_frontend/caffe2_op_creator.cpp
contrib/nnc/passes/caffe_frontend/caffe_op_creator.cpp
contrib/nnc/passes/interpreter/Interpreter.cpp
contrib/nnc/passes/interpreter/ops/Conv2D.cpp [moved from contrib/nnc/passes/interpreter/ops/conv_2D.cpp with 92% similarity]
contrib/nnc/passes/interpreter/ops/Conv2D.h [moved from contrib/nnc/passes/interpreter/ops/conv_2D.h with 81% similarity]
contrib/nnc/passes/interpreter/ops/DeConv2D.cpp
contrib/nnc/passes/interpreter/ops/DeConv2D.h
contrib/nnc/passes/interpreter/ops/DepthwiseConv2D.cpp [moved from contrib/nnc/passes/interpreter/ops/Depthwise_conv_2D.cpp with 86% similarity]
contrib/nnc/passes/interpreter/ops/DepthwiseConv2D.h [moved from contrib/nnc/passes/interpreter/ops/Depthwise_conv_2D.h with 72% similarity]
contrib/nnc/passes/interpreter/ops/conv_FFT.cpp
contrib/nnc/passes/interpreter/ops/conv_FFT.h
contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.cpp
contrib/nnc/passes/onnx_frontend/ONNXOpCreator.cpp
contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp
contrib/nnc/passes/soft_backend/SBSerializer.cpp
contrib/nnc/passes/soft_backend/code_snippets/cpp_operations.def
contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp
contrib/nnc/tests/interpreter/gen/gen_test_data.py
contrib/nnc/tests/interpreter/graph_creator.cpp
contrib/nnc/tests/interpreter/op_info_util.cpp
contrib/nnc/tests/interpreter/op_info_util.h
contrib/nnc/tests/interpreter/test_data/test_description.txt
contrib/nnc/unittests/acl_backend/MIRToDOM.cpp
contrib/nnc/unittests/soft_backend/CPPOperations.cpp

index d469c69..f319e3d 100644 (file)
@@ -100,7 +100,6 @@ void IrDotDumper::visit(ops::Conv2DOp& op) {
   auto nodeInfo = DotIrNodeInfo().withType("Conv2D", op.getName())
                                  .withInShapes(getInputShapes(op))
                                  .withOutShapes(getOutputShapes(op))
-                                 .withKernelShape(op.getKernel().getShape())
                                  .withStride(op.getStrides())
                                  .withShape("Padding before", Shape(op.getPaddingBefore()))
                                  .withShape("Padding after", Shape(op.getPaddingAfter()));
@@ -112,7 +111,6 @@ void IrDotDumper::visit(ops::DepthwiseConv2DOp& op) {
   auto nodeInfo = DotIrNodeInfo().withType("DepthwiseConv2D", op.getName())
                                  .withInShapes(getInputShapes(op))
                                  .withOutShapes(getOutputShapes(op))
-                                 .withKernelShape(op.getKernel().getShape())
                                  .withStride(op.getStrides())
                                  .withShape("Padding before", Shape(op.getPaddingBefore()))
                                  .withShape("Padding after", Shape(op.getPaddingAfter()));
@@ -228,7 +226,6 @@ void IrDotDumper::visit(ops::DeConv2DOp& op) {
   auto node_info = DotIrNodeInfo().withType("DeConv2D", op.getName())
           .withInShapes(getInputShapes(op))
           .withOutShapes(getOutputShapes(op))
-          .withKernelShape(op.getKernel().getShape())
           .withPadType(op.getPaddingType())
           .withStride(op.getStrides());
 
index f49dba1..401df5a 100644 (file)
@@ -49,12 +49,6 @@ DotIrNodeInfo &DotIrNodeInfo::withOutShapes(DotIrNodeInfo::Shapes &&outShapes)
   return *this;
 }
 
-DotIrNodeInfo &DotIrNodeInfo::withKernelShape(const Shape &kernelShape)
-{
-  this->kernelShape = kernelShape;
-  return *this;
-}
-
 DotIrNodeInfo &DotIrNodeInfo::withStride(const Shape &strideShape)
 {
   this->strideShape = strideShape;
index 36a5435..e2966a5 100644 (file)
@@ -21,8 +21,8 @@ namespace mir {
 namespace ops {
 
 void Conv2DOp::inferOutputShapes() {
-  auto& input_shape = getInputShape(0);
-  auto& kernel_shape = _kernel.getShape();
+  const auto& input_shape = getInputShape(0);
+  const auto& kernel_shape = getInputShape(1);
 
   assert(input_shape.rank() == 4);
   assert(kernel_shape.rank() == 4);
index 95df7c3..ff9dc1a 100644 (file)
@@ -21,14 +21,14 @@ namespace mir {
 namespace ops {
 
 void DeConv2DOp::inferPaddings() {
-  auto& input_shape = getInputShape(0);
-  auto& kernel_shape = _kernel.getShape();
-  auto output_shape = getOutputShape(0);
+  const auto& input_shape = getInputShape(0);
+  const auto& kernel_shape = getInputShape(1);
+  const auto& output_shape = getOutputShape(0);
 
-  // As stupid as it sounds but it seems like there is no difference in padding calculation
-  // between SAME and VALID padding types ( at least for tflite )
+  // It seems like there is no difference in padding calculation
+  // between SAME and VALID padding types (at least for tflite).
   for (int d = 0; d < 2; ++d) {
-    //See `ComputePadding` in tflite sources
+    // See `ComputePadding` in tflite sources.
     int pad = (input_shape.dim(d + 1) - 1) * _strides.dim(d)
               + kernel_shape.dim(d) - output_shape.dim(d + 1);
 
@@ -44,8 +44,8 @@ void DeConv2DOp::inferOutputShapes() {
   // Input shape: [N, Hi, Wi, Ci]
   // Kernel shape: [Hk, Wk, Co, Ci]
   // Output shape: [N, Ho, Wo, Co]
-  auto& input_shape = getInputShape(0);
-  auto& kernel_shape = _kernel.getShape();
+  const auto& input_shape = getInputShape(0);
+  const auto& kernel_shape = getInputShape(1);
 
   assert(input_shape.rank() == 4);
   assert(kernel_shape.rank() == 4);
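
A quick arithmetic check of the padding formula in inferPaddings above (values chosen for illustration only): for a dimension with input size 5, stride 2, kernel size 3 and output size 9,

    pad = (5 - 1) * 2 + 3 - 9 = 2

which is consistent with the transposed-convolution relation output = (input - 1) * stride + kernel - pad.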
index 324260b..81c585c 100644 (file)
@@ -21,8 +21,8 @@ namespace mir {
 namespace ops {
 
 void DepthwiseConv2DOp::inferOutputShapes() {
-  auto& input_shape = getInputShape(0);
-  auto& kernel_shape = getKernel().getShape();
+  const auto& input_shape = getInputShape(0);
+  const auto& kernel_shape = getInputShape(1);
 
   assert(input_shape.rank() == 4);
   assert(kernel_shape.rank() == 4);
@@ -43,7 +43,7 @@ void DepthwiseConv2DOp::inferOutputShapes() {
     // out_size = ceil((in_size - kernel_size + 1) / stride) =
     //   (in_size - kernel_size + 1 + stride - 1) / stride =
     //   (in_size - kernel_size) / stride + 1
-    output_shape.dim(1 + i) = (padded_input - kernel_shape.dim(i)) /_strides.dim(i) + 1;
+    output_shape.dim(1 + i) = (padded_input - kernel_shape.dim(i)) / _strides.dim(i) + 1;
   }
 
   setOutputShape(0, output_shape);
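
Plugging concrete numbers into the derivation in the comment above (illustrative only): with padded input size 7, kernel size 3 and stride 2, both forms of the expression agree:

    ceil((7 - 3 + 1) / 2) = ceil(5 / 2) = 3
    (7 - 3) / 2 + 1       = 2 + 1      = 3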
index ec0fc75..64f0e2f 100644 (file)
@@ -61,7 +61,7 @@ public:
   DotIrNodeInfo &withType(const std::string &typeName, const std::string &nodeName);
   DotIrNodeInfo &withInShapes(Shapes &&inShapes);
   DotIrNodeInfo &withOutShapes(Shapes &&outShapes);
-  DotIrNodeInfo &withKernelShape(const Shape &kernelShape);
+
   DotIrNodeInfo &withStride(const Shape &strideShape);
   DotIrNodeInfo &withShape(const std::string &shapeName, const Shape &shape);
   DotIrNodeInfo &withPadType(PadType padType);
index f17e0de..bd9a9da 100644 (file)
 #ifndef _NNC_CORE_IR_MODEL_CONV_2D_H_
 #define _NNC_CORE_IR_MODEL_CONV_2D_H_
 
-#include <vector>
-
 #include "core/modelIR/Operation.h"
-#include "core/modelIR/operations/CommonProps.h"
-#include "core/modelIR/TensorVariant.h"
+#include <vector>
 
 namespace nnc {
 namespace mir {
@@ -29,21 +26,18 @@ namespace ops {
 
 class Conv2DOp : public Operation {
 public:
-  Conv2DOp(const IODescriptor& arg,
-           const TensorVariant& kernel,
+  Conv2DOp(const IODescriptor& input,
+           const IODescriptor& kernel,
            const Shape& strides,
            const std::vector<int32_t>& padding_before,
            const std::vector<int32_t>& padding_after)
-      : Operation(Type::conv2D, {arg}),
-        _kernel(kernel),
+      : Operation(Type::conv2D, {input, kernel}),
         _strides(strides),
         _paddingBefore(padding_before),
         _paddingAfter(padding_after) {
     inferOutputShapes();
   }
 
-  const TensorVariant& getKernel() const { return _kernel; }
-
   const Shape& getStrides() const { return _strides; }
 
   const std::vector<int32_t>& getPaddingBefore() const { return _paddingBefore; }
@@ -53,7 +47,6 @@ public:
 private:
   void inferOutputShapes();
 
-  TensorVariant _kernel;
   Shape _strides;
   std::vector<int32_t> _paddingBefore;
   std::vector<int32_t> _paddingAfter;
index 3187fa9..9a403b2 100644 (file)
@@ -19,7 +19,7 @@
 
 #include "core/modelIR/Operation.h"
 #include "core/modelIR/operations/CommonProps.h"
-#include "core/modelIR/TensorVariant.h"
+#include <vector>
 
 namespace nnc {
 namespace mir {
@@ -27,12 +27,11 @@ namespace ops {
 
 class DeConv2DOp : public Operation {
 public:
-  DeConv2DOp(const IODescriptor& arg,
-             const TensorVariant& kernel,
+  DeConv2DOp(const IODescriptor& input,
+             const IODescriptor& kernel,
              const Shape& strides,
              const std::vector<int32_t>& paddings)
-      : Operation(Type::deConv2D, {arg}),
-        _kernel(kernel),
+      : Operation(Type::deConv2D, {input, kernel}),
         _strides(strides),
         _paddingType(PaddingType::Custom),
         _paddingBefore(paddings),
@@ -40,12 +39,11 @@ public:
     inferOutputShapes();
   }
 
-  DeConv2DOp(const IODescriptor& arg,
-             const TensorVariant& kernel,
+  DeConv2DOp(const IODescriptor& input,
+             const IODescriptor& kernel,
              const Shape& strides,
              PaddingType padding_type)
-      : Operation(Type::deConv2D, {arg}),
-        _kernel(kernel),
+      : Operation(Type::deConv2D, {input, kernel}),
         _strides(strides),
         _paddingType(padding_type),
         _paddingBefore(2),
@@ -54,24 +52,21 @@ public:
     inferOutputShapes();
   }
 
-  DeConv2DOp(const IODescriptor& arg,
-             const TensorVariant& kernel,
+  DeConv2DOp(const IODescriptor& input,
+             const IODescriptor& kernel,
              const Shape& strides,
              PaddingType padding_type,
              const Shape& output_shape)
-    : Operation(Type::deConv2D, {arg}),
-      _kernel(kernel),
-      _strides(strides),
-      _paddingType(padding_type),
-      _paddingBefore(2),
-      _paddingAfter(2) {
+      : Operation(Type::deConv2D, {input, kernel}),
+        _strides(strides),
+        _paddingType(padding_type),
+        _paddingBefore(2),
+        _paddingAfter(2) {
     assert(_paddingType != PaddingType::Custom);
     setOutputShape(0, output_shape);
     inferPaddings();
   }
 
-  const TensorVariant& getKernel() const { return _kernel; }
-
   const Shape& getStrides() const { return _strides; }
 
   PaddingType getPaddingType() const { return _paddingType; }
@@ -88,7 +83,6 @@ private:
    */
   void inferPaddings();
 
-  const TensorVariant _kernel;
   Shape _strides;
   PaddingType _paddingType;
   std::vector<int32_t> _paddingBefore;
index dd99400..ca71211 100644 (file)
 #ifndef _NNC_CORE_IR_MODEL_DEPTHWISE_CONV_2D_H_
 #define _NNC_CORE_IR_MODEL_DEPTHWISE_CONV_2D_H_
 
-#include <vector>
-
 #include "core/modelIR/Operation.h"
-#include "core/modelIR/TensorVariant.h"
-#include "core/modelIR/operations/CommonProps.h"
+#include <vector>
 
 namespace nnc {
 namespace mir {
@@ -29,21 +26,18 @@ namespace ops {
 
 class DepthwiseConv2DOp : public Operation {
 public:
-  DepthwiseConv2DOp(const IODescriptor& arg,
-                    const TensorVariant& kernel,
+  DepthwiseConv2DOp(const IODescriptor& input,
+                    const IODescriptor& kernel,
                     const Shape& strides,
                     const std::vector<int32_t>& padding_before,
                     const std::vector<int32_t>& padding_after)
-      : Operation(Type::depthwiseConv, {arg}),
-        _kernel(kernel),
+      : Operation(Type::depthwiseConv, {input, kernel}),
         _strides(strides),
         _paddingBefore(padding_before),
         _paddingAfter(padding_after) {
     inferOutputShapes();
   }
 
-  const TensorVariant& getKernel() const { return _kernel; }
-
   const Shape& getStrides() const { return _strides; }
 
   const std::vector<int32_t>& getPaddingBefore() const { return _paddingBefore; }
@@ -53,7 +47,6 @@ public:
 private:
   void inferOutputShapes();
 
-  TensorVariant _kernel;
   Shape _strides;
   std::vector<int32_t> _paddingBefore;
   std::vector<int32_t> _paddingAfter;
index 51a9542..3d74d66 100644 (file)
@@ -684,12 +684,16 @@ void AclCppOpGenerator::visit(ops::PadOp& op) {
 
 template <typename Op>
 void AclCppOpGenerator::genConvolution(Op& op, const string& acl_func_name, const string& suffix) {
-  auto ir_weights = transposeTensor<3, 2, 0, 1>(op.getKernel());
-  const Shape& ir_weights_shape = ir_weights.getShape();
+  const auto& prev_nodes = op.getPrevNodes();
+  assert(prev_nodes.size() == 2);
 
-  auto& prev_nodes = op.getPrevNodes();
-  assert(prev_nodes.size() == 1);
   auto in_op = prev_nodes[0].op;
+  auto ir_weights_op = dynamic_cast<ops::ConstantOp*>(prev_nodes[1].op);
+  if (ir_weights_op == nullptr)
+    throw AclCppException("Unsupported operation type");
+
+  auto ir_weights = transposeTensor<3, 2, 0, 1>(ir_weights_op->getValue());
+  const Shape& ir_weights_shape = ir_weights.getShape();
 
   // get output tensor name that is used as base for other names
   const string output_tensor_name = tensorName(&op);
index 53bc3e1..4f7c46a 100644 (file)
@@ -288,14 +288,16 @@ std::vector<IODescriptor> Caffe2OpCreator::convertConv(const std::vector<IODescr
   if (is_depthwise) {
    // TODO: properly handle kernel with layer multiplier
     auto transposed_tensor = mir::transposeTensor<0, 1, 3, 2>(kernel_tensor);
-    result = createOp<ops::DepthwiseConv2DOp>("Depthwise_Conv2D", convertCaffeToMIR(inputs[0]), transposed_tensor,
-                                              stride_shape, pad_before, pad_after);
+    auto kernel = createOp<ops::ConstantOp>("Constant", transposed_tensor)->getOutput(0);
+    result = createOp<ops::DepthwiseConv2DOp>("Depthwise_Conv2D", convertCaffeToMIR(inputs[0]),
+                                              kernel, stride_shape, pad_before, pad_after);
   } else {
     // first we need to convert kernel of grouped convolution to appropriate ordinary kernel
     if (num_groups != 1)
       kernel_tensor = fixGroupedKernel(num_groups, kernel_tensor);
 
-    result = createOp<ops::Conv2DOp>("Conv2D", convertCaffeToMIR(inputs[0]), kernel_tensor,
+    auto kernel = createOp<ops::ConstantOp>("Constant", kernel_tensor)->getOutput(0);
+    result = createOp<ops::Conv2DOp>("Conv2D", convertCaffeToMIR(inputs[0]), kernel,
                                      stride_shape, pad_before, pad_after);
   }
 
index 2369b72..423c8e4 100644 (file)
@@ -254,14 +254,16 @@ CaffeOpCreator::convertConvolution(const caffe::LayerParameter& layer,
     // This is depthwise convolution
    // TODO: properly handle kernel with layer multiplier
     auto transposed_tensor = transposeTensor<0, 1, 3, 2>(kernel_weights);
-    result = createOp<ops::DepthwiseConv2DOp>(layer.name(), convertCaffeToMIR(inputs[0]),
-                                              transposed_tensor, strides, padding, padding);
+    auto kernel = createOp<ops::ConstantOp>("", transposed_tensor)->getOutput(0);
+    result = createOp<ops::DepthwiseConv2DOp>(layer.name(), convertCaffeToMIR(inputs[0]), kernel,
+                                              strides, padding, padding);
   } else {
     if (num_groups != 1) {
       // first we need to convert kernel of grouped convolution to appropriate ordinary kernel
       kernel_weights = fixGroupedKernel(params.group(), kernel_weights);
     }
-    result = createOp<ops::Conv2DOp>(layer.name(), convertCaffeToMIR(inputs[0]), kernel_weights,
+    auto kernel = createOp<ops::ConstantOp>("", kernel_weights)->getOutput(0);
+    result = createOp<ops::Conv2DOp>(layer.name(), convertCaffeToMIR(inputs[0]), kernel,
                                      strides, padding, padding);
   }
 
@@ -290,8 +292,9 @@ CaffeOpCreator::convertDeconvolution(const caffe::LayerParameter& layer,
     // first we need to convert kernel of grouped convolution to appropriate ordinary kernel
     kernel_weights = fixGroupedKernel(opts.group(), kernel_weights);
   }
-  auto result = createOp<ops::DeConv2DOp>(layer.name(), convertCaffeToMIR(inputs[0]),
-                                            kernel_weights, strides, padding);
+  auto kernel = createOp<ops::ConstantOp>("", kernel_weights)->getOutput(0);
+  auto result = createOp<ops::DeConv2DOp>(layer.name(), convertCaffeToMIR(inputs[0]), kernel,
+                                          strides, padding);
 
   // bias_term is optional (so might not be present) and defaults to true
   if (opts.bias_term()) {
index 4946c10..a54674d 100644 (file)
@@ -49,9 +49,9 @@
 #include "ops/BatchNorm.h"
 #include "ops/Bias.h"
 #include "ops/Concat.h"
-#include "ops/conv_2D.h"
+#include "ops/Conv2D.h"
 #include "ops/DeConv2D.h"
-#include "ops/Depthwise_conv_2D.h"
+#include "ops/DepthwiseConv2D.h"
 #include "ops/Dropout.h"
 #include "ops/FullyConnected.h"
 #include "ops/Gather.h"
@@ -165,8 +165,11 @@ void NNInterpreter::visit(ops::ConcatOp& op) {
 }
 
 void NNInterpreter::visit(ops::Conv2DOp& op) {
-  auto operand = op.getPrevNodes()[0];
-  var(op.getId()) = Conv2D(var(operand.op->getId())[operand.index], op)();
+  auto input = op.getPrevNodes()[0];
+  auto kernel = op.getPrevNodes()[1];
+  auto input_tensor = var(input.op->getId())[input.index];
+  auto kernel_tensor = var(kernel.op->getId())[kernel.index];
+  var(op.getId()) = Conv2D(input_tensor, kernel_tensor, op)();
   DUMP(op, true);
 }
 
@@ -234,9 +237,11 @@ void NNInterpreter::visit(ops::CappedReluOp& op) {
 }
 
 void NNInterpreter::visit(ops::DepthwiseConv2DOp& op){
-  auto operand = op.getPrevNodes()[0];
-  TensorVariant input(var(operand.op->getId())[operand.index]);
-  var(op.getId()) = DepthwiseConv2D(input, op)();
+  auto input = op.getPrevNodes()[0];
+  auto kernel = op.getPrevNodes()[1];
+  auto input_tensor = var(input.op->getId())[input.index];
+  auto kernel_tensor = var(kernel.op->getId())[kernel.index];
+  var(op.getId()) = DepthwiseConv2D(input_tensor, kernel_tensor, op)();
   DUMP(op, true);
 }
 
@@ -340,8 +345,11 @@ void NNInterpreter::visit(ops::ElementwiseOp& op) {
 }
 
 void NNInterpreter::visit(ops::DeConv2DOp& op) {
-  auto operand = op.getPrevNodes()[0];
-  var(op.getId()) = DeConv2D(var(operand.op->getId())[operand.index], op)();
+  auto input = op.getPrevNodes()[0];
+  auto kernel = op.getPrevNodes()[1];
+  auto input_tensor = var(input.op->getId())[input.index];
+  auto kernel_tensor = var(kernel.op->getId())[kernel.index];
+  var(op.getId()) = DeConv2D(input_tensor, kernel_tensor, op)();
   DUMP(op, false);
 }
 
similarity index 92%
rename from contrib/nnc/passes/interpreter/ops/conv_2D.cpp
rename to contrib/nnc/passes/interpreter/ops/Conv2D.cpp
  * limitations under the License.
  */
 
-#include <cmath>
-
-#include "core/modelIR/ShapeRange.h"
-
-#include "conv_2D.h"
+#include "Conv2D.h"
 #include "common.h"
+#include "core/modelIR/ShapeRange.h"
+#include <cmath>
 
 namespace nnc
 {
@@ -40,7 +38,7 @@ Index reduce(const Index &idx)
 // Refer to https://www.tensorflow.org/api_docs/python/tf/nn/conv2d for info
 std::vector<TensorVariant> Conv2D::operator()()
 {
-  auto res = allocate_tensor(_out_shape);
+  auto res = allocate_tensor(_op.getOutputShape(0));
   Tensor<float> resAccesor(res);
   Shape strides{_op.getStrides().dim(0), _op.getStrides().dim(1), 1};
   Index pads{_op.getPaddingBefore().at(0), _op.getPaddingBefore().at(1), 0};
@@ -99,9 +97,10 @@ std::vector<TensorVariant> Conv2D::operator()()
   return {res};
 }
 
-Conv2D::Conv2D(const TensorVariant &input, const Conv2DOp &op)
-    : _input(input), _kernel(op.getKernel()), _strides(op.getStrides()),
-      _out_shape(op.getOutputShape(0)), _op(op) {
+Conv2D::Conv2D(const TensorVariant& input,
+               const TensorVariant& kernel,
+               const Conv2DOp& op)
+    : _input(input), _kernel(kernel), _op(op) {
   assert(_op.getInputShape(0).rank() == 4);
   assert(_input.getShape().rank() == 4);
   assert(_kernel.getShape().rank() == 4);
similarity index 81%
rename from contrib/nnc/passes/interpreter/ops/conv_2D.h
rename to contrib/nnc/passes/interpreter/ops/Conv2D.h
index 79d3015..01d0cea 100644 (file)
 #include "OperationImpl.h"
 #include "core/modelIR/operations/Conv2DOp.h"
 
-namespace nnc
-{
+namespace nnc {
 
-class Conv2D : public OperationImpl<float>
-{
+class Conv2D : public OperationImpl<float> {
 public:
-  explicit Conv2D(const mir::TensorVariant &input, const mir::ops::Conv2DOp &op);
+  Conv2D(const mir::TensorVariant& input,
+         const mir::TensorVariant& kernel,
+         const mir::ops::Conv2DOp& op);
+
   std::vector<mir::TensorVariant> operator()() override;
 
 private:
   const mir::Tensor<float> _input;
   mir::Tensor<float> _kernel;
-  const mir::Shape _strides;
-  const mir::Shape &_out_shape;
-  const mir::ops::Conv2DOp &_op;
+  const mir::ops::Conv2DOp& _op;
 };
 
 } // namespace nnc
index e19e9e1..5b0afd8 100644 (file)
@@ -29,7 +29,8 @@ using namespace mir;
 using namespace mir::ops;
 
 std::vector<nnc::mir::TensorVariant> nnc::DeConv2D::operator()() {
-  Shape out_shape = _out_shape;
+  const auto& strides = _op.getStrides();
+  Shape out_shape = _op.getOutputShape(0);
   auto res = allocate_tensor(out_shape);
   Tensor<float> res_accesor(res);
   Index pads({_op.getPaddingBefore().at(0), _op.getPaddingBefore().at(1), 0});
@@ -66,8 +67,8 @@ std::vector<nnc::mir::TensorVariant> nnc::DeConv2D::operator()() {
       bool is_from_input = true;
       for (int32_t d = 1; d < input_idx.rank() - 1; ++d) {
         const auto num = (out_idx.at(d) + pads.at(d - 1) - kernel_idx.at(d - 1));
-        const auto div_res = num / _strides.dim(d - 1);
-        const auto rem = num % _strides.dim(d - 1);
+        const auto div_res = num / strides.dim(d - 1);
+        const auto rem = num % strides.dim(d - 1);
         is_from_input = is_from_input && rem == 0;
         if (rem != 0) break;
         input_idx.at(d) = div_res;
@@ -94,9 +95,8 @@ std::vector<nnc::mir::TensorVariant> nnc::DeConv2D::operator()() {
   return {res};
 }
 
-DeConv2D::DeConv2D(const TensorVariant& input, const DeConv2DOp& op)
-    : _input(input), _kernel(op.getKernel()), _strides(op.getStrides()),
-      _padding(op.getPaddingType()), _out_shape(op.getOutputShape(0)), _op(op) {
+DeConv2D::DeConv2D(const TensorVariant& input, const TensorVariant& kernel, const DeConv2DOp& op)
+    : _input(input), _kernel(kernel), _op(op) {
   // Input shape: [N, Hi, Wi, Ci]
   // Kernel shape: [Hk, Wk, Co, Ci]
   assert(_op.getInputShape(0).rank() == 4);
index d15a25e..ea2861c 100644 (file)
@@ -20,8 +20,8 @@
 #include "OperationImpl.h"
 #include "core/modelIR/operations/Deconv2DOp.h"
 
-namespace nnc
-{
+namespace nnc {
+
 /**
  * @brief Transposed convolution (or Deconvolution)
 * @param input The input tensor
@@ -31,20 +31,18 @@ namespace nnc
  * hence all the indexing can be deduced by expressing the input index
  * of Conv in terms of its output index.
  */
-class DeConv2D : public OperationImpl<float>
-{
+class DeConv2D : public OperationImpl<float> {
 public:
-  explicit DeConv2D(const mir::TensorVariant &input, const mir::ops::DeConv2DOp &op);
+  DeConv2D(const mir::TensorVariant& input,
+           const mir::TensorVariant& kernel,
+           const mir::ops::DeConv2DOp& op);
 
   std::vector<mir::TensorVariant> operator()() override;
 
 private:
   const mir::Tensor<float> _input;
   const mir::TensorVariant _kernel;
-  const mir::Shape _strides;
-  const mir::ops::PaddingType _padding;
-  const mir::Shape &_out_shape;
-  const mir::ops::DeConv2DOp &_op;
+  const mir::ops::DeConv2DOp& _op;
 };
 
 } // namespace nnc
similarity index 86%
rename from contrib/nnc/passes/interpreter/ops/Depthwise_conv_2D.cpp
rename to contrib/nnc/passes/interpreter/ops/DepthwiseConv2D.cpp
  * limitations under the License.
  */
 
-#include "core/modelIR/ShapeRange.h"
-
-#include "Depthwise_conv_2D.h"
+#include "DepthwiseConv2D.h"
 #include "common.h"
+#include "core/modelIR/ShapeRange.h"
 
 namespace nnc
 {
@@ -27,10 +26,10 @@ using namespace mir::ops;
 
 std::vector<TensorVariant> DepthwiseConv2D::operator()()
 {
-  TensorVariant res = allocate_tensor(_out_shape);
+  TensorVariant res = allocate_tensor(_op.getOutputShape(0));
   Tensor<float> resAccessor(res);
 
-  Shape strides({_strides.dim(0), _strides.dim(1), 1});
+  Shape strides({_op.getStrides().dim(0), _op.getStrides().dim(1), 1});
   Index pads({_op.getPaddingBefore().at(0), _op.getPaddingBefore().at(1), 0});
 
   Shape outShape = res.getShape();
@@ -80,11 +79,10 @@ std::vector<TensorVariant> DepthwiseConv2D::operator()()
   return {res};
 }
 
-DepthwiseConv2D::DepthwiseConv2D(const TensorVariant &input, const DepthwiseConv2DOp &op)
-    : _input(input), _kernel(op.getKernel()), _strides(op.getStrides()),
-      _out_shape(op.getOutputShape(0)), _op(op)
-{
-
+DepthwiseConv2D::DepthwiseConv2D(const TensorVariant& input,
+                                 const TensorVariant& kernel,
+                                 const DepthwiseConv2DOp& op)
+    : _input(input), _kernel(kernel), _op(op) {
   assert(_op.getInputShape(0).rank() == 4);
   assert(_input.getShape().rank() == 4);
   assert(_kernel.getShape().rank() == 4);
similarity index 72%
rename from contrib/nnc/passes/interpreter/ops/Depthwise_conv_2D.h
rename to contrib/nnc/passes/interpreter/ops/DepthwiseConv2D.h
 #define _NNC_CORE_BACKEND_INTERPRETER_DEPTHWISE_CONV2D_IMPL_
 
 #include "OperationImpl.h"
-
-#include "core/modelIR/operations/CommonProps.h"
 #include "core/modelIR/operations/DepthwiseConv2DOp.h"
 
-namespace nnc
-{
+namespace nnc {
 
-class DepthwiseConv2D : public OperationImpl<float>
-{
+class DepthwiseConv2D : public OperationImpl<float> {
 public:
-  explicit DepthwiseConv2D(const mir::TensorVariant &input, const mir::ops::DepthwiseConv2DOp &op);
-  virtual std::vector<mir::TensorVariant> operator()() override;
+  DepthwiseConv2D(const mir::TensorVariant& input,
+                  const mir::TensorVariant& kernel,
+                  const mir::ops::DepthwiseConv2DOp& op);
+
+  std::vector<mir::TensorVariant> operator()() override;
 
 private:
   const mir::Tensor<float> _input;
   const mir::Tensor<float> _kernel;
-  const mir::Shape _strides;
-  const mir::Shape &_out_shape;
-  const mir::ops::DepthwiseConv2DOp &_op;
+  const mir::ops::DepthwiseConv2DOp& _op;
 };
 
 } // namespace nnc
index 9737346..9c30760 100644 (file)
@@ -76,8 +76,10 @@ std::vector<TensorVariant> Conv2D_FFT::operator()()
   return {res};
 }
 
-Conv2D_FFT::Conv2D_FFT(const TensorVariant &input, const Conv2DOp &op)
-    : _input(input), _kernel(op.getKernel()), _strides(op.getStrides()),
+Conv2D_FFT::Conv2D_FFT(const TensorVariant& input,
+                       const TensorVariant& kernel,
+                       const Conv2DOp& op)
+    : _input(input), _kernel(kernel), _strides(op.getStrides()),
       _out_shape(op.getOutputShape(0)), _op(op)
 {
   // Same assertions as in Conv2D
index 2ebe0c8..7aa3894 100644 (file)
@@ -53,7 +53,10 @@ typedef std::complex<float> FFT_complex;
 class Conv2D_FFT : public OperationImpl<float>
 {
 public:
-  explicit Conv2D_FFT(const mir::TensorVariant &input, const mir::ops::Conv2DOp &op);
+  Conv2D_FFT(const mir::TensorVariant& input,
+             const mir::TensorVariant& kernel,
+             const mir::ops::Conv2DOp& op);
+
   std::vector<mir::TensorVariant> operator()() override;
 
 protected:
index ba13ddb..3ddc20a 100644 (file)
@@ -215,14 +215,14 @@ void ONNXImporterImpl::dump(const std::vector<mir::IODescriptor>& input_descrs,
       case ONNXOpCode::opConv: {
         assert(dynamic_cast<mir::ops::TransposeOp*>(op) != nullptr);
         if (auto* conv = dynamic_cast<mir::ops::Conv2DOp*>(op->getPrevNodes()[0].op)) {
-          std::cout << " (Conv2D)Weights" << conv->getKernel().getShape() << " Strides" <<
+          std::cout << " (Conv2D)Weights" << conv->getInputShape(1) << " Strides" <<
                     conv->getStrides() << " Padding(" << conv->getPaddingBefore()[0] <<
                     " " << conv->getPaddingBefore()[1] << ")" << ":(" <<
                     conv->getPaddingAfter()[0] << " " << conv->getPaddingAfter()[1] << ")";
         } else {
           auto* dept = dynamic_cast<mir::ops::DepthwiseConv2DOp*>(op->getPrevNodes()[0].op);
           assert(dept);
-          std::cout << " (DepthwiseConv2D)Weights" << dept->getKernel().getShape() << " Strides" <<
+          std::cout << " (DepthwiseConv2D)Weights" << dept->getInputShape(1) << " Strides" <<
                     dept->getStrides() << " Padding(" << dept->getPaddingBefore()[0] <<
                     " " << dept->getPaddingBefore()[1] << ")" << ":(" <<
                     dept->getPaddingAfter()[0] << " " << dept->getPaddingAfter()[1] << ")";
index d118454..32b9d25 100644 (file)
@@ -175,16 +175,16 @@ ONNXOpCreator::convertConv2D(const std::vector<mir::IODescriptor>& inputs,
   if (is_depthwise) {
    // TODO: properly handle kernel with layer multiplier
     auto transposed_tensor = mir::transposeTensor<0, 1, 3, 2>(kernel_tensor);
-    result = createOp<ops::DepthwiseConv2DOp>(transposed_input,
-                                              transposed_tensor, cdata.strides_shape,
+    auto kernel = createOp<ops::ConstantOp>(transposed_tensor)->getOutput(0);
+    result = createOp<ops::DepthwiseConv2DOp>(transposed_input, kernel, cdata.strides_shape,
                                               cdata.padding_before, cdata.padding_after);
   } else {
     // first we need to convert kernel of grouped convolution to appropriate ordinary kernel
     if (num_groups != 1)
       kernel_tensor = fixGroupedKernel(num_groups, kernel_tensor);
-    result = createOp<ops::Conv2DOp>(transposed_input, kernel_tensor,
-                                     cdata.strides_shape, cdata.padding_before,
-                                     cdata.padding_after);
+    auto kernel = createOp<ops::ConstantOp>(kernel_tensor)->getOutput(0);
+    result = createOp<ops::Conv2DOp>(transposed_input, kernel, cdata.strides_shape,
+                                     cdata.padding_before, cdata.padding_after);
   }
 
   if (inputs.size() > 2)
index ae35ab4..6eef9ad 100644 (file)
@@ -204,7 +204,7 @@ void ModelAnalyzer::visit(ops::ConcatOp& op) {
 }
 
 void ModelAnalyzer::visit(ops::Conv2DOp& op) {
-  const auto& kernel_shape = op.getKernel().getShape();
+  const auto& kernel_shape = op.getInputShape(1);
   const auto& out_shape = op.getOutputShape(0);
   const int32_t tmp_size = kernel_shape.dim(0) * kernel_shape.dim(1) * kernel_shape.dim(2)
                            * out_shape.dim(0) * out_shape.dim(1) * out_shape.dim(2);
@@ -332,7 +332,7 @@ void ModelAnalyzer::visit(mir::ops::EluOp& op) {
 }
 
 void ModelAnalyzer::visit(mir::ops::DeConv2DOp& op) {
-  const auto& kernel_shape = op.getKernel().getShape();
+  const auto& kernel_shape = op.getInputShape(1);
   const auto& out_shape = op.getOutputShape(0);
   const int32_t tmp_size = kernel_shape.dim(0) * kernel_shape.dim(1) * kernel_shape.dim(3) *
                            out_shape.dim(0) * out_shape.dim(1) * out_shape.dim(2);
index 0b3dadd..afdfd5e 100644 (file)
@@ -63,7 +63,6 @@ using namespace std;
 using nnc::mir::Shape;
 using nnc::mir::Index;
 using nnc::mir::ShapeRange;
-using nnc::mir::transposeTensor;
 using nnc::mir::TensorVariant;
 
 namespace ops = nnc::mir::ops;
@@ -155,9 +154,6 @@ void Serializer::visit(ops::ConcatOp& op) {
 
 void Serializer::visit(ops::Conv2DOp& op) {
   _curOp->_paramStartOffset = _buffer.size();
-  // serialize kernel
-  // HWCN -> NHWC
-  serializeTensor(transposeTensor<3, 0, 1, 2>(op.getKernel()));
   // serialize strides
   serializeShape(op.getStrides());
   // serialize pads
@@ -169,9 +165,6 @@ void Serializer::visit(ops::Conv2DOp& op) {
 
 void Serializer::visit(ops::DepthwiseConv2DOp& op) {
   _curOp->_paramStartOffset = _buffer.size();
-  // serialize kernel
-  const TensorVariant& kernel = op.getKernel();
-  serializeTensor(kernel);
   // serialize strides
   serializeShape(op.getStrides());
   // serialize pads
@@ -303,9 +296,6 @@ void Serializer::visit(mir::ops::EluOp& op) {
 
 void Serializer::visit(mir::ops::DeConv2DOp& op) {
   _curOp->_paramStartOffset = _buffer.size();
-  // serialize kernel
-  // HWCN -> "IN"HW"OUT"
-  serializeTensor(transposeTensor<2, 0, 1, 3>(op.getKernel()));
   // serialize strides
   serializeShape(op.getStrides());
   // serialize pads
index 2aaabca..f5bef77 100644 (file)
@@ -102,13 +102,6 @@ inline RuntimeShape shapeToRuntimeShape(const Shape &s) {
   return sh;
 }
 
-inline RuntimeShape shapeToRuntimeShapePad4(const Shape &s) {
-  assert(s.getDims()==3);
-  RuntimeShape sh({1,(int)s[0],(int)s[1],(int)s[2]});
-  return sh;
-}
-
-
 Dims<4> shapeToDims(const Shape &s)
 {
   Dims<4> dims;
@@ -151,18 +144,6 @@ static inline Shape deserializeShape(const char *&buf)
   return s;
 }
 
-struct Kernel
-{
-  const float *data;
-  Dims<4> dims;
-};
-
-struct KernelRT
-{
-  RuntimeShape shape;
-  const float *data;
-};
-
 __attribute__((unused))
 static bool isAddrAligned(const void *data, int alignment)
 {
@@ -177,46 +158,12 @@ static inline Tensor deserializeTensor(const char*& buf)
   assert(element_size == 4 && "Unsupported element size");
   Shape shape = deserializeShape(buf);
   const float* data = reinterpret_cast<const float*>(buf);
-
+  assert(isAddrAligned(data, 4));
   Tensor tensor(shape, const_cast<float*>(data));
   buf += element_size * shape.getNumElems();
   return tensor;
 }
 
-static inline Kernel deserializeKernel(const char *&buf)
-{
-  int32_t dType = deserializeT<int32_t>(buf);
-  assert(dType == 1 && "Unknown data type");
-  UNUSED(dType);
-  int32_t eSize = deserializeT<int32_t>(buf);
-  assert(eSize == 4 && "Unsupported element size");
-  UNUSED(eSize);
-  Kernel k;
-  k.dims = shapeToDims(deserializeShape(buf));
-  k.data = reinterpret_cast<const float *>(buf);
-  assert(isAddrAligned(buf, 4) && "data should be aligned to 4 bytes to use arm vector instructions");
-  buf += volume(k.dims) * eSize;
-  return k;
-}
-
-static inline KernelRT deserializeKernelRT(const char *&buf)
-{
-  int32_t dType = deserializeT<int32_t>(buf);
-  assert(dType == 1 && "Unknown data type");
-  UNUSED(dType);
-  int32_t eSize = deserializeT<int32_t>(buf);
-  assert(eSize == 4 && "Unsupported element size");
-  UNUSED(eSize);
-  KernelRT k={
-    shapeToRuntimeShape(deserializeShape(buf)),
-    reinterpret_cast<const float *>(buf)
-  };
-
-  assert(isAddrAligned(buf, 4) && "data should be aligned to 4 bytes to use arm vector instructions");
-  buf += k.shape.FlatSize() * eSize;
-  return k;
-}
-
 // This operation takes as input multiple tensors: at least 2, likely fewer than 7.
 // The parameter pack generalizes over any possible number of inputs.
 template <class ...Args>
@@ -237,125 +184,137 @@ void concat(Tensor &out, const char *params, const Args &...inputs)
                 out.getData(), shapeToDims(out.getShape()));
 }
 
-void conv2d(Tensor& out, const char* params, const Tensor& in, Tensor& temporary) {
-  const float *input = in.getData();
-  Dims<4> input_d = shapeToDims(in.getShape());
-  Kernel kernel = deserializeKernel(params);
+void conv2d(Tensor& out, const char* params, const Tensor& input, const Tensor& kernel,
+            Tensor& temporary) {
   Shape strides = deserializeShape(params);
   Shape pads = deserializeShape(params);
-  Shape out_s = deserializeShape(params);
+  Shape out_shape = deserializeShape(params);
+  out.reShape(out_shape);
 
-  out.reShape(out_s);
+  assert(strides.getDims() == 2);
+  const auto stride_h = static_cast<int>(strides[0]);
+  const auto stride_w = static_cast<int>(strides[1]);
 
-  Dims<4> out_d = shapeToDims(out_s);
+  assert(pads.getDims() == 2);
+  const auto pad_h = static_cast<int>(pads[0]);
+  const auto pad_w = static_cast<int>(pads[1]);
+
+  // Transpose the kernel from HWIO to OHWI format.
+  Shape kernel_shape = kernel.getShape();
+  kernel_shape = {kernel_shape[3], kernel_shape[0], kernel_shape[1], kernel_shape[2]};
+  Dims<4> kernel_dims = shapeToDims(kernel_shape);
+  unique_ptr<float[]> kernel_data(new float[volume(kernel_dims)]);
+  TransposeParams transpose_params{4, {3, 0, 1, 2}};
+  Transpose(transpose_params,
+            shapeToRuntimeShape(kernel.getShape()), kernel.getData(),
+            shapeToRuntimeShape(kernel_shape), kernel_data.get());
+
+  Dims<4> out_dims = shapeToDims(out_shape);
+  Dims<4> im2col_dims{{kernel_dims.sizes[0] * kernel_dims.sizes[1] * kernel_dims.sizes[2],
+                          out_dims.sizes[1],
+                          out_dims.sizes[2],
+                          out_dims.sizes[3]},
+                      {}};
 
-  const int im2col_d0 = static_cast<int>(kernel.dims.sizes[0] * kernel.dims.sizes[1] * kernel.dims.sizes[2]);
-  const int im2col_d1 = out_d.sizes[1];
-  const int im2col_d2 = out_d.sizes[2];
-  const int im2col_d3 = out_d.sizes[3];
-  Dims<4> im2col_d{{im2col_d0, im2col_d1, im2col_d2, im2col_d3},{}};
   int stride = 1;
-  for (int i = 0; i < 4; ++i)
-  {
-    im2col_d.strides[i] = stride;
-    stride *= im2col_d.sizes[i];
+  for (int i = 0; i < 4; ++i) {
+    im2col_dims.strides[i] = stride;
+    stride *= im2col_dims.sizes[i];
   }
 
-  assert(strides.getDims() == 2);
-  const int stride_w = strides[1];
-  const int stride_h = strides[0];
-  assert(pads.getDims() == 2);
-  const int pad_w = pads[1];
-  const int pad_h = pads[0];
-
   float* im2col_data = nullptr;
-  if (stride_w != 1 || stride_h != 1 || kernel.dims.sizes[1] != 1 || kernel.dims.sizes[2] != 1)
-  {
+  if (stride_w != 1 || stride_h != 1 || kernel_dims.sizes[1] != 1 || kernel_dims.sizes[2] != 1) {
     im2col_data = temporary.getData();
   }
 
-  Conv(input, input_d,
-       kernel.data, kernel.dims,
+  Conv(input.getData(), shapeToDims(input.getShape()),
+       kernel_data.get(), kernel_dims,
        stride_w, stride_h,
        pad_w, pad_h,
-       out.getData(), out_d,
-       im2col_data, im2col_d);
+       out.getData(), out_dims,
+       im2col_data, im2col_dims);
 }
 
-void convTransposed2d(Tensor& out, const char* params, const Tensor& in, Tensor& temporary) {
-  const float *input = in.getData();
-  RuntimeShape input_shape = shapeToRuntimeShape(in.getShape());
-  KernelRT kernel = deserializeKernelRT(params);
+void convTransposed2d(Tensor& out, const char* params, const Tensor& input, const Tensor& kernel,
+                      Tensor& temporary) {
   Shape strides = deserializeShape(params);
   Shape pads = deserializeShape(params);
-  Shape out_s = deserializeShape(params);
-
-  out.reShape(out_s);
-
-  RuntimeShape out_shape = shapeToRuntimeShape(out_s);
+  Shape out_shape = deserializeShape(params);
+  out.reShape(out_shape);
 
   assert(strides.getDims() == 2);
-  const short stride_w = strides[1];
-  const short stride_h = strides[0];
+  const auto stride_h = static_cast<int16>(strides[0]);
+  const auto stride_w = static_cast<int16>(strides[1]);
+
   assert(pads.getDims() == 2);
-  const short pad_w = pads[1];
-  const short pad_h = pads[0];
+  const auto pad_h = static_cast<int16>(pads[0]);
+  const auto pad_w = static_cast<int16>(pads[1]);
+
+  // Transpose the kernel from HWOI to OHWI format.
+  Shape kernel_shape = kernel.getShape();
+  kernel_shape = {kernel_shape[2], kernel_shape[0], kernel_shape[1], kernel_shape[3]};
+  Dims<4> kernel_dims = shapeToDims(kernel_shape);
+  unique_ptr<float[]> kernel_data(new float[volume(kernel_dims)]);
+  TransposeParams transpose_params{4, {2, 0, 1, 3}};
+  Transpose(transpose_params,
+      shapeToRuntimeShape(kernel.getShape()), kernel.getData(),
+      shapeToRuntimeShape(kernel_shape), kernel_data.get());
+
+  RuntimeShape input_rt_shape = shapeToRuntimeShape(input.getShape());
+  RuntimeShape out_rt_shape = shapeToRuntimeShape(out_shape);
+  RuntimeShape kernel_rt_shape = shapeToRuntimeShape(kernel_shape);
 
-  const int ker_width = kernel.shape.Dims(2);
-  const int ker_height = kernel.shape.Dims(1);
+  const int32 kernel_height = kernel_rt_shape.Dims(1);
+  const int32 kernel_width = kernel_rt_shape.Dims(2);
 
-  RuntimeShape im2col_shape = RuntimeShape({
-                                             (int)out_s[0],
-                                             (int)out_s[1],
-                                             (int)out_s[2],
-                                             // in depth
-                                             input_shape.Dims(3) * ker_width * ker_height
-                                           });
+  RuntimeShape im2col_shape{out_rt_shape.Dims(0),
+                            out_rt_shape.Dims(1),
+                            out_rt_shape.Dims(2),
+                            input_rt_shape.Dims(3) * kernel_width * kernel_height};
 
-  const auto convPara = ConvParams({PaddingType::kSame,
-                                    PaddingValues({pad_w,pad_h}), stride_w, stride_h});
+  ConvParams conv_params{PaddingType::kSame, {pad_w, pad_h}, stride_w, stride_h};
 
-  TransposeConv(
-    convPara, input_shape, input, kernel.shape, kernel.data,
-    out_shape, out.getData(), im2col_shape, temporary.getData());
+  TransposeConv(conv_params,
+      input_rt_shape, input.getData(),
+      kernel_rt_shape, kernel_data.get(),
+      out_rt_shape, out.getData(),
+      im2col_shape, temporary.getData());
 }
 
-void depthwiseConv2d(Tensor &out, const char *params, const Tensor &in)
-{
-  const float *input = in.getData();
-  Dims<4> input_d = shapeToDims(in.getShape());
-  Kernel kernel = deserializeKernel(params);
+void depthwiseConv2d(Tensor& out, const char* params, const Tensor& input, const Tensor& kernel) {
   Shape strides = deserializeShape(params);
   Shape pads = deserializeShape(params);
-  Shape out_s = deserializeShape(params);
+  Shape out_shape = deserializeShape(params);
+  out.reShape(out_shape);
 
   assert(strides.getDims() == 2);
-  const int stride_w = strides[1];
-  const int stride_h = strides[0];
-  assert(pads.getDims() == 2);
-  const int pad_w = pads[1];
-  const int pad_h = pads[0];
+  const auto stride_h = static_cast<int>(strides[0]);
+  const auto stride_w = static_cast<int>(strides[1]);
 
-  out.reShape(out_s);
+  assert(pads.getDims() == 2);
+  const auto pad_h = static_cast<int>(pads[0]);
+  const auto pad_w = static_cast<int>(pads[1]);
 
-  Dims<4> out_d = shapeToDims(out_s);
+  Dims<4> input_dims = shapeToDims(input.getShape());
+  Dims<4> kernel_dims = shapeToDims(kernel.getShape());
+  Dims<4> out_dims = shapeToDims(out_shape);
 
-  int depth_multiplier = out_d.sizes[0] / input_d.sizes[0];
-  assert(out_d.sizes[0] % input_d.sizes[0] == 0);
+  int depth_multiplier = out_dims.sizes[0] / input_dims.sizes[0];
+  assert(out_dims.sizes[0] % input_dims.sizes[0] == 0);
 
-  //reshape kernel: squash zero and first dimensions
-  const int kernel_w = kernel.dims.sizes[2];
-  const int kernel_h = kernel.dims.sizes[3];
-  const int output_channels = kernel.dims.sizes[0] * kernel.dims.sizes[1];
-  assert(output_channels == out_d.sizes[0]);
-  kernel.dims = shapeToDims({kernel_h, kernel_w, output_channels});
+  // Reshape kernel -- squash zeroth and first dimensions.
+  const int output_channels = kernel_dims.sizes[0] * kernel_dims.sizes[1];
+  assert(output_channels == out_dims.sizes[0]);
+  const int kernel_w = kernel_dims.sizes[2];
+  const int kernel_h = kernel_dims.sizes[3];
+  kernel_dims = shapeToDims({kernel_h, kernel_w, output_channels});
 
-  DepthwiseConv(input, input_d,
-                kernel.data, kernel.dims,
+  DepthwiseConv(input.getData(), input_dims,
+                kernel.getData(), kernel_dims,
                 stride_w, stride_h,
                 pad_w, pad_h,
                 depth_multiplier,
-                out.getData(), shapeToDims(out.getShape()));
+                out.getData(), out_dims);
 }
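
As a concrete instance of the depth_multiplier computation above (illustrative numbers only): with 8 input channels and a kernel whose squashed leading dimensions give 16 output channels,

    depth_multiplier = 16 / 8 = 2

and the preceding assert guarantees that this division is exact.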
 
 void softmax(Tensor &out, const char *params, const Tensor &in)
index 32007c9..ef221d4 100644 (file)
@@ -96,8 +96,8 @@ TFLiteOpCreator::convertConv2D(const std::vector<mir::IODescriptor>& inputs,
   calculatePadding(opts->padding(), input_shape, kernel_shape, strides, padding_before,
                    padding_after);
 
-  auto result = createOp<ops::Conv2DOp>(inputs[0], params[0],
-                                        strides, padding_before, padding_after);
+  auto kernel = createOp<ops::ConstantOp>(params[0])->getOutput(0);
+  auto result = createOp<ops::Conv2DOp>(inputs[0], kernel, strides, padding_before, padding_after);
   auto bias = createOp<ops::ConstantOp>(params[1]);
   result = createOp<ops::BiasAddOp>(result->getOutput(0), bias->getOutput(0));
   return {addFusedActivation(result->getOutput(0), opts->fused_activation_function())};
@@ -121,7 +121,8 @@ TFLiteOpCreator::convertDepthwiseConv2D(const std::vector<mir::IODescriptor>& in
   calculatePadding(opts->padding(), input_shape, kernel_shape, strides, padding_before,
                    padding_after);
 
-  auto result = createOp<ops::DepthwiseConv2DOp>(inputs[0], params[0],
+  auto kernel = createOp<ops::ConstantOp>(params[0])->getOutput(0);
+  auto result = createOp<ops::DepthwiseConv2DOp>(inputs[0], kernel,
                                                  strides, padding_before, padding_after);
   auto bias = createOp<ops::ConstantOp>(params[1]);
   result = createOp<ops::BiasAddOp>(result->getOutput(0), bias->getOutput(0));
@@ -231,7 +232,8 @@ TFLiteOpCreator::convertTransposeConv(const std::vector<mir::IODescriptor>& inpu
                                       const Shape& output_shape) {
   Shape strides{opts->stride_h(), opts->stride_w()};
 
-  auto result = createOp<ops::DeConv2DOp>(inputs[0], params[1],
+  auto kernel = createOp<ops::ConstantOp>(params[1])->getOutput(0);
+  auto result = createOp<ops::DeConv2DOp>(inputs[0], kernel,
                                           strides, paddingMap[opts->padding()], output_shape);
   return {result->getOutput(0)};
 }
index 68c6455..29b51ca 100644 (file)
@@ -18,8 +18,8 @@ from opinfo import PoolType
 # 'axis' for CAPPED_RELU is not an error, it just denotes a numeric parameter.
 OP_FORMATS = {
     'FULLY_CONNECTED': (),
-    'CONV_2D': ('kernels', 'padType', 'shapes'),
-    'DEPTHWISE_CONV_2D': ('kernels', 'padType', 'shapes'),
+    'CONV_2D': ('padType', 'shapes'),
+    'DEPTHWISE_CONV_2D': ('padType', 'shapes'),
     'POOL_2D': ('padType', 'poolType', 'shapes'),
     'CONCATENATION': ('axis',),
     'RESHAPE': ('shapes',),
index d2deaf5..f1fd0e0 100644 (file)
@@ -48,14 +48,14 @@ static Operation* createFullyConnected(std::unique_ptr<Graph>& g,
 static Operation* createConv2D(std::unique_ptr<Graph>& g,
                                const std::vector<IODescriptor>& inputs,
                                const opinfo::OperatorInfo* opInfo) {
-  return g->create<ops::Conv2DOp>("y", inputs[0], *getKernel(opInfo), getShapeParam(opInfo, 0),
+  return g->create<ops::Conv2DOp>("y", inputs[0], inputs[1], getShapeParam(opInfo, 0),
                                   std::vector<int32_t>{0, 0}, std::vector<int32_t>{0, 0});
 }
 
 static Operation* createDepthwiseConv2D(std::unique_ptr<Graph>& g,
                                         const std::vector<IODescriptor>& inputs,
                                         const opinfo::OperatorInfo* opInfo) {
-  return g->create<ops::DepthwiseConv2DOp>("y", inputs[0], *getKernel(opInfo),
+  return g->create<ops::DepthwiseConv2DOp>("y", inputs[0], inputs[1],
                                            getShapeParam(opInfo, 0), std::vector<int32_t>{0, 0},
                                            std::vector<int32_t>{0, 0});
 }
index f662f0e..2774437 100644 (file)
@@ -41,11 +41,6 @@ std::shared_ptr<TensorVariant> getTensor(const opinfo::Tensor* t)
   return std::make_shared<TensorVariant>(tensorShape, tensorBufferCopy, type, elementSize);
 }
 
-std::shared_ptr<TensorVariant> getKernel(const opinfo::OperatorInfo* opInfo)
-{
-  return getTensor(opInfo->kernels()->Get(0));
-}
-
 ops::PoolOp::PoolingType getPoolingType(const opinfo::OperatorInfo* opInfo)
 {
   switch (opInfo->poolType())
index 5531616..bdadf95 100644 (file)
@@ -30,7 +30,6 @@
 
 
 std::shared_ptr<nnc::mir::TensorVariant> getTensor(const opinfo::Tensor* t);
-std::shared_ptr<nnc::mir::TensorVariant> getKernel(const opinfo::OperatorInfo* opInfo);
 nnc::mir::ops::PoolOp::PoolingType getPoolingType(const opinfo::OperatorInfo* opInfo);
 nnc::mir::Shape getShapeParam(const opinfo::OperatorInfo* opInfo, unsigned int n);
 int getAxis(const opinfo::OperatorInfo* opInfo);
index de85ccd..d2f2b2d 100644 (file)
@@ -11,12 +11,12 @@ CONV_2D
 # kernel shape: [height, width, out_channels]
 # padding type: (VALID | SAME)
 # strides: [h_stride, w_stride]
-[5, 5, 1] [3, 3, 1, 1] VALID [1, 1]
-[64, 64, 4] [3, 3, 4, 2] VALID [1, 1]
+[[5, 5, 1] [3, 3, 1, 1]] VALID [1, 1]
+[[64, 64, 4] [3, 3, 4, 2]] VALID [1, 1]
 
 DEPTHWISE_CONV_2D
-[5, 5, 10] [3, 3, 10, 1] VALID [1, 1]
-[20, 20, 8] [3, 1, 8, 2] SAME [2, 2]
+[[5, 5, 10] [3, 3, 10, 1]] VALID [1, 1]
+[[20, 20, 8] [3, 1, 8, 2]] SAME [2, 2]
 
 POOL_2D
 # input shape: [height, width, in_channels]
index 0dfdae2..a164827 100644 (file)
@@ -312,13 +312,14 @@ TEST(acl_backend_mir_to_dom, conv2d) {
   const int32_t channels = 3;
   mir::Shape kernel_shape{3, 3, channels, 1}; // Height, Width, input Channels, output Channel
   mir::Shape strides{1, 1};
-  mir::TensorVariant kernel = createTensorVariant(kernel_shape);
+  mir::TensorVariant kernel_tensor = createTensorVariant(kernel_shape);
 
   Graph g;
   OpConstructor op_generator =
-      [kernel, strides](mir::Graph& g,
-                        const std::vector<mir::IODescriptor>& inputs) {
+      [kernel_tensor, strides](mir::Graph& g,
+                               const std::vector<mir::IODescriptor>& inputs) {
         std::vector<int32_t> padding{0, 0};
+        auto kernel = g.create<mir::ops::ConstantOp>("", kernel_tensor)->getOutput(0);
         return g.create<mir::ops::Conv2DOp>("conv2d", inputs[0], kernel, strides, padding, padding);
       };
 
@@ -338,16 +339,16 @@ TEST(acl_backend_mir_to_dom, depthwise_conv) {
   const int32_t channels = 3;
   mir::Shape kernel_shape{3, 3, channels, 1}; // Height, Width, Channels, Channel multiplier
   mir::Shape strides{1, 1};
-  mir::TensorVariant kernel = createTensorVariant(kernel_shape);
+  mir::TensorVariant kernel_tensor = createTensorVariant(kernel_shape);
 
   Graph g;
   OpConstructor op_generator =
-      [kernel, strides](mir::Graph& g,
-                        const std::vector<mir::IODescriptor>& inputs) {
-          std::vector<int32_t> padding{0, 0};
-          return g.create<mir::ops::DepthwiseConv2DOp>("depthwiseConv2d",
-                                                       inputs[0], kernel,
-                                                       strides, padding, padding);
+      [kernel_tensor, strides](mir::Graph& g,
+                               const std::vector<mir::IODescriptor>& inputs) {
+        std::vector<int32_t> padding{0, 0};
+        auto kernel = g.create<mir::ops::ConstantOp>("", kernel_tensor)->getOutput(0);
+        return g.create<mir::ops::DepthwiseConv2DOp>("depthwiseConv2d", inputs[0], kernel,
+                                                     strides, padding, padding);
       };
 
   vector<Shape> input_shapes{{1, 10, 10, channels}};
index eeded6b..9245763 100644 (file)
@@ -594,21 +594,22 @@ TEST(cpp_operations_test, convTransposed2d) {
           for (iT stride_h = 1; stride_h <= 3; ++stride_h)
             for (iT stride_w = 1; stride_w <= 3; ++stride_w) {
               vector<int> input_shape_data{3, 9, 3, static_cast<int>(input_c)};  // NHWC
-              mir::Shape kernel_shape{kernel_h, kernel_w, output_c, input_c};
+              vector<int> kernel_shape_data{kernel_h, kernel_w, output_c, input_c};
               mir::Shape strides{stride_h, stride_w};
-              vector<unique_ptr<mir::TensorVariant>> input_ntensors(1);
-              Tensor input_atensor;
-              fillTensors(input_ntensors[0], input_atensor, input_shape_data, 1.0f);
+              vector<unique_ptr<mir::TensorVariant>> input_ntensors(2);
+              Tensor input_atensor0;
+              Tensor input_atensor1;
+              fillTensors(input_ntensors[0], input_atensor0, input_shape_data, 1.0f);
+              fillTensors(input_ntensors[1], input_atensor1, kernel_shape_data, 1.0f);
               auto pad_t = mir::ops::PaddingType::Same;
-              mir::TensorVariant kernel = createNTensor(kernel_shape, 1.0f);
-              auto op_generator = [&kernel, &strides, pad_t](
+              auto op_generator = [&strides, pad_t](
                 mir::Graph& g, const std::vector<mir::IODescriptor>& inputs) {
 
-                return g.create<mir::ops::DeConv2DOp>("y", inputs[0], kernel, strides, pad_t);
+                return g.create<mir::ops::DeConv2DOp>("y", inputs[0], inputs[1], strides, pad_t);
               };
 
-              createAndRunTestGraph(op_generator, convTransposed2d, input_ntensors, input_atensor,
-                                    temporary);
+              createAndRunTestGraph(op_generator, convTransposed2d, input_ntensors,
+                                    input_atensor0, input_atensor1, temporary);
             }
 }
 
@@ -626,20 +627,22 @@ TEST(cpp_operations_test, conv2d) {
           for (iT stride_h = 1; stride_h <= 3; ++stride_h)
             for (iT stride_w = 1; stride_w <= 3; ++stride_w) {
               vector<int> input_shape_data{1, 5, 7, static_cast<int>(input_c)};  // NHWC
-              mir::Shape kernel_shape{kernel_h, kernel_w, input_c, output_c}; // HWCN
+              vector<int> kernel_shape_data{kernel_h, kernel_w, input_c, output_c}; // HWCN
               mir::Shape strides{stride_h, stride_w};
-              vector<unique_ptr<mir::TensorVariant>> input_ntensors(1);
-              Tensor input_atensor;
-              fillTensors(input_ntensors[0], input_atensor, input_shape_data, 1.0f);
-              mir::TensorVariant kernel = createNTensor(kernel_shape, 1.0f);
-              auto op_generator = [&kernel, &strides](mir::Graph& g,
-                                                   const std::vector<mir::IODescriptor>& inputs) {
+              vector<unique_ptr<mir::TensorVariant>> input_ntensors(2);
+              Tensor input_atensor0;
+              Tensor input_atensor1;
+              fillTensors(input_ntensors[0], input_atensor0, input_shape_data, 1.0f);
+              fillTensors(input_ntensors[1], input_atensor1, kernel_shape_data, 1.0f);
+              auto op_generator = [&strides](mir::Graph& g,
+                                             const std::vector<mir::IODescriptor>& inputs) {
                 std::vector<int32_t> padding{0, 0};
-                return g.create<mir::ops::Conv2DOp>("y", inputs[0], kernel, strides, padding,
-                                                    padding);
+                return g.create<mir::ops::Conv2DOp>("y", inputs[0], inputs[1],
+                                                    strides, padding, padding);
               };
 
-              createAndRunTestGraph(op_generator, conv2d, input_ntensors, input_atensor, temporary);
+              createAndRunTestGraph(op_generator, conv2d, input_ntensors,
+                                    input_atensor0, input_atensor1, temporary);
             }
 }
 
@@ -657,20 +660,22 @@ TEST(cpp_operations_test, depthwise_conv) {
           for (iT stride_h = 1; stride_h <= 3; ++stride_h)
             for (iT multiplier = 1; multiplier <= 2; ++multiplier) {
               vector<int> input_shape_data{1, 5, 7, static_cast<int>(channels)};  // NHWC
-              mir::Shape kernel_shape{kernel_h, kernel_w, channels, multiplier}; // HWCN
+              vector<int> kernel_shape_data{kernel_h, kernel_w, channels, multiplier}; // HWCN
               mir::Shape strides{stride_h, stride_w};
-              vector<unique_ptr<mir::TensorVariant>> input_ntensors(1);
-              Tensor input_atensor;
-              fillTensors(input_ntensors[0], input_atensor, input_shape_data, 1.0f);
-              mir::TensorVariant kernel = createNTensor(kernel_shape, 1.0f);
-              auto op_generator = [&kernel, &strides](mir::Graph& g,
-                                                    const std::vector<mir::IODescriptor>& inputs) {
+              vector<unique_ptr<mir::TensorVariant>> input_ntensors(2);
+              Tensor input_atensor0;
+              Tensor input_atensor1;
+              fillTensors(input_ntensors[0], input_atensor0, input_shape_data, 1.0f);
+              fillTensors(input_ntensors[1], input_atensor1, kernel_shape_data, 1.0f);
+              auto op_generator = [&strides](mir::Graph& g,
+                                             const std::vector<mir::IODescriptor>& inputs) {
                 std::vector<int32_t> padding{0, 0};
-                return g.create<mir::ops::DepthwiseConv2DOp>("y", inputs[0], kernel, strides,
-                                                             padding, padding);
+                return g.create<mir::ops::DepthwiseConv2DOp>("y", inputs[0], inputs[1],
+                                                             strides, padding, padding);
               };
 
-              createAndRunTestGraph(op_generator, depthwiseConv2d, input_ntensors, input_atensor);
+              createAndRunTestGraph(op_generator, depthwiseConv2d, input_ntensors,
+                                    input_atensor0, input_atensor1);
             }
 }