* limitations under the License.
*/
+#include <algorithm>
+#include <cassert>        // assert()
+#include <unordered_set>  // std::unordered_set
+// run() uses mir::GraphPatternMatcher; this header path is assumed to follow
+// the modelIR layout of the includes below.
+#include "core/modelIR/GraphPatternMatcher.h"
#include "passes/optimizations/CombineTransposes.h"
#include "core/modelIR/operations/TransposeOp.h"
#include "core/modelIR/Graph.h"
using namespace mir;
-PassData CombineTransposes::run(PassData data) {
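+// Composes two transpose axis orders into the single order that is equivalent
+// to applying `order1` first and `order2` second. For example, {1, 0, 2}
+// followed by {0, 2, 1} combines to {1, 2, 0} (exercised by the
+// combineTransposesLinear test added below).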
+std::vector<std::size_t> combineAxisOrders(const std::vector<std::size_t>& order1,
+                                           const std::vector<std::size_t>& order2) {
+  assert(order1.size() == order2.size());
+  std::vector<std::size_t> res(order1.size());
+  for (std::size_t i = 0; i < order1.size(); ++i) {
+    res[order2[order1[i]]] = i;
+  }
+  return res;
+}
+
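+// A transpose is an identity when every axis keeps its position.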
+static bool isIdentityTranspose(const std::vector<size_t>& axis_order) {
+  for (size_t i = 0; i < axis_order.size(); ++i) {
+    if (axis_order[i] != i) {
+      return false;
+    }
+  }
+  return true;
+}
+
+nnc::PassData nnc::CombineTransposes::run(nnc::PassData data) {
auto g = static_cast<Graph*>(data);
assert(g);
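+  // Collect every edge where one transpose feeds directly into another.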
+  GraphPatternMatcher matcher(g);
+  auto is_tr = [](const Operation* op) { return op->getType() == Operation::Type::transpose; };
+  std::vector<std::pair<Operation*, Operation*>> matches = matcher.matchEdge(is_tr, is_tr);
+  std::unordered_set<Operation*> deleted_nodes;
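+  // Fuse pairs until a fixed point is reached: every sweep mutates the graph,
+  // so the set of matched edges is recomputed after each pass over it.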
+  while (!matches.empty()) {
+    for (auto& match : matches) {
+      // A node of this pair may already have been fused away in this sweep;
+      // if so, abandon the sweep and re-match.
+      if (deleted_nodes.find(match.first) != deleted_nodes.end()) { break; }
+      auto* top_transpose = dynamic_cast<mir::ops::TransposeOp*>(match.first);
+      if (deleted_nodes.find(match.second) != deleted_nodes.end()) { break; }
+      auto* bottom_transpose = dynamic_cast<mir::ops::TransposeOp*>(match.second);
+      auto combined_axis_order =
+          combineAxisOrders(top_transpose->getAxisOrder(), bottom_transpose->getAxisOrder());
+
+      if (!isIdentityTranspose(combined_axis_order)) {
+        auto new_tr_op = g->create<mir::ops::TransposeOp>(
+            top_transpose->getName() + "new",
+            top_transpose->getInput(0)->getProducer(), combined_axis_order);
+
+        g->replaceNode(bottom_transpose, new_tr_op);
+      } else {
+        // The pair cancels out: connect the producer feeding the top transpose
+        // to all consumers of the bottom one.
+        Operation* top = top_transpose->getInput(0)->getProducer()->getNode();
+        g->replaceNode(bottom_transpose, top);
+      }
+      deleted_nodes.emplace(bottom_transpose);
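+      // Remove the top transpose if the fusion left it without consumers.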
+      if (top_transpose->getOutput(0)->getConsumers().empty()) {
+        g->removeNode(top_transpose);
+        deleted_nodes.emplace(top_transpose);
+      }
+    }
+    matches = matcher.matchEdge(is_tr, is_tr);
+  }
return g;
}
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "passes/optimizations/CombineTransposes.h"
+#include "core/modelIR/operations/TransposeOp.h"
+#include "core/modelIR/operations/ReluOp.h"
+#include "core/modelIR/operations/ElementwiseOp.h"
+#include "core/modelIR/operations/ConstantOp.h"
+
+#include <gtest/gtest.h>
+
+using namespace std;
+using namespace nnc;
+using namespace mir;
+
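+// Dumps each visited op as a "<kind>_<name>." token so that a graph's final
+// structure can be asserted with a single string comparison.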
+class DumpVisitor : public Visitor {
+public:
+  DumpVisitor(std::ostream& s) : _s(s) {}
+
+  void visit(ops::InputOp& op) override { _s << "i_" << op.getName() << "."; }
+
+  void visit(ops::TransposeOp& op) override { _s << "t_" << op.getName() << "."; }
+
+  void visit(ops::ReluOp& op) override { _s << "r_" << op.getName() << "."; }
+
+  void visit(ops::ElementwiseOp& op) override { _s << "e_" << op.getName() << "."; }
+
+  std::ostream& _s;
+};
+
+TEST(OptPass, eliminateTransposesLinear) {
+  mir::Graph g;
+  /* Create graph:
+   *      [input]
+   *        ||
+   *   [Transpose 1]
+   *        ||
+   *   [Transpose 2]
+   *        ||
+   *   [Transpose 3]
+   *        ||
+   *      [relu]
+   *
+   * Transposes 1 and 2 compose to the identity, so only one transpose
+   * should survive.
+   */
+  Operation* input = g.create<ops::InputOp>("input", Shape{1, 2, 3});
+  Operation* tr1 = g.create<ops::TransposeOp>("tr", input->getOutput(0), vector<size_t>{1, 0, 2});
+  Operation* tr2 = g.create<ops::TransposeOp>("tr", tr1->getOutput(0), vector<size_t>{1, 0, 2});
+  Operation* tr3 = g.create<ops::TransposeOp>("tr", tr2->getOutput(0), vector<size_t>{1, 0, 2});
+  Operation* relu = g.create<ops::ReluOp>("relu", tr3->getOutput(0));
+
+  // Run the pass and dump the resulting graph
+  std::stringstream ss;
+  DumpVisitor d(ss);
+  CombineTransposes pass;
+  pass.run(&g);
+  g.accept(&d);
+  // Assert only one transpose remains
+  ASSERT_EQ("i_input.t_tr.r_relu.", ss.str());
+}
+
+TEST(OptPass, combineTransposesLinear) {
+  mir::Graph g;
+  /* Create graph:
+   *      [input]
+   *        ||
+   *   [Transpose 1]
+   *        ||
+   *   [Transpose 2]
+   *        ||
+   *      [relu]
+   */
+  Operation* input = g.create<ops::InputOp>("input", Shape{1, 2, 3});
+  Operation* tr1 = g.create<ops::TransposeOp>("tr1", input->getOutput(0), vector<size_t>{1, 0, 2});
+  Operation* tr2 = g.create<ops::TransposeOp>("tr2", tr1->getOutput(0), vector<size_t>{0, 2, 1});
+  Operation* relu = g.create<ops::ReluOp>("relu", tr2->getOutput(0));
+
+  std::stringstream ss;
+  DumpVisitor d(ss);
+  CombineTransposes pass;
+  pass.run(&g);
+  g.accept(&d);
+
+  // Assert transposes are combined
+  ASSERT_EQ("i_input.t_tr1new.r_relu.", ss.str());
+  auto ax_ord_actual = dynamic_cast<ops::TransposeOp*>(
+      (*g.getInputs()[0]->getOutput(0)->getConsumers().begin())->getNode())->getAxisOrder();
+  auto ax_ord_true = vector<size_t>{1, 2, 0};
+  ASSERT_EQ(ax_ord_actual, ax_ord_true);
+}
+
+TEST(OptPass, combineTransposesBush) {
+  mir::Graph g;
+  /* Create graph:
+   *         [input]
+   *           ||
+   *      [Transpose 1]
+   *       //       \\
+   * [Transpose 2] [Transpose 3]
+   *       \\       //
+   *    [Elementwise<add>]
+   */
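+  // Both pairs (1 -> 2 and 1 -> 3) compose to the identity, so every
+  // transpose should disappear.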
+  Operation* input = g.create<ops::InputOp>("input", Shape{1, 2, 3, 2});
+  Operation* tr1 = g.create<ops::TransposeOp>("tr1", input->getOutput(0),
+                                              vector<size_t>{1, 0, 2, 3});
+  Operation* tr2 = g.create<ops::TransposeOp>("tr2", tr1->getOutput(0),
+                                              vector<size_t>{1, 0, 2, 3});
+  Operation* tr3 = g.create<ops::TransposeOp>("tr3", tr1->getOutput(0),
+                                              vector<size_t>{1, 0, 2, 3});
+  Operation* elw = g.create<ops::ElementwiseOp>("elewiseAdd",
+                                                vector<Operation::Output*>{tr2->getOutput(0),
+                                                                           tr3->getOutput(0)},
+                                                ops::ElementwiseOp::OpType::add);
+  std::stringstream ss;
+  DumpVisitor d(ss);
+  CombineTransposes pass;
+  pass.run(&g);
+  g.accept(&d);
+  ASSERT_EQ("i_input.e_elewiseAdd.", ss.str());
+  ASSERT_EQ(elw->getInput(0)->getProducer()->getNode()->getName(), "input");
+  ASSERT_EQ(elw->getInput(1)->getProducer()->getNode()->getName(), "input");
+}
+
+TEST(OptPass, combineTransposesOpOrder) {
+  mir::Graph g;
+  /* Create graph:
+   *    [input]      [input2]
+   *       ||           ||
+   *  [Transpose 0] [Transpose 1]
+   *       ||           ||
+   *  [Transpose 2] [Transpose 3]
+   *        \\         //
+   *     [Elementwise<add>]
+   */
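+  // Each operand's pair of transposes composes to the identity; after the
+  // pass the add must still see its operands in the original order.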
+  Operation* in1 = g.create<ops::InputOp>("inp1", Shape{1, 2, 3});
+  Operation* in2 = g.create<ops::InputOp>("inp2", Shape{1, 2, 3});
+  Operation* tr0 = g.create<ops::TransposeOp>("tr0", in1->getOutput(0), vector<size_t>{1, 0, 2});
+  Operation* tr1 = g.create<ops::TransposeOp>("tr1", in2->getOutput(0), vector<size_t>{2, 1, 0});
+  Operation* tr2 = g.create<ops::TransposeOp>("tr2", tr0->getOutput(0), vector<size_t>{1, 0, 2});
+  Operation* tr3 = g.create<ops::TransposeOp>("tr3", tr1->getOutput(0), vector<size_t>{2, 1, 0});
+  Operation* elw = g.create<ops::ElementwiseOp>("elewiseAdd",
+                                                vector<Operation::Output*>{tr2->getOutput(0),
+                                                                           tr3->getOutput(0)},
+                                                ops::ElementwiseOp::OpType::add);
+  g.create<ops::OutputOp>("out", elw->getOutput(0));
+  // Walk each transpose chain up to its input op and remember the ids.
+  int n1 = elw->getInput(0)->getProducer()->getNode()->getInput(0)->getProducer()->getNode()
+               ->getInput(0)->getProducer()->getNode()->getId();
+  int n2 = elw->getInput(1)->getProducer()->getNode()->getInput(0)->getProducer()->getNode()
+               ->getInput(0)->getProducer()->getNode()->getId();
+  CombineTransposes pass;
+  pass.run(&g);
+  ASSERT_EQ(g.getOutputs()[0]->getInput(0)->getProducer()->getNode()->getName(), "elewiseAdd");
+  // Order is preserved: each operand is now produced directly by its input op
+  ASSERT_EQ(n1, elw->getInput(0)->getProducer()->getNode()->getId());
+  ASSERT_EQ(n2, elw->getInput(1)->getProducer()->getNode()->getId());
+}