From bf34852d52caf863bf9742893ddc7d72d6f77b52 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9C=A4=ED=98=84=EC=8B=9D/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 4 Oct 2019 17:20:13 +0900 Subject: [PATCH] [logo] simplifying FilterEncode-FilterDecode (#7875) * [logo] simplifying FilterEncode-FilterDecode to Transpose (or to Forward) This commit simplifying FilterEncode-FilterDecode to Transpose (or to Forward) Signed-off-by: Hyun Sik Yoon * fix wrong const size --- compiler/logo/CMakeLists.txt | 1 + .../src/Passes/SimplifyDomainConversionPass.cpp | 113 +++++++++++++- .../Passes/SimplifyDomainConversionPass.test.cpp | 164 +++++++++++++++++++++ 3 files changed, 272 insertions(+), 6 deletions(-) create mode 100644 compiler/logo/src/Passes/SimplifyDomainConversionPass.test.cpp diff --git a/compiler/logo/CMakeLists.txt b/compiler/logo/CMakeLists.txt index c76d009..88ce51c 100644 --- a/compiler/logo/CMakeLists.txt +++ b/compiler/logo/CMakeLists.txt @@ -19,3 +19,4 @@ nnas_find_package(GTest REQUIRED) GTest_AddTest(logo_test ${TESTS}) target_include_directories(logo_test PRIVATE src) target_link_libraries(logo_test logo) +target_link_libraries(logo_test stdex) diff --git a/compiler/logo/src/Passes/SimplifyDomainConversionPass.cpp b/compiler/logo/src/Passes/SimplifyDomainConversionPass.cpp index 1f30717..89acdcf 100644 --- a/compiler/logo/src/Passes/SimplifyDomainConversionPass.cpp +++ b/compiler/logo/src/Passes/SimplifyDomainConversionPass.cpp @@ -20,7 +20,10 @@ #include #include +#include + #include +#include #include namespace @@ -42,6 +45,19 @@ bool equal(const Permutation *lhs, const Permutation *lhs, const Permutation *rhs) +{ + for (const auto &axis : + {FilterAxis::Count, FilterAxis::Depth, FilterAxis::Height, FilterAxis::Width}) + { + if (lhs->axis(axis) != rhs->axis(axis)) + { + return false; + } + } + return true; +} + void set_input_null(loco::Node *node) { if (auto casted = dynamic_cast(node)) @@ -50,6 +66,8 @@ void set_input_null(loco::Node *node) casted->input(nullptr); else if (auto casted = dynamic_cast(node)) casted->input(nullptr); + else if (auto casted = dynamic_cast(node)) + casted->input(nullptr); else assert(false && "not supported node type"); } @@ -93,7 +111,7 @@ bool SimplifyDomainConversionPass::run(loco::Graph *g) if (equal(perm_encoder->perm(), perm_decoder->perm())) { - candidates.insert({encode_node, decode_node->input()}); + forwardCandidates.insert({encode_node, decode_node->input()}); } } @@ -126,7 +144,59 @@ bool SimplifyDomainConversionPass::run(loco::Graph *g) if (equal(perm_encoder->perm(), perm_decoder->perm())) { - candidates.insert({decode_node, encode_node->input()}); + forwardCandidates.insert({decode_node, encode_node->input()}); + } + } + + // Let's find `FilterEncode -- FilterDecode` pattern + void visit(loco::FilterDecode *decode_node) final + { + using namespace loco; + + auto encode_node = dynamic_cast(decode_node->input()); + if (encode_node == nullptr) + { + return; + } + assert(encode_node->input() != nullptr); + + auto encoder = encode_node->encoder(); + assert(encoder != nullptr); + + auto decoder = decode_node->decoder(); + assert(decoder != nullptr); + + // NOTE Work only for permuting codec + auto perm_decoder = dynamic_cast *>(decoder); + auto perm_encoder = dynamic_cast *>(encoder); + + if (perm_encoder == nullptr || perm_decoder == nullptr) + { + return; + } + + if (equal(perm_encoder->perm(), perm_decoder->perm())) + { + forwardCandidates.insert({decode_node, encode_node->input()}); + } + else + { + std::vector perm_vec; + perm_vec.resize(4); + + auto enc_perm = perm_encoder->perm(); + auto dec_perm = perm_decoder->perm(); + + for (const auto &axis : + {FilterAxis::Count, FilterAxis::Height, FilterAxis::Width, FilterAxis::Depth}) + { + auto from = enc_perm->axis(axis); + auto to = dec_perm->axis(axis); + perm_vec[to] = from; + } + + transposeCandidates.insert(stdex::make_unique( + encode_node, decode_node, encode_node->input(), perm_vec)); } } @@ -136,7 +206,7 @@ bool SimplifyDomainConversionPass::run(loco::Graph *g) if (auto encode_node = dynamic_cast(decode_node->input())) { assert(encode_node->input() != nullptr); - candidates.insert({decode_node, encode_node->input()}); + forwardCandidates.insert({decode_node, encode_node->input()}); } } @@ -144,7 +214,23 @@ bool SimplifyDomainConversionPass::run(loco::Graph *g) using SimplifyingInfo = std::pair; - std::set candidates; + std::set forwardCandidates; + + struct TransposeCtx + { + loco::Node *first_node; // starting node of subgraph that will be replaced + loco::Node *last_node; // end node of subgraph that will be replaced + loco::Node *input_node; // input of subgraph + std::vector perm_vec; // perm vector for transpose + + TransposeCtx(loco::Node *first, loco::Node *last, loco::Node *input, + std::vector perm) + : first_node(first), last_node(last), input_node(input), perm_vec(perm) + { /* empty */ + } + }; + + std::set> transposeCandidates; }; Collector collector; @@ -158,7 +244,7 @@ bool SimplifyDomainConversionPass::run(loco::Graph *g) } } - for (auto p : collector.candidates) + for (auto p : collector.forwardCandidates) { auto forward_node = g->nodes()->create(); forward_node->input(p.second); @@ -166,7 +252,22 @@ bool SimplifyDomainConversionPass::run(loco::Graph *g) set_input_null(p.first); } - return collector.candidates.size() > 0; + for (auto &ctx : collector.transposeCandidates) + { + auto transpose_node = g->nodes()->create(); + { + transpose_node->perm()->size(ctx->perm_vec.size()); + + for (loco::TensorAxis axis = 0; axis < ctx->perm_vec.size(); axis++) + transpose_node->perm()->axis(axis) = ctx->perm_vec[axis]; + } + + transpose_node->input(ctx->input_node); + replace(ctx->last_node).with(transpose_node); + set_input_null(ctx->first_node); + } + + return (collector.forwardCandidates.size() > 0 or collector.transposeCandidates.size() > 0); } } // namespace logo diff --git a/compiler/logo/src/Passes/SimplifyDomainConversionPass.test.cpp b/compiler/logo/src/Passes/SimplifyDomainConversionPass.test.cpp new file mode 100644 index 0000000..bbb07b9 --- /dev/null +++ b/compiler/logo/src/Passes/SimplifyDomainConversionPass.test.cpp @@ -0,0 +1,164 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "TestHelper.h" + +#include +#include + +#include + +namespace +{ + +// code borrowed from GraphBlock.h/cpp in exo-tflite +enum class FilterLayout +{ + OHWI, // a.k.a., NHWC, Tensorflow Lite uses this layout + HWIO, // Tensorflow format +}; + +template loco::Permutation perm(); + +template <> loco::Permutation perm() +{ + // Make NHWC permutation for encoder and decoder + loco::Permutation OHWI; // a.k.a., NHWC + + OHWI.axis(loco::FilterAxis::Count) = 0; + OHWI.axis(loco::FilterAxis::Height) = 1; + OHWI.axis(loco::FilterAxis::Width) = 2; + OHWI.axis(loco::FilterAxis::Depth) = 3; + + return OHWI; +} + +template <> loco::Permutation perm() +{ + // Make NHWC permutation for encoder and decoder + loco::Permutation HWIO; + + HWIO.axis(loco::FilterAxis::Height) = 0; + HWIO.axis(loco::FilterAxis::Width) = 1; + HWIO.axis(loco::FilterAxis::Depth) = 2; + HWIO.axis(loco::FilterAxis::Count) = 3; + + return HWIO; +} + +template loco::FilterDecode *make_filter_decode(loco::Node *input_for_decode) +{ + loco::Graph *g = input_for_decode->graph(); + + auto decoder = stdex::make_unique>(); + + decoder->perm(perm()); + + auto dec = g->nodes()->create(); + dec->input(input_for_decode); + dec->decoder(std::move(decoder)); + + return dec; +} + +template loco::FilterEncode *make_filter_encode(loco::Node *input_for_encode) +{ + loco::Graph *g = input_for_encode->graph(); + + auto encoder = stdex::make_unique>(); + + encoder->perm(perm()); + + auto enc = g->nodes()->create(); + enc->input(input_for_encode); + enc->encoder(std::move(encoder)); + + return enc; +} + +/* + test case: + ConstGen (2x3x4x5) ---- FeatureEncode ---- FeatureDecode --- Push + 0 H O 0 + 1 W H 1 + 2 I(depth) W 2 + 3 O(coutn) I 3 + + axis 0 ---------------------> H --------------> H -----------> 1 + axis 1 ---------------------> W --------------> W -----------> 2 + axis 2 ---------------------> I --------------> I -----------> 3 + axis 3 ---------------------> O --------------> O -----------> 0 + + so perm vercotr of Tranpose = [3, 0, 1, 2] +*/ +void create_net_FilterEncode_FilterDecode_different_perms(loco::Graph *graph) +{ + assert(graph); + + auto const_node = graph->nodes()->create(); + { + const_node->dtype(loco::DataType::FLOAT32); + const_node->rank(4); + const_node->dim(0) = 1; // any number + const_node->dim(1) = 1; + const_node->dim(2) = 1; + const_node->dim(3) = 1; + const_node->size(1 * 1 * 1 * 1); + for (uint32_t i = 0; i < 1 * 1 * 1 * 1; i++) + const_node->at(i) = 3.14f; // any number + } + + auto encoder = make_filter_encode(const_node); + auto decoder = make_filter_decode(encoder); + + auto push_node = graph->nodes()->create(); + { + push_node->from(decoder); + } + + auto graph_output = graph->outputs()->create(); + { + graph_output->name("output"); + graph_output->dtype(loco::DataType::FLOAT32); + loco::link(graph_output, push_node); + } +} + +} // namespace + +TEST(SimplifyDomainConversionPass, FilterEncode_FilterDecode_different_perms) +{ + auto graph = loco::make_graph(); + create_net_FilterEncode_FilterDecode_different_perms(graph.get()); + + logo::SimplifyDomainConversionPass pass; + while (pass.run(graph.get()) == true) + ; + + auto tr = logo::test::find_first_node_by_type(graph.get()); + { + ASSERT_EQ(tr->perm()->size(), 4); + ASSERT_EQ(tr->perm()->axis(0), 3); + ASSERT_EQ(tr->perm()->axis(1), 0); + ASSERT_EQ(tr->perm()->axis(2), 1); + ASSERT_EQ(tr->perm()->axis(3), 2); + } + + auto const_gen = dynamic_cast(tr->input()); + ASSERT_NE(const_gen, nullptr); +} -- 2.7.4