From 7dc64d4862cc387bab79c90602dba07866634138 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 27 Aug 2019 13:20:53 +0900 Subject: [PATCH] [logo] Introduce ResolveDuplicateReshapePass (#6938) This will introduce ResolveDuplicateReshapePass copied from moco Signed-off-by: SaeHie Park --- .../include/logo/ResolveDuplicateReshapePass.h | 41 ++++++++ .../src/Passes/ResolveDuplicateReshapePass.cpp | 108 +++++++++++++++++++++ 2 files changed, 149 insertions(+) create mode 100644 compiler/logo/include/logo/ResolveDuplicateReshapePass.h create mode 100644 compiler/logo/src/Passes/ResolveDuplicateReshapePass.cpp diff --git a/compiler/logo/include/logo/ResolveDuplicateReshapePass.h b/compiler/logo/include/logo/ResolveDuplicateReshapePass.h new file mode 100644 index 0000000..7e6c67f --- /dev/null +++ b/compiler/logo/include/logo/ResolveDuplicateReshapePass.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LOGO_RESOLVE_DUPLICATE_RESHAPE_PASS_H__ +#define __LOGO_RESOLVE_DUPLICATE_RESHAPE_PASS_H__ + +#include + +#include + +namespace logo +{ + +/** + * @brief Resolve duplicated Reshape nodes in a row + */ +class ResolveDuplicateReshapePass final : public Pass +{ +public: + const char *name(void) const final { return "ResolveDuplicateReshapePass"; } + +public: + bool run(loco::Graph *graph) override; +}; + +} // namespace logo + +#endif // __LOGO_RESOLVE_DUPLICATE_RESHAPE_PASS_H__ diff --git a/compiler/logo/src/Passes/ResolveDuplicateReshapePass.cpp b/compiler/logo/src/Passes/ResolveDuplicateReshapePass.cpp new file mode 100644 index 0000000..d3c74cb --- /dev/null +++ b/compiler/logo/src/Passes/ResolveDuplicateReshapePass.cpp @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include + +namespace +{ + +/// @return true when 'node' and its input node are both FixedReshapes +bool is_duplicate_reshape(loco::Node *node) +{ + auto node_as_reshape = dynamic_cast(node); + + if (!node_as_reshape) + return false; + + auto input_as_reshape = dynamic_cast(node_as_reshape->input()); + + if (!input_as_reshape) + return false; + + return true; +} + +/** + * @brief Remap reshape's input to its input's input, i.e. bypass input reshape + * + * Before: + * + * In ----- FixedReshape_1 ----- [Out_1]* + * \ + * ------- FixedReshape_2 --- [Out_2]* + * ('reshape' arg) + * + * After: + * + * In ----- FixedReshape_1 ----- [Out_1]* + * \ + * --------------------------- FixedReshape_2 --- [Out_2]* + * + * Note: In case of no Out_1, FixedReshape_1 becomes dead node. + * Out_1 can be another FixedReshape as well, which would be resolved in + * another occurance of this transform pass. + */ +void remap_input(loco::FixedReshape *reshape) +{ + auto input_reshape = dynamic_cast(reshape->input()); + + auto volume = [](loco::FixedReshape *node) { + uint32_t vol = 1; + for (uint32_t axis = 0; axis < node->rank(); ++axis) + { + assert(node->dim(axis).known()); + vol *= node->dim(axis).value(); + } + return vol; + }; + + // Volume mismatch between duplicate reshapes is pointless + assert(volume(reshape) == volume(input_reshape)); + + // Set node's input as input's input, i.e. bypass + reshape->input(input_reshape->input()); +} + +} // namespace + +namespace logo +{ + +bool ResolveDuplicateReshapePass::run(loco::Graph *graph) +{ + auto outputs = loco::output_nodes(graph); + + bool changed = false; + for (auto node : loco::postorder_traversal(outputs)) + { + if (is_duplicate_reshape(node)) + { + auto node_as_reshape = dynamic_cast(node); + + remap_input(node_as_reshape); + + changed = true; + } + } + + return changed; +} + +} // namespace logo -- 2.7.4