[moco/tf] Implement Dead Node Elimination (#4032)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 1 Jul 2019 10:50:32 +0000 (19:50 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 1 Jul 2019 10:50:32 +0000 (19:50 +0900)
Now, moco.tf compilation pipiline support Dead Node Elimination. Note that
this optimization is off by default.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
contrib/moco-tf/src/Knob.lst
contrib/moco-tf/src/Optimizer.cpp
contrib/moco-tf/src/Optimizer.test.cpp
contrib/moco-tf/src/Transforms.h
contrib/moco-tf/src/Transforms/RemoveDeadNodeTransform.cpp [new file with mode: 0644]
contrib/moco-tf/src/Transforms/RemoveDeadNodeTransform.h [new file with mode: 0644]

index 91dd355..f22001a 100644 (file)
@@ -3,4 +3,5 @@
 #endif // KNOB_BOOL
 
 // KNOB_BOOL(NAME, DEFAULT_VALUE, DESCRIPTION)
+KNOB_BOOL(RemoveDeadNode, false, Enable RemoveDeadNode optimization)
 KNOB_BOOL(RemoveForwardNode, false, Enable RemoveForwardNode optimization)
index 3e9a9af..b547c85 100644 (file)
@@ -32,6 +32,11 @@ void Optimizer::optimize(loco::Graph *g) const
   moco::tf::Phase phase;
 
   /* TRANSFORM DECLARATION BEGIN */
+  if (moco::tf::get<moco::tf::Knob::RemoveDeadNode>())
+  {
+    phase.emplace_back(stdex::make_unique<RemoveDeadNodeTransform>());
+  }
+
   if (moco::tf::get<moco::tf::Knob::RemoveForwardNode>())
   {
     phase.emplace_back(stdex::make_unique<RemoveForwardNodeTransform>());
index 486ee7e..4f74d33 100644 (file)
@@ -55,3 +55,31 @@ TEST(Optimizer, simple_forward_graph)
 
   SUCCEED();
 }
+
+TEST(Optimizer, simple_forward_graph_with_one_valid_output)
+{
+  moco::tf::Optimizer o;
+
+  /**
+   * Create a simple graph that forwards a constant as graph-level output
+   */
+  loco::Graph g;
+  {
+    auto output = g.outputs()->create();
+
+    auto constgen = g.nodes()->create<loco::ConstGen>();
+    constgen->shape({2, 3});
+
+    auto forward = g.nodes()->create<loco::Forward>();
+    forward->input(constgen);
+
+    auto pull = g.nodes()->create<loco::Push>();
+    pull->from(forward);
+
+    output->node(pull);
+  }
+
+  o.optimize(&g);
+
+  SUCCEED();
+}
index 541d961..ba6ba6e 100644 (file)
@@ -20,6 +20,7 @@
 #include "Transforms/ClearAnnotTransform.h"
 #include "Transforms/FixPaddingTransform.h"
 #include "Transforms/FixShapeTransform.h"
+#include "Transforms/RemoveDeadNodeTransform.h"
 #include "Transforms/RemoveForwardNodeTransform.h"
 
 #endif // __MOCO_TF_TRANSFORMS_H__
diff --git a/contrib/moco-tf/src/Transforms/RemoveDeadNodeTransform.cpp b/contrib/moco-tf/src/Transforms/RemoveDeadNodeTransform.cpp
new file mode 100644 (file)
index 0000000..64abe5a
--- /dev/null
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "RemoveDeadNodeTransform.h"
+
+#include <loco/IR/Algorithm.h>
+#include <loco/IR/CanonicalDialect.h>
+#include <loco/IR/CanonicalNode.h>
+
+#include <set>
+
+namespace moco
+{
+namespace tf
+{
+
+bool RemoveDeadNodeTransform::run(loco::Graph *g)
+{
+  // Let's enumerate nodes required to compute output nodes
+  auto active_nodes = loco::active_nodes(loco::output_nodes(g));
+
+  // Find dead(= non-active) nodes
+  std::set<loco::Node *> candidates;
+
+  for (auto node : loco::all_nodes(g))
+  {
+    if (active_nodes.find(node) == active_nodes.end())
+    {
+      candidates.insert(node);
+    }
+  }
+
+  // Let's drop the references from each dead node first and then remove these dead nodes
+  //
+  // Why?
+  //
+  // Let us consider the following example:
+  //    %0 = Pull(...)
+  //    %1 = ConstGen(...)
+  //    %2 = Forward(input: %1)
+  //    %3 = Push(from: %0) <- OUTPUT
+  //
+  // Forward (%2) is dead as it does not contribute to the final result (%3). However, it
+  // refers to another dead node (%1).
+  //
+  // This example indicates that naive implementation results in dangling references.
+  //
+  // There are two possible solutions:
+  //  1. Destroy nodes in topological order
+  //  2. Drop the reference first and then destroy them
+  //
+  // The current implementation takes the latter approach for the simplicity of implementation.
+  for (auto node : candidates)
+  {
+    node->drop();
+  }
+
+  for (auto node : candidates)
+  {
+    g->nodes()->destroy(node);
+  }
+
+  return candidates.size() > 0;
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/contrib/moco-tf/src/Transforms/RemoveDeadNodeTransform.h b/contrib/moco-tf/src/Transforms/RemoveDeadNodeTransform.h
new file mode 100644 (file)
index 0000000..dbfd81a
--- /dev/null
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_REMOVE_DEAD_NODE_TRANSFORM_H__
+#define __MOCO_TF_REMOVE_DEAD_NODE_TRANSFORM_H__
+
+#include "Transform.h"
+
+namespace moco
+{
+namespace tf
+{
+
+struct RemoveDeadNodeTransform final : public Transform
+{
+  const char *name(void) const final { return "RemoveDeadNodeTransform"; }
+
+  bool run(loco::Graph *g);
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_REMOVE_DEAD_NODE_TRANSFORM_H__