From 8d6459c1ab468a069edc75d6316633bcfb092e30 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=D0=90=D0=BD=D0=B4=D1=80=D0=B5=D0=B9=20=D0=A8=D0=B5=D0=B4?= =?utf8?q?=D1=8C=D0=BA=D0=BE/AI=20Tools=20Lab=20/SRR/Engineer/=EC=82=BC?= =?utf8?q?=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 12 Feb 2019 19:48:50 +0300 Subject: [PATCH] [nnc] Introduce Compose Transposes optimization (#2972) This PR Introduces `Combine Transposes` optimization which combines adjacent transposes and eliminates the noop ones. Signed-off-by: Andrei Shedko --- .../nnc/passes/optimizations/CombineTransposes.cpp | 54 ++++++- contrib/nnc/unittests/CMakeLists.txt | 1 + contrib/nnc/unittests/optimizations/CMakeLists.txt | 3 + .../unittests/optimizations/CombineTransposes.cpp | 161 +++++++++++++++++++++ 4 files changed, 218 insertions(+), 1 deletion(-) create mode 100644 contrib/nnc/unittests/optimizations/CMakeLists.txt create mode 100644 contrib/nnc/unittests/optimizations/CombineTransposes.cpp diff --git a/contrib/nnc/passes/optimizations/CombineTransposes.cpp b/contrib/nnc/passes/optimizations/CombineTransposes.cpp index 0a6f2b9..e210b4d 100644 --- a/contrib/nnc/passes/optimizations/CombineTransposes.cpp +++ b/contrib/nnc/passes/optimizations/CombineTransposes.cpp @@ -14,6 +14,7 @@ * limitations under the License. 
*/ +#include <unordered_set> #include "passes/optimizations/CombineTransposes.h" #include "core/modelIR/operations/TransposeOp.h" #include "core/modelIR/Graph.h" @@ -23,9 +24,60 @@ namespace nnc { using namespace mir; -PassData CombineTransposes::run(PassData data) { +std::vector<size_t> combineAxisOrders(const std::vector<size_t>& order1, + const std::vector<size_t>& order2) { + assert(order1.size() == order2.size()); + std::vector<size_t> res(order1.size()); + for (size_t i = 0; i < order1.size(); i++) { + res[order2[order1[i]]] = i; + } + return res; +} + +static bool isIdentityTranspose(const std::vector<size_t>& axis_order) { + for (size_t i = 0; i < ( axis_order.size()); i++) { + if (axis_order[i] != i) { + return false; + } + } + return true; +} + +nnc::PassData nnc::CombineTransposes::run(nnc::PassData data) { auto g = static_cast<Graph*>(data); assert(g); + GraphPatternMatcher matcher(g); + auto is_tr = [](const Operation* op1) { return op1->getType() == Operation::Type::transpose; }; + std::vector<std::pair<Operation*, Operation*>> matches = matcher.matchEdge(is_tr, is_tr); + std::unordered_set<Operation*> deleted_nodes; + while (!matches.empty()) { + for (std::pair<Operation*, Operation*> match : matches) { + if (deleted_nodes.find(match.first) != deleted_nodes.end()) { break; }; + auto* top_transpose = dynamic_cast<ops::TransposeOp*>(match.first); + if (deleted_nodes.find(match.second) != deleted_nodes.end()) { break; }; + auto* bottom_transpose = dynamic_cast<ops::TransposeOp*>(match.second); + auto combined_axis_order = combineAxisOrders(top_transpose->getAxisOrder(), + bottom_transpose->getAxisOrder()); + + if (!isIdentityTranspose(combined_axis_order)) { + auto new_tr_op = g->create<ops::TransposeOp>( + top_transpose->getName() + "new", + top_transpose->getInput(0)->getProducer(), combined_axis_order); + + g->replaceNode(bottom_transpose, new_tr_op); + } else { + // Connect top input to all outputs of bottom + Operation* top = top_transpose->getInput(0)->getProducer()->getNode(); + g->replaceNode(bottom_transpose, top); + } + deleted_nodes.emplace(bottom_transpose); + if (top_transpose->getOutput(0)->getConsumers().empty()) { + 
g->removeNode(top_transpose); + deleted_nodes.emplace(top_transpose); + } + } + matches = matcher.matchEdge(is_tr, is_tr); + }; return g; } diff --git a/contrib/nnc/unittests/CMakeLists.txt b/contrib/nnc/unittests/CMakeLists.txt index 7d61e4e..64779ed 100644 --- a/contrib/nnc/unittests/CMakeLists.txt +++ b/contrib/nnc/unittests/CMakeLists.txt @@ -7,6 +7,7 @@ add_subdirectory(core) add_subdirectory(soft_backend) add_subdirectory(acl_backend) add_subdirectory(support) +add_subdirectory(optimizations) if(NNC_FRONTEND_CAFFE_ENABLED) add_subdirectory(caffe_frontend) endif() diff --git a/contrib/nnc/unittests/optimizations/CMakeLists.txt b/contrib/nnc/unittests/optimizations/CMakeLists.txt new file mode 100644 index 0000000..af6e278 --- /dev/null +++ b/contrib/nnc/unittests/optimizations/CMakeLists.txt @@ -0,0 +1,3 @@ +set(TESTS_OPTIMIZATIONS_SRC "CombineTransposes.cpp") +nnc_add_unit_test(tests_for_optimizations ${TESTS} ${TESTS_OPTIMIZATIONS_SRC}) +nncc_target_link_libraries(tests_for_optimizations nnc_optimizations nnc_support nnc_core) diff --git a/contrib/nnc/unittests/optimizations/CombineTransposes.cpp b/contrib/nnc/unittests/optimizations/CombineTransposes.cpp new file mode 100644 index 0000000..3a09572 --- /dev/null +++ b/contrib/nnc/unittests/optimizations/CombineTransposes.cpp @@ -0,0 +1,161 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "passes/optimizations/CombineTransposes.h" +#include "core/modelIR/operations/TransposeOp.h" +#include "core/modelIR/operations/ReluOp.h" +#include "core/modelIR/operations/ElementwiseOp.h" +#include "core/modelIR/operations/ConstantOp.h" + +#include <gtest/gtest.h> + +using namespace std; +using namespace nnc; +using namespace mir; + +class DumpVisitor : public Visitor { +public: + DumpVisitor(std::ostream& s) : _s(s) {} + + void visit(ops::InputOp& op) override { _s << "i_" << op.getName() << "."; }; + + void visit(ops::TransposeOp& op) override { _s << "t_" << op.getName() << "."; } + + void visit(ops::ReluOp& op) override { _s << "r_" << op.getName() << "."; } + + void visit(ops::ElementwiseOp& op) override { _s << "e_" << op.getName() << "."; } + + std::ostream& _s; +}; + +TEST(OptPass, eliminateTransposesLinear) { + mir::Graph g; + /* Create graph: + * [input] + * || + * [Transpose 1] + * || + * [Transpose 2] + * || + * [relu] + */ + Operation* input = g.create<ops::InputOp>("input", Shape{1, 2, 3}); + Operation* tr1 = g.create<ops::TransposeOp>("tr", input->getOutput(0), vector<size_t>{1, 0, 2}); + Operation* tr15 = g.create<ops::TransposeOp>("tr", tr1->getOutput(0), vector<size_t>{1, 0, 2}); + Operation* tr2 = g.create<ops::TransposeOp>("tr", tr15->getOutput(0), vector<size_t>{1, 0, 2}); + Operation* relu = g.create<ops::ReluOp>("relu", tr2->getOutput(0)); + + // Check that layout is desired + std::stringstream ss; + DumpVisitor d(ss); + CombineTransposes pass; + pass.run(&g); + g.accept(&d); + // Assert only 1 transpose remains + ASSERT_EQ("i_input.t_tr.r_relu.", ss.str()); +} + +TEST(OptPass, combineTransposesLinear) { + mir::Graph g; + /* Create graph: + * [input] + * || + * [Transpose 1] + * || + * [Transpose 2] + * || + * [relu] + */ + Operation* input = g.create<ops::InputOp>("input", Shape{1, 2, 3}); + Operation* tr1 = g.create<ops::TransposeOp>("tr1", input->getOutput(0), vector<size_t>{1, 0, 2}); + Operation* tr2 = g.create<ops::TransposeOp>("tr2", tr1->getOutput(0), vector<size_t>{0, 2, 1}); + Operation* relu = g.create<ops::ReluOp>("relu", tr2->getOutput(0)); + + std::stringstream ss; + DumpVisitor d(ss); + CombineTransposes 
 pass; + pass.run(&g); + g.accept(&d); + + // Assert transposes are combined + ASSERT_EQ("i_input.t_tr1new.r_relu.", ss.str()); + auto ax_ord_actual = dynamic_cast<ops::TransposeOp*>( + ( *( g.getInputs()[0]->getOutput(0)->getConsumers().begin()))->getNode())->getAxisOrder(); + auto ax_ord_true = vector<size_t>{1, 2, 0}; + ASSERT_TRUE(ax_ord_actual == ax_ord_true); +} + +TEST(OptPass, combineTransposesBush) { + mir::Graph g; + /* Create graph: + * [input] + * || + * [Transpose 1] + * // \\ + *[Transpose 2] [Transpose 3] + * \\ // + * [Elementwise] + */ + Operation* input = g.create<ops::InputOp>("input", Shape{1, 2, 3, 2}); + Operation* tr1 = g.create<ops::TransposeOp>("tr1", input->getOutput(0), + vector<size_t>{1, 0, 2, 3}); + Operation* tr2 = g.create<ops::TransposeOp>("tr2", tr1->getOutput(0), vector<size_t>{1, 0, 2, 3}); + Operation* tr3 = g.create<ops::TransposeOp>("tr3", tr1->getOutput(0), vector<size_t>{1, 0, 2, 3}); + Operation* elw = g.create<ops::ElementwiseOp>("elewiseAdd", + vector<Operation::Output*>{tr2->getOutput(0), + tr3->getOutput(0)}, + ops::ElementwiseOp::OpType::add); + std::stringstream ss; + DumpVisitor d(ss); + CombineTransposes pass; + pass.run(&g); + g.accept(&d); + ASSERT_EQ("i_input.e_elewiseAdd.", ss.str()); + ASSERT_EQ(elw->getInput(0)->getProducer()->getNode()->getName(), "input"); + ASSERT_EQ(elw->getInput(1)->getProducer()->getNode()->getName(), "input"); +} + +TEST(OptPass, combineTransposesOpOrder) { + mir::Graph g; + /* Create graph: + * [input] [input2] + * || || + * [Transpose 0] [Transpose1] + * || || + * [Transpose 2] [Transpose 3] + * \\ // + * [Elementwise] + */ + Operation* in1 = g.create<ops::InputOp>("inp1", Shape{1, 2, 3}); + Operation* in2 = g.create<ops::InputOp>("inp2", Shape{1, 2, 3}); + Operation* tr0 = g.create<ops::TransposeOp>("tr0", in1->getOutput(0), vector<size_t>{1, 0, 2}); + Operation* tr1 = g.create<ops::TransposeOp>("tr1", in2->getOutput(0), vector<size_t>{2, 1, 0}); + Operation* tr2 = g.create<ops::TransposeOp>("tr2", tr0->getOutput(0), vector<size_t>{1, 0, 2}); + Operation* tr3 = g.create<ops::TransposeOp>("tr3", tr1->getOutput(0), vector<size_t>{2, 1, 0}); + Operation* elw = g.create<ops::ElementwiseOp>("elewiseAdd", + vector<Operation::Output*>{tr2->getOutput(0), + tr3->getOutput(0)}, + ops::ElementwiseOp::OpType::add); + 
g.create<ops::OutputOp>("out", elw->getOutput(0)); + int n1 = elw->getInput(0)->getNode()->getInput(0)->getNode()->getInput(0)->getNode()->getId(); + int n2 = elw->getInput(1)->getNode()->getInput(0)->getNode()->getInput(0)->getNode()->getId(); + CombineTransposes pass; + pass.run(&g); + ASSERT_EQ(g.getOutputs()[0]->getInput(0)->getProducer()->getNode()->getName(), "elewiseAdd"); + //Order is preserved + ASSERT_EQ(n1, elw->getInput(0)->getNode()->getId()); + ASSERT_EQ(n2, elw->getInput(1)->getNode()->getId()); +} -- 2.7.4