[logo] Introduce ResolveDuplicateReshapePass (#6938)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Tue, 27 Aug 2019 04:20:53 +0000 (13:20 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 27 Aug 2019 04:20:53 +0000 (13:20 +0900)
This will introduce ResolveDuplicateReshapePass copied from moco

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/logo/include/logo/ResolveDuplicateReshapePass.h [new file with mode: 0644]
compiler/logo/src/Passes/ResolveDuplicateReshapePass.cpp [new file with mode: 0644]

diff --git a/compiler/logo/include/logo/ResolveDuplicateReshapePass.h b/compiler/logo/include/logo/ResolveDuplicateReshapePass.h
new file mode 100644 (file)
index 0000000..7e6c67f
--- /dev/null
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOGO_RESOLVE_DUPLICATE_RESHAPE_PASS_H__
+#define __LOGO_RESOLVE_DUPLICATE_RESHAPE_PASS_H__
+
+#include <logo/Pass.h>
+
+#include <loco.h>
+
+namespace logo
+{
+
+/**
+ * @brief  Resolve duplicated Reshape nodes in a row
+ */
+class ResolveDuplicateReshapePass final : public Pass
+{
+public:
+  const char *name(void) const final { return "ResolveDuplicateReshapePass"; }
+
+public:
+  bool run(loco::Graph *graph) override;
+};
+
+} // namespace logo
+
+#endif // __LOGO_RESOLVE_DUPLICATE_RESHAPE_PASS_H__
diff --git a/compiler/logo/src/Passes/ResolveDuplicateReshapePass.cpp b/compiler/logo/src/Passes/ResolveDuplicateReshapePass.cpp
new file mode 100644 (file)
index 0000000..d3c74cb
--- /dev/null
@@ -0,0 +1,108 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <logo/ResolveDuplicateReshapePass.h>
+
+#include <loco.h>
+
+#include <cassert>
+
+namespace
+{
+
+/// @return  true when 'node' and its input node are both FixedReshapes
+bool is_duplicate_reshape(loco::Node *node)
+{
+  auto node_as_reshape = dynamic_cast<loco::FixedReshape *>(node);
+
+  if (!node_as_reshape)
+    return false;
+
+  auto input_as_reshape = dynamic_cast<loco::FixedReshape *>(node_as_reshape->input());
+
+  if (!input_as_reshape)
+    return false;
+
+  return true;
+}
+
+/**
+ * @brief  Remap reshape's input to its input's input, i.e. bypass input reshape
+ *
+ * Before:
+ *
+ *   In ----- FixedReshape_1 ----- [Out_1]*
+ *                        \
+ *                         ------- FixedReshape_2 --- [Out_2]*
+ *                                 ('reshape' arg)
+ *
+ * After:
+ *
+ *   In ----- FixedReshape_1 ----- [Out_1]*
+ *    \
+ *     --------------------------- FixedReshape_2 --- [Out_2]*
+ *
+ * Note: In case of no Out_1, FixedReshape_1 becomes dead node.
+ *       Out_1 can be another FixedReshape as well, which would be resolved in
+ *       another occurance of this transform pass.
+ */
+void remap_input(loco::FixedReshape *reshape)
+{
+  auto input_reshape = dynamic_cast<loco::FixedReshape *>(reshape->input());
+
+  auto volume = [](loco::FixedReshape *node) {
+    uint32_t vol = 1;
+    for (uint32_t axis = 0; axis < node->rank(); ++axis)
+    {
+      assert(node->dim(axis).known());
+      vol *= node->dim(axis).value();
+    }
+    return vol;
+  };
+
+  // Volume mismatch between duplicate reshapes is pointless
+  assert(volume(reshape) == volume(input_reshape));
+
+  // Set node's input as input's input, i.e. bypass
+  reshape->input(input_reshape->input());
+}
+
+} // namespace
+
+namespace logo
+{
+
+bool ResolveDuplicateReshapePass::run(loco::Graph *graph)
+{
+  auto outputs = loco::output_nodes(graph);
+
+  bool changed = false;
+  for (auto node : loco::postorder_traversal(outputs))
+  {
+    if (is_duplicate_reshape(node))
+    {
+      auto node_as_reshape = dynamic_cast<loco::FixedReshape *>(node);
+
+      remap_input(node_as_reshape);
+
+      changed = true;
+    }
+  }
+
+  return changed;
+}
+
+} // namespace logo