[nnc] Fix soft backend implementation for Gemm (#2750)
authorАндрей Тищенко/AI Tools Lab /SRR/Staff Engineer/삼성전자 <a.tischenko@partner.samsung.com>
Mon, 14 Jan 2019 14:03:01 +0000 (17:03 +0300)
committerEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Mon, 14 Jan 2019 14:03:01 +0000 (17:03 +0300)
- Fixed generation of cpu soft backend gemm operation
- Small refactoring and typo fixes of interpreter and onnx frontend

Signed-off-by: Andrew V. Tischenko <a.tischenko@partner.samsung.com>
13 files changed:
contrib/nnc/include/core/modelIR/operations/GemmOp.h
contrib/nnc/include/passes/interpreter/Interpreter.h
contrib/nnc/passes/interpreter/Interpreter.cpp
contrib/nnc/passes/interpreter/interpreter_pass.cpp
contrib/nnc/passes/interpreter/ops/Gemm.h
contrib/nnc/passes/onnx_frontend/ONNXImporterImpl.cpp
contrib/nnc/passes/onnx_frontend/ONNXOpCreator.cpp
contrib/nnc/passes/soft_backend/CPPGenerator.cpp
contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp
contrib/nnc/passes/soft_backend/SBSerializer.cpp
contrib/nnc/passes/soft_backend/code_snippets/cpp_gemm.def
contrib/nnc/passes/soft_backend/code_snippets/cpp_operations.def
contrib/nnc/unittests/soft_backend/CPPOperations.cpp

index bb90c3c..8ab3ea5 100644 (file)
@@ -26,8 +26,8 @@ namespace ops {
 
 class GemmOp : public Operation {
 public:
-  GemmOp(IODescriptor arg, const IODescriptor b, const IODescriptor c) :
-            Operation(Type::gemmOp, {arg, b, c}) {
+  GemmOp(IODescriptor a, IODescriptor b, IODescriptor c) :
+            Operation(Type::gemmOp, {a, b, c}) {
     inferOutputShapes();
   }
 
@@ -37,5 +37,4 @@ private:
 } // namespace ops
 } // namespace mir
 } // namespace nnc
-
 #endif //_NNC_CORE_IR_MODEL_GEMM_OP_H_
index 23dfaf4..a42d5f3 100644 (file)
@@ -45,12 +45,12 @@ public:
   void visit(ops::Conv2DOp& op) override;
   void visit(ops::DeConv2DOp& op) override;
   void visit(ops::DepthwiseConv2DOp& op) override;
-  void visit(ops::GemmOp& op) override;
   void visit(ops::DropoutOp& op) override;
   void visit(ops::ElementwiseOp& op) override;
   void visit(ops::EluOp& op) override;
   void visit(ops::FullyConnectedOp& op) override;
   void visit(ops::GatherOp& op) override;
+  void visit(ops::GemmOp& op) override;
   void visit(ops::LeakyReluOp& op) override;
   void visit(ops::PadOp& op) override;
   void visit(ops::PoolOp& op) override;
index f0fd1c8..c6cd178 100644 (file)
@@ -106,7 +106,7 @@ void NNInterpreter::dump(Operation& op, bool all) {
   if (do_it || all) {
     auto last_idx = shape.rank() - 1;
     for (auto idx : ShapeRange(shape)) {
-      if (!idx.at(last_idx))
+      if (!(idx.at(last_idx) % 15))
         std::cout << "\n";
       dumpIndex(idx);
       if (tensor.getDataType() == DTYPE::FLOAT32)
index 792e54c..8e6bafc 100644 (file)
@@ -105,14 +105,14 @@ PassData InterpreterPass::run(PassData data) {
 
   // Check nodes
   const auto& outputs = g->collectOutputs();
-  
+#if 0
+  interpreter.dump(*outputs[0], true);
+#endif
+
   for (auto& out : outputs) {
     auto outputNode = interpreter.getResult(out);
-    if (outputNode.empty()) {
+    if (outputNode.empty())
       throw PassException("No value for output node <" + out->getName() + ">");
-    } else {
-      std::cout << "Output node <" + out->getName() + "> found" << std::endl;
-    }
   }
 
   bool is_several_outs = (outputs.size() > 1);
index 0d94813..e6bf2fd 100644 (file)
@@ -26,15 +26,15 @@ namespace nnc {
 template<typename T>
 class Gemm : public OperationImpl<T> {
 public:
-    Gemm(const mir::TensorVariant& a, const mir::TensorVariant& b, const mir::TensorVariant& c,
-         mir::ops::GemmOp& op) : _op(op), _tensor_a(a), _tensor_b(b), _tensor_c(c) {}
+    Gemm(const mir::TensorVariant& a, const mir::TensorVariant& b,
+         const mir::TensorVariant& c, mir::ops::GemmOp& op) :
+                           _op(op), _tensor_a(a), _tensor_b(b), _tensor_c(c) {}
 
   std::vector<mir::TensorVariant> operator()() override {
     mir::TensorVariant res = OperationImpl<T>::allocate_tensor(_op.getOutputShape(0));
     mir::Tensor<T> accessor(res);
     mir::ShapeRange out_range(res.getShape());
 
-//    mir::Tensor<T> tensor_b(_b);
     auto b_shape = _tensor_b.getShape();
     int32_t b_rank = b_shape.rank();
 
@@ -67,7 +67,8 @@ public:
     // We'd like to broadcast Tensor C to the output shape
     assert(_op.getOutputShape(0).rank() == 2);
     assert((_op.getOutputShape(0).rank() == _op.getInputShape(2).rank()) ||
-           ((_op.getInputShape(2).rank() == 1) && (_op.getOutputShape(0).dim(0) == 1)));
+           ((_op.getInputShape(2).rank() == 1) &&
+            (_op.getOutputShape(0).dim(0) == 1)));
 
     auto t = mir::TensorVariant (_tensor_c, _op.getOutputShape(0));
     mir::Tensor<T> tensor_c(t);
@@ -81,7 +82,7 @@ private:
   mir::ops::GemmOp& _op;
   mir::Tensor<T> _tensor_a;
   mir::Tensor<T> _tensor_b;
-  const mir::TensorVariant _tensor_c;
+  mir::TensorVariant _tensor_c;
 };
 } // namespace nnc
 
index d9643c4..e874295 100644 (file)
@@ -326,7 +326,9 @@ mir::Graph *ONNXImporterImpl::createIR() {
     assert (outputs.size());
     // FIXME: it should be done properly via the given graph outputs
     _graphOutputs.assign(outputs.begin(), outputs.end());
+#if 0
     dump(input_nodes, outputs, onnx_node);
+#endif
   }
   // set graph outputs
   // TODO: it should be done with onnx graph outputs
index 30750f0..32b35f8 100644 (file)
@@ -376,6 +376,8 @@ std::vector<Operation*> ONNXOpCreator::convertGemm(InputOps& inputs,
   bool trans_a = found ? ivalue : 0;
   std::tie (found, ivalue) = getIntAttribute(onnx_node, "transB");
   bool trans_b = found ? ivalue : 0;
+  std::tie (found, ivalue) = getIntAttribute(onnx_node, "broadcast");
+  bool broadcast = found ? ivalue : 0;
   std::tie (found, fvalue) = getFloatAttribute(onnx_node, "alpha");
   float alpha = found ? fvalue : 1.0;
   std::tie (found, fvalue) = getFloatAttribute(onnx_node, "beta");
@@ -408,6 +410,7 @@ std::vector<Operation*> ONNXOpCreator::convertGemm(InputOps& inputs,
   //
   auto input_c = inputs[2]->getOutput(0);
   auto beta_tensor = createTensor(beta, input_c.op->getOutputShape(0));
+  // TODO: check 'broadcast' attribute here
   if ((mult_a_b.rank() == 2) && (input_c.op->getOutputShape(0).rank() == 1)) {
     beta_tensor = TensorVariant(beta_tensor, mult_a_b);
   }
@@ -435,5 +438,4 @@ std::vector<Operation*> ONNXOpCreator::convertMIRToONNX(const mir::IODescriptor&
   // NHWC -> NCHW
   return createOp<ops::TransposeOp>(arg, std::vector<std::size_t>{0, 3, 1, 2});
 }
-
 } // namespace nnc
index 18b7955..1e26996 100644 (file)
@@ -55,6 +55,7 @@ using namespace std;
 #include "cpp_pad.generated.h"
 #include "cpp_transpose.generated.h"
 #include "cpp_gather.generated.h"
+#include "cpp_gemm.generated.h"
 
 namespace nnc
 {
@@ -297,6 +298,7 @@ void CPPCodeGenerator::materializeCode(ostream &out, const ModelAnalyzer &ma, co
   out.write(cpp_slice, sizeof(cpp_slice));
   out.write(cpp_elementwise, sizeof(cpp_elementwise));
   out.write(cpp_elu, sizeof(cpp_elu));
+  out.write(cpp_gemm, sizeof(cpp_gemm));
   out.write(cpp_tanh, sizeof(cpp_tanh));
   out.write(cpp_pad, sizeof(cpp_pad));
   out.write(cpp_sqrt, sizeof(cpp_sqrt));
index b4d0d17..4d80c4f 100644 (file)
@@ -24,6 +24,7 @@
 #include "core/modelIR/operations/BiasAddOp.h"
 #include "core/modelIR/operations/CappedReluOp.h"
 #include "core/modelIR/operations/ConcatOp.h"
+#include <core/modelIR/operations/ConstantOp.h>
 #include "core/modelIR/operations/Conv2DOp.h"
 #include "core/modelIR/operations/Deconv2DOp.h"
 #include "core/modelIR/operations/DepthwiseConv2DOp.h"
@@ -89,9 +90,9 @@ void ModelAnalyzer::addOpDescr(Operation* op, const string& function_name) {
   vector<size_t> node_input_tensors;
   for (const IODescriptor &d: op->getPrevNodes()) {
     size_t idx = d.index;
-    Operation *op = d.op;
-    assert(_opToDescr.find(op) != _opToDescr.end());
-    const OpDescr &descr = *_opToDescr[op];
+    Operation *prev_op = d.op;
+    assert(_opToDescr.find(prev_op) != _opToDescr.end());
+    const OpDescr &descr = *_opToDescr[prev_op];
     const size_t &inTid = descr._outputs[idx];
     node_input_tensors.push_back(inTid);
   }
@@ -225,7 +226,7 @@ void ModelAnalyzer::visit(ops::FullyConnectedOp& op) {
 }
 
 void ModelAnalyzer::visit(ops::GemmOp& op) {
-  addOpDescr(&op, "gemm");
+  addOpDescr(&op, "gemmOp");
 }
 
 void ModelAnalyzer::visit(ops::CappedReluOp& op) {
index ffd6c20..85601f6 100644 (file)
@@ -222,6 +222,7 @@ void Serializer::visit(ops::FullyConnectedOp& op) {
 
 void Serializer::visit(ops::GemmOp& op) {
   _curOp->_paramStartOffset = _buffer.size();
+  serializeShape(op.getOutputShape(0));
 }
 
 void Serializer::visit(ops::CappedReluOp& op) {
index 1b9e094..4ecef2b 100644 (file)
@@ -13,22 +13,23 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-inline void gemm(const float* tensor_a_data, const Dims<4>& tensor_a_dims,
-                   const float* tensor_b_data, const Dims<4>& tensor_b_dims,
-                   const float* tensor_c_data, const Dims<4>& tensor_c_dims,
-                         float* output_data,   const Dims<4>& output_dims) {
-  const auto tensor_a_map =
-      MapAsMatrixWithFirstDimAsRows(tensor_a_data, tensor_a_dims);
-  const auto tensor_b_map =
-      MapAsMatrixWithFirstDimAsRows(tensor_b_data, tensor_b_dims);
-  const auto tensor_c_map =
-      MapAsMatrixWithFirstDimAsRows(tensor_c_data, tensor_c_dims);
+inline void gemm(const float* input_a, const Dims<4>& input_a_dims,
+                const float* input_b, const Dims<4>& input_b_dims,
+                const float* input_c, const Dims<4>& input_c_dims,
+                float* output_data, const Dims<4>& out_dims) {
+  const auto input_matrix_a_map =
+      MapAsMatrixWithFirstDimAsRows(input_a, input_a_dims);
+  const auto input_matrix_b_map =
+      MapAsMatrixWithFirstDimAsRows(input_b, input_b_dims);
+  const auto input_matrix_c_map =
+      MapAsMatrixWithFirstDimAsRows(input_c, input_c_dims);
   auto output_matrix_map =
-      MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
-  Gemm(tensor_a_map, tensor_b_map, &output_matrix_map);
-  auto size = tensor_a_dims.sizes[0] * tensor_a_dims.sizes[1] *
-              tensor_a_dims.sizes[2] * tensor_a_dims.sizes[3];
-  for (int i = 0; i < size; i++) {
-    output_data[i] = output_data[i] + tensor_c_data[i];
-  }
+      MapAsMatrixWithFirstDimAsRows(output_data, out_dims);
+
+  Gemm(input_matrix_b_map, input_matrix_a_map, &output_matrix_map);
+
+  int len = out_dims.sizes[0] * out_dims.sizes[1] *
+            out_dims.sizes[2] * out_dims.sizes[3];
+  for (int i = 0; i < len; i++)
+    output_data[i] += input_c[i];
 }
index 02b5fae..73d73c6 100644 (file)
@@ -436,6 +436,15 @@ void fullConnect(Tensor &out, const char *params, const Tensor &in)
                  out.getData(), shapeToDims(out_s));
 }
 
+void gemmOp(Tensor &out, const char *params, const Tensor &tensor_a, const Tensor &tensor_b, const Tensor &tensor_c) {
+  Shape out_s = deserializeShape(params);
+  out.reShape(out_s);
+
+  gemm(tensor_a.getData(), shapeToDims(tensor_a.getShape()),
+       tensor_b.getData(), shapeToDims(tensor_b.getShape()),
+       tensor_c.getData(), shapeToDims(tensor_c.getShape()),
+       out.getData(), shapeToDims(out_s));
+}
 /**
  * @brief Resize assuming tflite axis order (NHWC)
  */
index 5d184f9..2d78eb1 100644 (file)
@@ -37,6 +37,7 @@
 #include "code_snippets/cpp_elu.def"
 #include "code_snippets/cpp_fully_connected.def"
 #include "code_snippets/cpp_gather.def"
+#include "code_snippets/cpp_gemm.def"
 #include "code_snippets/cpp_sigmoid.def"
 #include "code_snippets/cpp_pad.def"
 #include "code_snippets/cpp_pool.def"