[mir2loco] Support BiasAdd operation (#5901)
authorПавел Ильютченко/AI Tools Lab /SRR/Engineer/삼성전자 <p.iliutchenk@samsung.com>
Fri, 2 Aug 2019 07:56:05 +0000 (10:56 +0300)
committerEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Fri, 2 Aug 2019 07:56:05 +0000 (10:56 +0300)
* Transform `mir::BiasAddOp` to `loco::BiasAdd<Domain::Tensor>` using `loco::BiasEncode>
* Test for this

Signed-off-by: Pavel Iliutchenko <p.iliutchenk@samsung.com>
compiler/mir2loco/src/mir2loco.cpp
compiler/mir2loco/src/mir2loco.test.cpp

index 17e9d22..f236039 100644 (file)
@@ -16,6 +16,7 @@
 
 #include "mir2loco.h"
 
+#include "mir/ops/BiasAddOp.h"
 #include "mir/ops/ConcatOp.h"
 #include "mir/ops/ConstantOp.h"
 #include "mir/ops/PoolOp.h"
@@ -122,7 +123,26 @@ loco::DataType DTYPE2DataType(const mir::DTYPE &dtype)
 
 void Transformer::visit(mir::ops::BatchNormOp &op) { throw std::runtime_error("NYI"); }
 
-void Transformer::visit(mir::ops::BiasAddOp &op) { throw std::runtime_error("NYI"); }
+void Transformer::visit(mir::ops::BiasAddOp &op)
+{
+  // Set Input
+  auto input = op.getInput(0)->getProducer()->getNode();
+  auto bias = op.getInput(1)->getProducer()->getNode();
+  auto loco_input = _mir2loco_map.at(input);
+  auto loco_bias = _mir2loco_map.at(bias);
+  // Create BiasEncode
+  auto bias_encode = _loco_graph->nodes()->create<loco::BiasEncode>();
+  bias_encode->input(loco_bias);
+  // Set value and bias
+  auto bias_add_node = _loco_graph->nodes()->create<loco::TensorBiasAdd>();
+  bias_add_node->value(loco_input);
+  bias_add_node->bias(bias_encode);
+  // Set axis
+  bias_add_node->axis(3); // NHWC
+  // Not set Shape
+  // Add to map
+  _mir2loco_map.emplace(&op, bias_add_node);
+}
 
 void Transformer::visit(mir::ops::CappedReluOp &op) { throw std::runtime_error("NYI"); }
 
index dcc85e1..b8d6770 100644 (file)
@@ -16,6 +16,7 @@
 
 #include "mir2loco.h"
 
+#include "mir/ops/BiasAddOp.h"
 #include "mir/ops/ConcatOp.h"
 #include "mir/ops/ConstantOp.h"
 #include "mir/ops/PoolOp.h"
@@ -262,6 +263,7 @@ TEST_F(TestTransformer_mir2loco, Const_Float_Test)
 
   mir2loco::Transformer transformer;
   auto loco_graph = transformer.transform(&mir_graph);
+
   loco::ConstGen *const_node = dynamic_cast<loco::ConstGen *>(loco_graph->nodes()->at(0));
   loco::Push *push_node = dynamic_cast<loco::Push *>(loco_graph->nodes()->at(1));
 
@@ -276,3 +278,48 @@ TEST_F(TestTransformer_mir2loco, Const_Float_Test)
   for (int i = 0; i < 6; i++)
     ASSERT_FLOAT_EQ(const_node->at<loco::DataType::FLOAT32>(i), data[i]);
 }
+
+TEST_F(TestTransformer_mir2loco, Bias_Add_Test)
+{
+  mir::Graph mir_graph;
+
+  mir::Shape input_shape{5, 6, 7, 3};
+  mir::Shape bias_shape{3};
+  auto *input = mir_graph.create<mir::ops::InputOp>("input", input_shape);
+  auto *bias = mir_graph.create<mir::ops::InputOp>("bias", bias_shape);
+  auto *bias_add =
+      mir_graph.create<mir::ops::BiasAddOp>("bias_add", input->getOutput(0), bias->getOutput(0));
+  auto *output = mir_graph.create<mir::ops::OutputOp>("output", bias_add->getOutput(0));
+
+  mir2loco::Transformer transformer;
+  auto loco_graph = transformer.transform(&mir_graph);
+
+  loco::Pull *pull1_node = dynamic_cast<loco::Pull *>(loco_graph->nodes()->at(0));
+  loco::Pull *pull2_node = dynamic_cast<loco::Pull *>(loco_graph->nodes()->at(1));
+  loco::BiasEncode *bias_enc_node = dynamic_cast<loco::BiasEncode *>(loco_graph->nodes()->at(2));
+  loco::TensorBiasAdd *bias_add_node =
+      dynamic_cast<loco::TensorBiasAdd *>(loco_graph->nodes()->at(3));
+  loco::Push *push_node = dynamic_cast<loco::Push *>(loco_graph->nodes()->at(4));
+
+  ASSERT_NE(pull1_node, nullptr);
+  ASSERT_NE(pull2_node, nullptr);
+  ASSERT_NE(bias_enc_node, nullptr);
+  ASSERT_NE(bias_add_node, nullptr);
+  ASSERT_NE(push_node, nullptr);
+
+  ASSERT_EQ(bias_enc_node->input(), pull2_node);
+  ASSERT_EQ(bias_add_node->value(), pull1_node);
+  ASSERT_EQ(bias_add_node->bias(), bias_enc_node);
+  ASSERT_EQ(push_node->from(), bias_add_node);
+  // Shape check
+  ASSERT_EQ(pull1_node->rank(), 4);
+  ASSERT_EQ(pull1_node->dim(0), 5);
+  ASSERT_EQ(pull1_node->dim(1), 6);
+  ASSERT_EQ(pull1_node->dim(2), 7);
+  ASSERT_EQ(pull1_node->dim(3), 3);
+
+  ASSERT_EQ(pull2_node->rank(), 1);
+  ASSERT_EQ(pull2_node->dim(0), 3);
+
+  ASSERT_EQ(bias_add_node->axis(), 3);
+}