[mir2loco] Switch to binary elementwise operations (#6413)
authorСергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 <s.barannikov@samsung.com>
Fri, 9 Aug 2019 09:29:45 +0000 (12:29 +0300)
committerAlexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Fri, 9 Aug 2019 09:29:45 +0000 (12:29 +0300)
Switch to new binary elementwise operations.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
compiler/mir2loco/include/mir2loco.h
compiler/mir2loco/src/mir2loco.cpp
compiler/mir2loco/src/mir2loco.test.cpp

index 469de9f..79ad08f 100644 (file)
@@ -26,6 +26,7 @@ public:
   Transformer() = default;
   ~Transformer() = default;
 
+  void visit(mir::ops::AddOp &op) override;
   void visit(mir::ops::BatchNormOp &op) override;
   void visit(mir::ops::CappedReluOp &op) override;
   void visit(mir::ops::ConcatOp &op) override;
@@ -34,7 +35,6 @@ public:
   void visit(mir::ops::DeConv2DOp &op) override;
   void visit(mir::ops::DepthwiseConv2DOp &op) override;
   void visit(mir::ops::DropoutOp &op) override;
-  void visit(mir::ops::ElementwiseOp &op) override;
   void visit(mir::ops::EluOp &op) override;
   void visit(mir::ops::FullyConnectedOp &op) override;
   void visit(mir::ops::GatherOp &op) override;
index 57b49f5..0dede32 100644 (file)
 
 #include "mir2loco.h"
 
+#include "mir/ops/AddOp.h"
 #include "mir/ops/ConcatOp.h"
 #include "mir/ops/ConstantOp.h"
 #include "mir/ops/Conv2DOp.h"
-#include "mir/ops/ElementwiseOp.h"
 #include "mir/ops/PoolOp.h"
 #include "mir/ops/ReluOp.h"
 #include "mir/ops/ReshapeOp.h"
@@ -122,6 +122,21 @@ loco::DataType ConvertDataType(mir::DataType data_type)
 }
 } // namespace
 
+void Transformer::visit(mir::ops::AddOp &op)
+{
+  // Get Input
+  auto lhs = _mir2loco_map.at(op.getInput(0)->getProducer()->getNode());
+  auto rhs = _mir2loco_map.at(op.getInput(1)->getProducer()->getNode());
+
+  auto result = _loco_graph->nodes()->create<loco::EltwiseAdd>();
+  result->lhs(lhs);
+  result->rhs(rhs);
+
+  // Not set Shape
+  // Add to map
+  _mir2loco_map.emplace(&op, result);
+}
+
 void Transformer::visit(mir::ops::BatchNormOp &op) { throw std::runtime_error("NYI"); }
 
 void Transformer::visit(mir::ops::CappedReluOp &op) { throw std::runtime_error("NYI"); }
@@ -233,36 +248,6 @@ void Transformer::visit(mir::ops::DepthwiseConv2DOp &op) { throw std::runtime_er
 
 void Transformer::visit(mir::ops::DropoutOp &op) { throw std::runtime_error("NYI"); }
 
-void Transformer::visit(mir::ops::ElementwiseOp &op)
-{
-  // TODO Currently, MIR supports arbitrary number of inputs (>= 2).
-  if (op.getNumInputs() != 2)
-    throw std::runtime_error("NYI");
-
-  // Get Input
-  auto lhs = _mir2loco_map.at(op.getInput(0)->getProducer()->getNode());
-  auto rhs = _mir2loco_map.at(op.getInput(1)->getProducer()->getNode());
-  loco::Node *result = nullptr;
-  switch (op.getOpType())
-  {
-    case mir::ops::ElementwiseOp::OpType::add:
-    {
-      auto add_node = _loco_graph->nodes()->create<loco::EltwiseAdd>();
-      add_node->lhs(lhs);
-      add_node->rhs(rhs);
-      result = add_node;
-      break;
-    }
-    default:
-    {
-      throw std::runtime_error("NYI");
-    }
-  }
-  // Not set Shape
-  // Add to map
-  _mir2loco_map.emplace(&op, result);
-}
-
 void Transformer::visit(mir::ops::EluOp &op) { throw std::runtime_error("NYI"); }
 
 void Transformer::visit(mir::ops::FullyConnectedOp &op) { throw std::runtime_error("NYI"); }
index 824299f..5ed3e23 100644 (file)
 
 #include "mir2loco.h"
 
+#include "mir/ops/AddOp.h"
 #include "mir/ops/ConcatOp.h"
 #include "mir/ops/ConstantOp.h"
 #include "mir/ops/Conv2DOp.h"
-#include "mir/ops/ElementwiseOp.h"
 #include "mir/ops/PoolOp.h"
 #include "mir/ops/ReluOp.h"
 #include "mir/ops/ReshapeOp.h"
@@ -288,9 +288,7 @@ TEST_F(TestTransformer_mir2loco, Add_Test)
 
   auto *input1 = mir_graph.create<mir::ops::InputOp>("input1", input_shape);
   auto *input2 = mir_graph.create<mir::ops::InputOp>("input2", input_shape);
-  auto *add = mir_graph.create<mir::ops::ElementwiseOp>(
-      "bias_add", std::vector<mir::Operation::Output *>{input1->getOutput(0), input2->getOutput(0)},
-      mir::ops::ElementwiseOp::OpType::add);
+  auto *add = mir_graph.create<mir::ops::AddOp>("add", input1->getOutput(0), input2->getOutput(0));
   auto *output = mir_graph.create<mir::ops::OutputOp>("output", add->getOutput(0));
 
   mir2loco::Transformer transformer;