[mir_onnx] Switch to binary elementwise operations (#6412)
authorСергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 <s.barannikov@samsung.com>
Fri, 9 Aug 2019 09:26:55 +0000 (12:26 +0300)
committerAlexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Fri, 9 Aug 2019 09:26:55 +0000 (12:26 +0300)
Switch to new binary elementwise operations.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
compiler/mir-onnx-importer/ONNXHelpers.h
compiler/mir-onnx-importer/Op/Add.cpp
compiler/mir-onnx-importer/Op/BatchNormalization.cpp
compiler/mir-onnx-importer/Op/Conv.cpp
compiler/mir-onnx-importer/Op/Gemm.cpp
compiler/mir-onnx-importer/Op/Max.cpp
compiler/mir-onnx-importer/Op/Mul.cpp
compiler/mir-onnx-importer/Op/Scale.cpp
compiler/mir-onnx-importer/Op/Sum.cpp

index 6c05362..9700c11 100644 (file)
@@ -22,7 +22,6 @@
 #include "mir/ShapeRange.h"
 
 #include "mir/ops/TransposeOp.h"
-#include "mir/ops/ElementwiseOp.h"
 
 #include "onnx/onnx.pb.h"
 
@@ -145,22 +144,6 @@ inline mir::Operation::Output *convertMIRToONNX(mir::Graph *graph, mir::Operatio
       ->getOutput(0);
 }
 
-inline mir::Operation::Output *createAdd(mir::Graph *graph, mir::Operation::Output *arg1,
-                                         mir::Operation::Output *arg2)
-{
-  std::vector<mir::Operation::Output *> inputs{arg1, arg2};
-  return graph->create<mir::ops::ElementwiseOp>("", inputs, mir::ops::ElementwiseOp::OpType::add)
-      ->getOutput(0);
-}
-
-inline mir::Operation::Output *createMul(mir::Graph *graph, mir::Operation::Output *arg1,
-                                         mir::Operation::Output *arg2)
-{
-  std::vector<mir::Operation::Output *> inputs{arg1, arg2};
-  return graph->create<mir::ops::ElementwiseOp>("", inputs, mir::ops::ElementwiseOp::OpType::mul)
-      ->getOutput(0);
-}
-
 } // namespace mir_onnx
 
 #endif // __MIR_ONNX_HELPERS_H__
index ed413d8..a500ee6 100644 (file)
@@ -18,7 +18,7 @@
 
 #include "ONNXHelpers.h"
 
-#include "mir/ops/ElementwiseOp.h"
+#include "mir/ops/AddOp.h"
 
 namespace mir_onnx
 {
@@ -28,9 +28,8 @@ AddNodeConverter::convert(const onnx::NodeProto &onnx_node,
                           const std::vector<mir::Operation::Output *> &inputs,
                           mir::Graph *graph) const
 {
-  auto result =
-      createOp<mir::ops::ElementwiseOp>(graph, inputs, mir::ops::ElementwiseOp::OpType::add);
-  return {result->getOutput(0)};
+  auto result = createOp<mir::ops::AddOp>(graph, inputs[0], inputs[1])->getOutput(0);
+  return {result};
 }
 
 } // namespace mir_onnx
index aac2a64..30464f1 100644 (file)
@@ -21,7 +21,9 @@
 #include "mir/ShapeRange.h"
 #include "mir/Tensor.h"
 
+#include "mir/ops/AddOp.h"
 #include "mir/ops/ConstantOp.h"
+#include "mir/ops/MulOp.h"
 
 #include <cmath>
 
