From c96835113f8b3b3006b8963d3a0ef7fb438fd17d Mon Sep 17 00:00:00 2001 From: =?utf8?q?=D0=9F=D0=B0=D0=B2=D0=B5=D0=BB=20=D0=98=D0=BB=D1=8C=D1=8E?= =?utf8?q?=D1=82=D1=87=D0=B5=D0=BD=D0=BA=D0=BE/AI=20Tools=20Lab=20/SRR/Eng?= =?utf8?q?ineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 25 Sep 2019 17:50:29 +0300 Subject: [PATCH] [mir2loco] Implemented transformation for FullyConnectedOp (#7414) * Support transformation of mir::ops::FullyConnectedOp to loco::MatMul Signed-off-by: Pavel Iliutchenko --- compiler/mir2loco/include/mir2loco.h | 1 + compiler/mir2loco/src/mir2loco.cpp | 65 +++++++++++++++++++++++++++++++++ compiler/mir2loco/src/mir2loco.test.cpp | 54 +++++++++++++++++++++++++++ 3 files changed, 120 insertions(+) diff --git a/compiler/mir2loco/include/mir2loco.h b/compiler/mir2loco/include/mir2loco.h index fdedd59..ff1faf0 100644 --- a/compiler/mir2loco/include/mir2loco.h +++ b/compiler/mir2loco/include/mir2loco.h @@ -34,6 +34,7 @@ public: void visit(mir::ops::Conv2DOp &op) override; void visit(mir::ops::DeConv2DOp &op) override; void visit(mir::ops::DepthwiseConv2DOp &op) override; + void visit(mir::ops::FullyConnectedOp &op) override; void visit(mir::ops::InputOp &op) override; void visit(mir::ops::MaxPool2DOp &op) override; void visit(mir::ops::MulOp &op) override; diff --git a/compiler/mir2loco/src/mir2loco.cpp b/compiler/mir2loco/src/mir2loco.cpp index 9ac5e73..87b86d7 100644 --- a/compiler/mir2loco/src/mir2loco.cpp +++ b/compiler/mir2loco/src/mir2loco.cpp @@ -23,6 +23,7 @@ #include "mir/ops/Conv2DOp.h" #include "mir/ops/Deconv2DOp.h" #include "mir/ops/DepthwiseConv2DOp.h" +#include "mir/ops/FullyConnectedOp.h" #include "mir/ops/MaxPool2DOp.h" #include "mir/ops/MulOp.h" #include "mir/ops/ReluOp.h" @@ -142,6 +143,40 @@ loco::Permutation createHWOIFilterPermutation() return perm; } +loco::Permutation createMatrixPermutation(bool height_first = true) +{ + loco::Permutation perm; + if (height_first) + { + perm.axis(loco::MatrixAxis::Height) = 0; 
perm.axis(loco::MatrixAxis::Width) = 1; + } + else + { + perm.axis(loco::MatrixAxis::Width) = 0; + perm.axis(loco::MatrixAxis::Height) = 1; + } + return perm; +} + +loco::MatrixEncode *createMatrixEncode(loco::Graph *graph, bool height_first = true) +{ + auto encode_node = graph->nodes()->create(); + auto perm = createMatrixPermutation(height_first); + auto enc = stdex::make_unique>(perm); + encode_node->encoder(std::move(enc)); + return encode_node; +} + +loco::MatrixDecode *createMatrixDecode(loco::Graph *graph, bool height_first = true) +{ + auto decode_node = graph->nodes()->create(); + auto perm = createMatrixPermutation(height_first); + auto dec = stdex::make_unique>(perm); + decode_node->decoder(std::move(dec)); + return decode_node; +} + loco::FilterEncode *createFilterEncode(loco::Graph *graph, const loco::Permutation &perm) { @@ -438,6 +473,36 @@ void Transformer::visit(mir::ops::DepthwiseConv2DOp &op) _mir2loco_map.emplace(&op, feature_dec); } +void Transformer::visit(mir::ops::FullyConnectedOp &op) +{ + auto input = op.getInput(0)->getProducer()->getNode(); + auto weights = op.getInput(1)->getProducer()->getNode(); + // Check 2D shape + assert(op.getInput(0)->getProducer()->getShape().rank() == 2); + assert(op.getInput(1)->getProducer()->getShape().rank() == 2); + // Get Nodes + auto input_node = _mir2loco_map.at(input); + auto weights_node = _mir2loco_map.at(weights); + // MatrixEncode + auto input_enc = createMatrixEncode(_loco_graph.get()); + auto weights_enc = createMatrixEncode(_loco_graph.get()); + // Set inputs for encodes + input_enc->input(input_node); + weights_enc->input(weights_node); + // Create op + auto mat_mul = _loco_graph->nodes()->create(); + // Set lhs and rhs + mat_mul->lhs(input_enc); + mat_mul->rhs(weights_enc); + // MatrixDecode + auto matrix_dec = createMatrixDecode(_loco_graph.get()); + // Set input + matrix_dec->input(mat_mul); + // Not set Shape + // Add to map + _mir2loco_map.emplace(&op, matrix_dec); +} + void 
Transformer::visit(mir::ops::InputOp &op) { auto pull_node = _loco_graph->nodes()->create(); diff --git a/compiler/mir2loco/src/mir2loco.test.cpp b/compiler/mir2loco/src/mir2loco.test.cpp index f7348af..7ed0820 100644 --- a/compiler/mir2loco/src/mir2loco.test.cpp +++ b/compiler/mir2loco/src/mir2loco.test.cpp @@ -23,6 +23,7 @@ #include "mir/ops/Conv2DOp.h" #include "mir/ops/Deconv2DOp.h" #include "mir/ops/DepthwiseConv2DOp.h" +#include "mir/ops/FullyConnectedOp.h" #include "mir/ops/MaxPool2DOp.h" #include "mir/ops/MulOp.h" #include "mir/ops/ReluOp.h" @@ -616,3 +617,56 @@ TEST_F(TestTransformer_mir2loco, DeConv2D_Test) ASSERT_NE(push_node, nullptr); ASSERT_EQ(push_node->from(), decode_node); } + +TEST_F(TestTransformer_mir2loco, FullyConnected_Test) +{ + mir::Graph mir_graph; + + mir::Shape input_shape{10, 2}; + auto *input = mir_graph.create(input_shape)->getOutput(0); + const float data[] = {5.9, 5.32, 54.11231, 3.409}; + auto mir_tensor = + mir::TensorVariant(mir::DataType::FLOAT32, mir::Shape{2, 2}, (const void *)data); + auto *constant = mir_graph.create(mir_tensor)->getOutput(0); + auto *fc = mir_graph.create(input, constant)->getOutput(0); + mir_graph.create(fc); + input->setName("x"); + fc->setName("y"); + + mir2loco::Transformer transformer; + auto loco_graph = transformer.transform(&mir_graph); + + // Pull + auto inputs = loco_graph->inputs(); + loco::Pull *pull_node = loco::pull_node(loco_graph.get(), 0); + ASSERT_NE(pull_node, nullptr); + // MatrixEncode + auto pull_uses = loco::succs(pull_node); + ASSERT_EQ(pull_uses.size(), 1); + loco::MatrixEncode *encode_node = dynamic_cast(*pull_uses.begin()); + ASSERT_NE(encode_node, nullptr); + ASSERT_EQ(encode_node->input(), pull_node); + // MatMul + auto encode_uses = loco::succs(encode_node); + ASSERT_EQ(encode_uses.size(), 1); + loco::MatMul *fc_node = dynamic_cast(*encode_uses.begin()); + ASSERT_NE(fc_node, nullptr); + loco::MatrixEncode *kernel_encode_node = dynamic_cast(fc_node->rhs()); + 
ASSERT_NE(kernel_encode_node, nullptr); + ASSERT_EQ(fc_node->lhs(), encode_node); + // ConstGen + loco::ConstGen *const_node = dynamic_cast(kernel_encode_node->input()); + ASSERT_NE(const_node, nullptr); + // MatrixDecode + auto fc_uses = loco::succs(fc_node); + ASSERT_EQ(fc_uses.size(), 1); + loco::MatrixDecode *decode_node = dynamic_cast(*fc_uses.begin()); + ASSERT_NE(decode_node, nullptr); + ASSERT_EQ(decode_node->input(), fc_node); + // Push + auto decode_uses = loco::succs(decode_node); + ASSERT_EQ(decode_uses.size(), 1); + loco::Push *push_node = dynamic_cast(*decode_uses.begin()); + ASSERT_NE(push_node, nullptr); + ASSERT_EQ(push_node->from(), decode_node); +} -- 2.7.4