[nnc] New Gemm implementation, GlobalAveragePool and Conv2D updates. (#2622)
authorАндрей Тищенко/AI Tools Lab /SRR/Staff Engineer/삼성전자 <a.tischenko@partner.samsung.com>
Thu, 13 Dec 2018 09:50:49 +0000 (12:50 +0300)
committerРоман Михайлович Русяев/AI Tools Lab /SRR/Staff Engineer/삼성전자 <r.rusyaev@samsung.com>
Thu, 13 Dec 2018 09:50:49 +0000 (12:50 +0300)
The first steps in new Gemm implementation were done: Now it's a standard NN operator with proper shape inference. Interpreter (an d maybe soft backend) support will be added in the next patch.
Pool and Conv operators are using similar input data that's why they were partially merged. As result they now produce proper output shape.

Signed-off-by: Andrew V. Tischenko a.tischenko@partner.samsung.com
22 files changed:
contrib/nnc/core/CMakeLists.txt
contrib/nnc/core/modelIR/IrDotDumper.cpp
contrib/nnc/core/modelIR/Operation.cpp
contrib/nnc/core/modelIR/operations/GemmOp.cpp [new file with mode: 0644]
contrib/nnc/include/core/modelIR/IrDotDumper.h
contrib/nnc/include/core/modelIR/operations/ConstantOp.h
contrib/nnc/include/core/modelIR/operations/GemmOp.h [new file with mode: 0644]
contrib/nnc/include/core/modelIR/operations/operations.lst.h
contrib/nnc/include/passes/interpreter/Interpreter.h
contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.cpp
contrib/nnc/passes/acl_soft_backend/AclCppOpGenerator.h
contrib/nnc/passes/interpreter/Interpreter.cpp
contrib/nnc/passes/interpreter/ops/Gemm.cpp [new file with mode: 0644]
contrib/nnc/passes/interpreter/ops/Gemm.h [new file with mode: 0644]
contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.cpp
contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.h
contrib/nnc/passes/onnx_frontend/ONNXOpCreator.cpp
contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp
contrib/nnc/passes/soft_backend/ModelAnalyzer.h
contrib/nnc/passes/soft_backend/SBSerializer.cpp
contrib/nnc/passes/soft_backend/SBSerializer.h
contrib/nnc/tests/interpreter/graph_creator.cpp

index 52aba7f..f2713f8 100644 (file)
@@ -4,6 +4,7 @@ set(SOURCES "modelIR/operations/ConcatOp.cpp"
             "modelIR/operations/DepthwiseConv2DOp.cpp"
             "modelIR/operations/FullyConnectedOp.cpp"
             "modelIR/operations/GatherOp.cpp"
+            "modelIR/operations/GemmOp.cpp"
             "modelIR/operations/PadOp.cpp"
             "modelIR/operations/PoolOp.cpp"
             "modelIR/operations/SqueezeOp.cpp"
index 84d261e..23731df 100644 (file)
@@ -30,6 +30,7 @@
 #include "core/modelIR/operations/EluOp.h"
 #include "core/modelIR/operations/FullyConnectedOp.h"
 #include "core/modelIR/operations/GatherOp.h"
+#include "core/modelIR/operations/GemmOp.h"
 #include "core/modelIR/operations/PadOp.h"
 #include "core/modelIR/operations/PoolOp.h"
 #include "core/modelIR/operations/ReduceFOp.h"
@@ -125,6 +126,13 @@ void IrDotDumper::visit(ops::FullyConnectedOp& op) {
   dotBuilder.updateWithOp(&op, nodeInfo);
 }
 
+void IrDotDumper::visit(ops::GemmOp& op) {
+  auto nodeInfo = DotIrNodeInfo().withType("Gemm", op.getName())
+                                 .withInShapes(getInputShapes(op))
+                                 .withOutShapes(getOutputShapes(op));
+  dotBuilder.updateWithOp(&op, nodeInfo);
+}
+
 void IrDotDumper::visit(ops::SoftmaxOp& op) {
   auto nodeInfo = DotIrNodeInfo().withType("Softmax", op.getName())
                                  .withInShapes(getInputShapes(op))
index 2005ea6..97765ae 100644 (file)
@@ -28,6 +28,7 @@
 #include "core/modelIR/operations/EluOp.h"
 #include "core/modelIR/operations/FullyConnectedOp.h"
 #include "core/modelIR/operations/GatherOp.h"
+#include "core/modelIR/operations/GemmOp.h"
 #include "core/modelIR/operations/PadOp.h"
 #include "core/modelIR/operations/PoolOp.h"
 #include "core/modelIR/operations/ReduceFOp.h"
diff --git a/contrib/nnc/core/modelIR/operations/GemmOp.cpp b/contrib/nnc/core/modelIR/operations/GemmOp.cpp
new file mode 100644 (file)
index 0000000..d0dd275
--- /dev/null
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "core/modelIR/operations/GemmOp.h"
+
+namespace nnc {
+namespace mir {
+namespace ops {
+
+void GemmOp::inferOutputShapes() {
+//Input tensor A: The shape of A should be (M, K) if transA is 0, or (K, M) if transA is non-zero.
+//Input tensor B: The shape of B should be (K, N) if transB is 0, or (N, K) if transB is non-zero.
+//Input tensor C: The shape of C should be unidirectional broadcastable to (M, N).
+//Output tensor Y: shape (M, N) or vector(N) if M == 1.
+  std::vector<int32_t > vector(2); // this vector will be used to create shapes of tensor A, B
+
+  // Flatten the shape by dim(0)
+  mir::Shape shape0 = {getInputShape(0).dim(0),
+                       getInputShape(0).numElements() / getInputShape(0).dim(0)};
+  assert(shape0.rank() == 2);
+  Shape shape_a(2);
+  if (_transA) {
+    shape_a.dim(0) = shape0.dim(shape0.rank() - 1);
+    shape_a.dim(1) = shape0.dim(shape0.rank() - 2);
+  } else {
+    shape_a.dim(0) = shape0.dim(shape0.rank() - 2);
+    shape_a.dim(1) = shape0.dim(shape0.rank() - 1);
+  }
+
+  auto& shape1 = getInputShape(1);
+  // It must be a matrice
+  assert(shape1.rank() == 2);
+  Shape shape_b(2);
+
+  if (_transB) {
+    shape_b.dim(0) = shape1.dim(shape1.rank() - 1);
+    shape_b.dim(1) = shape1.dim(shape1.rank() - 2);
+  } else {
+    shape_b.dim(0) = shape1.dim(shape1.rank() - 2);
+    shape_b.dim(1) = shape1.dim(shape1.rank() - 1);
+  }
+
+  // Number of cols in tensor A must be equal to number of rows in tensor B
+  assert(shape_a.dim(1) == shape_b.dim(0));
+  Shape mult_a_b({shape_a.dim(0), shape_b.dim(1)});
+
+  Shape shape_c = getInputShape(2);
+
+  if (shape_c.rank() == 1){
+    assert(mult_a_b.dim(0) == 1);
+    assert(mult_a_b.dim(1) == shape_c.dim(0));
+  } else {
+    assert(shape_c.rank() == 2);
+    assert((mult_a_b.dim(0) == shape_c.dim(0)) && (mult_a_b.dim(1) == shape_c.dim(1)));
+  }
+  setOutputShape(0, mult_a_b);
+}
+
+} // namespace ops
+} // namespace mir
+} // namespace nnc
index f5a907a..dad796c 100644 (file)
@@ -18,7 +18,6 @@
 #define _NNC_BACKEND_INTERPRETER_CORE_DOTDUMPER_
 
 #include "core/modelIR/Visitor.h"
-
 #include "core/modelIR/ir_dot_builder.h"
 
 namespace nnc
@@ -45,6 +44,7 @@ public:
   void visit(ops::EluOp& op) override;
   void visit(ops::FullyConnectedOp& op) override;
   void visit(ops::GatherOp& op) override;
+  void visit(ops::GemmOp& op) override;
   void visit(ops::PadOp& op) override;
   void visit(ops::PoolOp& op) override;
   void visit(ops::ReduceFOp& op) override;
index 075ccad..92d22f4 100644 (file)
@@ -25,14 +25,15 @@ namespace ops {
 
 class ConstantOp : public Operation {
 public:
-  ConstantOp(const TensorVariant& value) : Operation(Type::constant, {}), _value(value) {
-    setOutputShape(0, _value.getShape());
+  ConstantOp(const std::shared_ptr<mir::TensorVariant>& value) :
+                                            Operation(Type::constant, {}), _value(value) {
+    setOutputShape(0, _value->getShape());
   }
 
-  const TensorVariant& getValue() const { return _value; }
+    const std::shared_ptr<mir::TensorVariant>& getValue() const { return _value; }
 
 private:
-  TensorVariant _value;
+    const std::shared_ptr<mir::TensorVariant>& _value;
 };
 
 } // namespace ops
diff --git a/contrib/nnc/include/core/modelIR/operations/GemmOp.h b/contrib/nnc/include/core/modelIR/operations/GemmOp.h
new file mode 100644 (file)
index 0000000..4c10af9
--- /dev/null
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _NNC_CORE_IR_MODEL_GEMM_OP_H_
+#define _NNC_CORE_IR_MODEL_GEMM_OP_H_
+
+#include "core/modelIR/Operation.h"
+#include "core/modelIR/TensorVariant.h"
+
+namespace nnc {
+namespace mir {
+namespace ops {
+
+class GemmOp : public Operation {
+public:
+    GemmOp(const IODescriptor a, const IODescriptor b, const IODescriptor c,
+           bool transA, bool transB, float alpha, float beta) :
+          Operation(Type::gemmOp, {a, b, c}),
+          _a(a), _b(b), _c(c), _transA(transA),_transB(transB), _alpha(alpha), _beta(beta) {
+    inferOutputShapes();
+  }
+
+  bool  getTransA() {return _transA;}
+  bool  getTransB() {return _transB;}
+  float getAlpha()  {return _alpha;}
+  float getBeta()   {return _beta;}
+
+private:
+  void inferOutputShapes();
+
+  const IODescriptor _a;
+  const IODescriptor _b;
+  const IODescriptor _c;
+  bool  _transA;
+  bool  _transB;
+  float _alpha;
+  float _beta;
+};
+} // namespace ops
+} // namespace mir
+} // namespace nnc
+
+#endif //_NNC_CORE_IR_MODEL_GEMM_OP_H_
index c4ae1d2..9292e71 100644 (file)
@@ -25,6 +25,7 @@ HANDLE_OP(gather, GatherOp)
 HANDLE_OP(softmax, SoftmaxOp)
 HANDLE_OP(pool, PoolOp)
 HANDLE_OP(fullyConnected, FullyConnectedOp)
+HANDLE_OP(gemmOp, GemmOp)
 HANDLE_OP(cappedReLU, CappedReluOp)
 HANDLE_OP(biasAdd, BiasAddOp)
 HANDLE_OP(variable, VariableOp)
index f1ddba3..25239c2 100644 (file)
@@ -49,6 +49,7 @@ public:
   void visit(ops::EluOp& op) override;
   void visit(ops::FullyConnectedOp& op) override;
   void visit(ops::GatherOp& op) override;
+  void visit(ops::GemmOp& op) override;
   void visit(ops::PadOp& op) override;
   void visit(ops::PoolOp& op) override;
   void visit(ops::ReduceFOp& op) override;
index 404fa3b..46d5c6f 100644 (file)
@@ -17,6 +17,7 @@
 #include "core/modelIR/operations/DropoutOp.h"
 #include "core/modelIR/operations/ElementwiseOp.h"
 #include "core/modelIR/operations/FullyConnectedOp.h"
+#include "core/modelIR/operations/GemmOp.h"
 #include "core/modelIR/operations/PoolOp.h"
 #include "core/modelIR/operations/ReduceFOp.h"
 #include "core/modelIR/operations/ReluOp.h"
@@ -290,6 +291,10 @@ void AclCppOpGenerator::visit(ops::FullyConnectedOp& op) {
   runLayer(layer);
 }
 
+void AclCppOpGenerator::visit(ops::GemmOp& op) {
+  assert(false);
+}
+
 void AclCppOpGenerator::visit(ops::CappedReluOp& op) {
   genActivation(op, "LU_BOUNDED_RELU", op.getCap());
 }
index b6108fc..de10424 100644 (file)
@@ -60,6 +60,7 @@ public:
   void visit(mir::ops::EluOp& op) override;
   void visit(mir::ops::FullyConnectedOp& op) override;
   void visit(mir::ops::GatherOp& op) override;
+  void visit(mir::ops::GemmOp& op) override;
   void visit(mir::ops::PadOp& op) override;
   void visit(mir::ops::PoolOp& op) override;
   void visit(mir::ops::ReduceFOp& op) override;
index 76cd84a..a8cf131 100644 (file)
@@ -22,6 +22,7 @@
 #include "passes/interpreter/Interpreter.h"
 
 #include "core/modelIR/operations/FullyConnectedOp.h"
+#include "core/modelIR/operations/GemmOp.h"
 #include "core/modelIR/operations/SoftmaxOp.h"
 #include "core/modelIR/operations/CappedReluOp.h"
 #include "core/modelIR/operations/DepthwiseConv2DOp.h"
@@ -51,6 +52,7 @@
 #include "ops/DeConv2D.h"
 #include "ops/Depthwise_conv_2D.h"
 #include "ops/FullyConnected.h"
+#include "ops/Gemm.h"
 #include "ops/Pool.h"
 #include "ops/Reshape.h"
 #include "ops/Softmax.h"
@@ -80,7 +82,7 @@ void NNInterpreter::visit(ops::VariableOp& op) {
 }
 
 void NNInterpreter::visit(ops::ConstantOp& op) {
-  var(op.getId()) = {op.getValue()};
+  var(op.getId()) = {*op.getValue()};
 }
 
 std::vector<TensorVariant> &NNInterpreter::getResult(Operation* op) {
@@ -156,6 +158,13 @@ void NNInterpreter::visit(ops::FullyConnectedOp& op) {
   var(op.getId()) = FullyConnected<float>(input, op)();
 }
 
+void NNInterpreter::visit(ops::GemmOp& op) {
+  mapByName(&op);
+  auto operand = op.getPrevNodes()[0];
+  TensorVariant input = var(operand.op->getId())[operand.index];
+  var(op.getId()) = Gemm<float>(input, op)();
+}
+
 void NNInterpreter::visit(ops::CappedReluOp& op) {
   mapByName(&op);
   auto operand = op.getPrevNodes()[0];
diff --git a/contrib/nnc/passes/interpreter/ops/Gemm.cpp b/contrib/nnc/passes/interpreter/ops/Gemm.cpp
new file mode 100644 (file)
index 0000000..4259656
--- /dev/null
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Gemm.h"
+//Do not remove
+//Used to force compile Gemm.h
diff --git a/contrib/nnc/passes/interpreter/ops/Gemm.h b/contrib/nnc/passes/interpreter/ops/Gemm.h
new file mode 100644 (file)
index 0000000..e130626
--- /dev/null
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _NNC_CORE_BACKEND_INTERPRETER_GEMM_
+#define _NNC_CORE_BACKEND_INTERPRETER_GEMM_
+
+#include "core/modelIR/ShapeRange.h"
+#include "core/modelIR/operations/GemmOp.h"
+#include "OperationImpl.h"
+
+namespace nnc
+{
+template<typename T>
+class Gemm : public OperationImpl<T>
+{
+public:
+    Gemm(const mir::TensorVariant &_input, const mir::ops::GemmOp &_op) :
+                                                            _op(_op), _input(_input) {}
+
+  std::vector<mir::TensorVariant> operator()() override {
+    mir::TensorVariant res = OperationImpl<T>::allocate_tensor(_op.getOutputShape(0));
+    return {res};
+  }
+
+private:
+  const mir::ops::GemmOp& _op;
+  const mir::Tensor<T> _input;
+};
+} // namespace nnc
+
+#endif //_NNC_CORE_BACKEND_INTERPRETER_GEMM_
index a453471..98f68b8 100644 (file)
@@ -150,23 +150,17 @@ void ONNXImporterImpl::createGraphInputs() {
       mir::Shape input_shape = ShapeHelper::createShape(onnx_tensor->dims(),
                                                    static_cast<size_t>(onnx_tensor->dims_size()));
       _inputTensors[name] = createTensor(onnx_tensor, input_shape);
-      auto constant = _graph->create<mir::ops::ConstantOp>(name, *_inputTensors[name].get());
+      auto constant = _graph->create<mir::ops::ConstantOp>(name, _inputTensors[name]);
       _tensorNameToPrevMirOp[name] = constant;
       constants.insert(constant);
     } else {
       // We're dealing with graph input
       auto onnx_input_shape = input.type().tensor_type().shape();
-      std::vector<int> shape_vector(onnx_input_shape.dim_size());
+      mir::Shape shape(4);
       for (int i = 0; i < onnx_input_shape.dim_size(); i++) {
         assert(onnx_input_shape.dim(i).has_dim_value());
-        shape_vector[i] = onnx_input_shape.dim(i).dim_value();
+        shape.dim(i) = onnx_input_shape.dim(i).dim_value();
       }
-      mir::Shape shape(shape_vector);
-      // TODO For now we only support convolutional networks. The input data have already been
-      // transformed from ONNX NCHW format to ModelIR NHWC; reflect the changes in the IR.
-      if (shape.rank() == 4)
-        shape = mir::Shape{shape.dim(0), shape.dim(1), shape.dim(2), shape.dim(0)};
-
       // TODO: Temporary solution!
       auto node = _graph->create<mir::ops::VariableOp>(name, shape);
       _tensorNameToPrevMirOp[name] = node;
@@ -176,6 +170,54 @@ void ONNXImporterImpl::createGraphInputs() {
     _graph->setConstants(constants);
 }
 
+static void dumpShape(const mir::Shape& shape) {
+  std::cout << "{";
+  for (int i = 0; i < shape.rank(); i++) {
+    std::cout << shape.dim(i) << (i == shape.rank() - 1 ? "} " : ", ");
+  }
+}
+
+void ONNXImporterImpl::dump(const std::vector<mir::Operation*>& ops, const onnx::NodeProto& onnx_node) {
+  for (auto op : ops) {
+    std::cout << onnx_node.op_type() << " '" << op->getName() << "' Input Shapes: ";
+    for (int i = 0; i < op->getNumInputs() ; i++) {
+      dumpShape(op->getInputShape(i));
+    }
+    std::cout << " Output Shapes: ";
+    for (int i = 0; i < op->getNumOutputs() ; i++) {
+      dumpShape(op->getOutputShape(i));
+    }
+    auto* onnx_op_type = ONNXPerfectHash::getONNXOpType(onnx_node.op_type().c_str(), onnx_node.op_type().size());
+    switch (onnx_op_type->opCode) {
+      case ONNXOpCode::opConv: {
+        auto *conv = dynamic_cast<mir::ops::Conv2DOp *>(op);
+        std::cout << " Weights tensor shape ";
+        dumpShape(conv->getKernel().getShape());
+        std::cout << " Strides  ";
+        dumpShape(conv->getStrides());
+        std::cout << " Padding before:  (" << conv->getPaddingBefore()[0] << " " << conv->getPaddingBefore()[1] << ")";
+        std::cout << " After:  (" << conv->getPaddingAfter()[0] << " " << conv->getPaddingAfter()[1] << ")";
+        break;
+      }
+      case ONNXOpCode::opGlobalAveragePool:
+      case ONNXOpCode::opAveragePool:
+      case ONNXOpCode::opMaxPool: {
+        auto *pool = dynamic_cast<mir::ops::PoolOp *>(op);
+        std::cout << " Kernel ";
+        dumpShape(pool->getWindowShape());
+        std::cout << " Strides  ";
+        dumpShape(pool->getStrides());
+        std::cout << " Padding before:  " << pool->getPaddingBefore()[0] << " " << pool->getPaddingBefore()[1];
+        std::cout << " After:  " << pool->getPaddingAfter()[0] << " " << pool->getPaddingAfter()[1];
+        break;
+      }
+      default:
+        break;
+    }
+    std::cout << "\n";
+  }
+}
+
 mir::Graph *ONNXImporterImpl::createIR() {
   GOOGLE_PROTOBUF_VERIFY_VERSION;
 
@@ -267,6 +309,7 @@ mir::Graph *ONNXImporterImpl::createIR() {
     assert (outputs.size());
     // FIXME: it should be done properly via the given graph outputs
     _graphOutputs.assign(outputs.begin(), outputs.end());
+    dump(outputs, onnx_node);
   }
   // set graph outputs
   // TODO: it should be done with onnx graph outputs
index 67667fb..24ec1c3 100644 (file)
@@ -39,6 +39,7 @@ public:
 
   void import() {};
   mir::Graph *createIR() override;
+  void dump(const std::vector<mir::Operation*>& op, const onnx::NodeProto& onnx_node);
 
 private:
   void createGraphInputs();
index 6d6faf1..0e9b111 100644 (file)
@@ -31,6 +31,7 @@
 #include "core/modelIR/operations/DepthwiseConv2DOp.h"
 #include "core/modelIR/operations/DropoutOp.h"
 #include "core/modelIR/operations/FullyConnectedOp.h"
+#include "core/modelIR/operations/GemmOp.h"
 #include "core/modelIR/operations/PoolOp.h"
 #include "core/modelIR/operations/ReluOp.h"
 #include "core/modelIR/operations/ReshapeOp.h"
@@ -86,34 +87,67 @@ static const mir::TensorVariant* createTensor(float data) {
   std::shared_ptr<char> shared_buffer (new char[buffer_size], std::default_delete<char[]>());
   memcpy(shared_buffer.get(), src_data, buffer_size);
   Shape tensor_shape = Shape({1});
+  // FIXME: it has to be shared_ptr
   auto mir_tensor = new mir::TensorVariant(tensor_shape, shared_buffer, type, element_size);
   return mir_tensor;
 }
 
+struct KernelStridesPadding {
+  Shape kernel_shape;
+  Shape strides_shape;
+  std::vector<int32_t> padding_before{0, 0};
+  std::vector<int32_t> padding_after{0, 0};
+};
+
+static void getKernelStridesPadding(const onnx::NodeProto &onnx_node, KernelStridesPadding &cdata) {
+  auto* kshape = findAttribute(onnx_node, "kernel_shape");
+  assert(kshape && kshape->ints_size());
+  auto* strides = findAttribute(onnx_node, "strides");
+  assert(strides && strides->ints_size());
+  auto* pads = findAttribute(onnx_node, "pads");
+
+  cdata.kernel_shape = ShapeHelper::createShape(kshape->ints(), kshape->ints_size());
+  cdata.strides_shape = ShapeHelper::createShape(strides->ints(), strides->ints_size());
+
+  if (pads) {
+    assert(pads->ints_size() >= 2);
+    cdata.padding_before[0] = pads->ints(0);
+    cdata.padding_before[1] = pads->ints(1);
+    // TODO: ONNX padding could be for the beginning and ending along each axis that's why we
+    // should select the interesting ones.
+    if (pads->ints_size() == 4) {
+      cdata.padding_after[0] = pads->ints(2);
+      cdata.padding_after[1] = pads->ints(3);
+    }
+  }
+};
+
 std::vector<Operation*> ONNXOpCreator::convertConv2D(InputOps& inputs,
                                                     const onnx::NodeProto& onnx_node) {
   assert(inputs.size() >= 2);
 
-  auto* strides = findAttribute(onnx_node, "strides");
-  assert(strides && strides->ints_size());
-  Shape onnx_strides_shape = ShapeHelper::createShape(strides->ints(), strides->ints_size());
-  // FIXME: it's a hack
-  Shape strides_shape = {onnx_strides_shape.dim(0), onnx_strides_shape.dim(1), 1};
+  KernelStridesPadding cdata;
+  getKernelStridesPadding(onnx_node, cdata);
+  // FIXME: It can be non-constant value.
   auto* in_weights = dynamic_cast<mir::ops::ConstantOp*>(inputs[1]);
-  assert(in_weights);
+  assert(in_weights && "Weights could be a constant tensor only");
   auto in_weights_tensor = in_weights->getValue();
-  // TODO: we don't support padding at the moment
-  std::vector<int32_t> padding_before{0, 0};
-  std::vector<int32_t> padding_after{0, 0};
-  Operation* input_bias;
-  if (inputs.size() > 2)
-    input_bias = inputs[2];
+  // We should transpose ONNX MCHW to HWOI
+  auto transposed = transposeTensor<2, 3, 1, 0>(in_weights_tensor);
+
+  mir::ops::ConstantOp* input_bias = nullptr;
+  if (inputs.size() > 2) {
+    input_bias = dynamic_cast<mir::ops::ConstantOp*>(inputs[2]);
+    assert(input_bias && "1D optional bias could be a constant tensor only");
+  }
 
   inputs.resize(1);
   std::vector<Operation*> outputs;
-  outputs = createOp<ops::Conv2DOp>(inputs[0]->getOutput(0), in_weights_tensor, strides_shape,
-                                    padding_before, padding_after);
-  // TODO: there could be bias tensor as inputs[2].
+  outputs = createOp<ops::Conv2DOp>(inputs[0]->getOutput(0), *transposed, cdata.strides_shape,
+                                    cdata.padding_before, cdata.padding_after);
+  if (input_bias)
+    outputs = createOp<ops::BiasAddOp>(outputs[0]->getOutput(0), *input_bias->getValue());
+
   return outputs;
 }
 
@@ -136,8 +170,7 @@ std::vector<Operation*> ONNXOpCreator::convertPool(InputOps& inputs, ONNXOpCode
   ops::PoolOp::PoolingType pool_type;
 
   std::vector<Operation*> result;
-  std::vector<int32_t> padding_before{0, 0};
-  std::vector<int32_t> padding_after{0, 0};
+  KernelStridesPadding cdata;
 
   switch (op_code) {
     case ONNXOpCode::opGlobalAveragePool:
@@ -147,7 +180,7 @@ std::vector<Operation*> ONNXOpCreator::convertPool(InputOps& inputs, ONNXOpCode
                                    ops::PoolOp::PoolingType::AVG,
                                    inputs[0]->getOutputShape(0), // kernel_shape
                                    Shape({1, 1}),                // strides_shape
-                                   padding_before, padding_after,// no padding
+                                   cdata.padding_before, cdata.padding_after,
                                    ops::PoolOp::BorderType::ZEROFILLED,
                                    ops::PoolOp::RoundMode::floor);
     case ONNXOpCode::opAveragePool:
@@ -162,29 +195,12 @@ std::vector<Operation*> ONNXOpCreator::convertPool(InputOps& inputs, ONNXOpCode
       assert(false);
   }
   // Proceed with Average or Max Pool
-  auto* kshape = findAttribute(onnx_node, "kernel_shape");
-  assert(kshape && kshape->ints_size());
-  auto* strides = findAttribute(onnx_node, "strides");
-  assert(strides && strides->ints_size());
-  auto* pads = findAttribute(onnx_node, "pads");
-
-  Shape kernel_shape = ShapeHelper::createShape(kshape->ints(), kshape->ints_size());
-  Shape strides_shape = ShapeHelper::createShape(strides->ints(), strides->ints_size());
-
-  if (pads) {
-    assert(pads->ints_size() >= 2);
-    padding_before[0] = pads->ints(0);
-    padding_before[1] = pads->ints(1);
-    // TODO: ONNX padding could be for the beginning and ending along each axis that's why we
-    // should select the interesting ones.
-    if (pads->ints_size() == 4) {
-      padding_after[0] = pads->ints(2);
-      padding_after[1] = pads->ints(3);
-    }
+  getKernelStridesPadding(onnx_node, cdata);
 
-  }
-  result = createOp<ops::PoolOp>(inputs[0]->getOutput(0), pool_type, kernel_shape, strides_shape,
-                                 padding_before, padding_after, border_type,
+  result = createOp<ops::PoolOp>(inputs[0]->getOutput(0), pool_type,
+                                 cdata.kernel_shape, cdata.strides_shape,
+                                 cdata.padding_before, cdata.padding_after,
+                                 border_type,
                                  ops::PoolOp::RoundMode::floor);
   return result;
 }
@@ -252,7 +268,7 @@ std::vector<Operation*> ONNXOpCreator::convertBatchNorm(InputOps& inputs,
 }
 
 std::vector<Operation*> ONNXOpCreator::convertDropout(InputOps& inputs,
-                                                     const onnx::NodeProto& onnx_node) {
+                                                      const onnx::NodeProto& onnx_node) {
   bool found;
   float value;
   std::tie(found, value) = getFloatAttribute(onnx_node, "ratio");
@@ -271,8 +287,32 @@ std::vector<Operation*> ONNXOpCreator::convertScale(InputOps& inputs,
 }
 
 std::vector<Operation*> ONNXOpCreator::convertGemm(InputOps& inputs,
-                                                  const onnx::NodeProto& onnx_node) {
-  // TODO: NIY
-  return inputs;
+                                                   const onnx::NodeProto& onnx_node) {
+  bool  found;
+  int   ivalue;
+  float fvalue;
+
+  std::tie (found, ivalue) = getIntAttribute(onnx_node, "transA");
+  bool transA = found ? ivalue : 0;
+  std::tie (found, ivalue) = getIntAttribute(onnx_node, "transB");
+  bool transB = found ? ivalue : 0;
+  std::tie (found, fvalue) = getIntAttribute(onnx_node, "alpha");
+  float alpha = found ? fvalue : 1.0;
+  std::tie (found, fvalue) = getIntAttribute(onnx_node, "beta");
+  float beta = found ? fvalue : 1.0;
+
+  // Flatten the shape by dim(0)
+  mir::Shape shape0 ({inputs[0]->getOutputShape(0).dim(0),
+                      inputs[0]->getOutputShape(0).numElements() /
+                                                        inputs[0]->getOutputShape(0).dim(0)});
+  auto reshape = createOp<ops::ReshapeOp>(inputs[0]->getOutput(0), shape0);
+
+  std::vector<IODescriptor> descriptors;
+  descriptors.push_back(reshape[0]->getOutput(0));
+  descriptors.push_back(inputs[1]->getOutput(0));
+  descriptors.push_back(inputs[2]->getOutput(0));
+
+  return createOp<ops::GemmOp>(reshape[0]->getOutput(0), inputs[1]->getOutput(0),
+                               inputs[2]->getOutput(0), transA, transB, alpha, beta);
 }
 } // namespace nnc
index 3fe75d3..ebb0b2c 100644 (file)
@@ -36,6 +36,7 @@
 #include "core/modelIR/operations/ElementwiseOp.h"
 #include "core/modelIR/operations/EluOp.h"
 #include "core/modelIR/operations/FullyConnectedOp.h"
+#include "core/modelIR/operations/GemmOp.h"
 #include "core/modelIR/operations/GatherOp.h"
 #include "core/modelIR/operations/PadOp.h"
 #include "core/modelIR/operations/PoolOp.h"
@@ -207,6 +208,10 @@ void ModelAnalyzer::visit(ops::FullyConnectedOp& op) {
   addOpDescr(&op, "fullConnect");
 }
 
+void ModelAnalyzer::visit(ops::GemmOp& op) {
+  addOpDescr(&op, "gemm");
+}
+
 void ModelAnalyzer::visit(ops::CappedReluOp& op) {
   addOpDescr(&op, "cappedRelu");
 }
index a9ce551..36cb487 100644 (file)
@@ -102,6 +102,7 @@ public:
   void visit(mir::ops::EluOp& op) override;
   void visit(mir::ops::FullyConnectedOp& op) override;
   void visit(mir::ops::GatherOp& op) override;
+  void visit(mir::ops::GemmOp& op) override;
   void visit(mir::ops::PadOp& op) override;
   void visit(mir::ops::PoolOp& op) override;
   void visit(mir::ops::ReduceFOp& op) override;
index c2241ff..29bbe5a 100644 (file)
@@ -28,6 +28,7 @@
 #include "core/modelIR/operations/SoftmaxOp.h"
 #include "core/modelIR/operations/PoolOp.h"
 #include "core/modelIR/operations/FullyConnectedOp.h"
+#include "core/modelIR/operations/GemmOp.h"
 #include "core/modelIR/operations/CappedReluOp.h"
 #include "core/modelIR/operations/BiasAddOp.h"
 #include "core/modelIR/operations/ReluOp.h"
@@ -229,6 +230,10 @@ void Serializer::visit(ops::FullyConnectedOp& op) {
   serializeShape(op.getOutputShape(0));
 }
 
+void Serializer::visit(ops::GemmOp& op) {
+  _curOp->_paramStartOffset = _buffer.size();
+}
+
 void Serializer::visit(ops::CappedReluOp& op) {
   _curOp->_paramStartOffset = _buffer.size();
   serializeT<float>(op.getCap());
@@ -245,7 +250,7 @@ void Serializer::visit(ops::VariableOp& op) {
 
 void Serializer::visit(ops::ConstantOp& op) {
   _curOp->_paramStartOffset = _buffer.size();
-  serializeTensor(op.getValue());
+  serializeTensor(*op.getValue());
 }
 
 void Serializer::visit(ops::ReluOp& op) {
index 748db4e..2e68ebf 100644 (file)
@@ -54,6 +54,7 @@ public:
   void visit(mir::ops::EluOp& op) override;
   void visit(mir::ops::FullyConnectedOp& op) override;
   void visit(mir::ops::GatherOp& op) override;
+  void visit(mir::ops::GemmOp& op) override;
   void visit(mir::ops::PadOp& op) override;
   void visit(mir::ops::PoolOp& op) override;
   void visit(mir::ops::ReduceFOp& op) override;
index a23ac4e..1bf3a58 100644 (file)
@@ -19,6 +19,7 @@
 
 #include "core/modelIR/operations/VariableOp.h"
 #include "core/modelIR/operations/FullyConnectedOp.h"
+#include "core/modelIR/operations/GemmOp.h"
 #include "core/modelIR/operations/Conv2DOp.h"
 #include "core/modelIR/operations/DepthwiseConv2DOp.h"
 #include "core/modelIR/operations/PoolOp.h"