@@ -51,7 +53,7 @@ BatchNormalizationNodeConverter::convert(const onnx::NodeProto &onnx_node,
 
   auto data = convertONNXToMIR(graph, inputs[0]);
   auto mean = createOp<mir::ops::ConstantOp>(graph, mean_tensor)->getOutput(0);
-  auto result = createAdd(graph, data, mean);
+  auto result = createOp<mir::ops::AddOp>(graph, data, mean)->getOutput(0);
 
   // res2 = res1 * scale / (var + epsilon)
   mir::Tensor<float> multiplier(scale_tensor);
@@ -59,11 +61,11 @@ BatchNormalizationNodeConverter::convert(const onnx::NodeProto &onnx_node,
   for (auto &idx : mir::ShapeRange(scale_tensor.getShape()))
     multiplier.at(idx) /= std::sqrt(var_accessor.at(idx) + epsilon);
   auto scale = createOp<mir::ops::ConstantOp>(graph, scale_tensor)->getOutput(0);
-  result = createMul(graph, result, scale);
+  result = createOp<mir::ops::MulOp>(graph, result, scale)->getOutput(0);
 
   // overall_res = res2 + bias
   auto bias = createOp<mir::ops::ConstantOp>(graph, bias_tensor)->getOutput(0);
-  result = createAdd(graph, result, bias);
+  result = createOp<mir::ops::AddOp>(graph, result, bias)->getOutput(0);
 
   return {convertMIRToONNX(graph, result)};
 }
index 6daa68e..0a3d21d 100644 (file)
 
 #include "mir/TensorUtil.h"
 
+#include "mir/ops/AddOp.h"
 #include "mir/ops/ConstantOp.h"
 #include "mir/ops/Conv2DOp.h"
 #include "mir/ops/DepthwiseConv2DOp.h"
-#include "mir/ops/ElementwiseOp.h"
 
 namespace mir_onnx
 {
@@ -78,7 +78,7 @@ ConvNodeConverter::convert(const onnx::NodeProto &onnx_node,
 
   if (inputs.size() > 2)
   {
-    result = createAdd(graph, result, inputs[2]);
+    result = createOp<mir::ops::AddOp>(graph, result, inputs[2])->getOutput(0);
   }
 
   return {convertMIRToONNX(graph, result)};
index bff6bd2..e079137 100644 (file)
@@ -22,6 +22,7 @@
 
 #include "mir/ops/ConstantOp.h"
 #include "mir/ops/GemmOp.h"
+#include "mir/ops/MulOp.h"
 #include "mir/ops/ReshapeOp.h"
 #include "mir/ops/TransposeOp.h"
 
@@ -65,7 +66,7 @@ GemmNodeConverter::convert(const onnx::NodeProto &onnx_node,
   {
     auto alpha_tensor = createScalarTensor(alpha_val, input_a->getShape());
     auto alpha = createOp<mir::ops::ConstantOp>(graph, alpha_tensor)->getOutput(0);
-    input_a = createMul(graph, input_a, alpha);
+    input_a = createOp<mir::ops::MulOp>(graph, input_a, alpha)->getOutput(0);
   }
 
   // 2. Prepare input matrix B
@@ -88,7 +89,7 @@ GemmNodeConverter::convert(const onnx::NodeProto &onnx_node,
   }
   auto beta = createOp<mir::ops::ConstantOp>(graph, beta_tensor)->getOutput(0);
   std::vector<mir::Operation::Output *> mul_inputs = {beta, input_c};
-  auto c_mult = createMul(graph, beta, input_c);
+  auto c_mult = createOp<mir::ops::MulOp>(graph, beta, input_c)->getOutput(0);
   assert(c_mult->getShape() == mult_a_b);
   auto result = createOp<mir::ops::GemmOp>(graph, input_a, input_b, c_mult);
   return {result->getOutput(0)};
index b32ca43..3a6ba69 100644 (file)
@@ -18,7 +18,7 @@
 
 #include "ONNXHelpers.h"
 
-#include "mir/ops/ElementwiseOp.h"
+#include "mir/ops/MaxOp.h"
 
 namespace mir_onnx
 {
@@ -28,9 +28,8 @@ MaxNodeConverter::convert(const onnx::NodeProto &onnx_node,
                           const std::vector<mir::Operation::Output *> &inputs,
                           mir::Graph *graph) const
 {
-  auto result =
-      createOp<mir::ops::ElementwiseOp>(graph, inputs, mir::ops::ElementwiseOp::OpType::max);
-  return {result->getOutput(0)};
+  auto result = createOp<mir::ops::MaxOp>(graph, inputs[0], inputs[1])->getOutput(0);
+  return {result};
 }
 
 } // namespace mir_onnx
index e095233..f27af87 100644 (file)
@@ -18,7 +18,7 @@
 
 #include "ONNXHelpers.h"
 
-#include "mir/ops/ElementwiseOp.h"
+#include "mir/ops/MulOp.h"
 
 namespace mir_onnx
 {
@@ -28,9 +28,8 @@ MulNodeConverter::convert(const onnx::NodeProto &onnx_node,
                           const std::vector<mir::Operation::Output *> &inputs,
                           mir::Graph *graph) const
 {
-  auto result =
-      createOp<mir::ops::ElementwiseOp>(graph, inputs, mir::ops::ElementwiseOp::OpType::mul);
-  return {result->getOutput(0)};
+  auto result = createOp<mir::ops::MulOp>(graph, inputs[0], inputs[1])->getOutput(0);
+  return {result};
 }
 
 } // namespace mir_onnx
index 188c3df..6334082 100644 (file)
@@ -18,6 +18,7 @@
 
 #include "ONNXHelpers.h"
 
+#include "mir/ops/MulOp.h"
 #include "mir/ops/ConstantOp.h"
 
 namespace mir_onnx
@@ -35,7 +36,7 @@ ScaleNodeConverter::convert(const onnx::NodeProto &onnx_node,
   const auto &shape = inputs[0]->getShape();
   auto scale_tensor = createScalarTensor(scale_val, shape);
   auto scale = createOp<mir::ops::ConstantOp>(graph, scale_tensor)->getOutput(0);
-  auto result = createMul(graph, inputs[0], scale);
+  auto result = createOp<mir::ops::MulOp>(graph, inputs[0], scale)->getOutput(0);
   return {result};
 }
 
index d2fa94c..c25f786 100644 (file)
@@ -18,7 +18,7 @@
 
 #include "ONNXHelpers.h"
 
-#include "mir/ops/ElementwiseOp.h"
+#include "mir/ops/AddOp.h"
 
 namespace mir_onnx
 {
@@ -28,9 +28,15 @@ SumNodeConverter::convert(const onnx::NodeProto &onnx_node,
                           const std::vector<mir::Operation::Output *> &inputs,
                           mir::Graph *graph) const
 {
-  auto result =
-      createOp<mir::ops::ElementwiseOp>(graph, inputs, mir::ops::ElementwiseOp::OpType::add);
-  return {result->getOutput(0)};
+  assert(inputs.size() >= 1);
+
+  auto result = inputs[0];
+  for (int i = 1; i < static_cast<int>(inputs.size()); ++i)
+  {
+    result = createOp<mir::ops::AddOp>(graph, result, inputs[i])->getOutput(0);
+  }
+
+  return {result};
 }
 
 } // namespace mir_onnx