From c96835113f8b3b3006b8963d3a0ef7fb438fd17d Mon Sep 17 00:00:00 2001
From: =?utf8?q?=D0=9F=D0=B0=D0=B2=D0=B5=D0=BB=20=D0=98=D0=BB=D1=8C=D1=8E?=
=?utf8?q?=D1=82=D1=87=D0=B5=D0=BD=D0=BA=D0=BE/AI=20Tools=20Lab=20/SRR/Eng?=
=?utf8?q?ineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?=
Date: Wed, 25 Sep 2019 17:50:29 +0300
Subject: [PATCH] [mir2loco] Implemented transformation for FullyConnectedOp
(#7414)
* Support transformation of mir::ops::FullyConnectedOp to loco::MatMul
Signed-off-by: Pavel Iliutchenko
---
compiler/mir2loco/include/mir2loco.h | 1 +
compiler/mir2loco/src/mir2loco.cpp | 65 +++++++++++++++++++++++++++++++++
compiler/mir2loco/src/mir2loco.test.cpp | 54 +++++++++++++++++++++++++++
3 files changed, 120 insertions(+)
diff --git a/compiler/mir2loco/include/mir2loco.h b/compiler/mir2loco/include/mir2loco.h
index fdedd59..ff1faf0 100644
--- a/compiler/mir2loco/include/mir2loco.h
+++ b/compiler/mir2loco/include/mir2loco.h
@@ -34,6 +34,7 @@ public:
void visit(mir::ops::Conv2DOp &op) override;
void visit(mir::ops::DeConv2DOp &op) override;
void visit(mir::ops::DepthwiseConv2DOp &op) override;
+ void visit(mir::ops::FullyConnectedOp &op) override;
void visit(mir::ops::InputOp &op) override;
void visit(mir::ops::MaxPool2DOp &op) override;
void visit(mir::ops::MulOp &op) override;
diff --git a/compiler/mir2loco/src/mir2loco.cpp b/compiler/mir2loco/src/mir2loco.cpp
index 9ac5e73..87b86d7 100644
--- a/compiler/mir2loco/src/mir2loco.cpp
+++ b/compiler/mir2loco/src/mir2loco.cpp
@@ -23,6 +23,7 @@
#include "mir/ops/Conv2DOp.h"
#include "mir/ops/Deconv2DOp.h"
#include "mir/ops/DepthwiseConv2DOp.h"
+#include "mir/ops/FullyConnectedOp.h"
#include "mir/ops/MaxPool2DOp.h"
#include "mir/ops/MulOp.h"
#include "mir/ops/ReluOp.h"
@@ -142,6 +143,40 @@ loco::Permutation createHWOIFilterPermutation()
return perm;
}
+loco::Permutation<loco::Domain::Matrix> createMatrixPermutation(bool height_first = true)
+{
+  loco::Permutation<loco::Domain::Matrix> perm; // assigns index 0/1 to the Height/Width axes
+  if (height_first)
+  {
+    perm.axis(loco::MatrixAxis::Height) = 0; // default order: Height, then Width
+    perm.axis(loco::MatrixAxis::Width) = 1;
+  }
+  else
+  {
+    perm.axis(loco::MatrixAxis::Width) = 0; // swapped order: Width, then Height
+    perm.axis(loco::MatrixAxis::Height) = 1;
+  }
+  return perm;
+}
+
+loco::MatrixEncode *createMatrixEncode(loco::Graph *graph, bool height_first = true)
+{
+  auto encode_node = graph->nodes()->create<loco::MatrixEncode>(); // node is created in `graph`
+  auto perm = createMatrixPermutation(height_first);
+  auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Matrix>>(perm);
+  encode_node->encoder(std::move(enc)); // encoder is moved into the node
+  return encode_node;
+}
+
+loco::MatrixDecode *createMatrixDecode(loco::Graph *graph, bool height_first = true)
+{
+  auto decode_node = graph->nodes()->create<loco::MatrixDecode>(); // node is created in `graph`
+  auto perm = createMatrixPermutation(height_first);
+  auto dec = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Matrix>>(perm);
+  decode_node->decoder(std::move(dec)); // decoder is moved into the node
+  return decode_node;
+}
+
loco::FilterEncode *createFilterEncode(loco::Graph *graph,
const loco::Permutation &perm)
{
@@ -438,6 +473,36 @@ void Transformer::visit(mir::ops::DepthwiseConv2DOp &op)
_mir2loco_map.emplace(&op, feature_dec);
}
+void Transformer::visit(mir::ops::FullyConnectedOp &op)
+{
+  auto input = op.getInput(0)->getProducer()->getNode();
+  auto weights = op.getInput(1)->getProducer()->getNode();
+  // Both operands are expected to be rank-2 (verified by assert only)
+  assert(op.getInput(0)->getProducer()->getShape().rank() == 2);
+  assert(op.getInput(1)->getProducer()->getShape().rank() == 2);
+  // Look up the loco nodes previously registered for both producers
+  auto input_node = _mir2loco_map.at(input);
+  auto weights_node = _mir2loco_map.at(weights);
+  // Create a MatrixEncode for each operand (default height-first permutation)
+  auto input_enc = createMatrixEncode(_loco_graph.get());
+  auto weights_enc = createMatrixEncode(_loco_graph.get());
+  // Feed the operands into their encoders
+  input_enc->input(input_node);
+  weights_enc->input(weights_node);
+  // Create the MatMul node
+  auto mat_mul = _loco_graph->nodes()->create<loco::MatMul>();
+  // The encoded input is the left operand, the encoded weights the right one
+  mat_mul->lhs(input_enc);
+  mat_mul->rhs(weights_enc);
+  // Create a MatrixDecode for the result (default height-first permutation)
+  auto matrix_dec = createMatrixDecode(_loco_graph.get());
+  // Connect the decoder to the MatMul node
+  matrix_dec->input(mat_mul);
+  // NOTE The output shape is not set here
+  // Register the decode node as the loco counterpart of this op
+  _mir2loco_map.emplace(&op, matrix_dec);
+}
+
void Transformer::visit(mir::ops::InputOp &op)
{
auto pull_node = _loco_graph->nodes()->create();
diff --git a/compiler/mir2loco/src/mir2loco.test.cpp b/compiler/mir2loco/src/mir2loco.test.cpp
index f7348af..7ed0820 100644
--- a/compiler/mir2loco/src/mir2loco.test.cpp
+++ b/compiler/mir2loco/src/mir2loco.test.cpp
@@ -23,6 +23,7 @@
#include "mir/ops/Conv2DOp.h"
#include "mir/ops/Deconv2DOp.h"
#include "mir/ops/DepthwiseConv2DOp.h"
+#include "mir/ops/FullyConnectedOp.h"
#include "mir/ops/MaxPool2DOp.h"
#include "mir/ops/MulOp.h"
#include "mir/ops/ReluOp.h"
@@ -616,3 +617,56 @@ TEST_F(TestTransformer_mir2loco, DeConv2D_Test)
ASSERT_NE(push_node, nullptr);
ASSERT_EQ(push_node->from(), decode_node);
}
+
+TEST_F(TestTransformer_mir2loco, FullyConnected_Test)
+{
+  mir::Graph mir_graph; // MIR graph under test: Input -> FullyConnected(const weights) -> Output
+
+  mir::Shape input_shape{10, 2};
+  auto *input = mir_graph.create<mir::ops::InputOp>(input_shape)->getOutput(0);
+  const float data[] = {5.9, 5.32, 54.11231, 3.409};
+  auto mir_tensor =
+      mir::TensorVariant(mir::DataType::FLOAT32, mir::Shape{2, 2}, (const void *)data);
+  auto *constant = mir_graph.create<mir::ops::ConstantOp>(mir_tensor)->getOutput(0);
+  auto *fc = mir_graph.create<mir::ops::FullyConnectedOp>(input, constant)->getOutput(0);
+  mir_graph.create<mir::ops::OutputOp>(fc);
+  input->setName("x");
+  fc->setName("y");
+
+  mir2loco::Transformer transformer;
+  auto loco_graph = transformer.transform(&mir_graph);
+
+  // Pull: the graph input must exist
+  auto inputs = loco_graph->inputs();
+  loco::Pull *pull_node = loco::pull_node(loco_graph.get(), 0);
+  ASSERT_NE(pull_node, nullptr);
+  // MatrixEncode: the only user of Pull
+  auto pull_uses = loco::succs(pull_node);
+  ASSERT_EQ(pull_uses.size(), 1);
+  loco::MatrixEncode *encode_node = dynamic_cast<loco::MatrixEncode *>(*pull_uses.begin());
+  ASSERT_NE(encode_node, nullptr);
+  ASSERT_EQ(encode_node->input(), pull_node);
+  // MatMul: the only user of the input encoder; its rhs is the encoded weights
+  auto encode_uses = loco::succs(encode_node);
+  ASSERT_EQ(encode_uses.size(), 1);
+  loco::MatMul *fc_node = dynamic_cast<loco::MatMul *>(*encode_uses.begin());
+  ASSERT_NE(fc_node, nullptr);
+  loco::MatrixEncode *kernel_encode_node = dynamic_cast<loco::MatrixEncode *>(fc_node->rhs());
+  ASSERT_NE(kernel_encode_node, nullptr);
+  ASSERT_EQ(fc_node->lhs(), encode_node);
+  // ConstGen: feeds the weights encoder
+  loco::ConstGen *const_node = dynamic_cast<loco::ConstGen *>(kernel_encode_node->input());
+  ASSERT_NE(const_node, nullptr);
+  // MatrixDecode: the only user of MatMul
+  auto fc_uses = loco::succs(fc_node);
+  ASSERT_EQ(fc_uses.size(), 1);
+  loco::MatrixDecode *decode_node = dynamic_cast<loco::MatrixDecode *>(*fc_uses.begin());
+  ASSERT_NE(decode_node, nullptr);
+  ASSERT_EQ(decode_node->input(), fc_node);
+  // Push: the only user of the decoder
+  auto decode_uses = loco::succs(decode_node);
+  ASSERT_EQ(decode_uses.size(), 1);
+  loco::Push *push_node = dynamic_cast<loco::Push *>(*decode_uses.begin());
+  ASSERT_NE(push_node, nullptr);
+  ASSERT_EQ(push_node->from(), decode_node);
+}
--
2.7.4