[exo] Fold Reshape of Const Pass (#8432)
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>
Thu, 24 Oct 2019 04:32:26 +0000 (13:32 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Thu, 24 Oct 2019 04:32:26 +0000 (13:32 +0900)
This commit introduces FoldReshapeOfConst Pass, which folds TFLReshape
followed by TFLConst

Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
compiler/exo/src/Pass/FoldReshapeOfConst.cpp [new file with mode: 0644]
compiler/exo/src/Pass/FoldReshapeOfConst.h [new file with mode: 0644]

diff --git a/compiler/exo/src/Pass/FoldReshapeOfConst.cpp b/compiler/exo/src/Pass/FoldReshapeOfConst.cpp
new file mode 100644 (file)
index 0000000..b58b364
--- /dev/null
@@ -0,0 +1,115 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FoldReshapeOfConst.h"
+
+#include "Check.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/TFLNodeVisitor.h"
+
+#include <loco/IR/Nodes.h>
+#include <loco/Service/ShapeInference.h>
+
+namespace
+{
+
+/**
+ * @brief   Check if node is TFLReshape and its input is TFLConst
+ * @return  Casted TFLReshape for foldable candidate, nullptr otherwise
+ */
+locoex::TFLReshape *as_candidate(loco::Node *node)
+{
+  auto reshape = dynamic_cast<locoex::TFLReshape *>(node);
+  if (not reshape)
+    return nullptr;
+
+  // Only accept Constant input of Reshape
+  if (not dynamic_cast<locoex::TFLConst *>(reshape->tensor()))
+    return nullptr;
+
+  return reshape;
+}
+
+uint32_t volume(loco::Node *tensor_node)
+{
+  auto shape = loco::shape_get(tensor_node).as<loco::TensorShape>();
+
+  uint32_t vol = 1;
+  for (uint32_t axis = 0; axis < shape.rank(); ++axis)
+    vol *= shape.dim(axis).value();
+
+  return vol;
+}
+
+void fold_reshape_of_const(locoex::TFLReshape *reshape)
+{
+  const loco::DataType FLOAT32 = loco::DataType::FLOAT32;
+
+  auto const_orig = dynamic_cast<locoex::TFLConst *>(reshape->tensor());
+
+  // Exceptions
+  {
+    EXO_ASSERT(const_orig, "Only support for Reshape-Const pair");
+    // TODO support other data types
+    if (const_orig->dtype() != FLOAT32)
+      EXO_THROW("NYI for this data type");
+
+    if (volume(const_orig) != volume(reshape))
+      EXO_THROW("New shape of Reshape is not matched");
+  }
+
+  auto new_shape = loco::shape_get(reshape).as<loco::TensorShape>();
+
+  // TFLConst to replace
+  auto const_new = reshape->graph()->nodes()->create<locoex::TFLConst>();
+
+  const_new->dtype(FLOAT32);
+  const_new->rank(new_shape.rank());
+  const_new->size<FLOAT32>(const_orig->size<FLOAT32>());
+  for (uint32_t axis = 0; axis < new_shape.rank(); ++axis)
+    const_new->dim(axis) = new_shape.dim(axis);
+
+  for (uint32_t i = 0; i < const_new->size<FLOAT32>(); ++i)
+  {
+    const_new->at<FLOAT32>(i) = const_orig->at<FLOAT32>(i);
+  }
+
+  // replace
+  loco::replace(reshape).with(const_new);
+}
+
+} // namespace
+
+namespace exo
+{
+
+bool FoldReshapeOfConst::run(loco::Graph *g)
+{
+  bool changed = false;
+  for (auto node : loco::active_nodes(loco::output_nodes(g)))
+  {
+    if (auto reshape = as_candidate(node))
+    {
+      fold_reshape_of_const(reshape);
+      changed = true;
+    }
+  }
+
+  return changed;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Pass/FoldReshapeOfConst.h b/compiler/exo/src/Pass/FoldReshapeOfConst.h
new file mode 100644 (file)
index 0000000..66b4b58
--- /dev/null
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __PASS_FOLD_RESHAPE_OF_CONST_H__
+#define __PASS_FOLD_RESHAPE_OF_CONST_H__
+
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Class to fuse TFLReshape + TFLConst into one equivalent TFLConst
+ *
+ * <before>
+ *      TFLConst --- TFLReshape --- Out
+ *
+ * <after>
+ *      TFLConst --- TFLReshape ---
+ *      TFLConst (new) ------------ Out
+ *
+ * TODO This pass is for temporary. Deprecate this pass.
+ */
+struct FoldReshapeOfConst final : public logo::Pass
+{
+  const char *name(void) const final { return "exo::FoldReshapeOfConst"; }
+
+  bool run(loco::Graph *g) final;
+};
+
+} // namespace exo
+
+#endif // __PASS_FOLD_RESHAPE_OF_CONST_H__