From e36419377e918846b67cf2ebfde60147a2df6f7f Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 27 Aug 2019 10:30:20 +0900 Subject: [PATCH] [logo] Introduce ConstantFoldingPass (#6923) * [logo] Introduce ConstantFoldingPass This will introduce ConstantFoldingPass compiled from moco Signed-off-by: SaeHie Park * add requires items --- compiler/logo/CMakeLists.txt | 2 + compiler/logo/include/logo/ConstantFoldingPass.h | 41 ++++++ compiler/logo/requires.cmake | 2 + compiler/logo/src/Passes/ConstantFoldingPass.cpp | 170 +++++++++++++++++++++++ 4 files changed, 215 insertions(+) create mode 100644 compiler/logo/include/logo/ConstantFoldingPass.h create mode 100644 compiler/logo/src/Passes/ConstantFoldingPass.cpp diff --git a/compiler/logo/CMakeLists.txt b/compiler/logo/CMakeLists.txt index c632fbf..e69eb67 100644 --- a/compiler/logo/CMakeLists.txt +++ b/compiler/logo/CMakeLists.txt @@ -7,6 +7,8 @@ set_target_properties(logo PROPERTIES POSITION_INDEPENDENT_CODE ON) target_include_directories(logo PRIVATE src) target_include_directories(logo PUBLIC include) target_link_libraries(logo PUBLIC loco) +target_link_libraries(logo PRIVATE locomotiv) +target_link_libraries(logo PRIVATE stdex) if(NOT ENABLE_TEST) return() diff --git a/compiler/logo/include/logo/ConstantFoldingPass.h b/compiler/logo/include/logo/ConstantFoldingPass.h new file mode 100644 index 0000000..99ccdc3 --- /dev/null +++ b/compiler/logo/include/logo/ConstantFoldingPass.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LOGO_CONSTANT_FOLDING_PASS_H__ +#define __LOGO_CONSTANT_FOLDING_PASS_H__ + +#include + +#include + +namespace logo +{ + +/** + * @brief Performs constant folding optimization + */ +class ConstantFoldingPass : public Pass +{ +public: + const char *name(void) const final { return "ConstantFoldingPass"; } + +public: + bool run(loco::Graph *graph) override; +}; + +} // namespace logo + +#endif // __LOGO_CONSTANT_FOLDING_PASS_H__ diff --git a/compiler/logo/requires.cmake b/compiler/logo/requires.cmake index 44f6870..362a4e2 100644 --- a/compiler/logo/requires.cmake +++ b/compiler/logo/requires.cmake @@ -1 +1,3 @@ require("loco") +require("locomotiv") +require("stdex") diff --git a/compiler/logo/src/Passes/ConstantFoldingPass.cpp b/compiler/logo/src/Passes/ConstantFoldingPass.cpp new file mode 100644 index 0000000..c31b3c9 --- /dev/null +++ b/compiler/logo/src/Passes/ConstantFoldingPass.cpp @@ -0,0 +1,170 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +#include + +#include + +#include +#include + +namespace +{ + +uint64_t num_elements(const loco::NodeMixin &shape) +{ + if (shape.rank() == 0) + { + return 0; + } + + uint64_t res = 1; + + for (uint32_t axis = 0; axis < shape.rank(); ++axis) + { + assert(shape.dim(axis).known()); + res *= shape.dim(axis).value(); + } + + return res; +} + +/// @brief For some op, constant folding should not be performed. This returns true if node is such +/// op. +bool skip(const loco::Node *node) +{ + static std::set skip_op = { + // TODO Current implementation works for 'Tensor' domain only. Support other domains such as + // `Feature`, `Filter`, `Bias`, etc. + static_cast(loco::CanonicalOpcode::FilterEncode), + static_cast(loco::CanonicalOpcode::FeatureEncode), + static_cast(loco::CanonicalOpcode::BiasEncode), + static_cast(loco::CanonicalOpcode::DepthwiseFilterEncode), + + // We don't perform constant folding for Push + static_cast(loco::CanonicalOpcode::Push), + }; + + if (node->dialect() == loco::CanonicalDialect::get()) + { + if (skip_op.find(node->opnum()) != skip_op.end()) + return true; + } + + return false; +} + +/// @brief Checks if a node is a target of constant folding transform +bool foldable(const loco::Node *node) +{ + if (node->dialect() == loco::CanonicalDialect::get()) + { + if (skip(node)) + return false; + + if (node->arity() == 0) // e.g., when a node is e.g, ConstGen or Pull + return false; + + // When all args are ConstGen, let's do Constant Folding Transforms + for (int i = 0; i < node->arity(); i++) + { + if (node->arg(i)->opnum() != static_cast(loco::CanonicalOpcode::ConstGen)) + return false; + } + + return true; + } + else + { + return false; + } +} + +void fold(loco::Graph *graph, loco::Node *node) +{ + assert(foldable(node)); // sanity check to find a mistake when this function is reused later + + // calcluate foldable node + locomotiv::Session sess(graph, std::vector{node}); + sess.infer(); + auto data = sess.get_output(0); + + assert(data != nullptr); + + auto shape = data->shape(); + auto dtype = data->dtype(); + + // build ConstGen + auto new_const = graph->nodes()->create(); + { + new_const->dtype(dtype); + + new_const->rank(shape->rank()); + for (int d = 0; d < shape->rank(); d++) + new_const->dim(d) = shape->dim(d); + + auto count = num_elements(*new_const); + + if (dtype == loco::DataType::FLOAT32) + { + new_const->size(count); + + auto const_buf = data->as_f32_bufptr()->base(); + for (int x = 0; x < count; x++) + new_const->at(x) = const_buf[x]; + } + else if (dtype == loco::DataType::S32) + { + new_const->size(count); + + auto const_buf = data->as_s32_bufptr()->base(); + for (int x = 0; x < count; x++) + new_const->at(x) = const_buf[x]; + } + } + + // replace node with new_const + loco::replace(node).with(new_const); +} + +} // namespace + +namespace logo +{ + +bool ConstantFoldingPass::run(loco::Graph *graph) +{ + auto outputs = loco::output_nodes(graph); + + bool changed = false; + for (auto node : loco::postorder_traversal(outputs)) + { + if (foldable(node)) + { + fold(graph, node); + changed = true; + } + } + + return changed; +} + +} // namespace logo -- 2.7.4