[moco/tf] Introduce ResolveFusedBatchNorm (#4125)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 8 Jul 2019 02:45:27 +0000 (11:45 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 8 Jul 2019 02:45:27 +0000 (11:45 +0900)
This will introduce ResolveFusedBatchNorm transformation that decomposes FusedBatchNorm node into Add and Mul

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
contrib/moco-tf/CMakeLists.txt
contrib/moco-tf/src/TFOptimizer.cpp
contrib/moco-tf/src/Transforms.h
contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.cpp [new file with mode: 0644]
contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.h [new file with mode: 0644]
contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.test.cpp [new file with mode: 0644]

index b654225..9cb18b9 100644 (file)
@@ -36,6 +36,8 @@ nncc_find_package(GTest REQUIRED)
 add_executable(moco_tf_frontend_test ${TESTS})
 target_include_directories(moco_tf_frontend_test PRIVATE src)
 target_link_libraries(moco_tf_frontend_test gtest_main)
+target_link_libraries(moco_tf_frontend_test locop)
+target_link_libraries(moco_tf_frontend_test moco_log)
 target_link_libraries(moco_tf_frontend_test moco_tf_frontend)
 target_link_libraries(moco_tf_frontend_test stdex)
 add_test(moco_tf_frontend_test moco_tf_frontend_test)
index f1380d3..a1468fb 100644 (file)
@@ -32,7 +32,10 @@ void TFOptimizer::optimize(loco::Graph *g) const
   moco::tf::Phase phase;
 
   /* TRANSFORM DECLARATION BEGIN */
-
+  if (moco::tf::get<moco::tf::Knob::ResolveFusedBatchNorm>())
+  {
+    phase.emplace_back(stdex::make_unique<moco::tf::ResolveFusedBatchNorm>());
+  }
   /* TRANSFORM DECLARATION END */
 
   moco::tf::PhaseRunner<moco::tf::PhaseStrategy::Saturate> phase_runner{g};
index ea628d5..5b01e00 100644 (file)
@@ -23,6 +23,7 @@
 #include "Transforms/RemoveDeadNodeTransform.h"
 #include "Transforms/RemoveForwardNodeTransform.h"
 #include "Transforms/ReorderDecodeTransform.h"
+#include "Transforms/ResolveFusedBatchNorm.h"
 #include "Transforms/SimplifyDomainConversionTransform.h"
 
 #endif // __MOCO_TF_TRANSFORMS_H__
diff --git a/contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.cpp b/contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.cpp
new file mode 100644 (file)
index 0000000..eee7b42
--- /dev/null
@@ -0,0 +1,260 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ResolveFusedBatchNorm.h"
+
+#include "IR/TFAdd.h"
+#include "IR/TFMul.h"
+
+#include "Convert.h"
+
+#include "IR/TFFusedBatchNorm.h"
+
+#include <loco.h>
+#include <moco/Log.h>
+
+#include <cassert>
+#include <cmath>
+#include <memory>
+
+namespace
+{
+
+bool is_same_shape(loco::ConstGen *lc, loco::ConstGen *rc)
+{
+  if (lc->rank() != rc->rank())
+    return false;
+
+  for (auto r = 0; r < lc->rank(); ++r)
+  {
+    if (lc->dim(r).value() != rc->dim(r).value())
+      return false;
+  }
+  return true;
+}
+
+void copy_shape(const loco::ConstGen *src, loco::ConstGen *dst)
+{
+  assert(src != nullptr);
+  assert(dst != nullptr);
+
+  uint32_t rank = src->rank();
+  dst->rank(rank);
+  for (uint32_t index = 0; index < rank; ++index)
+  {
+    if (src->dim(index).known())
+      dst->dim(index) = loco::make_dimension(src->dim(index).value());
+    else
+      dst->dim(index).unset();
+  }
+}
+
+/**
+ * @note resolve_to_muladd() will transform TFFusedBatchNorm to TFMul, TFAdd and two ConstGen
+ *
+ * <arguments>
+ * %0:input
+ * %1:gamma    : const
+ * %2:beta     : const
+ * %3:mean     : const
+ * %4:variance : const
+ * %5:epsilon  : const
+ *
+ * <constant operations>
+ * fbn_epsilon_array = make_array(%5:epsilon)
+ * fbn_epsilon = %4:variance + fbn_epsilon_array
+ * fbn_rsqrt = 1.0 / math::sqrt(fbn_epsilon)
+ *
+ * fbn_mean = %3:mean
+ * fbn_mul = fbn_rsqrt * %1:gamma
+ * fbn_offset = %2:beta
+ *
+ * fbn_mul_0_param = fbn_mul
+ * fbn_add_param = fbn_offset - fbn_mean * fbn_mul
+ *
+ * <new replace nodes>
+ * %11:fbn_mul_0_param = ConstGen(fbn_mul_0_param)
+ * %12:fbn_mul_0 = TFMul(%0:input, %11:fbn_mul_0_param)
+ * %21:fbn_add_param = ConstGen(fbn_add_param)
+ * %22:fbn = TFAdd(%12:fbn_mul_0,%21:fbn_add_param)
+ */
+bool resolve_to_muladd(loco::Graph *graph, moco::tf::TFFusedBatchNorm *node)
+{
+  LOGGER(lfbn);
+
+  auto tffbn_input = node->input();
+  if (tffbn_input == nullptr)
+  {
+    // This node is already converted
+    return false;
+  }
+
+  auto tffbn_gamma = dynamic_cast<loco::ConstGen *>(node->gamma());
+  auto tffbn_beta = dynamic_cast<loco::ConstGen *>(node->beta());
+  auto tffbn_mean = dynamic_cast<loco::ConstGen *>(node->mean());
+  auto tffbn_variance = dynamic_cast<loco::ConstGen *>(node->variance());
+
+  // all should be const
+  if (tffbn_gamma == nullptr || tffbn_beta == nullptr || tffbn_mean == nullptr ||
+      tffbn_variance == nullptr)
+  {
+    INFO(lfbn) << "TFFBN resolve_to_muladd: One of constant input node is not a constant"
+               << std::endl;
+    return false;
+  }
+  assert(tffbn_gamma->dtype() == loco::DataType::FLOAT32);
+  assert(tffbn_beta->dtype() == loco::DataType::FLOAT32);
+  assert(tffbn_mean->dtype() == loco::DataType::FLOAT32);
+  assert(tffbn_variance->dtype() == loco::DataType::FLOAT32);
+
+  // check all const shape are the same
+  if (!is_same_shape(tffbn_gamma, tffbn_beta) || !is_same_shape(tffbn_gamma, tffbn_mean) ||
+      !is_same_shape(tffbn_gamma, tffbn_variance))
+  {
+    INFO(lfbn) << "TFFBN resolve_to_muladd: Shape of constant are not same" << std::endl;
+    return false;
+  }
+
+  auto tffbn_epsilon = node->epsilon();
+  INFO(lfbn) << "TFFBN tffbn_epsilon = " << tffbn_epsilon << std::endl;
+  auto const_num_elements = tffbn_gamma->size<loco::DataType::FLOAT32>();
+  INFO(lfbn) << "TFFBN const_num_elements = " << const_num_elements << std::endl;
+
+  // fbn_epsilon = %4:variance + fbn_epsilon_array
+  std::unique_ptr<float> fbn_epsilon{new float[const_num_elements]};
+  for (int32_t i = 0; i < const_num_elements; i++)
+  {
+    auto variance = tffbn_variance->at<loco::DataType::FLOAT32>(i);
+    fbn_epsilon.get()[i] = variance + tffbn_epsilon;
+  }
+
+  // fbn_rsqrt = 1.0 / math::sqrt(fbn_epsilon)
+  std::unique_ptr<float> fbn_rsqrt{new float[const_num_elements]};
+  for (int32_t i = 0; i < const_num_elements; i++)
+  {
+    fbn_rsqrt.get()[i] = 1.0 / sqrt(fbn_epsilon.get()[i]);
+  }
+
+  // fbn_mean = %3:mean : TODO remove this block and use %3:mean
+  std::unique_ptr<float> fbn_mean{new float[const_num_elements]};
+  for (int32_t i = 0; i < const_num_elements; i++)
+  {
+    fbn_mean.get()[i] = tffbn_mean->at<loco::DataType::FLOAT32>(i);
+  }
+
+  // fbn_mul = fbn_rsqrt * %1:gamma
+  std::unique_ptr<float> fbn_mul{new float[const_num_elements]};
+  for (int32_t i = 0; i < const_num_elements; i++)
+  {
+    fbn_mul.get()[i] = fbn_rsqrt.get()[i] * tffbn_gamma->at<loco::DataType::FLOAT32>(i);
+  }
+
+  // fbn_offset = %2:beta : TODO remove this block and use %2:beta
+  std::unique_ptr<float> fbn_offset{new float[const_num_elements]};
+  for (int32_t i = 0; i < const_num_elements; i++)
+  {
+    fbn_offset.get()[i] = tffbn_beta->at<loco::DataType::FLOAT32>(i);
+  }
+
+  // fbn_mul_0_param = fbn_mul : remove this and use fbn_mul
+  std::unique_ptr<float> fbn_mul_0_param{new float[const_num_elements]};
+  for (int32_t i = 0; i < const_num_elements; i++)
+  {
+    fbn_mul_0_param.get()[i] = fbn_mul.get()[i];
+  }
+
+  // fbn_add_param = fbn_offset - fbn_mean * fbn_mul
+  std::unique_ptr<float> fbn_add_param{new float[const_num_elements]};
+  for (int32_t i = 0; i < const_num_elements; i++)
+  {
+    fbn_add_param.get()[i] = fbn_offset.get()[i] - fbn_mean.get()[i] * fbn_mul.get()[i];
+  }
+
+  INFO(lfbn) << "TFFBN create ConstGen" << std::endl;
+
+  /*
+   * %11:fbn_mul_0_param = ConstGen(fbn_mul_0_param)
+   * %21:fbn_add_param = ConstGen(fbn_add_param)
+   */
+  auto const_fbn_mul_0_param = graph->nodes()->create<loco::ConstGen>();
+  const_fbn_mul_0_param->dtype(loco::DataType::FLOAT32);
+  copy_shape(tffbn_gamma, const_fbn_mul_0_param);
+  const_fbn_mul_0_param->size<loco::DataType::FLOAT32>(const_num_elements);
+  for (int32_t i = 0; i < const_num_elements; i++)
+  {
+    const_fbn_mul_0_param->at<loco::DataType::FLOAT32>(i) = fbn_mul_0_param.get()[i];
+  }
+  auto const_fbn_add_param = graph->nodes()->create<loco::ConstGen>();
+  const_fbn_add_param->dtype(loco::DataType::FLOAT32);
+  copy_shape(tffbn_gamma, const_fbn_add_param);
+  const_fbn_add_param->size<loco::DataType::FLOAT32>(const_num_elements);
+  for (int32_t i = 0; i < const_num_elements; i++)
+  {
+    const_fbn_add_param->at<loco::DataType::FLOAT32>(i) = fbn_add_param.get()[i];
+  }
+
+  INFO(lfbn) << "TFFBN create TFMul, TFAdd" << std::endl;
+  /*
+   * %12:fbn_mul_0 = TFMul(%0:input, %11:fbn_mul_0_param)
+   * %22:fbn = TFAdd(%12:fbn_mul_0,%21:fbn_add_param)
+   */
+  auto fbn_mul_0 = graph->nodes()->create<moco::tf::TFMul>();
+  fbn_mul_0->x(tffbn_input);
+  fbn_mul_0->y(const_fbn_mul_0_param);
+
+  auto fbn = graph->nodes()->create<moco::tf::TFAdd>();
+  fbn->x(fbn_mul_0);
+  fbn->y(const_fbn_add_param);
+
+  // replace old node with new fbn
+  replace(node).with(fbn);
+  // unlink from graph
+  node->input(nullptr);
+  node->gamma(nullptr);
+  node->beta(nullptr);
+  node->mean(nullptr);
+  node->variance(nullptr);
+
+  return true;
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+bool ResolveFusedBatchNorm::run(loco::Graph *graph)
+{
+  for (auto node : loco::all_nodes(graph))
+  {
+    if (as<moco::tf::TFFusedBatchNorm>(node))
+    {
+      if (resolve_to_muladd(graph, as<moco::tf::TFFusedBatchNorm>(node)))
+      {
+        // tree has been changed. let's return so that we don't need to
+        // considier about following node is correct or not.
+        return true;
+      }
+    }
+  }
+
+  return false;
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.h b/contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.h
new file mode 100644 (file)
index 0000000..9243951
--- /dev/null
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_RESOLVE_FUSEDBATCHNORM_H__
+#define __MOCO_TF_RESOLVE_FUSEDBATCHNORM_H__
+
+#include "Transform.h"
+
+#include <loco.h>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief  Trasform TFFusedBatchNorm into TFAdd + TFRsqrt + TFMul + TFBatchNorm
+*/
+class ResolveFusedBatchNorm : public Transform
+{
+public:
+  const char *name(void) const final { return "ResolveFusedBatchNorm"; }
+
+public:
+  bool run(loco::Graph *graph) override;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_RESOLVE_FUSEDBATCHNORM_H__
diff --git a/contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.test.cpp b/contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.test.cpp
new file mode 100644 (file)
index 0000000..749cf24
--- /dev/null
@@ -0,0 +1,231 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ResolveFusedBatchNorm.h"
+
+#include "TestHelper.h"
+#include "IR/TFFusedBatchNorm.h"
+#include "Importer.h"
+
+#include <loco.h>
+#include <locop/FormattedGraph.h>
+#include <moco/Log.h>
+#include <stdex/Memory.h>
+
+#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/io/zero_copy_stream_impl.h>
+#include <google/protobuf/text_format.h>
+
+#include <gtest/gtest.h>
+
+using namespace moco::tf::test;
+
+namespace
+{
+// clang-format off
+const char *fbn_basic_pbtxt = STRING_CONTENT(
+node {
+  name: "input"
+  op: "Const"
+  attr {
+    key: "dtype"
+    value { type: DT_FLOAT }
+  }
+  attr {
+    key: "value"
+    value {
+      tensor {
+        dtype: DT_FLOAT
+        tensor_shape {
+          dim { size: 1 }
+          dim { size: 4 }
+          dim { size: 4 }
+          dim { size: 1 }
+        }
+        float_val: 1.0
+      }
+    }
+  }
+}
+node {
+  name: "gamma"
+  op: "Const"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "value"
+    value {
+      tensor {
+        dtype: DT_FLOAT
+        tensor_shape {
+          dim {
+            size: 1
+          }
+        }
+        float_val: 1.0
+      }
+    }
+  }
+}
+node {
+  name: "beta"
+  op: "Const"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "value"
+    value {
+      tensor {
+        dtype: DT_FLOAT
+        tensor_shape {
+          dim {
+            size: 1
+          }
+        }
+        float_val: 1.0
+      }
+    }
+  }
+}
+node {
+  name: "FBN_01/mean"
+  op: "Const"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "value"
+    value {
+      tensor {
+        dtype: DT_FLOAT
+        tensor_shape {
+          dim {
+            size: 1
+          }
+        }
+        float_val: 1.0
+      }
+    }
+  }
+}
+node {
+  name: "FBN_01/variance"
+  op: "Const"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "value"
+    value {
+      tensor {
+        dtype: DT_FLOAT
+        tensor_shape {
+          dim {
+            size: 1
+          }
+        }
+        float_val: 1.0
+      }
+    }
+  }
+}
+node {
+  name: "FBN_01"
+  op: "FusedBatchNorm"
+  input: "input"
+  input: "gamma"
+  input: "beta"
+  input: "FBN_01/mean"
+  input: "FBN_01/variance"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "data_format"
+    value {
+      s: "NHWC"
+    }
+  }
+  attr {
+    key: "epsilon"
+    value {
+      f: 0.001
+    }
+  }
+  attr {
+    key: "is_training"
+    value {
+      b: false
+    }
+  }
+}
+);
+// clang-format on
+
+} // namespace
+
+namespace
+{
+
+char to_char(bool b) { return b ? 'Y' : 'N'; }
+
+} // namespace
+
+TEST(ResolveFusedBatchNorm, fbn_resolve_basic)
+{
+  LOGGER(l);
+
+  // load graph
+  moco::tf::Importer importer;
+  moco::tf::ModelSignature signature;
+  signature.add_output(moco::tf::TensorName("FBN_01", 0));
+
+  tensorflow::GraphDef graph_def;
+  EXPECT_TRUE(parse_graphdef(fbn_basic_pbtxt, graph_def));
+  auto graph = importer.import(signature, graph_def);
+
+  INFO(l) << "Before ResolveFusedBatchNorm";
+  INFO(l) << locop::fmt<locop::LinearV1>(graph);
+
+  moco::tf::ResolveFusedBatchNorm transform;
+  bool changed = transform.run(graph.get());
+
+  INFO(l) << "After ResolveFusedBatchNorm " << to_char(changed);
+  INFO(l) << locop::fmt<locop::LinearV1>(graph);
+
+  // Output value test will be done with mocotest-tf
+  // Network structure of transformation is not important and may be changed
+  // in the future so it will not be checked here.
+
+  SUCCEED();
+}