[nncc-optimization] Sink transpose optimization added (#3292)
author Andrei Shedko/AI Tools Lab/SRR/Engineer/Samsung Electronics <a.shedko@samsung.com>
Tue, 7 May 2019 15:59:12 +0000 (18:59 +0300)
committer Efimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Tue, 7 May 2019 15:59:12 +0000 (18:59 +0300)
- Added the SinkTranspose and SinkRelu optimization passes
- Moved swapAdjacent and removeNodeIfUnsed to OptimizationUtils

Signed-off-by: Andrei Shedko <a.shedko@samsung.com>
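
For illustration, a minimal sketch of the rewrite SinkTranspose performs, built with the mir graph API exactly as the new unit tests use it (the variable names are illustrative only):

  // Before the pass: input -> Transpose -> ReLU
  // After the pass:  input -> ReLU -> Transpose
  mir::Graph g;
  mir::Operation* input = g.create<mir::ops::InputOp>("input", mir::Shape{1, 2, 3});
  mir::Operation* tr = g.create<mir::ops::TransposeOp>("tr", input->getOutput(0),
                                                       std::vector<std::size_t>{1, 0, 2});
  mir::Operation* relu = g.create<mir::ops::ReluOp>("relu", tr->getOutput(0));
  nnc::SinkTranspose pass;
  pass.run(&g); // the transpose now sits below the ReLU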
19 files changed:
contrib/mir/include/mir/GraphPatternMatcher.h
contrib/mir/include/mir/ops/PoolOp.h
contrib/mir/src/Graph.cpp
contrib/mir/src/GraphPatternMatcher.cpp
contrib/nnc/driver/Driver.cpp
contrib/nnc/include/passes/optimizations/CombineTransposes.h
contrib/nnc/include/passes/optimizations/OptimizationUtils.h [new file with mode: 0644]
contrib/nnc/include/passes/optimizations/SinkRelu.h [new file with mode: 0644]
contrib/nnc/include/passes/optimizations/SinkTranspose.h [new file with mode: 0644]
contrib/nnc/passes/optimizations/CMakeLists.txt
contrib/nnc/passes/optimizations/CombineTransposes.cpp
contrib/nnc/passes/optimizations/FuseArithmeticOps.cpp
contrib/nnc/passes/optimizations/OptimizationUtils.cpp [new file with mode: 0644]
contrib/nnc/passes/optimizations/SinkRelu.cpp [new file with mode: 0644]
contrib/nnc/passes/optimizations/SinkTranspose.cpp [new file with mode: 0644]
contrib/nnc/unittests/optimizations/CMakeLists.txt
contrib/nnc/unittests/optimizations/CombineTransposes.cpp
contrib/nnc/unittests/optimizations/SinkTest.cpp [new file with mode: 0644]
contrib/nnc/unittests/optimizations/Util.h [new file with mode: 0644]

diff --git a/contrib/mir/include/mir/GraphPatternMatcher.h b/contrib/mir/include/mir/GraphPatternMatcher.h
index 453743d..6e0fae4 100644
@@ -35,6 +35,14 @@ public:
    */
   std::vector<std::pair<Operation*, Operation*>> matchEdge(Predicate p1, Predicate p2);
 
+  /**
+   * @brief Match a two-level tree where the bottom node has multiple predecessors
+   * @param p1 Predicate for top node
+   * @param p2 Predicate for bottom node
+   * @return Vector of pairs: all matches; empty if no matches are found
+   */
+  std::vector<std::pair<std::vector<Operation*>, Operation*>> matchUpBush(Predicate p1, Predicate p2);
+
 private:
   Graph* _g;
 };
diff --git a/contrib/mir/include/mir/ops/PoolOp.h b/contrib/mir/include/mir/ops/PoolOp.h
index 46ac221..d5f9d80 100644
@@ -71,7 +71,6 @@ public:
   const std::vector<int32_t>& getPaddingBefore() const { return _paddingBefore; }
 
   const std::vector<int32_t>& getPaddingAfter() const { return _paddingAfter; }
-
 private:
   void inferOutputShapes();
 
diff --git a/contrib/mir/src/Graph.cpp b/contrib/mir/src/Graph.cpp
index c3b9fd0..f14acea 100644
@@ -144,4 +144,5 @@ void Graph::removeNode(Operation* op) {
   delete op;
 }
 
+
 } // namespace mir
diff --git a/contrib/mir/src/GraphPatternMatcher.cpp b/contrib/mir/src/GraphPatternMatcher.cpp
index 1fa42a8..92b0eed 100644
@@ -42,4 +42,24 @@ std::vector<std::pair<Operation*, Operation*>> GraphPatternMatcher::matchEdge(
   return matches;
 }
 
+std::vector<std::pair<std::vector<Operation*>, Operation*>>
+GraphPatternMatcher::matchUpBush(mir::GraphPatternMatcher::Predicate p1,
+                                 mir::GraphPatternMatcher::Predicate p2) {
+  std::vector<std::pair<std::vector<Operation*>, Operation*>> matches;
+  for (auto* root : _g->getNodes()) {
+    if (p2(root)) {
+      auto& prev_nodes = root->getInputs();
+      if (std::all_of(prev_nodes.begin(), prev_nodes.end(),
+                      [p1](const Operation::Input& input) { return p1(input.getProducer()->getNode()); })) {
+        std::vector<Operation*> tops;
+        tops.reserve(prev_nodes.size());
+        for (auto& pr : prev_nodes) {
+          tops.emplace_back(pr.getProducer()->getNode());
+        }
+        matches.emplace_back(tops, root);
+      }
+    }
+  }
+  return matches;
+}
 } // namespace mir
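
For reference, a sketch of how a pass consumes matchUpBush, following the call sites added in SinkRelu.cpp below:

  GraphPatternMatcher matcher(g);
  auto is_relu = [](const Operation* op) { return op->getType() == Operation::Type::ReLU; };
  auto is_concat = [](const Operation* op) { return op->getType() == Operation::Type::concat; };
  // Each match pairs all producers of a concat (each one a ReLU) with the concat itself.
  for (const auto& match : matcher.matchUpBush(is_relu, is_concat)) {
    const std::vector<Operation*>& tops = match.first; // the ReLU nodes
    Operation* bottom = match.second;                  // the concat node
    // ... rewrite the subgraph ...
  }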
diff --git a/contrib/nnc/driver/Driver.cpp b/contrib/nnc/driver/Driver.cpp
index a360bfe..cab9ea3 100644
@@ -25,6 +25,8 @@
 
 #include "passes/optimizations/CombineTransposes.h"
 #include "passes/optimizations/FuseArithmeticOps.h"
+#include "passes/optimizations/SinkRelu.h"
+#include "passes/optimizations/SinkTranspose.h"
 
 #include "support/CommandLine.h"
 #include "Definitions.h"
@@ -119,7 +121,10 @@ void Driver::registerBackendPass() {
 
 void Driver::registerOptimizationPass() {
   if (cli::doOptimizationPass) {
+    // TODO: consider managing the optimization pass ordering more intelligently
     _passManager.registerPass(std::unique_ptr<Pass>(new CombineTransposes()));
+    _passManager.registerPass(std::unique_ptr<Pass>(new SinkTranspose()));
+    _passManager.registerPass(std::unique_ptr<Pass>(new SinkRelu()));
     _passManager.registerPass(std::unique_ptr<Pass>(new FuseArithmeticOps()));
   }
 } // registerOptimizationPass
diff --git a/contrib/nnc/include/passes/optimizations/OptimizationUtils.h b/contrib/nnc/include/passes/optimizations/OptimizationUtils.h
new file mode 100644
index 0000000..de2be4c
--- /dev/null
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNCC_OPTIMIZATION_UTILS_H
+#define NNCC_OPTIMIZATION_UTILS_H
+
+#include "mir/Operation.h"
+#include "mir/Graph.h"
+
+namespace nnc {
+namespace opt_util {
+/**
+ * @brief Swap two adjacent nodes in the graph. Creates new nodes and replaces the old ones with them.
+ * @param g MIR Graph
+ * @param top Top node
+ * @param bottom Bottom node
+ */
+void swapAdjacent(mir::Graph* g, mir::Operation* top, mir::Operation* bottom);
+
+// TODO: this function and its usages should be removed once a DCE optimization is implemented
+void removeNodeIfUnsed(mir::Graph* g, mir::Operation* op);
+} // namespace opt_util
+} // namespace nnc
+
+#endif //NNCC_OPTIMIZATION_UTILS_H
diff --git a/contrib/nnc/include/passes/optimizations/SinkRelu.h b/contrib/nnc/include/passes/optimizations/SinkRelu.h
new file mode 100644
index 0000000..8a27853
--- /dev/null
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNCC_SINKRELU_H
+#define NNCC_SINKRELU_H
+
+#include "pass/Pass.h"
+#include "pass/PassData.h"
+
+namespace nnc {
+
+/**
+ * @brief This pass sinks ReLU below MaxPool and Concat nodes.
+ */
+class SinkRelu : public Pass {
+public:
+  PassData run(PassData data) override;
+
+  std::string getName() override { return "SinkRelu"; }
+};
+
+} // namespace nnc
+
+#endif //NNCC_SINKRELU_H
diff --git a/contrib/nnc/include/passes/optimizations/SinkTranspose.h b/contrib/nnc/include/passes/optimizations/SinkTranspose.h
new file mode 100644
index 0000000..2515b10
--- /dev/null
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNCC_SINKTRANSPOSE_H
+#define NNCC_SINKTRANSPOSE_H
+
+#include "pass/Pass.h"
+#include "pass/PassData.h"
+
+namespace nnc {
+
+/**
+ * @brief This pass sinks transposes below ReLU and Concat nodes (in that order).
+ * `concat(relu(tr(x)), relu(tr(y))) -> tr(concat'(relu(x), relu(y)))`
+ */
+class SinkTranspose : public Pass {
+public:
+  PassData run(PassData data) override;
+
+  std::string getName() override { return "SinkTranspose"; }
+};
+
+} // namespace nnc
+
+#endif //NNCC_SINKTRANSPOSE_H
diff --git a/contrib/nnc/passes/optimizations/CMakeLists.txt b/contrib/nnc/passes/optimizations/CMakeLists.txt
index 2808ea6..f29109a 100644
@@ -1,4 +1,8 @@
-set(OPTIMIZATIONS_SRC CombineTransposes.cpp FuseArithmeticOps.cpp)
+set(OPTIMIZATIONS_SRC CombineTransposes.cpp
+                      FuseArithmeticOps.cpp
+                      SinkRelu.cpp
+                      SinkTranspose.cpp
+                      OptimizationUtils.cpp)
 nnc_add_library(nnc_optimizations SHARED ${OPTIMIZATIONS_SRC})
 target_link_libraries(nnc_optimizations PRIVATE mir nnc_support)
 
diff --git a/contrib/nnc/passes/optimizations/CombineTransposes.cpp b/contrib/nnc/passes/optimizations/CombineTransposes.cpp
index ff958c2..9349c96 100644
  * limitations under the License.
  */
 
-#include <algorithm>
 #include "passes/optimizations/CombineTransposes.h"
 #include "mir/ops/TransposeOp.h"
 #include "mir/Graph.h"
 #include "mir/GraphPatternMatcher.h"
+#include <algorithm>
 
 namespace nnc {
 
diff --git a/contrib/nnc/passes/optimizations/FuseArithmeticOps.cpp b/contrib/nnc/passes/optimizations/FuseArithmeticOps.cpp
index b33a6a7..d17ade7 100644
@@ -15,6 +15,7 @@
  */
 
 #include "passes/optimizations/FuseArithmeticOps.h"
+#include "passes/optimizations/OptimizationUtils.h"
 #include "mir/ops/BiasAddOp.h"
 #include "mir/ops/ConstantOp.h"
 #include "mir/ops/Conv2DOp.h"
@@ -34,16 +35,11 @@ namespace {
 
 using namespace mir;
 using namespace std;
+using namespace opt_util;
 
 using OpType = Operation::Type;
 using Edge = pair<Operation*, Operation*>;
 
-// TODO: this function and it's usages should be removed, after DCE optimization will be implemented
-void removeNodeIfUnsed(Graph* g, Operation* op) {
-  if (op->getOutput(0)->getConsumers().empty())
-    g->removeNode(op);
-}
-
 /**
  * This function used to get 'ConstantOp' with weights of 'BiasAddOp', 'ScaleOp' or 'Conv2DOp'
  * For each of these ops weights stored in second input node
diff --git a/contrib/nnc/passes/optimizations/OptimizationUtils.cpp b/contrib/nnc/passes/optimizations/OptimizationUtils.cpp
new file mode 100644
index 0000000..edb2565
--- /dev/null
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "passes/optimizations/OptimizationUtils.h"
+
+namespace nnc {
+namespace opt_util {
+
+void swapAdjacent(mir::Graph* g, mir::Operation* top, mir::Operation* bottom) {
+  assert(top->getNumInputs() == bottom->getNumInputs() &&
+         top->getNumInputs() == bottom->getNumOutputs() && "incompatible ops");
+  auto& ins = top->getInputs();
+  std::vector<mir::Operation::Output*> prods;
+  prods.reserve(top->getNumInputs());
+  for (auto& in: ins) {
+    prods.emplace_back(in.getProducer());
+  }
+  mir::Operation* new_bottom = g->copyOpWithInputs(bottom, prods);
+  prods.clear();
+  prods.reserve(new_bottom->getNumOutputs());
+  for (mir::Operation::Output& out: new_bottom->getOutputs()) {
+    prods.emplace_back(&out);
+  }
+  mir::Operation* new_top = g->copyOpWithInputs(top, prods);
+  g->replaceNode(bottom, new_top);
+  g->replaceNode(top, new_bottom);
+}
+
+// TODO: this function and its usages should be removed once a DCE optimization is implemented
+void removeNodeIfUnsed(mir::Graph* g, mir::Operation* op) {
+  if (op->getOutput(0)->getConsumers().empty())
+    g->removeNode(op);
+}
+} // namespace opt_util
+} // namespace nnc
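
A usage note on swapAdjacent, as driven by the matchEdge loops in the sink passes below: both operations are re-created through copyOpWithInputs and the originals are replaced in the graph (names are preserved, as the unit tests rely on), so later rewrites should re-query the graph rather than reuse the old pointers.

  // before: x -> top -> bottom -> y;  after: x -> bottom' -> top' -> y
  for (auto pair : matcher.matchEdge(is_relu, is_max_pool))
    swapAdjacent(g, pair.first, pair.second);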
diff --git a/contrib/nnc/passes/optimizations/SinkRelu.cpp b/contrib/nnc/passes/optimizations/SinkRelu.cpp
new file mode 100644
index 0000000..8c58ff2
--- /dev/null
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "passes/optimizations/SinkRelu.h"
+#include "passes/optimizations/OptimizationUtils.h"
+#include "mir/ops/TransposeOp.h"
+#include "mir/ops/ConcatOp.h"
+#include "mir/ops/ReluOp.h"
+#include "mir/ops/PoolOp.h"
+#include "mir/Graph.h"
+#include "mir/GraphPatternMatcher.h"
+
+#include <string>
+#include <algorithm>
+
+namespace nnc {
+
+using namespace mir;
+using namespace opt_util;
+
+/*
+ Casts in this pass are safe because the pattern matcher has already verified the operation types.
+ */
+PassData SinkRelu::run(PassData data) {
+  auto g = static_cast<Graph*>(data);
+  assert(g);
+  GraphPatternMatcher matcher(g);
+  auto is_relu = [](const Operation* op) { return op->getType() == Operation::Type::ReLU; };
+  auto is_concat = [](const Operation* op) { return op->getType() == Operation::Type::concat; };
+  auto is_max_pool = [](const Operation* op) {
+    auto* p_op = dynamic_cast<const ops::PoolOp*>(op);
+    if (!p_op) return false;
+    return p_op->getPoolingType() == ops::PoolOp::PoolingType::MAX;
+  };
+  std::vector<std::pair<Operation*, Operation*>> matches;
+
+  // sink ReLU through MaxPool
+  matches = matcher.matchEdge(is_relu, is_max_pool);
+  for (auto pair: matches) {
+    swapAdjacent(g, pair.first, pair.second);
+  }
+  // sink ReLU through Concat
+  auto matches_v = matcher.matchUpBush(is_relu, is_concat);
+  for (const auto& pair : matches_v) {
+    auto relus = pair.first;
+    auto* concat = dynamic_cast<ops::ConcatOp*>(pair.second);
+    std::vector<Operation::Output*> pre_relu;
+    pre_relu.reserve(relus.size());
+    for (auto* r : relus) {
+      pre_relu.emplace_back(r->getInput(0)->getProducer());
+    }
+    // create replacement nodes
+    auto new_concat = g->create<ops::ConcatOp>(
+        concat->getName() + "_before_relu", pre_relu, concat->getAxis());
+    auto new_relu = g->create<ops::ReluOp>(
+        relus[0]->getName() + "_after_concat", new_concat->getOutput(0));
+
+    // concat is deleted here
+    g->replaceNode(concat, new_relu);
+    for (auto r: relus) {
+      removeNodeIfUnsed(g, r);
+    }
+  }
+  return g;
+}
+
+} // namespace nnc
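
The MaxPool half of this pass is sound because ReLU is monotonic and therefore commutes with max. A standalone scalar check of the identity (illustration only, not part of the change):

  #include <algorithm>
  #include <cassert>

  float relu(float x) { return std::max(x, 0.0f); }

  int main() {
    for (float a : {-2.0f, -0.5f, 0.0f, 1.0f})
      for (float b : {-1.0f, 0.0f, 3.0f})
        assert(std::max(relu(a), relu(b)) == relu(std::max(a, b)));
    return 0;
  }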
diff --git a/contrib/nnc/passes/optimizations/SinkTranspose.cpp b/contrib/nnc/passes/optimizations/SinkTranspose.cpp
new file mode 100644
index 0000000..4cd39e7
--- /dev/null
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "passes/optimizations/SinkTranspose.h"
+#include "passes/optimizations/OptimizationUtils.h"
+#include "mir/ops/TransposeOp.h"
+#include "mir/ops/ConcatOp.h"
+#include "mir/ops/ReluOp.h"
+#include "mir/Graph.h"
+#include "mir/GraphPatternMatcher.h"
+
+#include <string>
+#include <algorithm>
+
+namespace nnc {
+
+using namespace mir;
+using namespace opt_util;
+/*
+ Casts in this pass are safe because the pattern matcher has already verified the operation types.
+ */
+PassData SinkTranspose::run(PassData data) {
+  auto g = static_cast<Graph*>(data);
+  assert(g); // NOLINT
+  GraphPatternMatcher matcher(g);
+  auto is_tr = [](const Operation* op) { return op->getType() == Operation::Type::transpose; };
+  auto is_relu = [](const Operation* op) { return op->getType() == Operation::Type::ReLU; };
+  auto is_concat = [](const Operation* op) { return op->getType() == Operation::Type::concat; };
+  std::vector<std::pair<Operation*, Operation*>> matches;
+
+  // sink transpose below ReLU
+  matches = matcher.matchEdge(is_tr, is_relu);
+  for (auto pair : matches) {
+    swapAdjacent(g, pair.first, pair.second);
+  }
+
+  // sink transpose through Concat
+  auto v_matches = matcher.matchUpBush(is_tr, is_concat);
+  for (const auto& pair : v_matches) {
+    std::vector<Operation*> trs = pair.first;
+    auto* concat = dynamic_cast<ops::ConcatOp*>(pair.second);
+    auto axis_order = dynamic_cast<ops::TransposeOp*>(trs[0])->getAxisOrder();
+    if (std::all_of(trs.begin(), trs.end(), [&axis_order](Operation* tr) {
+          return dynamic_cast<ops::TransposeOp*>(tr)->getAxisOrder() == axis_order;
+        })) {
+      std::vector<Operation::Output*> prev_trans;
+      prev_trans.reserve(trs.size());
+      for (auto transpose : trs) {
+        prev_trans.emplace_back(transpose->getInput(0)->getProducer());
+      }
+      auto new_concat = g->create<ops::ConcatOp>(
+        concat->getName() + "_transposed", prev_trans, axis_order[concat->getAxis()]);
+      auto new_transpose = g->create<ops::TransposeOp>(trs[0]->getName() + "_after_concat",
+                                                       new_concat->getOutput(0),
+                                                       axis_order);
+      // removes old concat
+      g->replaceNode(concat, new_transpose);
+      for (auto tr: trs) {
+        removeNodeIfUnsed(g, tr);
+      }
+    }
+  }
+
+  return g;
+}
+
+} // namespace nnc
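
The one subtle step above is the concat axis: output axis i of a transpose carries input axis axis_order[i], so a concat applied on post-transpose axis k corresponds to pre-transpose axis axis_order[k]. That is exactly the axis_order[concat->getAxis()] argument passed to the new ConcatOp; a hypothetical helper making the mapping explicit:

  // e.g. axis_order {0, 3, 1, 2}: inputs of shape {1, 1, 2, 3} transpose to {1, 3, 1, 2};
  // a concat on post-transpose axis 1 maps to pre-transpose axis 3, producing
  // {1, 1, 2, 6} before the final transpose and {1, 6, 1, 2} after it.
  std::size_t remapConcatAxis(const std::vector<std::size_t>& axis_order, std::size_t k) {
    return axis_order[k];
  }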
diff --git a/contrib/nnc/unittests/optimizations/CMakeLists.txt b/contrib/nnc/unittests/optimizations/CMakeLists.txt
index ea8d366..1527254 100644
@@ -1,3 +1,5 @@
-set(TESTS_OPTIMIZATIONS_SRC "CombineTransposes.cpp" "FuseArithmeticOps.cpp")
+set(TESTS_OPTIMIZATIONS_SRC CombineTransposes.cpp
+                            SinkTest.cpp
+                            FuseArithmeticOps.cpp)
 nnc_add_unit_test(tests_for_optimizations ${TESTS} ${TESTS_OPTIMIZATIONS_SRC})
 optional_target_link_libraries(tests_for_optimizations nnc_optimizations nnc_support mir)
diff --git a/contrib/nnc/unittests/optimizations/CombineTransposes.cpp b/contrib/nnc/unittests/optimizations/CombineTransposes.cpp
index 9a559fb..727465b 100644
 #include "mir/ops/ReluOp.h"
 #include "mir/ops/ElementwiseOp.h"
 #include "mir/ops/ConstantOp.h"
-
+#include "mir/ops/TanhOp.h"
+#include "mir/ops/ConcatOp.h"
+#include "mir/ops/OutputOp.h"
+#include "mir/ops/PoolOp.h"
+#include "Util.h"
 #include <gtest/gtest.h>
 
 using namespace std;
@@ -28,21 +32,6 @@ using namespace mir;
 
 namespace {
 
-class DumpVisitor : public Visitor {
-public:
-  DumpVisitor(std::ostream& s) : _s(s) {}
-
-  void visit(ops::InputOp& op) override { _s << "i_" << op.getName() << "."; };
-
-  void visit(ops::TransposeOp& op) override { _s << "t_" << op.getName() << "."; }
-
-  void visit(ops::ReluOp& op) override { _s << "r_" << op.getName() << "."; }
-
-  void visit(ops::ElementwiseOp& op) override { _s << "e_" << op.getName() << "."; }
-
-  std::ostream& _s;
-};
-
 TEST(OptPass, eliminateTransposesLinear) {
   mir::Graph g;
   /*   Create graph:
@@ -82,7 +71,8 @@ TEST(OptPass, combineTransposesLinear) {
    *      [relu]
    */
   Operation* input = g.create<ops::InputOp>("input", Shape{1, 2, 3});
-  Operation* tr1 = g.create<ops::TransposeOp>("tr1", input->getOutput(0), vector<size_t>{1, 0, 2});
+  Operation* tr1 = g.create<ops::TransposeOp>("tr1", input->getOutput(0),
+                                              vector<size_t>{1, 0, 2});
   Operation* tr2 = g.create<ops::TransposeOp>("tr2", tr1->getOutput(0), vector<size_t>{0, 2, 1});
   Operation* relu = g.create<ops::ReluOp>("relu", tr2->getOutput(0));
 
@@ -114,8 +104,10 @@ TEST(OptPass, combineTransposesBush) {
   Operation* input = g.create<ops::InputOp>("input", Shape{1, 2, 3, 2});
   Operation* tr1 = g.create<ops::TransposeOp>("tr1", input->getOutput(0),
                                               vector<size_t>{1, 0, 2, 3});
-  Operation* tr2 = g.create<ops::TransposeOp>("tr2", tr1->getOutput(0), vector<size_t>{1, 0, 2, 3});
-  Operation* tr3 = g.create<ops::TransposeOp>("tr3", tr1->getOutput(0), vector<size_t>{1, 0, 2, 3});
+  Operation* tr2 = g.create<ops::TransposeOp>("tr2", tr1->getOutput(0),
+                                              vector<size_t>{1, 0, 2, 3});
+  Operation* tr3 = g.create<ops::TransposeOp>("tr3", tr1->getOutput(0),
+                                              vector<size_t>{1, 0, 2, 3});
   Operation* elw = g.create<ops::ElementwiseOp>("elewiseAdd",
                                                 vector<Operation::Output*>{tr2->getOutput(0),
                                                                            tr3->getOutput(0)},
@@ -161,5 +153,4 @@ TEST(OptPass, combineTransposesOpOrder) {
   ASSERT_EQ(n1, elw->getInput(0)->getNode()->getId());
   ASSERT_EQ(n2, elw->getInput(1)->getNode()->getId());
 }
-
-}  // unnamed namespace
+} // unnamed namespace
diff --git a/contrib/nnc/unittests/optimizations/SinkTest.cpp b/contrib/nnc/unittests/optimizations/SinkTest.cpp
new file mode 100644
index 0000000..81c535f
--- /dev/null
@@ -0,0 +1,185 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "passes/optimizations/SinkTranspose.h"
+#include "passes/optimizations/SinkRelu.h"
+#include "Util.h"
+#include "mir/ops/TransposeOp.h"
+#include "mir/ops/ReluOp.h"
+#include "mir/ops/TanhOp.h"
+#include "mir/ops/ConcatOp.h"
+#include "mir/ops/OutputOp.h"
+#include "mir/ops/PoolOp.h"
+#include "mir/Graph.h"
+#include "pass/PassData.h"
+
+#include <gtest/gtest.h>
+#include <sstream>
+#include <vector>
+
+using namespace std;
+using namespace nnc;
+using namespace mir;
+
+namespace {
+Operation* getPrev(Operation* op) {
+  assert(op->getNumInputs() == 1);
+  return op->getInput(0)->getProducer()->getNode();
+}
+
+Operation* getNext(Operation* op) {
+  assert(op->getNumOutputs() == 1 && (op->getOutput(0)->getConsumers().size() == 1));
+  return (*op->getOutput(0)->getConsumers().begin())->getNode();
+}
+
+/* This tests swapping relu and transpose */
+TEST(SinkTest, sinkTrReLU) {
+  mir::Graph g;
+  /*
+   * Create graph:
+   *      [input]
+   *        ||
+   *    [Transpose]
+   *        ||
+   *      [relu]
+   *        ||
+   *      [tanh]
+   */
+  Operation* input = g.create<ops::InputOp>("input", Shape{1, 2, 3});
+  Operation* tr1 = g.create<ops::TransposeOp>("tr1", input->getOutput(0),
+                                              vector<size_t>{1, 0, 2});
+  Operation* relu = g.create<ops::ReluOp>("relu", tr1->getOutput(0));
+  Operation* tanh = g.create<ops::TanhOp>("tanh", relu->getOutput(0));
+  Operation* out = g.create<ops::OutputOp>("out", tanh->getOutput(0));
+  (void) out;
+
+  // Check that the layout is as desired
+  SinkTranspose pass;
+  pass.run(&g);
+
+  // Assert the transpose was sunk below the relu
+  ASSERT_EQ(g.getInputs()[0]->getName(), "input");
+  ASSERT_EQ(getPrev(g.getOutputs()[0])->getName(), "tanh");
+  ASSERT_EQ(getNext(g.getInputs()[0])->getName(), "relu");
+  ASSERT_EQ(getPrev(tanh)->getName(), "tr1");
+}
+
+/* This tests swapping concat and transpose */
+TEST(SinkTest, sinkTrConcat) {
+  mir::Graph g;
+  /*
+   *   Create graph:
+   *     [input]     [input2]
+   *       ||          ||
+   *  [Transpose 1] [Transpose 2]
+   *         \\    //
+   *         [Concat]
+   *            ||
+   *          [TanH]
+   */
+  Operation* in1 = g.create<ops::InputOp>("inp1", Shape{1, 1, 2, 3});
+  Operation* in2 = g.create<ops::InputOp>("inp2", Shape{1, 1, 2, 3});
+  Operation* tr1 = g.create<ops::TransposeOp>("tr1", in1->getOutput(0),
+                                              vector<size_t>{0, 3, 1, 2});
+  Operation* tr2 = g.create<ops::TransposeOp>("tr2", in2->getOutput(0),
+                                              vector<size_t>{0, 3, 1, 2});
+  Operation* conc = g.create<ops::ConcatOp>("concat", vector<Operation::Output*>{
+      tr1->getOutput(0), tr2->getOutput(0)}, 1);
+  Operation* tanh = g.create<ops::TanhOp>("tanh", conc->getOutput(0));
+  Operation* out = g.create<ops::OutputOp>("out", tanh->getOutput(0));
+  (void) out;
+  // Check that layout is as desired
+  SinkTranspose pass;
+  pass.run(&g);
+
+  ASSERT_EQ(getPrev(getPrev(g.getOutputs()[0]))->getType(),
+            Operation::Type::transpose);
+  ASSERT_TRUE(static_cast<ops::TransposeOp*>(getPrev(tanh))->getAxisOrder() ==
+              vector<size_t>({0, 3, 1, 2}));
+  /* Expected Result:
+   * TanH(Transpose(Concat(inp1,inp2)))
+   */
+}
+
+/* This tests swapping concat and relu */
+TEST(SinkTest, sinkReluConcat) {
+  mir::Graph g;
+  /*
+   *   Create graph:
+   *     [ inp1 ]  [ inp2 ]
+   *        ||        ||
+   *     [ Relu 1] [ Relu 2]
+   *         \\     //
+   *        [ Concat ]
+   *            ||
+   *          [TanH]
+   */
+  Operation* in1 = g.create<ops::InputOp>("inp1", Shape{1, 1, 2, 3});
+  Operation* in2 = g.create<ops::InputOp>("inp2", Shape{1, 1, 2, 3});
+  Operation* relu1 = g.create<ops::ReluOp>("relu1", in1->getOutput(0));
+  Operation* relu2 = g.create<ops::ReluOp>("relu2", in2->getOutput(0));
+  Operation* conc = g.create<ops::ConcatOp>("concat", vector<Operation::Output*>{
+      relu1->getOutput(0), relu2->getOutput(0)}, 1);
+  Operation* tanh = g.create<ops::TanhOp>("tanh", conc->getOutput(0));
+  Operation* out = g.create<ops::OutputOp>("out", tanh->getOutput(0));
+  (void) out;
+
+  // Check that layout is as desired
+  SinkRelu pass;
+  pass.run(&g);
+
+  ASSERT_EQ(getPrev(getPrev(g.getOutputs()[0]))->getType(), Operation::Type::ReLU);
+  /* Expected Result:
+   * TanH(Relu(Concat(inp1,inp2)))
+   */
+}
+
+/* This tests swapping relu and max_pool */
+TEST(SinkTest, sinkPoolReLU) {
+  mir::Graph g;
+  /*
+   * Create graph:
+   *      [input]
+   *        ||
+   *      [relu]
+   *        ||
+   *     [MaxPool]
+   *        ||
+   *      [tanh]
+   */
+  Operation* input = g.create<ops::InputOp>("input", Shape{1, 4, 4, 3});
+  Operation* relu = g.create<ops::ReluOp>("relu", input->getOutput(0));
+  Operation* mp = g.create<ops::PoolOp>("pool", relu->getOutput(0),
+                                        ops::PoolOp::PoolingType::MAX, Shape{2, 2}, Shape{2, 2},
+                                        vector<int32_t>{0, 0}, vector<int32_t>{0, 0},
+                                        ops::PoolOp::BorderType::EMPTY);
+  Operation* tanh = g.create<ops::TanhOp>("tanh", mp->getOutput(0));
+  Operation* out = g.create<ops::OutputOp>("out", tanh->getOutput(0));
+  (void) out;
+
+  SinkRelu pass;
+  pass.run(&g);
+  stringstream ss;
+  DumpVisitor d{ss};
+  g.accept(&d);
+
+  // tanh(relu(pool(input)))
+  ASSERT_EQ(getNext(g.getInputs()[0])->getName(), "pool");
+  ASSERT_EQ(getPrev(g.getOutputs()[0])->getName(), "tanh");
+  ASSERT_EQ("i_input.p_pool.r_relu.th_tanh.", ss.str());
+
+}
+} // unnamed namespace
diff --git a/contrib/nnc/unittests/optimizations/Util.h b/contrib/nnc/unittests/optimizations/Util.h
new file mode 100644
index 0000000..b847174
--- /dev/null
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNCC_UTIL_H
+#define NNCC_UTIL_H
+#include "mir/ops/TransposeOp.h"
+#include "mir/ops/ReluOp.h"
+#include "mir/ops/ElementwiseOp.h"
+#include "mir/ops/ConstantOp.h"
+#include "mir/ops/TanhOp.h"
+#include "mir/ops/ConcatOp.h"
+#include "mir/ops/OutputOp.h"
+#include "mir/ops/PoolOp.h"
+#include "mir/Visitor.h"
+
+namespace nnc {
+
+class DumpVisitor : public mir::Visitor {
+public:
+  explicit DumpVisitor(std::ostream& s) : _s(s) {}
+
+  void visit(mir::ops::InputOp& op) override { _s << "i_" << op.getName() << "."; }
+
+  void visit(mir::ops::TanhOp& op) override { _s << "th_" << op.getName() << "."; }
+
+  void visit(mir::ops::ReluOp& op) override { _s << "r_" << op.getName() << "."; }
+
+  void visit(mir::ops::PoolOp& op) override { _s << "p_" << op.getName() << "."; }
+
+  void visit(mir::ops::TransposeOp& op) override { _s << "t_" << op.getName() << "."; }
+
+  void visit(mir::ops::ElementwiseOp& op) override { _s << "e_" << op.getName() << "."; }
+
+  std::ostream& _s;
+};
+
+} // namespace nnc
+#endif //NNCC_UTIL_H
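
Usage of DumpVisitor mirrors the sinkPoolReLU test above: each visited operation appends a one- or two-letter prefix plus its name, so a graph's final topology can be asserted as one compact string:

  std::stringstream ss;
  nnc::DumpVisitor d{ss};
  g.accept(&d);
  // e.g. for input -> pool -> relu -> tanh:
  // ss.str() == "i_input.p_pool.r_relu.th_tanh."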