From 9a1c91b65330992e908666534dac4c8d4bd7ab8e Mon Sep 17 00:00:00 2001
From: =?utf8?q?=D0=9F=D0=B0=D0=B2=D0=B5=D0=BB=20=D0=98=D0=BB=D1=8C=D1=8E?=
=?utf8?q?=D1=82=D1=87=D0=B5=D0=BD=D0=BA=D0=BE/AI=20Tools=20Lab=20/SRR/Eng?=
=?utf8?q?ineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?=
Date: Fri, 2 Aug 2019 10:56:05 +0300
Subject: [PATCH] [mir2loco] Support BiasAdd operation (#5901)
* Transform `mir::BiasAddOp` to `loco::BiasAdd` using `loco::BiasEncode>
* Test for this
Signed-off-by: Pavel Iliutchenko
---
compiler/mir2loco/src/mir2loco.cpp | 22 ++++++++++++++-
compiler/mir2loco/src/mir2loco.test.cpp | 47 +++++++++++++++++++++++++++++++++
2 files changed, 68 insertions(+), 1 deletion(-)
diff --git a/compiler/mir2loco/src/mir2loco.cpp b/compiler/mir2loco/src/mir2loco.cpp
index 17e9d22..f236039 100644
--- a/compiler/mir2loco/src/mir2loco.cpp
+++ b/compiler/mir2loco/src/mir2loco.cpp
@@ -16,6 +16,7 @@
#include "mir2loco.h"
+#include "mir/ops/BiasAddOp.h"
#include "mir/ops/ConcatOp.h"
#include "mir/ops/ConstantOp.h"
#include "mir/ops/PoolOp.h"
@@ -122,7 +123,26 @@ loco::DataType DTYPE2DataType(const mir::DTYPE &dtype)
void Transformer::visit(mir::ops::BatchNormOp &op) { throw std::runtime_error("NYI"); }
-void Transformer::visit(mir::ops::BiasAddOp &op) { throw std::runtime_error("NYI"); }
+void Transformer::visit(mir::ops::BiasAddOp &op)
+{
+ // Set Input
+ auto input = op.getInput(0)->getProducer()->getNode();
+ auto bias = op.getInput(1)->getProducer()->getNode();
+ auto loco_input = _mir2loco_map.at(input);
+ auto loco_bias = _mir2loco_map.at(bias);
+ // Create BiasEncode
+ auto bias_encode = _loco_graph->nodes()->create();
+ bias_encode->input(loco_bias);
+ // Set value and bias
+ auto bias_add_node = _loco_graph->nodes()->create();
+ bias_add_node->value(loco_input);
+ bias_add_node->bias(bias_encode);
+ // Set axis
+ bias_add_node->axis(3); // NHWC
+ // Not set Shape
+ // Add to map
+ _mir2loco_map.emplace(&op, bias_add_node);
+}
void Transformer::visit(mir::ops::CappedReluOp &op) { throw std::runtime_error("NYI"); }
diff --git a/compiler/mir2loco/src/mir2loco.test.cpp b/compiler/mir2loco/src/mir2loco.test.cpp
index dcc85e1..b8d6770 100644
--- a/compiler/mir2loco/src/mir2loco.test.cpp
+++ b/compiler/mir2loco/src/mir2loco.test.cpp
@@ -16,6 +16,7 @@
#include "mir2loco.h"
+#include "mir/ops/BiasAddOp.h"
#include "mir/ops/ConcatOp.h"
#include "mir/ops/ConstantOp.h"
#include "mir/ops/PoolOp.h"
@@ -262,6 +263,7 @@ TEST_F(TestTransformer_mir2loco, Const_Float_Test)
mir2loco::Transformer transformer;
auto loco_graph = transformer.transform(&mir_graph);
+
loco::ConstGen *const_node = dynamic_cast(loco_graph->nodes()->at(0));
loco::Push *push_node = dynamic_cast(loco_graph->nodes()->at(1));
@@ -276,3 +278,48 @@ TEST_F(TestTransformer_mir2loco, Const_Float_Test)
for (int i = 0; i < 6; i++)
ASSERT_FLOAT_EQ(const_node->at(i), data[i]);
}
+
+TEST_F(TestTransformer_mir2loco, Bias_Add_Test)
+{
+ mir::Graph mir_graph;
+
+ mir::Shape input_shape{5, 6, 7, 3};
+ mir::Shape bias_shape{3};
+ auto *input = mir_graph.create("input", input_shape);
+ auto *bias = mir_graph.create("bias", bias_shape);
+ auto *bias_add =
+ mir_graph.create("bias_add", input->getOutput(0), bias->getOutput(0));
+ auto *output = mir_graph.create("output", bias_add->getOutput(0));
+
+ mir2loco::Transformer transformer;
+ auto loco_graph = transformer.transform(&mir_graph);
+
+ loco::Pull *pull1_node = dynamic_cast(loco_graph->nodes()->at(0));
+ loco::Pull *pull2_node = dynamic_cast(loco_graph->nodes()->at(1));
+ loco::BiasEncode *bias_enc_node = dynamic_cast(loco_graph->nodes()->at(2));
+ loco::TensorBiasAdd *bias_add_node =
+ dynamic_cast(loco_graph->nodes()->at(3));
+ loco::Push *push_node = dynamic_cast(loco_graph->nodes()->at(4));
+
+ ASSERT_NE(pull1_node, nullptr);
+ ASSERT_NE(pull2_node, nullptr);
+ ASSERT_NE(bias_enc_node, nullptr);
+ ASSERT_NE(bias_add_node, nullptr);
+ ASSERT_NE(push_node, nullptr);
+
+ ASSERT_EQ(bias_enc_node->input(), pull2_node);
+ ASSERT_EQ(bias_add_node->value(), pull1_node);
+ ASSERT_EQ(bias_add_node->bias(), bias_enc_node);
+ ASSERT_EQ(push_node->from(), bias_add_node);
+ // Shape check
+ ASSERT_EQ(pull1_node->rank(), 4);
+ ASSERT_EQ(pull1_node->dim(0), 5);
+ ASSERT_EQ(pull1_node->dim(1), 6);
+ ASSERT_EQ(pull1_node->dim(2), 7);
+ ASSERT_EQ(pull1_node->dim(3), 3);
+
+ ASSERT_EQ(pull2_node->rank(), 1);
+ ASSERT_EQ(pull2_node->dim(0), 3);
+
+ ASSERT_EQ(bias_add_node->axis(), 3);
+}
--
2.7.4