From 3c6dc7df2fb315bade12eef56fe733a1172182c1 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 8 Jul 2019 11:45:27 +0900 Subject: [PATCH] [moco/tf] Introduce ResolveFusedBatchNorm (#4125) This will introduce ResolveFusedBatchNorm transformation that decomposes FusedBatchNorm node into Add and Mul Signed-off-by: SaeHie Park --- contrib/moco-tf/CMakeLists.txt | 2 + contrib/moco-tf/src/TFOptimizer.cpp | 5 +- contrib/moco-tf/src/Transforms.h | 1 + .../src/Transforms/ResolveFusedBatchNorm.cpp | 260 +++++++++++++++++++++ .../moco-tf/src/Transforms/ResolveFusedBatchNorm.h | 44 ++++ .../src/Transforms/ResolveFusedBatchNorm.test.cpp | 231 ++++++++++++++++++ 6 files changed, 542 insertions(+), 1 deletion(-) create mode 100644 contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.cpp create mode 100644 contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.h create mode 100644 contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.test.cpp diff --git a/contrib/moco-tf/CMakeLists.txt b/contrib/moco-tf/CMakeLists.txt index b654225..9cb18b9 100644 --- a/contrib/moco-tf/CMakeLists.txt +++ b/contrib/moco-tf/CMakeLists.txt @@ -36,6 +36,8 @@ nncc_find_package(GTest REQUIRED) add_executable(moco_tf_frontend_test ${TESTS}) target_include_directories(moco_tf_frontend_test PRIVATE src) target_link_libraries(moco_tf_frontend_test gtest_main) +target_link_libraries(moco_tf_frontend_test locop) +target_link_libraries(moco_tf_frontend_test moco_log) target_link_libraries(moco_tf_frontend_test moco_tf_frontend) target_link_libraries(moco_tf_frontend_test stdex) add_test(moco_tf_frontend_test moco_tf_frontend_test) diff --git a/contrib/moco-tf/src/TFOptimizer.cpp b/contrib/moco-tf/src/TFOptimizer.cpp index f1380d3..a1468fb 100644 --- a/contrib/moco-tf/src/TFOptimizer.cpp +++ b/contrib/moco-tf/src/TFOptimizer.cpp @@ -32,7 +32,10 @@ void TFOptimizer::optimize(loco::Graph *g) 
const moco::tf::Phase phase; /* TRANSFORM DECLARATION BEGIN */ - + if (moco::tf::get()) + { + phase.emplace_back(stdex::make_unique()); + } /* TRANSFORM DECLARATION END */ moco::tf::PhaseRunner phase_runner{g}; diff --git a/contrib/moco-tf/src/Transforms.h b/contrib/moco-tf/src/Transforms.h index ea628d5..5b01e00 100644 --- a/contrib/moco-tf/src/Transforms.h +++ b/contrib/moco-tf/src/Transforms.h @@ -23,6 +23,7 @@ #include "Transforms/RemoveDeadNodeTransform.h" #include "Transforms/RemoveForwardNodeTransform.h" #include "Transforms/ReorderDecodeTransform.h" +#include "Transforms/ResolveFusedBatchNorm.h" #include "Transforms/SimplifyDomainConversionTransform.h" #endif // __MOCO_TF_TRANSFORMS_H__ diff --git a/contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.cpp b/contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.cpp new file mode 100644 index 0000000..eee7b42 --- /dev/null +++ b/contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.cpp @@ -0,0 +1,260 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ResolveFusedBatchNorm.h" + +#include "IR/TFAdd.h" +#include "IR/TFMul.h" + +#include "Convert.h" + +#include "IR/TFFusedBatchNorm.h" + +#include +#include + +#include +#include +#include + +namespace +{ + +bool is_same_shape(loco::ConstGen *lc, loco::ConstGen *rc) +{ + if (lc->rank() != rc->rank()) + return false; + + for (auto r = 0; r < lc->rank(); ++r) + { + if (lc->dim(r).value() != rc->dim(r).value()) + return false; + } + return true; +} + +void copy_shape(const loco::ConstGen *src, loco::ConstGen *dst) +{ + assert(src != nullptr); + assert(dst != nullptr); + + uint32_t rank = src->rank(); + dst->rank(rank); + for (uint32_t index = 0; index < rank; ++index) + { + if (src->dim(index).known()) + dst->dim(index) = loco::make_dimension(src->dim(index).value()); + else + dst->dim(index).unset(); + } +} + +/** + * @note resolve_to_muladd() will transform TFFusedBatchNorm to TFMul, TFAdd and two ConstGen + * + * Operands of the TFFusedBatchNorm node: + * %0:input + * %1:gamma : const + * %2:beta : const + * %3:mean : const + * %4:variance : const + * %5:epsilon : attribute + * + * Constants computed at transform time: + * fbn_epsilon_array = make_array(%5:epsilon) + * fbn_epsilon = %4:variance + fbn_epsilon_array + * fbn_rsqrt = 1.0 / math::sqrt(fbn_epsilon) + * + * fbn_mean = %3:mean + * fbn_mul = fbn_rsqrt * %1:gamma + * fbn_offset = %2:beta + * + * fbn_mul_0_param = fbn_mul + * fbn_add_param = fbn_offset - fbn_mean * fbn_mul + * + * Nodes created as the replacement: + * %11:fbn_mul_0_param = ConstGen(fbn_mul_0_param) + * %12:fbn_mul_0 = TFMul(%0:input, %11:fbn_mul_0_param) + * %21:fbn_add_param = ConstGen(fbn_add_param) + * %22:fbn = TFAdd(%12:fbn_mul_0,%21:fbn_add_param) + */ +bool resolve_to_muladd(loco::Graph *graph, moco::tf::TFFusedBatchNorm *node) +{ + LOGGER(lfbn); + + auto tffbn_input = node->input(); + if (tffbn_input == nullptr) + { + // This node is already converted + return false; + } + + auto tffbn_gamma = dynamic_cast(node->gamma()); + auto tffbn_beta = dynamic_cast(node->beta()); + auto tffbn_mean = dynamic_cast(node->mean()); + auto 
tffbn_variance = dynamic_cast(node->variance()); + + // all should be const + if (tffbn_gamma == nullptr || tffbn_beta == nullptr || tffbn_mean == nullptr || + tffbn_variance == nullptr) + { + INFO(lfbn) << "TFFBN resolve_to_muladd: One of constant input node is not a constant" + << std::endl; + return false; + } + assert(tffbn_gamma->dtype() == loco::DataType::FLOAT32); + assert(tffbn_beta->dtype() == loco::DataType::FLOAT32); + assert(tffbn_mean->dtype() == loco::DataType::FLOAT32); + assert(tffbn_variance->dtype() == loco::DataType::FLOAT32); + + // check that all constant shapes are the same + if (!is_same_shape(tffbn_gamma, tffbn_beta) || !is_same_shape(tffbn_gamma, tffbn_mean) || + !is_same_shape(tffbn_gamma, tffbn_variance)) + { + INFO(lfbn) << "TFFBN resolve_to_muladd: Shape of constant are not same" << std::endl; + return false; + } + + auto tffbn_epsilon = node->epsilon(); + INFO(lfbn) << "TFFBN tffbn_epsilon = " << tffbn_epsilon << std::endl; + auto const_num_elements = tffbn_gamma->size(); + INFO(lfbn) << "TFFBN const_num_elements = " << const_num_elements << std::endl; + + // fbn_epsilon = %4:variance + fbn_epsilon_array + std::unique_ptr fbn_epsilon{new float[const_num_elements]}; + for (int32_t i = 0; i < const_num_elements; i++) + { + auto variance = tffbn_variance->at(i); + fbn_epsilon.get()[i] = variance + tffbn_epsilon; + } + + // fbn_rsqrt = 1.0 / math::sqrt(fbn_epsilon) + std::unique_ptr fbn_rsqrt{new float[const_num_elements]}; + for (int32_t i = 0; i < const_num_elements; i++) + { + fbn_rsqrt.get()[i] = 1.0 / sqrt(fbn_epsilon.get()[i]); + } + + // fbn_mean = %3:mean : TODO remove this block and use %3:mean + std::unique_ptr fbn_mean{new float[const_num_elements]}; + for (int32_t i = 0; i < const_num_elements; i++) + { + fbn_mean.get()[i] = tffbn_mean->at(i); + } + + // fbn_mul = fbn_rsqrt * %1:gamma + std::unique_ptr fbn_mul{new float[const_num_elements]}; + for (int32_t i = 0; i < const_num_elements; i++) + { + fbn_mul.get()[i] = fbn_rsqrt.get()[i] * 
tffbn_gamma->at(i); + } + + // fbn_offset = %2:beta : TODO remove this block and use %2:beta + std::unique_ptr fbn_offset{new float[const_num_elements]}; + for (int32_t i = 0; i < const_num_elements; i++) + { + fbn_offset.get()[i] = tffbn_beta->at(i); + } + + // fbn_mul_0_param = fbn_mul : TODO remove this block and use fbn_mul + std::unique_ptr fbn_mul_0_param{new float[const_num_elements]}; + for (int32_t i = 0; i < const_num_elements; i++) + { + fbn_mul_0_param.get()[i] = fbn_mul.get()[i]; + } + + // fbn_add_param = fbn_offset - fbn_mean * fbn_mul + std::unique_ptr fbn_add_param{new float[const_num_elements]}; + for (int32_t i = 0; i < const_num_elements; i++) + { + fbn_add_param.get()[i] = fbn_offset.get()[i] - fbn_mean.get()[i] * fbn_mul.get()[i]; + } + + INFO(lfbn) << "TFFBN create ConstGen" << std::endl; + + /* + * %11:fbn_mul_0_param = ConstGen(fbn_mul_0_param) + * %21:fbn_add_param = ConstGen(fbn_add_param) + */ + auto const_fbn_mul_0_param = graph->nodes()->create(); + const_fbn_mul_0_param->dtype(loco::DataType::FLOAT32); + copy_shape(tffbn_gamma, const_fbn_mul_0_param); + const_fbn_mul_0_param->size(const_num_elements); + for (int32_t i = 0; i < const_num_elements; i++) + { + const_fbn_mul_0_param->at(i) = fbn_mul_0_param.get()[i]; + } + auto const_fbn_add_param = graph->nodes()->create(); + const_fbn_add_param->dtype(loco::DataType::FLOAT32); + copy_shape(tffbn_gamma, const_fbn_add_param); + const_fbn_add_param->size(const_num_elements); + for (int32_t i = 0; i < const_num_elements; i++) + { + const_fbn_add_param->at(i) = fbn_add_param.get()[i]; + } + + INFO(lfbn) << "TFFBN create TFMul, TFAdd" << std::endl; + /* + * %12:fbn_mul_0 = TFMul(%0:input, %11:fbn_mul_0_param) + * %22:fbn = TFAdd(%12:fbn_mul_0,%21:fbn_add_param) + */ + auto fbn_mul_0 = graph->nodes()->create(); + fbn_mul_0->x(tffbn_input); + fbn_mul_0->y(const_fbn_mul_0_param); + + auto fbn = graph->nodes()->create(); + fbn->x(fbn_mul_0); + fbn->y(const_fbn_add_param); + + // replace old node with new 
fbn + replace(node).with(fbn); + // unlink from graph + node->input(nullptr); + node->gamma(nullptr); + node->beta(nullptr); + node->mean(nullptr); + node->variance(nullptr); + + return true; +} + +} // namespace + +namespace moco +{ +namespace tf +{ + +bool ResolveFusedBatchNorm::run(loco::Graph *graph) +{ + for (auto node : loco::all_nodes(graph)) + { + if (as(node)) + { + if (resolve_to_muladd(graph, as(node))) + { + // The graph has been changed. Return right away so that we don't need to + // consider whether the following nodes are still valid. + return true; + } + } + } + + return false; +} + +} // namespace tf +} // namespace moco diff --git a/contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.h b/contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.h new file mode 100644 index 0000000..9243951 --- /dev/null +++ b/contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __MOCO_TF_RESOLVE_FUSEDBATCHNORM_H__ +#define __MOCO_TF_RESOLVE_FUSEDBATCHNORM_H__ + +#include "Transform.h" + +#include + +namespace moco +{ +namespace tf +{ + +/** + * @brief Transform TFFusedBatchNorm into TFMul + TFAdd with two constant (ConstGen) operands +*/ +class ResolveFusedBatchNorm : public Transform +{ +public: + const char *name(void) const final { return "ResolveFusedBatchNorm"; } + +public: + bool run(loco::Graph *graph) override; +}; + +} // namespace tf +} // namespace moco + +#endif // __MOCO_TF_RESOLVE_FUSEDBATCHNORM_H__ diff --git a/contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.test.cpp b/contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.test.cpp new file mode 100644 index 0000000..749cf24 --- /dev/null +++ b/contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.test.cpp @@ -0,0 +1,231 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ResolveFusedBatchNorm.h" + +#include "TestHelper.h" +#include "IR/TFFusedBatchNorm.h" +#include "Importer.h" + +#include +#include +#include +#include + +#include +#include +#include + +#include + +using namespace moco::tf::test; + +namespace +{ +// clang-format off +const char *fbn_basic_pbtxt = STRING_CONTENT( +node { + name: "input" + op: "Const" + attr { + key: "dtype" + value { type: DT_FLOAT } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { size: 1 } + dim { size: 4 } + dim { size: 4 } + dim { size: 1 } + } + float_val: 1.0 + } + } + } +} +node { + name: "gamma" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + } + float_val: 1.0 + } + } + } +} +node { + name: "beta" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + } + float_val: 1.0 + } + } + } +} +node { + name: "FBN_01/mean" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + } + float_val: 1.0 + } + } + } +} +node { + name: "FBN_01/variance" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + } + float_val: 1.0 + } + } + } +} +node { + name: "FBN_01" + op: "FusedBatchNorm" + input: "input" + input: "gamma" + input: "beta" + input: "FBN_01/mean" + input: "FBN_01/variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "epsilon" + value { + f: 0.001 + } + } + attr { + key: "is_training" + value { + b: false + } + } +} +); +// clang-format on + +} // namespace + 
+namespace +{ + +char to_char(bool b) { return b ? 'Y' : 'N'; } + +} // namespace + +TEST(ResolveFusedBatchNorm, fbn_resolve_basic) +{ + LOGGER(l); + + // load graph + moco::tf::Importer importer; + moco::tf::ModelSignature signature; + signature.add_output(moco::tf::TensorName("FBN_01", 0)); + + tensorflow::GraphDef graph_def; + EXPECT_TRUE(parse_graphdef(fbn_basic_pbtxt, graph_def)); + auto graph = importer.import(signature, graph_def); + + INFO(l) << "Before ResolveFusedBatchNorm"; + INFO(l) << locop::fmt(graph); + + moco::tf::ResolveFusedBatchNorm transform; + bool changed = transform.run(graph.get()); + + INFO(l) << "After ResolveFusedBatchNorm " << to_char(changed); + INFO(l) << locop::fmt(graph); + + // Output values will be tested separately with mocotest-tf. + // The network structure produced by the transformation is not checked here + // because it is not important and may be changed in the future. + + SUCCEED(); +} -- 2.7.4