*/
std::vector<std::pair<Operation*, Operation*>> matchEdge(Predicate p1, Predicate p2);
+ /**
+ * @brief Match a two level tree where the bottommost node has multiple previous nodes
+ * @param p1 Predicate for top node
+ * @param p2 Predicate for bottom node
+ * @return Vector of pairs : all matches; empty if no matches are found
+ */
+ std::vector<std::pair<std::vector<Operation*>, Operation*>> matchUpBush(Predicate p1, Predicate p2);
+
private:
Graph* _g;
};
const std::vector<int32_t>& getPaddingBefore() const { return _paddingBefore; }
const std::vector<int32_t>& getPaddingAfter() const { return _paddingAfter; }
-
private:
void inferOutputShapes();
delete op;
}
+
} // namespace mir
return matches;
}
+ std::vector<std::pair<std::vector<Operation*>, Operation*>>
+ GraphPatternMatcher::matchUpBush(mir::GraphPatternMatcher::Predicate p1,
+ mir::GraphPatternMatcher::Predicate p2) {
+ std::vector<std::pair<std::vector<Operation*>, Operation*>> matches;
+ for (auto* root: _g->getNodes()) {
+ if (p2(root)) {
+ auto& prev_nodes = root->getInputs();
+ if (std::all_of(prev_nodes.begin(), prev_nodes.end(),
+ [p1](const Operation::Input& input) { return p1(input.getProducer()->getNode()); })) {
+ std::vector<Operation*> tops;
+ tops.reserve(prev_nodes.size());
+ for (auto& pr : prev_nodes) {
+ tops.emplace_back(pr.getProducer()->getNode());
+ }
+ matches.emplace_back(std::make_pair(tops, root));
+ }
+ }
+ }
+ return matches;
+ }
} // namespace mir
#include "passes/optimizations/CombineTransposes.h"
#include "passes/optimizations/FuseArithmeticOps.h"
+#include "passes/optimizations/SinkRelu.h"
+#include "passes/optimizations/SinkTranspose.h"
#include "support/CommandLine.h"
#include "Definitions.h"
void Driver::registerOptimizationPass() {
if (cli::doOptimizationPass) {
+ // TODO: maybe we should start managing the optimizations more intelligently?
_passManager.registerPass(std::unique_ptr<Pass>(new CombineTransposes()));
+ _passManager.registerPass(std::unique_ptr<Pass>(new SinkTranspose()));
+ _passManager.registerPass(std::unique_ptr<Pass>(new SinkRelu()));
_passManager.registerPass(std::unique_ptr<Pass>(new FuseArithmeticOps()));
}
} // registerOptimizationPass
} // namespace nnc
-
#endif //NNCC_COMBINE_TRANSPOSES_H
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNCC_OPTIMIZATION_UTILS_H
+#define NNCC_OPTIMIZATION_UTILS_H
+
+#include "mir/Operation.h"
+#include "mir/Graph.h"
+
+namespace nnc {
+namespace opt_util {
+/**
+* @brief Swap adjacent nodes in Graph. Creates new nodes and replaces the old ones with new.
+* @param g MIR Graph
+* @param top Node
+* @param bottom Node
+*/
+ void swapAdjacent(mir::Graph* g, mir::Operation* top, mir::Operation* bottom);
+
+// TODO: this function and it's usages should be removed, after DCE optimization will be implemented
+ void removeNodeIfUnsed(mir::Graph* g, mir::Operation* op);
+} // namespace opt_util
+} // namespace nnc
+
+#endif //NNCC_OPTIMIZATION_UTILS_H
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNCC_SINKRELU_H
+#define NNCC_SINKRELU_H
+
+#include "pass/Pass.h"
+#include "pass/PassData.h"
+
+namespace nnc {
+
+/**
+ * @brief This pass sinks relu below MaxPooling and Concat nodes.
+ */
+class SinkRelu : public Pass {
+public:
+ PassData run(PassData data) override;
+
+ std::string getName() override { return "SinkRelu"; };
+};
+
+} // namespace nnc
+
+#endif //NNCC_SINKRELU_H
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNCC_SINKTRANSPOSE_H
+#define NNCC_SINKTRANSPOSE_H
+
+#include "pass/Pass.h"
+#include "pass/PassData.h"
+
+namespace nnc {
+
+/**
+ * @brief This pass sinks transposes below Relu and Concat nodes (in that order).
+ * `concat(relu(tr(x)), relu(tr(y))) -> tr(concat'(relu(x), relu(y)))`
+ */
+class SinkTranspose : public Pass {
+public:
+ PassData run(PassData data) override;
+
+ std::string getName() override { return "SinkTranspose"; };
+};
+
+} // namespace nnc
+
+#endif //NNCC_SINKTRANSPOSE_H
-set(OPTIMIZATIONS_SRC CombineTransposes.cpp FuseArithmeticOps.cpp)
+set(OPTIMIZATIONS_SRC CombineTransposes.cpp
+ FuseArithmeticOps.cpp
+ SinkRelu.cpp
+ SinkTranspose.cpp
+ OptimizationUtils.cpp)
nnc_add_library(nnc_optimizations SHARED ${OPTIMIZATIONS_SRC})
target_link_libraries(nnc_optimizations PRIVATE mir nnc_support)
* limitations under the License.
*/
-#include <algorithm>
#include "passes/optimizations/CombineTransposes.h"
#include "mir/ops/TransposeOp.h"
#include "mir/Graph.h"
#include "mir/GraphPatternMatcher.h"
+#include <algorithm>
namespace nnc {
*/
#include "passes/optimizations/FuseArithmeticOps.h"
+#include "passes/optimizations/OptimizationUtils.h"
#include "mir/ops/BiasAddOp.h"
#include "mir/ops/ConstantOp.h"
#include "mir/ops/Conv2DOp.h"
using namespace mir;
using namespace std;
+using namespace opt_util;
using OpType = Operation::Type;
using Edge = pair<Operation*, Operation*>;
-// TODO: this function and it's usages should be removed, after DCE optimization will be implemented
-void removeNodeIfUnsed(Graph* g, Operation* op) {
- if (op->getOutput(0)->getConsumers().empty())
- g->removeNode(op);
-}
-
/**
* This function used to get 'ConstantOp' with weights of 'BiasAddOp', 'ScaleOp' or 'Conv2DOp'
* For each of these ops weights stored in second input node
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "passes/optimizations/OptimizationUtils.h"
+
+namespace nnc {
+namespace opt_util {
+
+void swapAdjacent(mir::Graph* g, mir::Operation* top, mir::Operation* bottom) {
+ assert(top->getNumInputs() == bottom->getNumInputs() &&
+ top->getNumInputs() == bottom->getNumOutputs() && "incompatible ops");
+ auto& ins = top->getInputs();
+ std::vector<mir::Operation::Output*> prods;
+ prods.reserve(top->getNumInputs());
+ for (auto& in: ins) {
+ prods.emplace_back(in.getProducer());
+ }
+ mir::Operation* new_bottom = g->copyOpWithInputs(bottom, prods);
+ prods.clear();
+ prods.reserve(new_bottom->getNumOutputs());
+ for (mir::Operation::Output& out: new_bottom->getOutputs()) {
+ prods.emplace_back(&out);
+ }
+ mir::Operation* new_top = g->copyOpWithInputs(top, prods);
+ g->replaceNode(bottom, new_top);
+ g->replaceNode(top, new_bottom);
+}
+
+// TODO: this function and it's usages should be removed, after DCE optimization will be implemented
+void removeNodeIfUnsed(mir::Graph* g, mir::Operation* op) {
+ if (op->getOutput(0)->getConsumers().empty())
+ g->removeNode(op);
+}
+} // namespace opt_util
+} // namespace nnc
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "passes/optimizations/SinkRelu.h"
+#include "passes/optimizations/OptimizationUtils.h"
+#include "mir/ops/TransposeOp.h"
+#include "mir/ops/ConcatOp.h"
+#include "mir/ops/ReluOp.h"
+#include "mir/ops/PoolOp.h"
+#include "mir/Graph.h"
+#include "mir/GraphPatternMatcher.h"
+
+#include <string>
+#include <algorithm>
+
+namespace nnc {
+
+using namespace mir;
+using namespace opt_util;
+
+/*
+ Static casts are safe here because we have already checked the types when matching.
+ */
+PassData SinkRelu::run(PassData data) {
+ auto g = static_cast<Graph*>(data);
+ assert(g);
+ GraphPatternMatcher matcher(g);
+ auto is_relu = [](const Operation* op) { return op->getType() == Operation::Type::ReLU; };
+ auto is_concat = [](const Operation* op) { return op->getType() == Operation::Type::concat; };
+ auto is_max_pool = [](const Operation* op) {
+ auto* p_op = dynamic_cast<const ops::PoolOp*>(op);
+ if (!p_op) return false;
+ return p_op->getPoolingType() == ops::PoolOp::PoolingType::MAX;
+ };
+ std::vector<std::pair<Operation*, Operation*>> matches;
+
+ // sink ReLU through MaxPool
+ matches = matcher.matchEdge(is_relu, is_max_pool);
+ for (auto pair: matches) {
+ swapAdjacent(g, pair.first, pair.second);
+ }
+ // sink ReLU through Concat
+ auto matches_v = matcher.matchUpBush(is_relu, is_concat);
+ for (const auto& pair : matches_v) {
+ auto relus = pair.first;
+ auto* concat = dynamic_cast<ops::ConcatOp*>(pair.second);
+ std::vector<Operation::Output*> pre_relu;
+ pre_relu.reserve(relus.size());
+ for (auto* r : relus) {
+ pre_relu.emplace_back(r->getInput(0)->getProducer());
+ }
+ // create replacement nodes
+ auto new_concat = g->create<ops::ConcatOp>(
+ concat->getName() + "_before_relu", pre_relu, concat->getAxis());
+ auto new_relu = g->create<ops::ReluOp>(
+ relus[0]->getName() + "_after_concat", new_concat->getOutput(0));
+
+ // concat is deleted here
+ g->replaceNode(concat, new_relu);
+ for (auto r: relus) {
+ removeNodeIfUnsed(g,r);
+ }
+ }
+ return g;
+}
+
+} // namespace nnc
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "passes/optimizations/SinkTranspose.h"
+#include "passes/optimizations/OptimizationUtils.h"
+#include "mir/ops/TransposeOp.h"
+#include "mir/ops/ConcatOp.h"
+#include "mir/ops/ReluOp.h"
+#include "mir/Graph.h"
+#include "mir/GraphPatternMatcher.h"
+
+#include <string>
+#include <algorithm>
+
+namespace nnc {
+
+using namespace mir;
+using namespace opt_util;
+/*
+ Static casts are safe here because we have already checked the types when matching.
+ */
+PassData SinkTranspose::run(PassData data) {
+ auto g = static_cast<Graph*>(data);
+ assert(g); // NOLINT
+ GraphPatternMatcher matcher(g);
+ auto is_tr = [](const Operation* op1) { return op1->getType() == Operation::Type::transpose; };
+ auto is_relu = [](const Operation* op2) { return op2->getType() == Operation::Type::ReLU; };
+ auto is_concat = [](const Operation* op2) { return op2->getType() == Operation::Type::concat; };
+ std::vector<std::pair<Operation*, Operation*>> matches;
+
+ // sink transpose below ReLU
+ matches = matcher.matchEdge(is_tr, is_relu);
+ for (auto pair : matches) {
+ swapAdjacent(g, pair.first, pair.second);
+ }
+
+ // sink transpose through Concat
+ auto v_matches = matcher.matchUpBush(is_tr, is_concat);
+ for (const auto& pair : v_matches) {
+ std::vector<Operation*> trs = pair.first;
+ auto* concat = dynamic_cast<ops::ConcatOp*>(pair.second);
+ auto axis_order = dynamic_cast<ops::TransposeOp* >(trs[0])->getAxisOrder();
+ if (std::all_of( trs.begin(), trs.end(), [&axis_order](Operation* tr) {
+ return dynamic_cast<ops::TransposeOp*>(tr)->getAxisOrder() == axis_order;
+ })) {
+ std::vector<Operation::Output*> prev_trans;
+ prev_trans.reserve(trs.size());
+ for (auto transpose : trs) {
+ prev_trans.emplace_back(transpose->getInput(0)->getProducer());
+ }
+ auto new_concat = g->create<ops::ConcatOp>(
+ concat->getName() + "_transposed", prev_trans, axis_order[concat->getAxis()]);
+ auto new_transpose = g->create<ops::TransposeOp>(trs[0]->getName() + "_after_concat",
+ new_concat->getOutput(0),
+ axis_order);
+ // removes old concat
+ g->replaceNode(concat, new_transpose);
+ for (auto tr: trs) {
+ removeNodeIfUnsed(g,tr);
+ }
+ }
+ }
+
+ return g;
+}
+
+} // namespace nnc
-set(TESTS_OPTIMIZATIONS_SRC "CombineTransposes.cpp" "FuseArithmeticOps.cpp")
+set(TESTS_OPTIMIZATIONS_SRC CombineTransposes.cpp
+ SinkTest.cpp
+ FuseArithmeticOps.cpp)
nnc_add_unit_test(tests_for_optimizations ${TESTS} ${TESTS_OPTIMIZATIONS_SRC})
optional_target_link_libraries(tests_for_optimizations nnc_optimizations nnc_support mir)
#include "mir/ops/ReluOp.h"
#include "mir/ops/ElementwiseOp.h"
#include "mir/ops/ConstantOp.h"
-
+#include "mir/ops/TanhOp.h"
+#include "mir/ops/ConcatOp.h"
+#include "mir/ops/OutputOp.h"
+#include "mir/ops/PoolOp.h"
+#include "Util.h"
#include <gtest/gtest.h>
using namespace std;
namespace {
-class DumpVisitor : public Visitor {
-public:
- DumpVisitor(std::ostream& s) : _s(s) {}
-
- void visit(ops::InputOp& op) override { _s << "i_" << op.getName() << "."; };
-
- void visit(ops::TransposeOp& op) override { _s << "t_" << op.getName() << "."; }
-
- void visit(ops::ReluOp& op) override { _s << "r_" << op.getName() << "."; }
-
- void visit(ops::ElementwiseOp& op) override { _s << "e_" << op.getName() << "."; }
-
- std::ostream& _s;
-};
-
TEST(OptPass, eliminateTransposesLinear) {
mir::Graph g;
/* Create graph:
* [relu]
*/
Operation* input = g.create<ops::InputOp>("input", Shape{1, 2, 3});
- Operation* tr1 = g.create<ops::TransposeOp>("tr1", input->getOutput(0), vector<size_t>{1, 0, 2});
+ Operation* tr1 = g.create<ops::TransposeOp>("tr1", input->getOutput(0),
+ vector<size_t>{1, 0, 2});
Operation* tr2 = g.create<ops::TransposeOp>("tr2", tr1->getOutput(0), vector<size_t>{0, 2, 1});
Operation* relu = g.create<ops::ReluOp>("relu", tr2->getOutput(0));
Operation* input = g.create<ops::InputOp>("input", Shape{1, 2, 3, 2});
Operation* tr1 = g.create<ops::TransposeOp>("tr1", input->getOutput(0),
vector<size_t>{1, 0, 2, 3});
- Operation* tr2 = g.create<ops::TransposeOp>("tr2", tr1->getOutput(0), vector<size_t>{1, 0, 2, 3});
- Operation* tr3 = g.create<ops::TransposeOp>("tr3", tr1->getOutput(0), vector<size_t>{1, 0, 2, 3});
+ Operation* tr2 = g.create<ops::TransposeOp>("tr2", tr1->getOutput(0),
+ vector<size_t>{1, 0, 2, 3});
+ Operation* tr3 = g.create<ops::TransposeOp>("tr3", tr1->getOutput(0),
+ vector<size_t>{1, 0, 2, 3});
Operation* elw = g.create<ops::ElementwiseOp>("elewiseAdd",
vector<Operation::Output*>{tr2->getOutput(0),
tr3->getOutput(0)},
ASSERT_EQ(n1, elw->getInput(0)->getNode()->getId());
ASSERT_EQ(n2, elw->getInput(1)->getNode()->getId());
}
-
-} // unnamed namespace
+} // unnamed namespace
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "passes/optimizations/SinkTranspose.h"
+#include "passes/optimizations/SinkRelu.h"
+#include "Util.h"
+#include "mir/ops/TransposeOp.h"
+#include "mir/ops/ReluOp.h"
+#include "mir/ops/TanhOp.h"
+#include "mir/ops/ConcatOp.h"
+#include "mir/ops/OutputOp.h"
+#include "mir/ops/PoolOp.h"
+#include "mir/Graph.h"
+#include "pass/PassData.h"
+
+#include <gtest/gtest.h>
+#include <sstream>
+#include <vector>
+
+using namespace std;
+using namespace nnc;
+using namespace mir;
+
+namespace {
+Operation* getPrev(Operation* op) {
+ assert(op->getNumInputs() == 1);
+ return op->getInput(0)->getProducer()->getNode();
+}
+
+Operation* getNext(Operation* op) {
+ assert(op->getNumOutputs() == 1 && ( op->getOutput(0)->getConsumers().size() == 1 ));
+ return ( *op->getOutput(0)->getConsumers().begin())->getNode();
+}
+
+/* This tests swapping relu and transpose */
+TEST(SinkTest, sinkTrReLU) {
+ mir::Graph g;
+ /*
+ * Create graph:
+ * [input]
+ * ||
+ * [Transpose]
+ * ||
+ * [relu]
+ * ||
+ * [tanh]
+ */
+ Operation* input = g.create<ops::InputOp>("input", Shape{1, 2, 3});
+ Operation* tr1 = g.create<ops::TransposeOp>("tr1", input->getOutput(0),
+ vector<size_t>{1, 0, 2});
+ Operation* relu = g.create<ops::ReluOp>("relu", tr1->getOutput(0));
+ Operation* tanh = g.create<ops::TanhOp>("tanh", relu->getOutput(0));
+ Operation* out = g.create<ops::OutputOp>("out", tanh->getOutput(0));
+ (void) out;
+
+// Check that layout is desired
+ SinkTranspose pass;
+ pass.run(&g);
+
+ // Assert transposes are removed
+ ASSERT_EQ(g.getInputs()[0]->getName(), "input");
+ ASSERT_EQ(getPrev(g.getOutputs()[0])->getName(), "tanh");
+ ASSERT_EQ(getNext(g.getInputs()[0])->getName(), "relu");
+ ASSERT_EQ(getPrev(tanh)->getName(), "tr1");
+}
+
+/* This tests swapping concat and transpose */
+TEST(SinkTest, sinkTrConcat) {
+ mir::Graph g;
+ /*
+ * Create graph:
+ * [input] [input2]
+ * || ||
+ * [Transpose 1] [Transpose 2]
+ * \\ //
+ * [Concat]
+ * ||
+ * [TanH]
+ */
+ Operation* in1 = g.create<ops::InputOp>("inp1", Shape{1, 1, 2, 3});
+ Operation* in2 = g.create<ops::InputOp>("inp2", Shape{1, 1, 2, 3});
+ Operation* tr1 = g.create<ops::TransposeOp>("tr1", in1->getOutput(0),
+ vector<size_t>{0, 3, 1, 2});
+ Operation* tr2 = g.create<ops::TransposeOp>("tr2", in2->getOutput(0),
+ vector<size_t>{0, 3, 1, 2});
+ Operation* conc = g.create<ops::ConcatOp>("concat", vector<Operation::Output*>{
+ tr1->getOutput(0), tr2->getOutput(0)}, 1);
+ Operation* tanh = g.create<ops::TanhOp>("tanh", conc->getOutput(0));
+ Operation* out = g.create<ops::OutputOp>("out", tanh->getOutput(0));
+ (void) out;
+ // Check that layout is as desired
+ SinkTranspose pass;
+ pass.run(&g);
+
+ ASSERT_EQ(getPrev(getPrev(g.getOutputs()[0]))->getType(),
+ Operation::Type::transpose);
+ ASSERT_TRUE(static_cast<ops::TransposeOp*>(getPrev(tanh))->getAxisOrder() ==
+ vector<size_t>({0, 3, 1, 2}));
+ /* Expected Result:
+ * TanH(Transpose(Concat(inp1,inp2)))
+ */
+}
+
+/* This tests swapping concat and transpose */
+TEST(SinkTest, sinkReluConcat) {
+ mir::Graph g;
+ /*
+ * Create graph:
+ * [ inp1 ] [ inp2 ]
+ * || ||
+ * [ Relu 1] [ Relu 2]
+ * \\ //
+ * [ Concat ]
+ * ||
+ * [TanH]
+ */
+ Operation* in1 = g.create<ops::InputOp>("inp1", Shape{1, 1, 2, 3});
+ Operation* in2 = g.create<ops::InputOp>("inp2", Shape{1, 1, 2, 3});
+ Operation* relu1 = g.create<ops::ReluOp>("relu1", in1->getOutput(0));
+ Operation* relu2 = g.create<ops::ReluOp>("relu2", in2->getOutput(0));
+ Operation* conc = g.create<ops::ConcatOp>("concat", vector<Operation::Output*>{
+ relu1->getOutput(0), relu2->getOutput(0)}, 1);
+ Operation* tanh = g.create<ops::TanhOp>("tanh", conc->getOutput(0));
+ Operation* out = g.create<ops::OutputOp>("out", tanh->getOutput(0));
+ (void) out;
+
+ // Check that layout is as desired
+ SinkRelu pass;
+ pass.run(&g);
+
+ ASSERT_EQ(getPrev(getPrev(g.getOutputs()[0]))->getType(), Operation::Type::ReLU);
+ /* Expected Result:
+ * TanH(Relu(Concat(inp1,inp2)))
+ */
+}
+
+/* This tests swapping relu and max_pool */
+TEST(SinkTest, sinkPoolReLU) {
+ mir::Graph g;
+ /*
+ * Create graph:
+ * [input]
+ * ||
+ * [relu]
+ * ||
+ * [MaxPool]
+ * ||
+ * [tanh]
+ */
+ Operation* input = g.create<ops::InputOp>("input", Shape{1, 4, 4, 3});
+ Operation* relu = g.create<ops::ReluOp>("relu", input->getOutput(0));
+ Operation* mp = g.create<ops::PoolOp>("pool", relu->getOutput(0),
+ ops::PoolOp::PoolingType::MAX, Shape{2, 2}, Shape{2, 2},
+ vector<int32_t>{0, 0}, vector<int32_t>{0, 0},
+ ops::PoolOp::BorderType::EMPTY);
+ Operation* tanh = g.create<ops::TanhOp>("tanh", mp->getOutput(0));
+ Operation* out = g.create<ops::OutputOp>("out", tanh->getOutput(0));
+ (void) out;
+
+ SinkRelu pass;
+ pass.run(&g);
+ stringstream ss;
+ DumpVisitor d{ss};
+ g.accept(&d);
+
+ // tanh(relu(pool(input)))
+ ASSERT_EQ(getNext(g.getInputs()[0])->getName(), "pool");
+ ASSERT_EQ(getPrev(g.getOutputs()[0])->getName(), "tanh");
+ ASSERT_EQ("i_input.p_pool.r_relu.th_tanh.", ss.str());
+
+}
+} // unnamed namespace
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNCC_UTIL_H
+#define NNCC_UTIL_H
+#include "mir/ops/TransposeOp.h"
+#include "mir/ops/ReluOp.h"
+#include "mir/ops/ElementwiseOp.h"
+#include "mir/ops/ConstantOp.h"
+#include "mir/ops/TanhOp.h"
+#include "mir/ops/ConcatOp.h"
+#include "mir/ops/OutputOp.h"
+#include "mir/ops/PoolOp.h"
+#include "mir/Visitor.h"
+
+namespace nnc {
+
+class DumpVisitor : public mir::Visitor {
+public:
+ explicit DumpVisitor(std::ostream& s) : _s(s) {}
+
+ void visit(mir::ops::InputOp& op) override { _s << "i_" << op.getName() << "."; };
+
+ void visit(mir::ops::TanhOp& op) override { _s << "th_" << op.getName() << "."; }
+
+ void visit(mir::ops::ReluOp& op) override { _s << "r_" << op.getName() << "."; }
+
+ void visit(mir::ops::PoolOp& op) override { _s << "p_" << op.getName() << "."; }
+
+ void visit(mir::ops::TransposeOp& op) override { _s << "t_" << op.getName() << "."; }
+
+ void visit(mir::ops::ElementwiseOp& op) override { _s << "e_" << op.getName() << "."; }
+
+ std::ostream& _s;
+};
+
+} // namespace nnc
+#endif //NNCC_UTIL_H