--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ResolveFusedBatchNorm.h"
+
+#include "IR/TFAdd.h"
+#include "IR/TFMul.h"
+
+#include "Convert.h"
+
+#include "IR/TFFusedBatchNorm.h"
+
+#include <loco.h>
+#include <moco/Log.h>
+
+#include <cassert>
+#include <cmath>
+#include <memory>
+
+namespace
+{
+
+bool is_same_shape(loco::ConstGen *lc, loco::ConstGen *rc)
+{
+ if (lc->rank() != rc->rank())
+ return false;
+
+ for (auto r = 0; r < lc->rank(); ++r)
+ {
+ if (lc->dim(r).value() != rc->dim(r).value())
+ return false;
+ }
+ return true;
+}
+
+void copy_shape(const loco::ConstGen *src, loco::ConstGen *dst)
+{
+ assert(src != nullptr);
+ assert(dst != nullptr);
+
+ uint32_t rank = src->rank();
+ dst->rank(rank);
+ for (uint32_t index = 0; index < rank; ++index)
+ {
+ if (src->dim(index).known())
+ dst->dim(index) = loco::make_dimension(src->dim(index).value());
+ else
+ dst->dim(index).unset();
+ }
+}
+
+/**
+ * @note resolve_to_muladd() will transform TFFusedBatchNorm to TFMul, TFAdd and two ConstGen
+ *
+ * <arguments>
+ * %0:input
+ * %1:gamma : const
+ * %2:beta : const
+ * %3:mean : const
+ * %4:variance : const
+ * %5:epsilon : const
+ *
+ * <constant operations>
+ * fbn_epsilon_array = make_array(%5:epsilon)
+ * fbn_epsilon = %4:variance + fbn_epsilon_array
+ * fbn_rsqrt = 1.0 / math::sqrt(fbn_epsilon)
+ *
+ * fbn_mean = %3:mean
+ * fbn_mul = fbn_rsqrt * %1:gamma
+ * fbn_offset = %2:beta
+ *
+ * fbn_mul_0_param = fbn_mul
+ * fbn_add_param = fbn_offset - fbn_mean * fbn_mul
+ *
+ * <new replace nodes>
+ * %11:fbn_mul_0_param = ConstGen(fbn_mul_0_param)
+ * %12:fbn_mul_0 = TFMul(%0:input, %11:fbn_mul_0_param)
+ * %21:fbn_add_param = ConstGen(fbn_add_param)
+ * %22:fbn = TFAdd(%12:fbn_mul_0,%21:fbn_add_param)
+ */
+bool resolve_to_muladd(loco::Graph *graph, moco::tf::TFFusedBatchNorm *node)
+{
+ LOGGER(lfbn);
+
+ auto tffbn_input = node->input();
+ if (tffbn_input == nullptr)
+ {
+ // This node is already converted
+ return false;
+ }
+
+ auto tffbn_gamma = dynamic_cast<loco::ConstGen *>(node->gamma());
+ auto tffbn_beta = dynamic_cast<loco::ConstGen *>(node->beta());
+ auto tffbn_mean = dynamic_cast<loco::ConstGen *>(node->mean());
+ auto tffbn_variance = dynamic_cast<loco::ConstGen *>(node->variance());
+
+ // all should be const
+ if (tffbn_gamma == nullptr || tffbn_beta == nullptr || tffbn_mean == nullptr ||
+ tffbn_variance == nullptr)
+ {
+ INFO(lfbn) << "TFFBN resolve_to_muladd: One of constant input node is not a constant"
+ << std::endl;
+ return false;
+ }
+ assert(tffbn_gamma->dtype() == loco::DataType::FLOAT32);
+ assert(tffbn_beta->dtype() == loco::DataType::FLOAT32);
+ assert(tffbn_mean->dtype() == loco::DataType::FLOAT32);
+ assert(tffbn_variance->dtype() == loco::DataType::FLOAT32);
+
+ // check all const shape are the same
+ if (!is_same_shape(tffbn_gamma, tffbn_beta) || !is_same_shape(tffbn_gamma, tffbn_mean) ||
+ !is_same_shape(tffbn_gamma, tffbn_variance))
+ {
+ INFO(lfbn) << "TFFBN resolve_to_muladd: Shape of constant are not same" << std::endl;
+ return false;
+ }
+
+ auto tffbn_epsilon = node->epsilon();
+ INFO(lfbn) << "TFFBN tffbn_epsilon = " << tffbn_epsilon << std::endl;
+ auto const_num_elements = tffbn_gamma->size<loco::DataType::FLOAT32>();
+ INFO(lfbn) << "TFFBN const_num_elements = " << const_num_elements << std::endl;
+
+ // fbn_epsilon = %4:variance + fbn_epsilon_array
+ std::unique_ptr<float> fbn_epsilon{new float[const_num_elements]};
+ for (int32_t i = 0; i < const_num_elements; i++)
+ {
+ auto variance = tffbn_variance->at<loco::DataType::FLOAT32>(i);
+ fbn_epsilon.get()[i] = variance + tffbn_epsilon;
+ }
+
+ // fbn_rsqrt = 1.0 / math::sqrt(fbn_epsilon)
+ std::unique_ptr<float> fbn_rsqrt{new float[const_num_elements]};
+ for (int32_t i = 0; i < const_num_elements; i++)
+ {
+ fbn_rsqrt.get()[i] = 1.0 / sqrt(fbn_epsilon.get()[i]);
+ }
+
+ // fbn_mean = %3:mean : TODO remove this block and use %3:mean
+ std::unique_ptr<float> fbn_mean{new float[const_num_elements]};
+ for (int32_t i = 0; i < const_num_elements; i++)
+ {
+ fbn_mean.get()[i] = tffbn_mean->at<loco::DataType::FLOAT32>(i);
+ }
+
+ // fbn_mul = fbn_rsqrt * %1:gamma
+ std::unique_ptr<float> fbn_mul{new float[const_num_elements]};
+ for (int32_t i = 0; i < const_num_elements; i++)
+ {
+ fbn_mul.get()[i] = fbn_rsqrt.get()[i] * tffbn_gamma->at<loco::DataType::FLOAT32>(i);
+ }
+
+ // fbn_offset = %2:beta : TODO remove this block and use %2:beta
+ std::unique_ptr<float> fbn_offset{new float[const_num_elements]};
+ for (int32_t i = 0; i < const_num_elements; i++)
+ {
+ fbn_offset.get()[i] = tffbn_beta->at<loco::DataType::FLOAT32>(i);
+ }
+
+ // fbn_mul_0_param = fbn_mul : remove this and use fbn_mul
+ std::unique_ptr<float> fbn_mul_0_param{new float[const_num_elements]};
+ for (int32_t i = 0; i < const_num_elements; i++)
+ {
+ fbn_mul_0_param.get()[i] = fbn_mul.get()[i];
+ }
+
+ // fbn_add_param = fbn_offset - fbn_mean * fbn_mul
+ std::unique_ptr<float> fbn_add_param{new float[const_num_elements]};
+ for (int32_t i = 0; i < const_num_elements; i++)
+ {
+ fbn_add_param.get()[i] = fbn_offset.get()[i] - fbn_mean.get()[i] * fbn_mul.get()[i];
+ }
+
+ INFO(lfbn) << "TFFBN create ConstGen" << std::endl;
+
+ /*
+ * %11:fbn_mul_0_param = ConstGen(fbn_mul_0_param)
+ * %21:fbn_add_param = ConstGen(fbn_add_param)
+ */
+ auto const_fbn_mul_0_param = graph->nodes()->create<loco::ConstGen>();
+ const_fbn_mul_0_param->dtype(loco::DataType::FLOAT32);
+ copy_shape(tffbn_gamma, const_fbn_mul_0_param);
+ const_fbn_mul_0_param->size<loco::DataType::FLOAT32>(const_num_elements);
+ for (int32_t i = 0; i < const_num_elements; i++)
+ {
+ const_fbn_mul_0_param->at<loco::DataType::FLOAT32>(i) = fbn_mul_0_param.get()[i];
+ }
+ auto const_fbn_add_param = graph->nodes()->create<loco::ConstGen>();
+ const_fbn_add_param->dtype(loco::DataType::FLOAT32);
+ copy_shape(tffbn_gamma, const_fbn_add_param);
+ const_fbn_add_param->size<loco::DataType::FLOAT32>(const_num_elements);
+ for (int32_t i = 0; i < const_num_elements; i++)
+ {
+ const_fbn_add_param->at<loco::DataType::FLOAT32>(i) = fbn_add_param.get()[i];
+ }
+
+ INFO(lfbn) << "TFFBN create TFMul, TFAdd" << std::endl;
+ /*
+ * %12:fbn_mul_0 = TFMul(%0:input, %11:fbn_mul_0_param)
+ * %22:fbn = TFAdd(%12:fbn_mul_0,%21:fbn_add_param)
+ */
+ auto fbn_mul_0 = graph->nodes()->create<moco::tf::TFMul>();
+ fbn_mul_0->x(tffbn_input);
+ fbn_mul_0->y(const_fbn_mul_0_param);
+
+ auto fbn = graph->nodes()->create<moco::tf::TFAdd>();
+ fbn->x(fbn_mul_0);
+ fbn->y(const_fbn_add_param);
+
+ // replace old node with new fbn
+ replace(node).with(fbn);
+ // unlink from graph
+ node->input(nullptr);
+ node->gamma(nullptr);
+ node->beta(nullptr);
+ node->mean(nullptr);
+ node->variance(nullptr);
+
+ return true;
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+bool ResolveFusedBatchNorm::run(loco::Graph *graph)
+{
+ for (auto node : loco::all_nodes(graph))
+ {
+ if (as<moco::tf::TFFusedBatchNorm>(node))
+ {
+ if (resolve_to_muladd(graph, as<moco::tf::TFFusedBatchNorm>(node)))
+ {
+ // tree has been changed. let's return so that we don't need to
+ // considier about following node is correct or not.
+ return true;
+ }
+ }
+ }
+
+ return false;
+}
+
+} // namespace tf
+} // namespace moco
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ResolveFusedBatchNorm.h"
+
+#include "TestHelper.h"
+#include "IR/TFFusedBatchNorm.h"
+#include "Importer.h"
+
+#include <loco.h>
+#include <locop/FormattedGraph.h>
+#include <moco/Log.h>
+#include <stdex/Memory.h>
+
+#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/io/zero_copy_stream_impl.h>
+#include <google/protobuf/text_format.h>
+
+#include <gtest/gtest.h>
+
+using namespace moco::tf::test;
+
+namespace
+{
+// clang-format off
+const char *fbn_basic_pbtxt = STRING_CONTENT(
+node {
+ name: "input"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim { size: 1 }
+ dim { size: 4 }
+ dim { size: 4 }
+ dim { size: 1 }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "gamma"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "beta"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "FBN_01/mean"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "FBN_01/variance"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "FBN_01"
+ op: "FusedBatchNorm"
+ input: "input"
+ input: "gamma"
+ input: "beta"
+ input: "FBN_01/mean"
+ input: "FBN_01/variance"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "data_format"
+ value {
+ s: "NHWC"
+ }
+ }
+ attr {
+ key: "epsilon"
+ value {
+ f: 0.001
+ }
+ }
+ attr {
+ key: "is_training"
+ value {
+ b: false
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+namespace
+{
+
+char to_char(bool b) { return b ? 'Y' : 'N'; }
+
+} // namespace
+
+TEST(ResolveFusedBatchNorm, fbn_resolve_basic)
+{
+ LOGGER(l);
+
+ // load graph
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+ signature.add_output(moco::tf::TensorName("FBN_01", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(parse_graphdef(fbn_basic_pbtxt, graph_def));
+ auto graph = importer.import(signature, graph_def);
+
+ INFO(l) << "Before ResolveFusedBatchNorm";
+ INFO(l) << locop::fmt<locop::LinearV1>(graph);
+
+ moco::tf::ResolveFusedBatchNorm transform;
+ bool changed = transform.run(graph.get());
+
+ INFO(l) << "After ResolveFusedBatchNorm " << to_char(changed);
+ INFO(l) << locop::fmt<locop::LinearV1>(graph);
+
+ // Output value test will be done with mocotest-tf
+ // Network structure of transformation is not important and may be changed
+ // in the future so it will not be checked here.
+
+ SUCCEED();
+}