[nnc] Introduce Compose Transposes optimization (#2972)
authorАндрей Шедько/AI Tools Lab /SRR/Engineer/삼성전자 <a.shedko@samsung.com>
Tue, 12 Feb 2019 16:48:50 +0000 (19:48 +0300)
committerEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Tue, 12 Feb 2019 16:48:50 +0000 (19:48 +0300)
This PR introduces the `Combine Transposes` optimization, which combines adjacent transpose operations and eliminates the resulting no-op ones.

Signed-off-by: Andrei Shedko <a.shedko@samsung.com>
contrib/nnc/passes/optimizations/CombineTransposes.cpp
contrib/nnc/unittests/CMakeLists.txt
contrib/nnc/unittests/optimizations/CMakeLists.txt [new file with mode: 0644]
contrib/nnc/unittests/optimizations/CombineTransposes.cpp [new file with mode: 0644]

index 0a6f2b9..e210b4d 100644 (file)
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 
+#include <algorithm>
 #include "passes/optimizations/CombineTransposes.h"
 #include "core/modelIR/operations/TransposeOp.h"
 #include "core/modelIR/Graph.h"
@@ -23,9 +24,60 @@ namespace nnc {
 
 using namespace mir;
 
-PassData CombineTransposes::run(PassData data) {
// Composes two transpose axis orders into one: the returned order describes a
// single transpose equivalent to applying `order1` first and `order2` second.
// Both arguments must be permutations of the same rank.
std::vector<size_t> combineAxisOrders(const std::vector<std::size_t>& order1,
                                      const std::vector<std::size_t>& order2) {
  assert(order1.size() == order2.size());
  const std::size_t rank = order1.size();
  std::vector<size_t> combined(rank);
  for (std::size_t axis = 0; axis < rank; ++axis)
    combined[order2[order1[axis]]] = axis;
  return combined;
}
+
// Returns true iff `axis_order` is the identity permutation (0, 1, ..., n-1),
// i.e. a transpose with this axis order is a no-op.
static bool isIdentityTranspose(const std::vector<size_t>& axis_order) {
  size_t expected_axis = 0;
  for (size_t axis : axis_order) {
    if (axis != expected_axis++)
      return false;
  }
  return true;
}
+
// Runs the Combine Transposes optimization: repeatedly finds pairs of directly
// connected transpose operations, fuses each pair into a single transpose with
// the composed axis order, and removes the pair entirely when the composition
// is the identity permutation.
// @param data  opaque pass data; must actually be a mir::Graph*.
// @return the same graph, mutated in place.
nnc::PassData nnc::CombineTransposes::run(nnc::PassData data) {
  auto g = static_cast<Graph*>(data);
  assert(g);
  GraphPatternMatcher matcher(g);
  auto is_tr = [](const Operation* op1) { return op1->getType() == Operation::Type::transpose; };
  std::vector<std::pair<Operation*, Operation*>> matches = matcher.matchEdge(is_tr, is_tr);
  std::unordered_set<Operation*> deleted_nodes;
  while (!matches.empty()) {
    for (std::pair<Operation*, Operation*> match : matches) {
      // If either endpoint was deleted by an earlier fusion, the remaining
      // matches in this sweep are stale: abandon the sweep and re-run the
      // matcher below rather than touching freed nodes.
      if (deleted_nodes.find(match.first) != deleted_nodes.end()) { break; };
      auto* top_transpose = dynamic_cast<mir::ops::TransposeOp*>(match.first);
      if (deleted_nodes.find(match.second) != deleted_nodes.end()) { break; };
      auto* bottom_transpose = dynamic_cast<mir::ops::TransposeOp*>(match.second);
      auto combined_axis_order = combineAxisOrders(top_transpose->getAxisOrder(),
                                                   bottom_transpose->getAxisOrder());

      if (!isIdentityTranspose(combined_axis_order)) {
        // Replace the bottom transpose with a single new transpose that reads
        // the top transpose's input directly and applies the composed order.
        auto new_tr_op = g->create<mir::ops::TransposeOp>(
            top_transpose->getName() + "new",
            top_transpose->getInput(0)->getProducer(), combined_axis_order);

        g->replaceNode(bottom_transpose, new_tr_op);
      } else {
        // The two transposes cancel out.
        // Connect top input to all outputs of bottom
        Operation* top = top_transpose->getInput(0)->getProducer()->getNode();
        g->replaceNode(bottom_transpose, top);
      }
      deleted_nodes.emplace(bottom_transpose);
      // The top transpose may still feed other consumers (e.g. a branching
      // graph); remove it only once it has become dead.
      if (top_transpose->getOutput(0)->getConsumers().empty()) {
        g->removeNode(top_transpose);
        deleted_nodes.emplace(top_transpose);
      }
    }
    matches = matcher.matchEdge(is_tr, is_tr);
  };
  return g;
}
 
index 7d61e4e..64779ed 100644 (file)
@@ -7,6 +7,7 @@ add_subdirectory(core)
 add_subdirectory(soft_backend)
 add_subdirectory(acl_backend)
 add_subdirectory(support)
+add_subdirectory(optimizations)
 if(NNC_FRONTEND_CAFFE_ENABLED)
   add_subdirectory(caffe_frontend)
 endif()
diff --git a/contrib/nnc/unittests/optimizations/CMakeLists.txt b/contrib/nnc/unittests/optimizations/CMakeLists.txt
new file mode 100644 (file)
index 0000000..af6e278
--- /dev/null
@@ -0,0 +1,3 @@
+set(TESTS_OPTIMIZATIONS_SRC "CombineTransposes.cpp")
+nnc_add_unit_test(tests_for_optimizations ${TESTS} ${TESTS_OPTIMIZATIONS_SRC})
+nncc_target_link_libraries(tests_for_optimizations nnc_optimizations nnc_support nnc_core)
diff --git a/contrib/nnc/unittests/optimizations/CombineTransposes.cpp b/contrib/nnc/unittests/optimizations/CombineTransposes.cpp
new file mode 100644 (file)
index 0000000..3a09572
--- /dev/null
@@ -0,0 +1,161 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "passes/optimizations/CombineTransposes.h"
+#include "core/modelIR/operations/TransposeOp.h"
+#include "core/modelIR/operations/ReluOp.h"
+#include "core/modelIR/operations/ElementwiseOp.h"
+#include "core/modelIR/operations/ConstantOp.h"
+
+#include <gtest/gtest.h>
+
+using namespace std;
+using namespace nnc;
+using namespace mir;
+
+class DumpVisitor : public Visitor {
+public:
+  DumpVisitor(std::ostream& s) : _s(s) {}
+
+  void visit(ops::InputOp& op) override { _s << "i_" << op.getName() << "."; };
+
+  void visit(ops::TransposeOp& op) override { _s << "t_" << op.getName() << "."; }
+
+  void visit(ops::ReluOp& op) override { _s << "r_" << op.getName() << "."; }
+
+  void visit(ops::ElementwiseOp& op) override { _s << "e_" << op.getName() << "."; }
+
+  std::ostream& _s;
+};
+
TEST(OptPass, eliminateTransposesLinear) {
  mir::Graph g;
  /*   Create graph:
   *      [input]
   *        ||
   *   [Transpose 1]  (swap axes 0 and 1)
   *        ||
   *   [Transpose 2]  (swap axes 0 and 1 — cancels Transpose 1)
   *        ||
   *   [Transpose 3]  (swap axes 0 and 1 — must survive alone)
   *        ||
   *      [relu]
   */
  // All three transposes deliberately share the name "tr" so the surviving
  // one still dumps as "t_tr." regardless of which pair got fused first.
  Operation* input = g.create<ops::InputOp>("input", Shape{1, 2, 3});
  Operation* tr1 = g.create<ops::TransposeOp>("tr", input->getOutput(0), vector<size_t>{1, 0, 2});
  Operation* tr15 = g.create<ops::TransposeOp>("tr", tr1->getOutput(0), vector<size_t>{1, 0, 2});
  Operation* tr2 = g.create<ops::TransposeOp>("tr", tr15->getOutput(0), vector<size_t>{1, 0, 2});
  Operation* relu = g.create<ops::ReluOp>("relu", tr2->getOutput(0));

  // Check that layout is desired
  std::stringstream ss;
  DumpVisitor d(ss);
  CombineTransposes pass;
  pass.run(&g);
  g.accept(&d);
  // Assert only 1 transpose remains
  ASSERT_EQ("i_input.t_tr.r_relu.", ss.str());
}
+
TEST(OptPass, combineTransposesLinear) {
  mir::Graph g;
  /* Create graph:
   *      [input]
   *        ||
   *   [Transpose 1]  (order {1, 0, 2})
   *        ||
   *   [Transpose 2]  (order {0, 2, 1})
   *        ||
   *      [relu]
   */
  Operation* input = g.create<ops::InputOp>("input", Shape{1, 2, 3});
  Operation* tr1 = g.create<ops::TransposeOp>("tr1", input->getOutput(0), vector<size_t>{1, 0, 2});
  Operation* tr2 = g.create<ops::TransposeOp>("tr2", tr1->getOutput(0), vector<size_t>{0, 2, 1});
  Operation* relu = g.create<ops::ReluOp>("relu", tr2->getOutput(0));

  std::stringstream ss;
  DumpVisitor d(ss);
  CombineTransposes pass;
  pass.run(&g);
  g.accept(&d);

  // Assert transposes are combined; the fused op is named after the top one
  // plus the "new" suffix appended by the pass.
  ASSERT_EQ("i_input.t_tr1new.r_relu.", ss.str());
  // {1,0,2} composed with {0,2,1} is the non-identity order {1,2,0}, so a
  // single transpose must remain with exactly that axis order.
  auto ax_ord_actual = dynamic_cast<ops::TransposeOp*>(
      ( *( g.getInputs()[0]->getOutput(0)->getConsumers().begin()))->getNode())->getAxisOrder();
  auto ax_ord_true = vector<size_t>{1, 2, 0};
  ASSERT_TRUE(ax_ord_actual == ax_ord_true);
}
+
TEST(OptPass, combineTransposesBush) {
  mir::Graph g;
  /*      Create graph:
   *         [input]
   *            ||
   *       [Transpose 1]
   *        //       \\
   *[Transpose 2] [Transpose 3]
   *       \\       //
   *    [Elementwise<add>]
   */
  // Every transpose uses order {1, 0, 3, 2}-style self-inverse {1, 0, 2, 3},
  // so both branches (tr1->tr2 and tr1->tr3) compose to the identity and the
  // whole bush must collapse: elementwise reads the input directly.
  Operation* input = g.create<ops::InputOp>("input", Shape{1, 2, 3, 2});
  Operation* tr1 = g.create<ops::TransposeOp>("tr1", input->getOutput(0),
                                              vector<size_t>{1, 0, 2, 3});
  Operation* tr2 = g.create<ops::TransposeOp>("tr2", tr1->getOutput(0), vector<size_t>{1, 0, 2, 3});
  Operation* tr3 = g.create<ops::TransposeOp>("tr3", tr1->getOutput(0), vector<size_t>{1, 0, 2, 3});
  Operation* elw = g.create<ops::ElementwiseOp>("elewiseAdd",
                                                vector<Operation::Output*>{tr2->getOutput(0),
                                                                           tr3->getOutput(0)},
                                                ops::ElementwiseOp::OpType::add);
  std::stringstream ss;
  DumpVisitor d(ss);
  CombineTransposes pass;
  pass.run(&g);
  g.accept(&d);
  ASSERT_EQ("i_input.e_elewiseAdd.", ss.str());
  // Both elementwise operands must now be produced by the original input node.
  ASSERT_EQ(elw->getInput(0)->getProducer()->getNode()->getName(), "input");
  ASSERT_EQ(elw->getInput(1)->getProducer()->getNode()->getName(), "input");
}
+
TEST(OptPass, combineTransposesOpOrder) {
  mir::Graph g;
  /*      Create graph:
   *   [input]     [input2]
   *      ||          ||
   * [Transpose 0] [Transpose1]
   *      ||          ||
   * [Transpose 2] [Transpose 3]
   *       \\       //
   *    [Elementwise<add>]
   */
  // Each branch's transpose pair is self-cancelling ({1,0,2} twice and
  // {2,1,0} twice), so the pass must wire each input straight into the
  // elementwise op — without swapping the operand order.
  // NOTE(review): ops::OutputOp is used below but its header is not included
  // here — presumably pulled in transitively; verify.
  Operation* in1 = g.create<ops::InputOp>("inp1", Shape{1, 2, 3});
  Operation* in2 = g.create<ops::InputOp>("inp2", Shape{1, 2, 3});
  Operation* tr0 = g.create<ops::TransposeOp>("tr0", in1->getOutput(0), vector<size_t>{1, 0, 2});
  Operation* tr1 = g.create<ops::TransposeOp>("tr1", in2->getOutput(0), vector<size_t>{2, 1, 0});
  Operation* tr2 = g.create<ops::TransposeOp>("tr2", tr0->getOutput(0), vector<size_t>{1, 0, 2});
  Operation* tr3 = g.create<ops::TransposeOp>("tr3", tr1->getOutput(0), vector<size_t>{2, 1, 0});
  Operation* elw = g.create<ops::ElementwiseOp>("elewiseAdd",
                                                vector<Operation::Output*>{tr2->getOutput(0),
                                                                           tr3->getOutput(0)},
                                                ops::ElementwiseOp::OpType::add);
  g.create<ops::OutputOp>("out", elw->getOutput(0));
  // Record the node ids of each branch's source BEFORE the pass runs.
  int n1 = elw->getInput(0)->getNode()->getInput(0)->getNode()->getInput(0)->getNode()->getId();
  int n2 = elw->getInput(1)->getNode()->getInput(0)->getNode()->getInput(0)->getNode()->getId();
  CombineTransposes pass;
  pass.run(&g);
  ASSERT_EQ(g.getOutputs()[0]->getInput(0)->getProducer()->getNode()->getName(), "elewiseAdd");
  //Order is preserved
  ASSERT_EQ(n1, elw->getInput(0)->getNode()->getId());
  ASSERT_EQ(n2, elw->getInput(1)->getNode()->getId());
}