[nnc] The first ONNX model resnet50 works on NNC interpreter (#2718)
authorАндрей Тищенко/AI Tools Lab /SRR/Staff Engineer/삼성전자 <a.tischenko@partner.samsung.com>
Fri, 21 Dec 2018 10:45:26 +0000 (13:45 +0300)
committerРоман Михайлович Русяев/AI Tools Lab /SRR/Staff Engineer/삼성전자 <r.rusyaev@samsung.com>
Fri, 21 Dec 2018 10:45:26 +0000 (13:45 +0300)
Several operators were fixed: BatchNormalization, Reshape, Gemm and Pooling. Now NNC is available to convert the ONNX resnt50 network, play it back in interpreter and to produce the out which is totally comparable with reference data.

Signed-off-by: Andrew V. Tischenko a.tischenko@partner.samsung.com
14 files changed:
contrib/nnc/core/modelIR/operations/GemmOp.cpp
contrib/nnc/core/modelIR/operations/PoolOp.cpp
contrib/nnc/include/core/modelIR/Graph.h
contrib/nnc/include/core/modelIR/operations/GemmOp.h
contrib/nnc/include/passes/interpreter/Interpreter.h
contrib/nnc/passes/interpreter/Interpreter.cpp
contrib/nnc/passes/interpreter/interpreter_pass.cpp
contrib/nnc/passes/interpreter/ops/Gemm.h
contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.cpp
contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.h
contrib/nnc/passes/onnx_frontend/ONNXOpCreator.cpp
contrib/nnc/passes/onnx_frontend/ONNXOpCreator.h
contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp
contrib/nnc/passes/soft_backend/code_snippets/cpp_gemm.def [new file with mode: 0644]

index d0dd275..d629f45 100644 (file)
@@ -21,51 +21,18 @@ namespace mir {
 namespace ops {
 
 void GemmOp::inferOutputShapes() {
-//Input tensor A: The shape of A should be (M, K) if transA is 0, or (K, M) if transA is non-zero.
-//Input tensor B: The shape of B should be (K, N) if transB is 0, or (N, K) if transB is non-zero.
-//Input tensor C: The shape of C should be unidirectional broadcastable to (M, N).
-//Output tensor Y: shape (M, N) or vector(N) if M == 1.
-  std::vector<int32_t > vector(2); // this vector will be used to create shapes of tensor A, B
+  auto shape_a = getInputShape(0);
+  auto shape_b = getInputShape(1);
+  assert((shape_a.rank() == shape_b.rank()) && (shape_a.rank() == 2));
+  assert(shape_a.dim(1) == shape_b.dim(0) && "Multiplicable");
 
-  // Flatten the shape by dim(0)
-  mir::Shape shape0 = {getInputShape(0).dim(0),
-                       getInputShape(0).numElements() / getInputShape(0).dim(0)};
-  assert(shape0.rank() == 2);
-  Shape shape_a(2);
-  if (_transA) {
-    shape_a.dim(0) = shape0.dim(shape0.rank() - 1);
-    shape_a.dim(1) = shape0.dim(shape0.rank() - 2);
-  } else {
-    shape_a.dim(0) = shape0.dim(shape0.rank() - 2);
-    shape_a.dim(1) = shape0.dim(shape0.rank() - 1);
-  }
-
-  auto& shape1 = getInputShape(1);
-  // It must be a matrice
-  assert(shape1.rank() == 2);
-  Shape shape_b(2);
-
-  if (_transB) {
-    shape_b.dim(0) = shape1.dim(shape1.rank() - 1);
-    shape_b.dim(1) = shape1.dim(shape1.rank() - 2);
-  } else {
-    shape_b.dim(0) = shape1.dim(shape1.rank() - 2);
-    shape_b.dim(1) = shape1.dim(shape1.rank() - 1);
-  }
-
-  // Number of cols in tensor A must be equal to number of rows in tensor B
-  assert(shape_a.dim(1) == shape_b.dim(0));
   Shape mult_a_b({shape_a.dim(0), shape_b.dim(1)});
 
-  Shape shape_c = getInputShape(2);
+  auto shape_c = getInputShape(2);
+  assert((mult_a_b == shape_c) ||
+         (((shape_c.rank() == 1)) && (mult_a_b.dim(0) == 1) &&
+          (mult_a_b.dim(1) == shape_c.dim(0))));
 
-  if (shape_c.rank() == 1){
-    assert(mult_a_b.dim(0) == 1);
-    assert(mult_a_b.dim(1) == shape_c.dim(0));
-  } else {
-    assert(shape_c.rank() == 2);
-    assert((mult_a_b.dim(0) == shape_c.dim(0)) && (mult_a_b.dim(1) == shape_c.dim(1)));
-  }
   setOutputShape(0, mult_a_b);
 }
 
index 742019b..318dd47 100644 (file)
@@ -61,6 +61,10 @@ void PoolOp::inferOutputShapes() {
       assert(false);
   }
 
+  for (int i = 0; i < output_shape.rank(); i++) {
+    assert(output_shape.dim(i) >= 0);
+  }
+
   setOutputShape(0, output_shape);
 }
 
index 3456c19..58b8ea9 100644 (file)
@@ -89,10 +89,6 @@ class Graph {
    */
   ops::VariableOp* replaceWithInputNode(const Operation* op);
 
-  void setConstants(std::set<Operation*> consts) {
-    _constants = consts;
-  }
-
   /**
    * @brief Change graph inputs to nodes with names in newInputs
    * @param new_inputs names of nodes to be made into input nodes
@@ -123,7 +119,6 @@ class Graph {
     _ops.push_back(op);
   }
 
-
   void registerOp(ops::ConstantOp* op) {
     _constants.insert(op);
     _ops.push_back(op);
index 4c10af9..bb90c3c 100644 (file)
@@ -26,28 +26,13 @@ namespace ops {
 
 class GemmOp : public Operation {
 public:
-    GemmOp(const IODescriptor a, const IODescriptor b, const IODescriptor c,
-           bool transA, bool transB, float alpha, float beta) :
-          Operation(Type::gemmOp, {a, b, c}),
-          _a(a), _b(b), _c(c), _transA(transA),_transB(transB), _alpha(alpha), _beta(beta) {
+  GemmOp(IODescriptor arg, const IODescriptor b, const IODescriptor c) :
+            Operation(Type::gemmOp, {arg, b, c}) {
     inferOutputShapes();
   }
 
-  bool  getTransA() {return _transA;}
-  bool  getTransB() {return _transB;}
-  float getAlpha()  {return _alpha;}
-  float getBeta()   {return _beta;}
-
 private:
   void inferOutputShapes();
-
-  const IODescriptor _a;
-  const IODescriptor _b;
-  const IODescriptor _c;
-  bool  _transA;
-  bool  _transB;
-  float _alpha;
-  float _beta;
 };
 } // namespace ops
 } // namespace mir
index 267ee28..e6f10d6 100644 (file)
@@ -21,6 +21,7 @@
 #include <map>
 #include <string>
 #include <unordered_map>
+#include <set>
 
 #include "core/modelIR/Visitor.h"
 #include "core/modelIR/Operation.h"
@@ -44,12 +45,12 @@ public:
   void visit(ops::Conv2DOp& op) override;
   void visit(ops::DeConv2DOp& op) override;
   void visit(ops::DepthwiseConv2DOp& op) override;
+  void visit(ops::GemmOp& op) override;
   void visit(ops::DropoutOp& op) override;
   void visit(ops::ElementwiseOp& op) override;
   void visit(ops::EluOp& op) override;
   void visit(ops::FullyConnectedOp& op) override;
   void visit(ops::GatherOp& op) override;
-  void visit(ops::GemmOp& op) override;
   void visit(ops::PadOp& op) override;
   void visit(ops::PoolOp& op) override;
   void visit(ops::ReduceFOp& op) override;
@@ -68,6 +69,7 @@ public:
 
   void setInput(const std::string &name, const TensorVariant& data);
   std::vector<TensorVariant> &getResult(Operation* op);
+  void dump(Operation& op, bool all = false);
 
   ~NNInterpreter() override = default;
 
index 6f5da6e..8e195d4 100644 (file)
 #include "ops/Pad.h"
 #include "ops/common.h"
 
-#include <vector>
 #include <cmath>
 #include <cassert>
+#include <fenv.h>
+#include <iostream>
+#include <vector>
 
 namespace nnc {
 
@@ -76,32 +78,79 @@ using namespace nnc::mir;
 
 std::vector<TensorVariant> &NNInterpreter::var(size_t id) { return vars[id]; }
 
-void NNInterpreter::setInput(const std::string &name, const TensorVariant& t) { data.emplace(name, t); }
+static void dumpIndex (Index ndx) {
+  for (int i = 0; i < ndx.rank(); i++) {
+    std::cout << (i ? "," : "(") << ndx.at(i);
+  }
+  std::cout << ")\t";
+}
+
+#if(0)
+  #define DUMP(x, y) dump(x, (y))
+#else
+  #define DUMP(x, y)
+#endif
+
+void NNInterpreter::dump(Operation& op, bool all) {
+  // TODO: in theory there could be several outputs from the given 'op'.
+  TensorVariant tensor = var(op.getId())[0];
+  std::cout << "Tensor '" << op.getName() << "' DType = " << (int)tensor.getDataType()  << ", ElementSize = " << tensor.getElementSize()
+           << ", Shape = {";
+  auto shape = tensor.getShape();
+  for (int i = 0; i < shape.rank(); i++) {
+    std::cout << shape.dim(i) << (i == shape.rank() - 1 ? "} " : ", ");
+  }
+  std::cout << "ElementsNumber " << shape.numElements() << "\n";
+  static bool do_it = false;
+  if (do_it || all) {
+    auto last_idx = shape.rank() - 1;
+    for (auto idx : ShapeRange(shape)) {
+      if (!idx.at(last_idx))
+        std::cout << "\n";
+      dumpIndex(idx);
+      if (tensor.getDataType() == DTYPE::FLOAT32)
+        std::cout << *(float_t*)tensor.at(idx) << "\t";
+      else
+        std::cout << *(int32_t*)tensor.at(idx) << "\t";
+    }
+    std::cout << "\n";
+  }
+}
+
+void NNInterpreter::setInput(const std::string &name, const TensorVariant& t) {
+// TODO: our tests are failed with fe enable exception
+//  feenableexcept(FE_INVALID | FE_OVERFLOW);
+//  |
+//                 FE_DIVBYZERO |
+//                 FE_OVERFLOW  |
+//                 FE_UNDERFLOW);
+//  feenableexcept(FE_ALL_EXCEPT);
+
+  data.emplace(name, t);
+}
 
 void NNInterpreter::visit(ops::VariableOp& op) {
   (void)op;
   auto it = data.find(op.getName());
   if( it == data.end() )
   {
-    throw std::runtime_error("Can't find data for node \"" + op.getName() + ". Input data was not set correctly?");
+    throw std::runtime_error("Can't find data for node \"" + op.getName() +
+                             ". Input data was not set correctly?");
   }
   var(op.getId()) = {it->second};
 }
 
 void NNInterpreter::visit(ops::ConstantOp& op) {
+  assert(data.find(op.getName()) == data.end());
   var(op.getId()) = {op.getValue()};
 }
 
 std::vector<TensorVariant> &NNInterpreter::getResult(Operation* op) {
   auto res = vars.find(op->getId());
   if (res != vars.end())
-  {
     return res->second;
-  }
   else
-  {
-    throw std::runtime_error("No such value");
-  }
+    throw std::runtime_error("No such value: " + std::to_string(op->getId()));
 }
 
 void NNInterpreter::visit(ops::ConcatOp& op) {
@@ -112,6 +161,7 @@ void NNInterpreter::visit(ops::ConcatOp& op) {
     ins.push_back(var(in.op->getId())[in.index]);
   }
   var(op.getId()) = Concat<float>(ins, op.getOutputShape(0), op.getAxis())();
+  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::Conv2DOp& op) {
@@ -123,6 +173,7 @@ void NNInterpreter::visit(ops::ReshapeOp& op) {
   auto operand = op.getPrevNodes()[0];
   auto input = var(operand.op->getId())[operand.index];
   var(op.getId()) = Reshape<float>(input, op.getOutputShape(0))();
+  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::ReluOp& op) {
@@ -130,6 +181,7 @@ void NNInterpreter::visit(ops::ReluOp& op) {
   Tensor<float> input(var(operand.op->getId())[operand.index]);
   var(op.getId()) = Fill<float>(
       op.getOutputShape(0), [&input](const Index &id) { return std::max(input.at(id), 0.0f); })();
+  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::SigmoidOp& op) {
@@ -144,12 +196,14 @@ void NNInterpreter::visit(ops::SoftmaxOp& op) {
   auto operand = op.getPrevNodes()[0];
   auto input = var(operand.op->getId())[operand.index];
   var(op.getId()) = Softmax(op.getInputShape(0), input, op.getAxis())();
+  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::PoolOp& op) {
   auto operand = op.getPrevNodes()[0];
   auto input = var(operand.op->getId())[operand.index];
   var(op.getId()) = Pool(input, op)();
+  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::FullyConnectedOp& op) {
@@ -159,9 +213,13 @@ void NNInterpreter::visit(ops::FullyConnectedOp& op) {
 }
 
 void NNInterpreter::visit(ops::GemmOp& op) {
-  auto operand = op.getPrevNodes()[0];
-  TensorVariant input = var(operand.op->getId())[operand.index];
-  var(op.getId()) = Gemm<float>(input, op)();
+  auto operand_a = op.getPrevNodes()[0];
+  auto operand_b = op.getPrevNodes()[1];
+  auto operand_c = op.getPrevNodes()[2];
+  const TensorVariant input_a = var(operand_a.op->getId())[operand_a.index];
+  const TensorVariant input_b = var(operand_b.op->getId())[operand_b.index];
+  const TensorVariant input_c = var(operand_c.op->getId())[operand_c.index];
+  var(op.getId()) = Gemm<float>(input_a, input_b, input_c, op)();
 }
 
 void NNInterpreter::visit(ops::CappedReluOp& op) {
@@ -182,6 +240,7 @@ void NNInterpreter::visit(ops::BiasAddOp& op) {
   auto operand = op.getPrevNodes()[0];
   auto input = var(operand.op->getId())[operand.index];
   var(op.getId()) = BiasAdd(input, op.getWeights(), op.getOutputShape(0))();
+  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::BatchNormOp& op) {
@@ -189,13 +248,15 @@ void NNInterpreter::visit(ops::BatchNormOp& op) {
   TensorVariant input(var(operand.op->getId())[operand.index]);
   // TODO implement this
     var(op.getId()) = BatchNorm<float>(input, op)();
+  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::ScaleOp& op) {
   auto operand = op.getPrevNodes()[0];
   TensorVariant input(var(operand.op->getId())[operand.index]);
   // TODO implement this
-   var(op.getId()) = Scale(input, op)();
+  var(op.getId()) = Scale(input, op)();
+  DUMP(op, false);
 }
 
 
@@ -213,6 +274,7 @@ void NNInterpreter::visit(ops::DropoutOp& op) {
   TensorVariant input(var(operand.op->getId())[operand.index]);
   // TODO implement this
    var(op.getId()) = Dropout<float>(input, op)();
+  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::TanhOp& op) {
@@ -267,11 +329,13 @@ void NNInterpreter::visit(ops::ElementwiseOp& op) {
         acc = func(acc, ins[i].at(id));
       return acc;
     })();
+  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::DeConv2DOp& op) {
   auto operand = op.getPrevNodes()[0];
   var(op.getId()) = DeConv2D(var(operand.op->getId())[operand.index], op)();
+  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::EluOp& op) {
@@ -283,6 +347,7 @@ void NNInterpreter::visit(ops::EluOp& op) {
     else
       return op.getAlpha()*(expf(input.at(id))-1);
   })();
+  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::SqueezeOp& op) {
@@ -290,12 +355,14 @@ void NNInterpreter::visit(ops::SqueezeOp& op) {
   auto& input = var(operand.op->getId())[operand.index];
   //Squeeze is just a special case of reshape
   var(op.getId()) = Reshape<float>(input, op.getOutputShape(0))();
+  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::PadOp& op) {
   auto operand = op.getPrevNodes()[0];
   auto& input = var(operand.op->getId())[operand.index];
   var(op.getId()) = Pad(input, op)();
+  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::SqrtOp& op) {
@@ -325,6 +392,7 @@ void NNInterpreter::visit(ops::ResizeOp& op) {
     default:
       assert(false && "Not supported Optype");
   }
+  DUMP(op, false);
 
 }
 
@@ -352,12 +420,14 @@ void NNInterpreter::visit(ops::ReduceFOp& op) {
     default:
       assert(false && "Not Implemented");
   }
+  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::TransposeOp& op) {
   auto operand = op.getPrevNodes()[0];
   auto& input = var(operand.op->getId())[operand.index];
   var(op.getId()) = Transpose(input, op)();
+  DUMP(op, false);
 }
 
 void NNInterpreter::visit(ops::GatherOp& op) {
index 58365cd..792e54c 100644 (file)
@@ -105,7 +105,7 @@ PassData InterpreterPass::run(PassData data) {
 
   // Check nodes
   const auto& outputs = g->collectOutputs();
-
+  
   for (auto& out : outputs) {
     auto outputNode = interpreter.getResult(out);
     if (outputNode.empty()) {
@@ -126,7 +126,7 @@ PassData InterpreterPass::run(PassData data) {
 #else
     std::cout << "Result <" << out_node->getName()
               << "> wasn't saved, due to lack of HDF5" << std::endl;
-    
+
 #endif  // NNC_HDF5_SUPPORTED
     if (is_several_outs)
       delete out_data;
index e130626..0d94813 100644 (file)
 #ifndef _NNC_CORE_BACKEND_INTERPRETER_GEMM_
 #define _NNC_CORE_BACKEND_INTERPRETER_GEMM_
 
-#include "core/modelIR/ShapeRange.h"
 #include "core/modelIR/operations/GemmOp.h"
+#include "core/modelIR/ShapeRange.h"
+#include "core/modelIR/TensorVariant.h"
 #include "OperationImpl.h"
 
-namespace nnc
-{
+namespace nnc {
 template<typename T>
-class Gemm : public OperationImpl<T>
-{
+class Gemm : public OperationImpl<T> {
 public:
-    Gemm(const mir::TensorVariant &_input, const mir::ops::GemmOp &_op) :
-                                                            _op(_op), _input(_input) {}
+    Gemm(const mir::TensorVariant& a, const mir::TensorVariant& b, const mir::TensorVariant& c,
+         mir::ops::GemmOp& op) : _op(op), _tensor_a(a), _tensor_b(b), _tensor_c(c) {}
 
   std::vector<mir::TensorVariant> operator()() override {
     mir::TensorVariant res = OperationImpl<T>::allocate_tensor(_op.getOutputShape(0));
+    mir::Tensor<T> accessor(res);
+    mir::ShapeRange out_range(res.getShape());
+
+//    mir::Tensor<T> tensor_b(_b);
+    auto b_shape = _tensor_b.getShape();
+    int32_t b_rank = b_shape.rank();
+
+    auto& in_shape = _tensor_a.getShape();
+    int32_t in_rank = in_shape.rank();
+    assert(in_shape.dim(in_rank - 1) == b_shape.dim(b_rank - 2));
+    (void)in_rank;
+
+    // First, we have to multply _input(which is alpha*tensorA) and tensor_b
+    auto len = b_shape.dim(b_rank - 2);
+    int32_t row;
+    int32_t col;
+    for (auto &out_idx : out_range) {
+      mir::Index t_idx = out_idx;
+      T& output_element = accessor.at(out_idx);
+      col = t_idx.at(-1);
+      row = t_idx.at(-2);
+      for (int32_t i = 0; i < len; ++i) {
+        t_idx.at(-1) = i;
+        T& in = _tensor_a.at(t_idx);
+        t_idx.at(-1) = col;
+        t_idx.at(-2) = i;
+        T& w = _tensor_b.at(t_idx);
+        t_idx.at(-2) = row;
+        output_element += w * in;
+      }
+    }
+
+    // Now we have to add result of multiplication and (beta*tensor_c)
+    // We'd like to broadcast Tensor C to the output shape
+    assert(_op.getOutputShape(0).rank() == 2);
+    assert((_op.getOutputShape(0).rank() == _op.getInputShape(2).rank()) ||
+           ((_op.getInputShape(2).rank() == 1) && (_op.getOutputShape(0).dim(0) == 1)));
+
+    auto t = mir::TensorVariant (_tensor_c, _op.getOutputShape(0));
+    mir::Tensor<T> tensor_c(t);
+    for (auto idx : mir::ShapeRange(_op.getOutputShape(0))) {
+      accessor.at(idx) += tensor_c.at(idx);
+    }
     return {res};
   }
 
 private:
-  const mir::ops::GemmOp& _op;
-  const mir::Tensor<T> _input;
+  mir::ops::GemmOp& _op;
+  mir::Tensor<T> _tensor_a;
+  mir::Tensor<T> _tensor_b;
+  const mir::TensorVariant _tensor_c;
 };
 } // namespace nnc
 
index d3e9e64..d9643c4 100644 (file)
 
 #include "core/modelIR/IrDotDumper.h"
 #include "core/modelIR/operations/ConstantOp.h"
-#include "core/modelIR/Operation.h"
-#include "core/modelIR/Shape.h"
-#include "core/modelIR/TensorVariant.h"
 #include "core/modelIR/operations/Conv2DOp.h"
 #include "core/modelIR/operations/ElementwiseOp.h"
+#include "core/modelIR/operations/TransposeOp.h"
 #include "core/modelIR/operations/VariableOp.h"
+#include "core/modelIR/Operation.h"
+#include "core/modelIR/Shape.h"
+#include "core/modelIR/TensorUtil.h"
+#include "core/modelIR/TensorVariant.h"
 #include "onnx/onnx_pb.h"
 #include "onnx/proto_utils.h"
 #include "passes/common_frontend/model_allocation.h"
@@ -35,6 +37,8 @@
 
 #include "ONNXImporterImpl.h"
 #include "ONNXPerfectHash.h"
+#include "ONNXOpCreator.h"
+
 
 namespace nnc {
 
@@ -85,6 +89,7 @@ static mir::TensorVariant createTensor(const onnx::TensorProto* tensor) {
   size_t element_size;
   size_t buffer_size;
   const char* src_data;
+  auto shape = ShapeHelper::createShape(tensor->dims(), static_cast<size_t>(tensor->dims_size()));
 
   if (tensor->float_data_size() != 0) {
     element_size = sizeof(float);
@@ -100,12 +105,20 @@ static mir::TensorVariant createTensor(const onnx::TensorProto* tensor) {
     element_size = sizeof(int32_t);
     buffer_size = tensor->int32_data_size() * element_size;
     src_data = reinterpret_cast<const char*>(tensor->int32_data().data());
-    throw PassException("WARNING: We don't support int32 tensors yet, investigate\n");
+    mir::DTYPE type = mir::DTYPE::INT32;
   } else if (tensor->int64_data_size() != 0) {
-    element_size = sizeof(int64_t);
+    // FIXME: we could lose the data here
+    type = mir::DTYPE::INT32;
+    element_size = sizeof(int32_t);
     buffer_size = tensor->int64_data_size() * element_size;
-    src_data = reinterpret_cast<const char*>(tensor->int64_data().data());
-    throw PassException("WARNING: We don't support int64 tensors yet, investigate\n");
+
+    auto src_data64 = reinterpret_cast<const int64_t *>(tensor->int64_data().data());
+    std::shared_ptr<char> shared_buffer (new char[buffer_size], std::default_delete<char[]>());
+    auto dst_data = reinterpret_cast<int32_t *>(shared_buffer.get());
+    for (int i = 0; i < tensor->int64_data_size(); i++) {
+      dst_data[i] = (int32_t)src_data64 [i];
+    }
+    return mir::TensorVariant(shape, shared_buffer, type, element_size);
   } else if (tensor->raw_data().size() != 0) {
     switch ((tensor->data_type())) {
       case onnx::TensorProto_DataType_FLOAT:
@@ -123,7 +136,6 @@ static mir::TensorVariant createTensor(const onnx::TensorProto* tensor) {
   std::shared_ptr<char> data(new char[buffer_size], std::default_delete<char[]>());
   memcpy(data.get(), src_data, buffer_size);
 
-  auto shape = ShapeHelper::createShape(tensor->dims(), static_cast<size_t>(tensor->dims_size()));
   return mir::TensorVariant(shape, data, type, element_size);
 }
 
@@ -131,11 +143,7 @@ void ONNXImporterImpl::createGraphInputs() {
   auto& graph = _model->graph();
   auto& initializer = graph.initializer();
   auto& value_info = graph.value_info();
-  auto init_size = graph.initializer_size();
-  auto val_size = graph.value_info_size();
-  auto inp_size = graph.input_size();
   std::map<std::string, const onnx::TensorProto*> onnx_tensors;
-  std::set<mir::Operation*> constants;
 
   // Collect all initializers of the given graph
   for (int i = 0; i < graph.initializer_size(); i++) {
@@ -153,10 +161,10 @@ void ONNXImporterImpl::createGraphInputs() {
       _inputTensors.insert(std::make_pair(name, createTensor(onnx_tensor)));
       auto constant = _graph->create<mir::ops::ConstantOp>(name, _inputTensors.at(name));
       _tensorNameToPrevMirOp[name] = constant;
-      constants.insert(constant);
     } else {
-      // We're dealing with graph input
+      // We're dealing with graph input (assuming the picture only)
       auto onnx_input_shape = input.type().tensor_type().shape();
+      assert(onnx_input_shape.dim_size() == 4);
       mir::Shape shape(4);
       for (int i = 0; i < onnx_input_shape.dim_size(); i++) {
         assert(onnx_input_shape.dim(i).has_dim_value());
@@ -167,31 +175,28 @@ void ONNXImporterImpl::createGraphInputs() {
       _tensorNameToPrevMirOp[name] = node;
     }
   }
-  if (!constants.empty())
-    _graph->setConstants(constants);
 }
 
-static void dumpShape(const mir::Shape& shape) {
-  std::cout << "{";
-  for (int i = 0; i < shape.rank(); i++) {
-    std::cout << shape.dim(i) << (i == shape.rank() - 1 ? "} " : ", ");
-  }
-}
-
-void ONNXImporterImpl::dump(const std::vector<mir::Operation*>& ops, const onnx::NodeProto& onnx_node) {
+void ONNXImporterImpl::dump(const std::vector<mir::Operation*>& inputs,
+                            const std::vector<mir::Operation*>& ops,
+                            const onnx::NodeProto& onnx_node) {
   for (auto op : ops) {
-    std::cout << onnx_node.op_type() << " '" << op->getName() << "' Input Shapes: ";
-    for (int i = 0; i < op->getNumInputs() ; i++) {
-      dumpShape(op->getInputShape(i));
-    }
-    std::cout << " Output Shapes: ";
-    for (int i = 0; i < op->getNumOutputs() ; i++) {
-      dumpShape(op->getOutputShape(i));
+    std::cout << onnx_node.op_type() << " '" << op->getName() << "'";
+    if (inputs[0]->getNumInputs() > 0) {
+      std::cout << "Input Shape: ";
+      dumpShape(inputs[0]->getOutputShape(0));
     }
+    std::cout << " Output Shape: ";
+    dumpShape(op->getOutputShape(0));
     auto* onnx_op_type = ONNXPerfectHash::getONNXOpType(onnx_node.op_type().c_str(), onnx_node.op_type().size());
     switch (onnx_op_type->opCode) {
       case ONNXOpCode::opConv: {
         auto *conv = dynamic_cast<mir::ops::Conv2DOp *>(op);
+        if (conv == nullptr) {
+          assert(dynamic_cast<mir::ops::TransposeOp *>(op) != nullptr);
+          conv = dynamic_cast<mir::ops::Conv2DOp *>(op->getPrevNodes()[0].op);
+        }
+        assert(conv);
         std::cout << " Weights tensor shape ";
         dumpShape(conv->getKernel().getShape());
         std::cout << " Strides  ";
@@ -204,6 +209,11 @@ void ONNXImporterImpl::dump(const std::vector<mir::Operation*>& ops, const onnx:
       case ONNXOpCode::opAveragePool:
       case ONNXOpCode::opMaxPool: {
         auto *pool = dynamic_cast<mir::ops::PoolOp *>(op);
+        if (pool == nullptr) {
+          assert(dynamic_cast<mir::ops::TransposeOp *>(op) != nullptr);
+          pool = dynamic_cast<mir::ops::PoolOp *>(op->getPrevNodes()[0].op);
+        }
+        assert(pool);
         std::cout << " Kernel ";
         dumpShape(pool->getWindowShape());
         std::cout << " Strides  ";
@@ -250,9 +260,6 @@ mir::Graph *ONNXImporterImpl::createIR() {
     auto* onnx_op_type = ONNXPerfectHash::getONNXOpType(op_type, onnx_node.op_type().size());
 
     switch (onnx_op_type->opCode) {
-      //case ONNXOpCode::opIdentity:
-        // TOD: We simply remove the operation because it does nothing. Is it OK?
-      //  break;
       case ONNXOpCode::opConv:
         outputs = _opCreator.convertConv2D(input_nodes, onnx_node);
         break;
@@ -283,7 +290,7 @@ mir::Graph *ONNXImporterImpl::createIR() {
         outputs = _opCreator.convertConcat(input_nodes, onnx_node);
         break;
       case ONNXOpCode::opReshape:
-        outputs = _opCreator.convertReshape(input_nodes[0], input_nodes[1]->getOutputShape(0));
+        outputs = _opCreator.convertReshape(input_nodes);
         break;
       case ONNXOpCode::opRelu:
         outputs = _opCreator.convertRelu(input_nodes);
@@ -310,7 +317,7 @@ mir::Graph *ONNXImporterImpl::createIR() {
         throw PassException("Invalid ONNXOpCode" + std::to_string((int)onnx_op_type->opCode));
     }
     // Set outputs' names
-    for (int i = 0; i < outputs.size(); i++){
+    for (int i = 0; i < outputs.size(); i++) {
       outputs[i]->setName(onnx_node.output(i));
       auto result = _tensorNameToPrevMirOp.emplace(outputs[i]->getName(), outputs[i]);
       if(!result.second)
@@ -319,7 +326,7 @@ mir::Graph *ONNXImporterImpl::createIR() {
     assert (outputs.size());
     // FIXME: it should be done properly via the given graph outputs
     _graphOutputs.assign(outputs.begin(), outputs.end());
-    dump(outputs, onnx_node);
+    dump(input_nodes, outputs, onnx_node);
   }
   // set graph outputs
   // TODO: it should be done with onnx graph outputs
index 3f0cd3f..7fd5683 100644 (file)
@@ -39,9 +39,16 @@ public:
 
   void import() {};
   mir::Graph *createIR() override;
-  void dump(const std::vector<mir::Operation*>& op, const onnx::NodeProto& onnx_node);
+  void dump(const std::vector<mir::Operation*>& inputs, const std::vector<mir::Operation*>& op,
+            const onnx::NodeProto& onnx_node);
 
-private:
+  static void dumpShape(mir::Shape shape) {
+    std::cout << "{";
+    for (int i = 0; i < shape.rank(); i++) {
+      std::cout << shape.dim(i) << (i == shape.rank() - 1 ? "} " : ", ");
+    }
+  }
+  private:
   void createGraphInputs();
   // This map maps onnx tensor names to MIR operations/nodes
   std::map<std::string, mir::Operation*> _tensorNameToPrevMirOp;
index 987ff29..30750f0 100644 (file)
 #include "core/modelIR/operations/ScaleOp.h"
 #include "core/modelIR/operations/SigmoidOp.h"
 #include "core/modelIR/operations/SoftmaxOp.h"
+#include "core/modelIR/operations/TransposeOp.h"
 #include "core/modelIR/operations/VariableOp.h"
 #include "core/modelIR/operations/ElementwiseOp.h"
 #include "passes/common_frontend/shape_helper.h"
 #include "pass/PassException.h"
 #include "ONNXOpCreator.h"
+#include "ONNXImporterImpl.h"
 
 namespace nnc {
 
@@ -80,16 +82,17 @@ static std::pair<bool, float> getFloatAttribute(const onnx::NodeProto& onnx_node
   return {false, 0.0};
 }
 
-static TensorVariant createTensor(float value) {
+// Create vector tensor filled with the given value
+static TensorVariant createTensor(float value, const mir::Shape& shape) {
   mir::DTYPE element_type = mir::DTYPE::FLOAT32;
   size_t element_size = sizeof(value);
-  size_t buffer_size = 1 * element_size;
-  const char* src_data = reinterpret_cast<const char*>(&value);
 
-  std::shared_ptr<char> data(new char[buffer_size], std::default_delete<char[]>());
-  std::memcpy(data.get(), src_data, buffer_size);
-  Shape shape{1};
-  return mir::TensorVariant(shape, data, element_type, element_size);
+  float* dst_ptr = new float[shape.numElements()];
+  for (int i = 0; i < shape.numElements(); i++) {
+    dst_ptr[i] = value;
+  }
+  std::shared_ptr<char> data((char*)dst_ptr, std::default_delete<char[]>());
+  return mir::TensorVariant({shape.numElements()}, data, element_type, element_size);
 }
 
 struct KernelStridesPadding {
@@ -143,12 +146,14 @@ std::vector<Operation*> ONNXOpCreator::convertConv2D(InputOps& inputs,
 
   inputs.resize(1);
   std::vector<Operation*> outputs;
-  outputs = createOp<ops::Conv2DOp>(inputs[0]->getOutput(0), transposed, cdata.strides_shape,
+  // Transpose ONNX NCHW to MIR NHWC
+  auto t_input = convertONNXToMIR(inputs[0]->getOutput(0));
+  outputs = createOp<ops::Conv2DOp>(t_input[0]->getOutput(0), transposed, cdata.strides_shape,
                                     cdata.padding_before, cdata.padding_after);
   if (input_bias)
     outputs = createOp<ops::BiasAddOp>(outputs[0]->getOutput(0), input_bias->getValue());
 
-  return outputs;
+  return convertMIRToONNX(outputs[0]->getOutput(0));
 }
 
 std::vector<Operation*> ONNXOpCreator::convertConcat(InputOps& inputs,
@@ -180,18 +185,22 @@ std::vector<Operation*> ONNXOpCreator::convertPool(InputOps& inputs, ONNXOpCode
 
   std::vector<Operation*> result;
   KernelStridesPadding cdata;
+  // Transpose ONNX NCHW to MIR NHWC
+  auto t_input = convertONNXToMIR(inputs[0]->getOutput(0));
 
   switch (op_code) {
-    case ONNXOpCode::opGlobalAveragePool:
+    case ONNXOpCode::opGlobalAveragePool: {
       // GlobalAveragePool is equivalent to AveragePool with kernel size equal
       // to the spatial dimension of input tensor
-      return createOp<ops::PoolOp>(inputs[0]->getOutput(0),
-                                   ops::PoolOp::PoolingType::AVG,
-                                   inputs[0]->getOutputShape(0), // kernel_shape
-                                   Shape({1, 1}),                // strides_shape
-                                   cdata.padding_before, cdata.padding_after,
-                                   ops::PoolOp::BorderType::ZEROFILLED,
-                                   ops::PoolOp::RoundMode::floor);
+      result = createOp<ops::PoolOp>(t_input[0]->getOutput(0),
+                                     ops::PoolOp::PoolingType::AVG,
+                                     t_input[0]->getOutputShape(0), // kernel_shape
+                                     Shape({1, 1}),                // strides_shape
+                                     cdata.padding_before, cdata.padding_after,
+                                     ops::PoolOp::BorderType::ZEROFILLED,
+                                     ops::PoolOp::RoundMode::floor);
+      return convertMIRToONNX(result[0]->getOutput(0));
+    }
     case ONNXOpCode::opAveragePool:
       border_type = ops::PoolOp::BorderType::ZEROFILLED;
       pool_type = ops::PoolOp::PoolingType::AVG;
@@ -206,16 +215,16 @@ std::vector<Operation*> ONNXOpCreator::convertPool(InputOps& inputs, ONNXOpCode
   // Proceed with Average or Max Pool
   getKernelStridesPadding(onnx_node, cdata);
 
-  result = createOp<ops::PoolOp>(inputs[0]->getOutput(0), pool_type,
+  result = createOp<ops::PoolOp>(t_input[0]->getOutput(0), pool_type,
                                  cdata.kernel_shape, cdata.strides_shape,
                                  cdata.padding_before, cdata.padding_after,
                                  border_type,
                                  ops::PoolOp::RoundMode::floor);
-  return result;
+  return convertMIRToONNX(result[0]->getOutput(0));
 }
 
 std::vector<Operation*> ONNXOpCreator::convertSoftmax(InputOps& inputs,
-                                                     const onnx::NodeProto& onnx_node) {
+                                                      const onnx::NodeProto& onnx_node) {
   int axis;
   bool found;
   std::tie (found, axis) = getIntAttribute(onnx_node);
@@ -223,8 +232,36 @@ std::vector<Operation*> ONNXOpCreator::convertSoftmax(InputOps& inputs,
   return createOp<ops::SoftmaxOp>(inputs[0]->getOutput(0), axis);
 }
 
-std::vector<Operation*> ONNXOpCreator::convertReshape(Operation* inputData, Shape outputShape) {
-  auto outputs = createOp<ops::ReshapeOp>(inputData->getOutput(0), outputShape);
+std::vector<Operation*> ONNXOpCreator::convertReshape(InputOps& inputs) {
+  // The original shape
+  auto in_shape = inputs[0]->getInputShape(0);
+
+  // Input tensor describing the new shape
+  // TODO: could it be not a constant?
+  auto* op = dynamic_cast<mir::ops::ConstantOp*>(inputs[1]);
+  assert(op && "We support constants only");
+  auto shape_tensor = op->getValue();
+  Shape shape_tensor_shape = (shape_tensor).getShape();
+  assert(shape_tensor_shape.rank() == 1);
+  // The rank of the new shape
+  auto cnt  = shape_tensor_shape.numElements();
+  // The vector to build the new shape from
+  std::vector<int32_t > shape_vector(cnt);
+  ShapeRange out_range(shape_tensor_shape);
+  Tensor<int32_t> tensor_accessor(shape_tensor);
+
+  int i = 0;
+  for (auto idx : out_range) {
+    if (tensor_accessor.at(idx) == 0)
+      shape_vector[i] = in_shape.dim(i);
+    else if (tensor_accessor.at(idx) == -1)
+      shape_vector[i] = Shape::autoDim;
+    else
+      shape_vector[i] = tensor_accessor.at(idx);
+    i++;
+  }
+  auto out_shape = Shape(shape_vector);
+  auto outputs = createOp<ops::ReshapeOp>(inputs[0]->getOutput(0), out_shape);
   return outputs;
 }
 
@@ -267,40 +304,39 @@ std::vector<Operation*> ONNXOpCreator::convertElementwise(InputOps& inputs,
     descriptors.push_back(input->getOutput(0));
   return createOp<ops::ElementwiseOp>(descriptors, op_type);
 }
-
 std::vector<Operation*> ONNXOpCreator::convertBatchNorm(InputOps& inputs,
-                                                       const onnx::NodeProto& onnx_node,
-                                                       InputTensors& input_tensors) {
+                                                        const onnx::NodeProto& onnx_node,
+                                                        InputTensors& input_tensors) {
+  // overall_res = (X - mean) / sqrt(var + epsilon) * scale + bias
   bool found;
   float value;
-
   std::tie(found, value) = getFloatAttribute(onnx_node, "epsilon");
-  float epsilon = found ? value : 1e-05;
-  std::tie(found, value) = getFloatAttribute(onnx_node, "momentum");
-  float momentum = found ? value : 0.9;
-  // FIXME: spatial vs. scale_factor
-  //std::tie(found, value) = getFloatAttribute(onnx_node, "spatial");
-  float scale_factor = 0.0f;
-  // Scale tensor
-  assert(input_tensors.find(inputs[1]->getName()) != input_tensors.end());
-  auto ptensor = input_tensors.at(inputs[1]->getName());
-  Tensor<float> nnc_scale(ptensor);
-  // Bias tensor
-  assert(input_tensors.find(inputs[2]->getName()) != input_tensors.end());
-  auto nnc_bias = input_tensors.at(inputs[2]->getName());
-  // TODO: there are 2 training tensors in the inputs
+  float epsilon = found ? value : 1e-05f;
 
-  inputs.resize(1);
-  auto mean_outputs = createOp<ops::BiasAddOp>(inputs[0]->getOutput(0), nnc_bias);
+  const auto& scale = input_tensors.at(inputs[1]->getName());
+  const auto& bias = input_tensors.at(inputs[2]->getName());
+  const auto& mean = input_tensors.at(inputs[3]->getName());
+  const auto& var = input_tensors.at(inputs[4]->getName());
+
+  // res1 = X - mean
+  Tensor<float> bias_data(mean);
+  for (auto& idx: ShapeRange(bias_data.getShape()))
+    bias_data.at(idx) *= -1;
 
-  // create scale argument from variance:
-  // multiply elements of variance by scaleFactor and
-  // normalize biased input using scale operation
-  for (Index idx : ShapeRange(nnc_scale.getShape()))
-    nnc_scale.at(idx) = 1.0f / std::sqrt(nnc_scale.at(idx) * scale_factor + epsilon);
+  auto data = convertONNXToMIR(inputs[0]->getOutput(0));
+  auto bias_add_1 = createOp<ops::BiasAddOp>(data[0]->getOutput(0), mean);
 
-  auto variance_outputs = createOp<ops::ScaleOp>(mean_outputs[0]->getOutput(0), ptensor);
-  return variance_outputs;
+  // res2 = res1 * scale / (var + epsilon)
+  Tensor<float> multiplier(scale);
+  Tensor<float> var_accessor(var);
+  for (auto& idx: ShapeRange(scale.getShape()))
+    multiplier.at(idx) /= std::sqrt(var_accessor.at(idx) + epsilon);
+  auto scale_op = createOp<ops::ScaleOp>(bias_add_1[0]->getOutput(0), scale);
+
+  // overall_res = res2 + bias
+  auto bias_add_2 = createOp<ops::BiasAddOp>(scale_op[0]->getOutput(0), bias);
+
+  return {convertMIRToONNX(bias_add_2[0]->getOutput(0))};
 }
 
 std::vector<Operation*> ONNXOpCreator::convertDropout(InputOps& inputs,
@@ -318,7 +354,8 @@ std::vector<Operation*> ONNXOpCreator::convertScale(InputOps& inputs,
   float value;
   std::tie(found, value) = getFloatAttribute(onnx_node, "scale");
   float scale = found ? value : 1.0;
-  auto outputs = createOp<ops::ScaleOp>(inputs[0]->getOutput(0), createTensor(scale));
+  auto outputs = createOp<ops::ScaleOp>(inputs[0]->getOutput(0),
+                                        createTensor(scale, inputs[0]->getOutputShape(0)));
   return outputs;
 }
 
@@ -328,28 +365,75 @@ std::vector<Operation*> ONNXOpCreator::convertGemm(InputOps& inputs,
   int   ivalue;
   float fvalue;
 
+  // Compute Y = alpha * A' * B' + beta * C, where input tensor A has shape (M, K) or (K, M),
+  // input tensor B has shape (K, N) or (N, K),
+  // input tensor C is broadcastable to shape (M, N), and output tensor Y has shape (M, N).
+  // A will be transposed before doing the computation if attribute transA is non-zero,
+  // same for B and transB. This operator supports unidirectional broadcasting
+  // (tensor C should be unidirectional broadcastable to tensor A * B).
+
   std::tie (found, ivalue) = getIntAttribute(onnx_node, "transA");
-  bool transA = found ? ivalue : 0;
+  bool trans_a = found ? ivalue : 0;
   std::tie (found, ivalue) = getIntAttribute(onnx_node, "transB");
-  bool transB = found ? ivalue : 0;
-  std::tie (found, fvalue) = getIntAttribute(onnx_node, "alpha");
+  bool trans_b = found ? ivalue : 0;
+  std::tie (found, fvalue) = getFloatAttribute(onnx_node, "alpha");
   float alpha = found ? fvalue : 1.0;
-  std::tie (found, fvalue) = getIntAttribute(onnx_node, "beta");
+  std::tie (found, fvalue) = getFloatAttribute(onnx_node, "beta");
   float beta = found ? fvalue : 1.0;
 
+  // 1. Prepare input matrix A
   // Flatten the shape by dim(0)
   mir::Shape shape0 ({inputs[0]->getOutputShape(0).dim(0),
                       inputs[0]->getOutputShape(0).numElements() /
                                                         inputs[0]->getOutputShape(0).dim(0)});
-  auto reshape = createOp<ops::ReshapeOp>(inputs[0]->getOutput(0), shape0);
+  auto input_a = createOp<ops::ReshapeOp>(inputs[0]->getOutput(0), shape0);
+  if (trans_a)
+    input_a = createOp<ops::TransposeOp>(input_a[0]->getOutput(0), std::vector<std::size_t>{1, 0});
+  if (alpha != 1.0)
+    input_a = createOp<ops::ScaleOp>(input_a[0]->getOutput(0),
+                                          createTensor(alpha, input_a[0]->getOutputShape(0)));
+
+  // 2. Prepare input matrix B
+  //
+  auto input_b = inputs[1]->getOutput(0);
+  if (trans_b)
+    input_b = createOp<ops::TransposeOp>(input_b, std::vector<std::size_t>{1, 0})[0]->getOutput(0);
+  // Number of cols in tensor A must be equal to number of rows in tensor B
+  assert(input_a[0]->getOutput(0).op->getOutputShape(0).dim(1) ==
+         input_b.op->getOutputShape(0).dim(0));
+  Shape mult_a_b({input_a[0]->getOutput(0).op->getOutputShape(0).dim(0),
+                  input_b.op->getOutputShape(0).dim(1)});
+
+  // 3. Prepare input matrix C
+  //
+  auto input_c = inputs[2]->getOutput(0);
+  auto beta_tensor = createTensor(beta, input_c.op->getOutputShape(0));
+  if ((mult_a_b.rank() == 2) && (input_c.op->getOutputShape(0).rank() == 1)) {
+    beta_tensor = TensorVariant(beta_tensor, mult_a_b);
+  }
+  auto constant = createOp<ops::ConstantOp>(beta_tensor)[0]->getOutput(0);
+  std::vector<IODescriptor> descriptors = {constant, input_c};
+  auto c_mult = createOp<ops::ElementwiseOp>(descriptors, ops::ElementwiseOp::OpType::mul);
+  assert(c_mult[0]->getOutput(0).op->getOutputShape(0) == mult_a_b);
+  return createOp<ops::GemmOp>(input_a[0]->getOutput(0), input_b, c_mult[0]->getOutput(0));
+}
 
-  std::vector<IODescriptor> descriptors;
-  descriptors.push_back(reshape[0]->getOutput(0));
-  descriptors.push_back(inputs[1]->getOutput(0));
-  descriptors.push_back(inputs[2]->getOutput(0));
+std::vector<Operation*>
+ONNXOpCreator::createInput(const std::string& input_name, const mir::Shape& input_shape) {
+  // TODO For now we only support convolutional networks with one element per batch.
+  assert(input_shape.rank() == 4 && input_shape.dim(0) == 1);
+  auto variable = _graph->create<ops::VariableOp>(input_name, input_shape);
+  return {variable};
+}
+
+std::vector<Operation*> ONNXOpCreator::convertONNXToMIR(const mir::IODescriptor& arg) {
+  // NCHW -> NHWC
+  return createOp<ops::TransposeOp>(arg, std::vector<std::size_t>{0, 2, 3, 1});
+}
 
-  return createOp<ops::GemmOp>(reshape[0]->getOutput(0), inputs[1]->getOutput(0),
-                               inputs[2]->getOutput(0), transA, transB, alpha, beta);
+std::vector<Operation*> ONNXOpCreator::convertMIRToONNX(const mir::IODescriptor& arg) {
+  // NHWC -> NCHW
+  return createOp<ops::TransposeOp>(arg, std::vector<std::size_t>{0, 3, 1, 2});
 }
 
 } // namespace nnc
index 479492a..30999b5 100644 (file)
@@ -43,7 +43,7 @@ public:
   std::vector<mir::Operation*> convertPool(InputOps& inputs, ONNXOpCode op_code,
                                           const onnx::NodeProto& onnx_node);
   std::vector<mir::Operation*> convertSoftmax(InputOps& inputs, const onnx::NodeProto& onnx_node);
-  std::vector<mir::Operation*> convertReshape(mir::Operation* input_data, mir::Shape output_shape);
+  std::vector<mir::Operation*> convertReshape(InputOps& inputs);
   std::vector<mir::Operation*> convertRelu(InputOps& inputs);
   std::vector<mir::Operation*> convertSigmoid(InputOps& inputs);
 
@@ -58,6 +58,10 @@ public:
   std::vector<mir::Operation*> convertGather(InputOps& inputs, const onnx::NodeProto& onnx_node);
   std::vector<mir::Operation*> convertGemm(InputOps& inputs, const onnx::NodeProto& onnx_node);
 
+  std::vector<mir::Operation*> createInput(const std::string&, const mir::Shape&);
+  std::vector<mir::Operation*> convertONNXToMIR(const mir::IODescriptor& arg);
+  std::vector<mir::Operation*> convertMIRToONNX(const mir::IODescriptor& arg);
+
 private:
   template <typename OpType, typename ...Types>
   std::vector<nnc::mir::Operation*> createOp(Types&&... args);
index db124f3..bd0c827 100644 (file)
@@ -28,7 +28,6 @@
 #include "core/modelIR/operations/BiasAddOp.h"
 #include "core/modelIR/operations/CappedReluOp.h"
 #include "core/modelIR/operations/ConcatOp.h"
-#include "core/modelIR/operations/ConstantOp.h"
 #include "core/modelIR/operations/Conv2DOp.h"
 #include "core/modelIR/operations/Deconv2DOp.h"
 #include "core/modelIR/operations/DepthwiseConv2DOp.h"
@@ -36,8 +35,8 @@
 #include "core/modelIR/operations/ElementwiseOp.h"
 #include "core/modelIR/operations/EluOp.h"
 #include "core/modelIR/operations/FullyConnectedOp.h"
-#include "core/modelIR/operations/GemmOp.h"
 #include "core/modelIR/operations/GatherOp.h"
+#include "core/modelIR/operations/GemmOp.h"
 #include "core/modelIR/operations/PadOp.h"
 #include "core/modelIR/operations/PoolOp.h"
 #include "core/modelIR/operations/ReduceFOp.h"
@@ -88,7 +87,7 @@ void ModelAnalyzer::addOpDescr(Operation* op, const string& opName) {
     nodeTid = allocateTensor(name, TensorDescription::Type::OUT);
     _named_tensors.push_back(nodeTid);
     type = OpDescr::Type::OUT;
-  } else  {
+  } else {
     // process ordinary op
     nodeTid = allocateTensor();
   }
@@ -141,8 +140,6 @@ void ModelAnalyzer::analyze(const mir::Graph* g) {
 
   // Collect all inputs and constants
   vector<Operation*> init_ops(g->collectInputs());
-  vector<Operation*> constant_ops(g->collectConstants());
-  init_ops.insert(init_ops.end(), constant_ops.begin(), constant_ops.end());
 
   // Walk all network inputs
   for (Operation* in : init_ops) {
diff --git a/contrib/nnc/passes/soft_backend/code_snippets/cpp_gemm.def b/contrib/nnc/passes/soft_backend/code_snippets/cpp_gemm.def
new file mode 100644 (file)
index 0000000..1b9e094
--- /dev/null
@@ -0,0 +1,34 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+inline void gemm(const float* tensor_a_data, const Dims<4>& tensor_a_dims,
+                   const float* tensor_b_data, const Dims<4>& tensor_b_dims,
+                   const float* tensor_c_data, const Dims<4>& tensor_c_dims,
+                         float* output_data,   const Dims<4>& output_dims) {
+  const auto tensor_a_map =
+      MapAsMatrixWithFirstDimAsRows(tensor_a_data, tensor_a_dims);
+  const auto tensor_b_map =
+      MapAsMatrixWithFirstDimAsRows(tensor_b_data, tensor_b_dims);
+  const auto tensor_c_map =
+      MapAsMatrixWithFirstDimAsRows(tensor_c_data, tensor_c_dims);
+  auto output_matrix_map =
+      MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+  Gemm(tensor_a_map, tensor_b_map, &output_matrix_map);
+  auto size = tensor_a_dims.sizes[0] * tensor_a_dims.sizes[1] *
+              tensor_a_dims.sizes[2] * tensor_a_dims.sizes[3];
+  for (int i = 0; i < size; i++) {
+    output_data[i] = output_data[i] + tensor_c_data[i];
+  }
+}