[moco-tf] Introduce ResolveDuplicateReshape (#6499)
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>
Mon, 12 Aug 2019 10:02:26 +0000 (19:02 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 12 Aug 2019 10:02:26 +0000 (19:02 +0900)
ResolvedDuplicateReshape transform introduced. This transform finds
duplicate FixedReshape pair, and let the later one to bypass the
former.

Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
compiler/moco-tf/src/Transforms/ResolveDuplicateReshape.cpp [new file with mode: 0644]
compiler/moco-tf/src/Transforms/ResolveDuplicateReshape.h [new file with mode: 0644]

diff --git a/compiler/moco-tf/src/Transforms/ResolveDuplicateReshape.cpp b/compiler/moco-tf/src/Transforms/ResolveDuplicateReshape.cpp
new file mode 100644 (file)
index 0000000..153fa36
--- /dev/null
@@ -0,0 +1,111 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ResolveDuplicateReshape.h"
+
+#include <loco.h>
+
+#include <cassert>
+
+namespace
+{
+
+using FixedReshape = loco::Reshape<loco::ReshapeType::Fixed>;
+
+/// @return  true when 'node' and its input node are both FixedReshapes
+bool is_duplicate_reshape(loco::Node *node)
+{
+  auto node_as_reshape = dynamic_cast<FixedReshape *>(node);
+
+  if (!node_as_reshape)
+    return false;
+
+  auto input_as_reshape = dynamic_cast<FixedReshape *>(node_as_reshape->input());
+
+  if (!input_as_reshape)
+    return false;
+
+  return true;
+}
+
+/**
+ * @brief  Remap reshape's input to its input's input, i.e. bypass input reshape
+ *
+ * Before:
+ *
+ *   In ----- FixedReshape_1 ----- [Out_1]*
+ *                        \
+ *                         ------- FixedReshape_2 --- [Out_2]*
+ *                                 ('reshape' arg)
+ *
+ * After:
+ *
+ *   In ----- FixedReshape_1 ----- [Out_1]*
+ *    \
+ *     --------------------------- FixedReshape_2 --- [Out_2]*
+ *
+ * Note: In case of no Out_1, FixedReshape_1 becomes dead node.
+ *       Out_1 can be another FixedReshape as well, which would be resolved in
+ *       another occurance of this transform pass.
+ */
+void remap_input(FixedReshape *reshape)
+{
+  auto input_reshape = dynamic_cast<FixedReshape *>(reshape->input());
+
+  auto volume = [](FixedReshape *node) {
+    uint32_t vol = 1;
+    for (uint32_t axis = 0; axis < node->rank(); ++axis)
+    {
+      assert(node->dim(axis).known());
+      vol *= node->dim(axis).value();
+    }
+    return vol;
+  };
+
+  // Volume mismatch between duplicate reshapes is pointless
+  assert(volume(reshape) == volume(input_reshape));
+
+  // Set node's input as input's input, i.e. bypass
+  reshape->input(input_reshape->input());
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+bool ResolveDuplicateReshape::run(loco::Graph *graph)
+{
+  auto outputs = loco::output_nodes(graph);
+
+  bool changed = false;
+  for (auto node : loco::postorder_traversal(outputs))
+  {
+    if (is_duplicate_reshape(node))
+    {
+      auto node_as_reshape = dynamic_cast<FixedReshape *>(node);
+
+      remap_input(node_as_reshape);
+    }
+  }
+
+  return changed;
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Transforms/ResolveDuplicateReshape.h b/compiler/moco-tf/src/Transforms/ResolveDuplicateReshape.h
new file mode 100644 (file)
index 0000000..fa69901
--- /dev/null
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_RESOLVE_DUPLICATE_RESHAPE_H__
+#define __MOCO_TF_RESOLVE_DUPLICATE_RESHAPE_H__
+
+#include "Transform.h"
+
+#include <loco.h>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief  Resolve duplicated Reshape nodes in a row
+ */
+class ResolveDuplicateReshape : public Transform
+{
+public:
+  const char *name(void) const final { return "ResolveDuplicateReshape"; }
+
+public:
+  bool run(loco::Graph *graph) override;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_RESOLVE_DUPLICATE_RESHAPE_H__