[nnc] Add transpose folding optimization (#6724)
authorSergei Barannikov/AI Tools Lab/SRR/Engineer/Samsung Electronics <s.barannikov@samsung.com>
Tue, 20 Aug 2019 10:48:24 +0000 (19:48 +0900)
committerAlexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Tue, 20 Aug 2019 10:48:24 +0000 (13:48 +0300)
Constant-fold sequences of operations of the form Constant -> Transpose: a Transpose whose input is a Constant is replaced with a single Constant holding the permuted data.
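
For example, with axis order {1, 0} (illustrative notation, not actual MIR syntax):

    Constant(W) -> Transpose({1, 0}) -> consumer
  becomes
    Constant(W') -> consumer,  where W'[j][i] = W[i][j]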

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
compiler/nnc/driver/Driver.cpp
compiler/nnc/include/passes/optimizations/ConstantFoldTranspose.h [new file with mode: 0644]
compiler/nnc/passes/optimizations/CMakeLists.txt
compiler/nnc/passes/optimizations/ConstantFoldTranspose.cpp [new file with mode: 0644]

index 117bb06..cd49be5 100644 (file)
@@ -24,6 +24,7 @@
 #include "passes/acl_soft_backend/AclCppGenerator.h"
 
 #include "passes/optimizations/CombineTransposes.h"
+#include "passes/optimizations/ConstantFoldTranspose.h"
 #include "passes/optimizations/RemoveDeadEnds.h"
 #include "passes/optimizations/FuseArithmeticOps.h"
 #include "passes/optimizations/SinkRelu.h"
@@ -140,6 +141,10 @@ void Driver::registerBackendPass()
 
 void Driver::registerOptimizationPass()
 {
+  // TODO For now this optimization is mandatory. Make it optional when the ACL backend is able
+  //  to handle all transposes that come from importers.
+  _passManager.registerPass(std::unique_ptr<Pass>(new ConstantFoldTranspose()));
+
   if (cli::doOptimizationPass)
   {
     // TODO: maybe we should start managing the optimizations more intelligently?
diff --git a/compiler/nnc/include/passes/optimizations/ConstantFoldTranspose.h b/compiler/nnc/include/passes/optimizations/ConstantFoldTranspose.h
new file mode 100644 (file)
index 0000000..96e2070
--- /dev/null
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNCC_CONSTANT_FOLD_TRANSPOSE_H
+#define NNCC_CONSTANT_FOLD_TRANSPOSE_H
+
+#include "pass/Pass.h"
+
+namespace nnc
+{
+
+class ConstantFoldTranspose : public Pass
+{
+public:
+  PassData run(PassData data) override;
+
+  std::string getName() override
+  {
+    static const std::string name("opt_constant_fold_transpose");
+    return name;
+  }
+};
+
+} // namespace nnc
+
+#endif // NNCC_CONSTANT_FOLD_TRANSPOSE_H
index 94ba075..00799ac 100644 (file)
@@ -1,9 +1,10 @@
 set(OPTIMIZATIONS_SRC CombineTransposes.cpp
-                      FuseArithmeticOps.cpp
-                      RemoveDeadEnds.cpp
-                      SinkRelu.cpp
-                      SinkTranspose.cpp
-                      OptimizationUtils.cpp)
+        ConstantFoldTranspose.cpp
+        FuseArithmeticOps.cpp
+        RemoveDeadEnds.cpp
+        SinkRelu.cpp
+        SinkTranspose.cpp
+        OptimizationUtils.cpp)
 nnc_add_library(nnc_optimizations SHARED ${OPTIMIZATIONS_SRC})
 target_link_libraries(nnc_optimizations PRIVATE mir nnc_support)
 
diff --git a/compiler/nnc/passes/optimizations/ConstantFoldTranspose.cpp b/compiler/nnc/passes/optimizations/ConstantFoldTranspose.cpp
new file mode 100644 (file)
index 0000000..69f2b21
--- /dev/null
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "passes/optimizations/ConstantFoldTranspose.h"
+#include "passes/optimizations/OptimizationUtils.h"
+#include "mir/GraphPatternMatcher.h"
+#include "mir/ShapeRange.h"
+#include "mir/Tensor.h"
+#include "mir/ops/ConstantOp.h"
+#include "mir/ops/TransposeOp.h"
+
+using namespace nnc;
+using namespace mir;
+
+// Copy & paste from interpreter backend.
+// TODO Extract this to a common place and use in both interpreter and optimizations.
+static void transpose(const TensorVariant &arg, TensorVariant &res,
+                      const std::vector<std::size_t> &axis_order)
+{
+  Tensor<float> arg_accessor(arg);
+  Tensor<float> res_accessor(res);
+
+  const auto &input_shape = arg.getShape();
+  const int num_axes = static_cast<int>(axis_order.size());
+  assert(num_axes == input_shape.rank());
+
+  ShapeRange in_range(input_shape);
+  Index out_index(input_shape.rank());
+
+  for (const auto &in_index : in_range)
+  {
+    for (int i = 0; i < num_axes; ++i)
+    {
+      out_index.at(i) = in_index.at(axis_order[i]);
+    }
+    res_accessor.at(out_index) = arg_accessor.at(in_index);
+  }
+}
+
+PassData ConstantFoldTranspose::run(PassData data)
+{
+  auto graph = static_cast<Graph *>(data);
+
+  GraphPatternMatcher matcher(graph);
+  auto is_constant = [](const Operation *op) { return op->getType() == Operation::Type::constant; };
+  auto is_transpose = [](const Operation *op) {
+    return op->getType() == Operation::Type::transpose;
+  };
+
+  auto matches = matcher.matchEdge(is_constant, is_transpose);
+  while (!matches.empty())
+  {
+    for (const auto &match : matches)
+    {
+      auto constant_op = dynamic_cast<ops::ConstantOp *>(match.first);
+      auto transpose_op = dynamic_cast<ops::TransposeOp *>(match.second);
+
+      // FIXME Revise this when we've got type information in operations.
+      TensorVariant res(DataType::FLOAT32, transpose_op->getOutputShape(0));
+      transpose(constant_op->getValue(), res, transpose_op->getAxisOrder());
+
+      auto new_op = graph->create<ops::ConstantOp>("", res);
+
+      graph->replaceNode(transpose_op, new_op);
+      opt_util::removeNodeIfUnused(graph, constant_op);
+    }
+    matches = matcher.matchEdge(is_constant, is_transpose);
+  }
+  return graph;
+}
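
For reference, here is the folding computation in isolation: a minimal standalone C++ sketch of the same permutation logic on a plain row-major buffer, independent of the MIR types used above. The helper name foldTranspose, the stride computation, and the demo values are illustrative and not part of nnc.

#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

// Permute a row-major float tensor: output dimension i takes input dimension
// axis_order[i], mirroring the index loop in ConstantFoldTranspose.cpp.
static std::vector<float> foldTranspose(const std::vector<float> &data,
                                        const std::vector<std::size_t> &shape,
                                        const std::vector<std::size_t> &axis_order)
{
  assert(shape.size() == axis_order.size());
  const std::size_t rank = shape.size();

  // Output shape: out_shape[i] = shape[axis_order[i]].
  std::vector<std::size_t> out_shape(rank);
  for (std::size_t i = 0; i < rank; ++i)
    out_shape[i] = shape[axis_order[i]];

  // Row-major strides of the output shape.
  auto strides = [](const std::vector<std::size_t> &s) {
    std::vector<std::size_t> st(s.size(), 1);
    for (std::size_t i = s.size() - 1; i-- > 0;)
      st[i] = st[i + 1] * s[i + 1];
    return st;
  };
  const auto out_strides = strides(out_shape);

  std::vector<float> result(data.size());
  std::vector<std::size_t> in_index(rank, 0);
  for (std::size_t flat = 0; flat < data.size(); ++flat)
  {
    // out_index[i] = in_index[axis_order[i]], exactly as in the pass.
    std::size_t out_offset = 0;
    for (std::size_t i = 0; i < rank; ++i)
      out_offset += in_index[axis_order[i]] * out_strides[i];
    result[out_offset] = data[flat];

    // Advance the multi-dimensional input index in row-major order.
    for (std::size_t i = rank; i-- > 0;)
    {
      if (++in_index[i] < shape[i])
        break;
      in_index[i] = 0;
    }
  }
  return result;
}

int main()
{
  // A 2x3 constant folded through a Transpose with axis order {1, 0} -> 3x2.
  const std::vector<float> w = {1, 2, 3, 4, 5, 6};
  const auto folded = foldTranspose(w, {2, 3}, {1, 0});
  for (float v : folded)
    std::cout << v << ' '; // prints: 1 4 2 5 3 6
  std::cout << '\n';
}

The pass performs the same index mapping through mir::ShapeRange and mir::Index, and is currently limited to FLOAT32 data, as the FIXME in run() notes.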