[moco-tf] Constant Folding Transform for Canonical Dialect (#5759)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Thu, 25 Jul 2019 00:59:31 +0000 (09:59 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Thu, 25 Jul 2019 00:59:31 +0000 (09:59 +0900)
* [moco-tf] Constant Folding Transform for Canonical Dialect

This commit adds constant folding transform for canonical dialect, plus a test and build files.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
* revised

compiler/moco-tf/CMakeLists.txt
compiler/moco-tf/requires.cmake
compiler/moco-tf/src/Transforms.h
compiler/moco-tf/src/Transforms/ConstantFoldingTransform.cpp [new file with mode: 0644]
compiler/moco-tf/src/Transforms/ConstantFoldingTransform.h [new file with mode: 0644]
compiler/moco-tf/src/Transforms/ConstantFoldingTransform.test.cpp [new file with mode: 0644]

index 9cb18b9..9cca82a 100644 (file)
@@ -26,6 +26,7 @@ target_link_libraries(moco_tf_frontend PRIVATE stdex)
 target_link_libraries(moco_tf_frontend PRIVATE cwrap)
 target_link_libraries(moco_tf_frontend PRIVATE moco_log)
 target_link_libraries(moco_tf_frontend PRIVATE pepper_strcast)
+target_link_libraries(moco_tf_frontend PRIVATE locomotiv)
 
 if(NOT ENABLE_TEST)
   return()
index 97fec45..1f31c62 100644 (file)
@@ -4,3 +4,4 @@ require("cwrap")
 require("stdex")
 require("moco-log")
 require("pepper-strcast")
+require("locomotiv")
index ecee6fc..3956245 100644 (file)
@@ -18,6 +18,7 @@
 #define __MOCO_TF_TRANSFORMS_H__
 
 #include "Transforms/ClearAnnotTransform.h"
+#include "Transforms/ConstantFoldingTransform.h"
 #include "Transforms/FixPaddingTransform.h"
 #include "Transforms/FixShapeTransform.h"
 #include "Transforms/FuseBinaryIntoPreceding.h"
diff --git a/compiler/moco-tf/src/Transforms/ConstantFoldingTransform.cpp b/compiler/moco-tf/src/Transforms/ConstantFoldingTransform.cpp
new file mode 100644 (file)
index 0000000..f6c0772
--- /dev/null
@@ -0,0 +1,175 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConstantFoldingTransform.h"
+
+#include <loco.h>
+#include <loco/IR/CanonicalDialect.h>
+#include <moco/Log.h>
+#include <stdex/Memory.h>
+
+#include <locomotiv/Session.h>
+
+#include <cassert>
+#include <stdexcept>
+
+namespace
+{
+
+using namespace moco::tf;
+
+uint64_t num_elements(const loco::NodeMixin<loco::NodeTrait::TensorShape> &shape)
+{
+  if (shape.rank() == 0)
+  {
+    return 0;
+  }
+
+  uint64_t res = 1;
+
+  for (uint32_t axis = 0; axis < shape.rank(); ++axis)
+  {
+    assert(shape.dim(axis).known());
+    res *= shape.dim(axis).value();
+  }
+
+  return res;
+}
+
+/// @brief For some op, constant folding should not be performed. This returns true if node is such
+/// op.
+bool skip(const loco::Node *node)
+{
+  static std::set<uint32_t> skip_op = {
+      // TODO Current implementation works for 'Tensor' domain only. Support other domains such as
+      //      `Feature`, `Filter`, `Bias`, etc.
+      static_cast<uint32_t>(loco::CanonicalOpcode::FilterEncode),
+      static_cast<uint32_t>(loco::CanonicalOpcode::FeatureEncode),
+      static_cast<uint32_t>(loco::CanonicalOpcode::BiasEncode),
+      static_cast<uint32_t>(loco::CanonicalOpcode::DepthwiseFilterEncode),
+
+      // We don't perform constant folding for Push
+      static_cast<uint32_t>(loco::CanonicalOpcode::Push),
+  };
+
+  if (node->dialect() == loco::CanonicalDialect::get())
+  {
+    if (skip_op.find(node->opnum()) != skip_op.end())
+      return true;
+  }
+
+  return false;
+}
+
+/// @brief Checks if a node is a target of constant folding transform
+bool foldable(const loco::Node *node)
+{
+  if (node->dialect() == loco::CanonicalDialect::get())
+  {
+    if (skip(node))
+      return false;
+
+    if (node->arity() == 0) // e.g., when a node is e.g, ConstGen or Pull
+      return false;
+
+    // When all args are ConstGen, let's do Constant Folding Transforms
+    for (int i = 0; i < node->arity(); i++)
+    {
+      if (node->arg(i)->opnum() != static_cast<uint32_t>(loco::CanonicalOpcode::ConstGen))
+        return false;
+    }
+
+    return true;
+  }
+  else
+  {
+    return false;
+  }
+}
+
+void fold(loco::Graph *graph, loco::Node *node)
+{
+  assert(foldable(node)); // sanity check to find a mistake when this function is reused later
+
+  // calcluate foldable node
+  locomotiv::Session sess(graph, std::vector<loco::Node *>{node});
+  sess.infer();
+  auto data = sess.get_output(0);
+
+  assert(data != nullptr);
+
+  auto shape = data->shape();
+  auto dtype = data->dtype();
+
+  // build ConstGen
+  auto new_const = graph->nodes()->create<loco::ConstGen>();
+  {
+    new_const->dtype(dtype);
+
+    new_const->rank(shape->rank());
+    for (int d = 0; d < shape->rank(); d++)
+      new_const->dim(d) = shape->dim(d);
+
+    auto count = num_elements(*new_const);
+
+    if (dtype == loco::DataType::FLOAT32)
+    {
+      new_const->size<loco::DataType::FLOAT32>(count);
+
+      auto const_buf = data->as_f32_bufptr()->base();
+      for (int x = 0; x < count; x++)
+        new_const->at<loco::DataType::FLOAT32>(x) = const_buf[x];
+    }
+    else if (dtype == loco::DataType::S32)
+    {
+      new_const->size<loco::DataType::S32>(count);
+
+      auto const_buf = data->as_s32_bufptr()->base();
+      for (int x = 0; x < count; x++)
+        new_const->at<loco::DataType::S32>(x) = const_buf[x];
+    }
+  }
+
+  // replace node with new_const
+  loco::replace(node).with(new_const);
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+bool ConstantFoldingTransform::run(loco::Graph *graph)
+{
+  auto outputs = loco::output_nodes(graph);
+
+  bool changed = false;
+  for (auto node : loco::postorder_traversal(outputs))
+  {
+    if (foldable(node))
+    {
+      fold(graph, node);
+      changed = true;
+    }
+  }
+
+  return changed;
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Transforms/ConstantFoldingTransform.h b/compiler/moco-tf/src/Transforms/ConstantFoldingTransform.h
new file mode 100644 (file)
index 0000000..890083b
--- /dev/null
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_CONSTANT_FOLDING_TRANSFORM_H__
+#define __MOCO_TF_CONSTANT_FOLDING_TRANSFORM_H__
+
+#include "Transform.h"
+
+#include <loco.h>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief  Performs constant folding optimization
+ */
+class ConstantFoldingTransform : public Transform
+{
+public:
+  const char *name(void) const final { return "ConstantFoldingTransform"; }
+
+public:
+  bool run(loco::Graph *graph) override;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_CONSTANT_FOLDING_TRANSFORM_H__
diff --git a/compiler/moco-tf/src/Transforms/ConstantFoldingTransform.test.cpp b/compiler/moco-tf/src/Transforms/ConstantFoldingTransform.test.cpp
new file mode 100644 (file)
index 0000000..c337d65
--- /dev/null
@@ -0,0 +1,118 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConstantFoldingTransform.h"
+
+#include "TestHelper.h"
+#include "IR/TFFusedBatchNorm.h"
+#include "Importer.h"
+#include "Canonicalizer.h"
+
+#include <loco.h>
+#include <locop/FormattedGraph.h>
+#include <moco/Log.h>
+
+#include <gtest/gtest.h>
+
+using namespace moco::tf::test;
+
+namespace
+{
+// clang-format off
+
+/*
+  test case:
+      ConstGen -- Relu -- Push
+
+  after constant folding
+              ConstGen -- Push
+*/
+const char *case01 = STRING_CONTENT(
+node {
+  name: "input"
+  op: "Const"
+  attr {
+    key: "dtype" value { type: DT_FLOAT }
+  }
+  attr {
+    key: "value"
+    value {
+      tensor {
+        dtype: DT_FLOAT
+        tensor_shape { dim { size: 2 } }
+        float_val: -3.14
+        float_val: 3.14
+      }
+    }
+  }
+}
+node {
+  name: "relu"
+  op: "Relu"
+  input: "input"
+  attr { key: "T" value { type: DT_FLOAT } }
+}
+);
+// clang-format on
+
+} // namespace
+
+namespace
+{
+
+char to_char(bool b) { return b ? 'Y' : 'N'; }
+
+} // namespace
+
+TEST(ConstantFolding, case01)
+{
+  LOGGER(l);
+
+  // load graph
+  moco::tf::Importer importer;
+  moco::tf::ModelSignature signature;
+  signature.add_output(moco::tf::TensorName("relu", 0));
+
+  tensorflow::GraphDef graph_def;
+  EXPECT_TRUE(parse_graphdef(case01, graph_def));
+  auto graph = importer.import(signature, graph_def);
+
+  // Convert graph to hold only Canonical dialect
+  moco::tf::Canonicalizer canonicalizer;
+  canonicalizer.canonicalize(graph.get());
+
+  INFO(l) << "Before ConstantFolding";
+  INFO(l) << locop::fmt<locop::LinearV1>(graph);
+
+  moco::tf::ConstantFoldingTransform transform;
+  while (transform.run(graph.get()) == true)
+  {
+    INFO(l) << "running ConstantFolding...";
+  }
+
+  INFO(l) << "After ConstantFolding ";
+  INFO(l) << locop::fmt<locop::LinearV1>(graph);
+
+  auto push = moco::tf::test::find_first_node_bytype<loco::Push>(graph.get());
+  auto const_gen = dynamic_cast<loco::ConstGen *>(push->from());
+
+  ASSERT_NE(const_gen, nullptr);
+  ASSERT_EQ(const_gen->size<loco::DataType::FLOAT32>(), 2);
+  ASSERT_EQ(const_gen->at<loco::DataType::FLOAT32>(0), 0); // result of relu (-3.14)
+  ASSERT_EQ(const_gen->at<loco::DataType::FLOAT32>(1), 3.14f);
+}
+
+// TODO Add more complex cases