From 4b03b342536739d08cb26c4dfccc679afca93402 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=D0=90=D0=BD=D0=B4=D1=80=D0=B5=D0=B9=20=D0=A8=D0=B5=D0=B4?= =?utf8?q?=D1=8C=D0=BA=D0=BE/AI=20Tools=20Lab=20/SRR/Engineer/=EC=82=BC?= =?utf8?q?=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 8 May 2019 16:37:32 +0300 Subject: [PATCH] [MIR] Added a Remove Dead Ends Pass (#3272) Added a `RemoveDeadEnds` pass that decreases compiler runtime memory usage. Signed-off-by: Andrei Shedko --- contrib/nnc/driver/Driver.cpp | 2 + .../passes/optimizations/OptimizationUtils.h | 2 +- .../include/passes/optimizations/RemoveDeadEnds.h | 39 +++++++++++++++++ contrib/nnc/passes/optimizations/CMakeLists.txt | 1 + .../nnc/passes/optimizations/FuseArithmeticOps.cpp | 10 ++--- .../nnc/passes/optimizations/OptimizationUtils.cpp | 9 ++-- .../nnc/passes/optimizations/RemoveDeadEnds.cpp | 31 ++++++++++++++ contrib/nnc/passes/optimizations/SinkRelu.cpp | 5 +-- contrib/nnc/passes/optimizations/SinkTranspose.cpp | 6 +-- contrib/nnc/unittests/optimizations/CMakeLists.txt | 1 + .../unittests/optimizations/FuseArithmeticOps.cpp | 18 +------- .../nnc/unittests/optimizations/RemoveDeadEnds.cpp | 49 ++++++++++++++++++++++ contrib/nnc/unittests/optimizations/SinkTest.cpp | 8 ++-- contrib/nnc/unittests/optimizations/Util.h | 11 +++++ 14 files changed, 154 insertions(+), 38 deletions(-) create mode 100644 contrib/nnc/include/passes/optimizations/RemoveDeadEnds.h create mode 100644 contrib/nnc/passes/optimizations/RemoveDeadEnds.cpp create mode 100644 contrib/nnc/unittests/optimizations/RemoveDeadEnds.cpp diff --git a/contrib/nnc/driver/Driver.cpp b/contrib/nnc/driver/Driver.cpp index cab9ea3..e060952 100644 --- a/contrib/nnc/driver/Driver.cpp +++ b/contrib/nnc/driver/Driver.cpp @@ -24,6 +24,7 @@ #include "passes/acl_soft_backend/AclCppGenerator.h" #include "passes/optimizations/CombineTransposes.h" +#include "passes/optimizations/RemoveDeadEnds.h" #include "passes/optimizations/FuseArithmeticOps.h" #include "passes/optimizations/SinkRelu.h" #include "passes/optimizations/SinkTranspose.h" @@ -126,6 +127,7 @@ void Driver::registerOptimizationPass() { _passManager.registerPass(std::unique_ptr(new SinkTranspose())); _passManager.registerPass(std::unique_ptr(new SinkRelu())); _passManager.registerPass(std::unique_ptr(new FuseArithmeticOps())); + _passManager.registerPass(std::unique_ptr(new RemoveDeadEnds())); } } // registerOptimizationPass diff --git a/contrib/nnc/include/passes/optimizations/OptimizationUtils.h b/contrib/nnc/include/passes/optimizations/OptimizationUtils.h index de2be4c..bec4c80 100644 --- a/contrib/nnc/include/passes/optimizations/OptimizationUtils.h +++ b/contrib/nnc/include/passes/optimizations/OptimizationUtils.h @@ -31,7 +31,7 @@ namespace opt_util { void swapAdjacent(mir::Graph* g, mir::Operation* top, mir::Operation* bottom); // TODO: this function and it's usages should be removed, after DCE optimization will be implemented - void removeNodeIfUnsed(mir::Graph* g, mir::Operation* op); + void removeNodeIfUnused(mir::Graph* g, mir::Operation* op); } // namespace opt_util } // namespace nnc diff --git a/contrib/nnc/include/passes/optimizations/RemoveDeadEnds.h b/contrib/nnc/include/passes/optimizations/RemoveDeadEnds.h new file mode 100644 index 0000000..16dc41a --- /dev/null +++ b/contrib/nnc/include/passes/optimizations/RemoveDeadEnds.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNCC_REMOVEDEADENDS_H +#define NNCC_REMOVEDEADENDS_H + +#include "pass/Pass.h" +#include "pass/PassData.h" + +namespace nnc { + +/** + * @brief This pass removes operations without uses. + * Importers currently only generate `sConstantOp`s without uses. + */ +class RemoveDeadEnds : public Pass { +public: + PassData run(PassData data) override; + + std::string getName() override {return "RemoveDeadEnds";}; +}; + +} // namespace nnc + + +#endif //NNCC_REMOVEDEADENDS_H diff --git a/contrib/nnc/passes/optimizations/CMakeLists.txt b/contrib/nnc/passes/optimizations/CMakeLists.txt index f29109a..94ba075 100644 --- a/contrib/nnc/passes/optimizations/CMakeLists.txt +++ b/contrib/nnc/passes/optimizations/CMakeLists.txt @@ -1,5 +1,6 @@ set(OPTIMIZATIONS_SRC CombineTransposes.cpp FuseArithmeticOps.cpp + RemoveDeadEnds.cpp SinkRelu.cpp SinkTranspose.cpp OptimizationUtils.cpp) diff --git a/contrib/nnc/passes/optimizations/FuseArithmeticOps.cpp b/contrib/nnc/passes/optimizations/FuseArithmeticOps.cpp index d17ade7..a99a033 100644 --- a/contrib/nnc/passes/optimizations/FuseArithmeticOps.cpp +++ b/contrib/nnc/passes/optimizations/FuseArithmeticOps.cpp @@ -167,9 +167,9 @@ bool fuseSuccessiveOps(Graph* g) { // Replace second successive operation with new one and remove old nodes g->replaceNode(edge.second, new_op); - removeNodeIfUnsed(g, edge.first); - removeNodeIfUnsed(g, const1_op); - removeNodeIfUnsed(g, const2_op); + removeNodeIfUnused(g, edge.first); + removeNodeIfUnused(g, const1_op); + removeNodeIfUnused(g, const2_op); } // If there is no successive operations to fuse - graph wasn't changed @@ -209,8 +209,8 @@ bool sinkBiasThroughScale(Graph* g) { // Replace old scale with new bias and remove old nodes g->replaceNode(old_scale_op, new_bias_op); - removeNodeIfUnsed(g, old_bias_op); - removeNodeIfUnsed(g, old_bias_const_op); + removeNodeIfUnused(g, old_bias_op); + removeNodeIfUnused(g, old_bias_const_op); } // If there is no bias-scale edges - graph wasn't changed diff --git a/contrib/nnc/passes/optimizations/OptimizationUtils.cpp b/contrib/nnc/passes/optimizations/OptimizationUtils.cpp index edb2565..5971b73 100644 --- a/contrib/nnc/passes/optimizations/OptimizationUtils.cpp +++ b/contrib/nnc/passes/optimizations/OptimizationUtils.cpp @@ -19,8 +19,10 @@ namespace nnc { namespace opt_util { void swapAdjacent(mir::Graph* g, mir::Operation* top, mir::Operation* bottom) { - assert(top->getNumInputs() == bottom->getNumInputs() && - top->getNumInputs() == bottom->getNumOutputs() && "incompatible ops"); + assert( + top->getNumInputs() == bottom->getNumInputs() && top->getNumInputs() == 1 && + top->getNumInputs() == top->getNumOutputs() && + top->getNumInputs() == bottom->getNumOutputs() && "incompatible ops"); auto& ins = top->getInputs(); std::vector prods; prods.reserve(top->getNumInputs()); @@ -39,9 +41,10 @@ void swapAdjacent(mir::Graph* g, mir::Operation* top, mir::Operation* bottom) { } // TODO: this function and it's usages should be removed, after DCE optimization will be implemented -void removeNodeIfUnsed(mir::Graph* g, mir::Operation* op) { +void removeNodeIfUnused(mir::Graph* g, mir::Operation* op) { if (op->getOutput(0)->getConsumers().empty()) g->removeNode(op); } + } // namespace opt_util } // namespace nnc diff --git a/contrib/nnc/passes/optimizations/RemoveDeadEnds.cpp b/contrib/nnc/passes/optimizations/RemoveDeadEnds.cpp new file mode 100644 index 0000000..3ec1297 --- /dev/null +++ b/contrib/nnc/passes/optimizations/RemoveDeadEnds.cpp @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "passes/optimizations/RemoveDeadEnds.h" +#include "mir/Graph.h" + +using namespace mir; + +nnc::PassData nnc::RemoveDeadEnds::run(PassData data) { + auto g = static_cast(data); + assert(g); + for (auto op: g->getNodes()) { + if (op->getOutput(0)->getConsumers().empty() && op->getType() == mir::Operation::Type::constant) { + g->removeNode(op); + } + } + return g; +} diff --git a/contrib/nnc/passes/optimizations/SinkRelu.cpp b/contrib/nnc/passes/optimizations/SinkRelu.cpp index 8c58ff2..8cfafee 100644 --- a/contrib/nnc/passes/optimizations/SinkRelu.cpp +++ b/contrib/nnc/passes/optimizations/SinkRelu.cpp @@ -31,9 +31,6 @@ namespace nnc { using namespace mir; using namespace opt_util; -/* - Static casts are safe here because we have already checked the types when matching. - */ PassData SinkRelu::run(PassData data) { auto g = static_cast(data); assert(g); @@ -71,7 +68,7 @@ PassData SinkRelu::run(PassData data) { // concat is deleted here g->replaceNode(concat, new_relu); for (auto r: relus) { - removeNodeIfUnsed(g,r); + removeNodeIfUnused(g, r); } } return g; diff --git a/contrib/nnc/passes/optimizations/SinkTranspose.cpp b/contrib/nnc/passes/optimizations/SinkTranspose.cpp index 4cd39e7..351e29f 100644 --- a/contrib/nnc/passes/optimizations/SinkTranspose.cpp +++ b/contrib/nnc/passes/optimizations/SinkTranspose.cpp @@ -29,9 +29,7 @@ namespace nnc { using namespace mir; using namespace opt_util; -/* - Static casts are safe here because we have already checked the types when matching. - */ + PassData SinkTranspose::run(PassData data) { auto g = static_cast(data); assert(g); // NOLINT @@ -69,7 +67,7 @@ PassData SinkTranspose::run(PassData data) { // removes old concat g->replaceNode(concat, new_transpose); for (auto tr: trs) { - removeNodeIfUnsed(g,tr); + removeNodeIfUnused(g, tr); } } } diff --git a/contrib/nnc/unittests/optimizations/CMakeLists.txt b/contrib/nnc/unittests/optimizations/CMakeLists.txt index 1527254..40d3719 100644 --- a/contrib/nnc/unittests/optimizations/CMakeLists.txt +++ b/contrib/nnc/unittests/optimizations/CMakeLists.txt @@ -1,5 +1,6 @@ set(TESTS_OPTIMIZATIONS_SRC CombineTransposes.cpp SinkTest.cpp + RemoveDeadEnds.cpp FuseArithmeticOps.cpp) nnc_add_unit_test(tests_for_optimizations ${TESTS} ${TESTS_OPTIMIZATIONS_SRC}) optional_target_link_libraries(tests_for_optimizations nnc_optimizations nnc_support mir) diff --git a/contrib/nnc/unittests/optimizations/FuseArithmeticOps.cpp b/contrib/nnc/unittests/optimizations/FuseArithmeticOps.cpp index 5f4854e..e738c0b 100644 --- a/contrib/nnc/unittests/optimizations/FuseArithmeticOps.cpp +++ b/contrib/nnc/unittests/optimizations/FuseArithmeticOps.cpp @@ -15,6 +15,7 @@ */ #include "passes/optimizations/FuseArithmeticOps.h" +#include "Util.h" #include "mir/Graph.h" #include "mir/ops/BiasAddOp.h" #include "mir/ops/ConstantOp.h" @@ -29,23 +30,6 @@ using namespace mir; namespace { -class DumpVisitor : public Visitor { -public: - explicit DumpVisitor(std::stringstream& s) : _s(s) {} - - void visit(ops::InputOp& op) override { _s << "i_" << op.getName() << "."; }; - - void visit(ops::ConstantOp& op) override { _s << "const_" << op.getName() << "."; } - - void visit(ops::ScaleOp& op) override { _s << "s_" << op.getName() << "."; } - - void visit(ops::BiasAddOp& op) override { _s << "b_" << op.getName() << "."; } - - void visit(ops::Conv2DOp& op) override { _s << "conv_" << op.getName() << "."; } - - std::stringstream& _s; -}; - TEST(OptPass, fuseConvBiasScaleScaleBias) { mir::Graph g; diff --git a/contrib/nnc/unittests/optimizations/RemoveDeadEnds.cpp b/contrib/nnc/unittests/optimizations/RemoveDeadEnds.cpp new file mode 100644 index 0000000..789b803 --- /dev/null +++ b/contrib/nnc/unittests/optimizations/RemoveDeadEnds.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#include "passes/optimizations/RemoveDeadEnds.h" +#include "mir/ops/TransposeOp.h" +#include "mir/ops/ReluOp.h" +#include "mir/ops/ElementwiseOp.h" +#include "mir/ops/ConstantOp.h" + +#include + +namespace { +using namespace std; +using namespace nnc; +using namespace mir; + +TEST(OptPass, removeDeadEndConstants) { + mir::Graph g; + /* Create graph: (with useless constants) + * [input] + * || + * [relu] + */ + Operation* C0 = g.create("C0", TensorVariant(DTYPE::FLOAT32, {2, 2})); + Operation* input = g.create("input", Shape{1, 2, 3}); + Operation* C1 = g.create("C1", TensorVariant(DTYPE::FLOAT32, {2, 2})); + Operation* C2 = g.create("C2", TensorVariant(DTYPE::FLOAT32, {2, 2})); + Operation* relu = g.create("relu", input->getOutput(0)); + + std::stringstream ss; + RemoveDeadEnds pass; + pass.run(&g); + ASSERT_EQ(2, g.getNodes().size()); +} +} // unnamed namespace diff --git a/contrib/nnc/unittests/optimizations/SinkTest.cpp b/contrib/nnc/unittests/optimizations/SinkTest.cpp index 81c535f..e75ed33 100644 --- a/contrib/nnc/unittests/optimizations/SinkTest.cpp +++ b/contrib/nnc/unittests/optimizations/SinkTest.cpp @@ -46,7 +46,7 @@ Operation* getNext(Operation* op) { } /* This tests swapping relu and transpose */ -TEST(SinkTest, sinkTrReLU) { +TEST(OptPass, sinkTrReLU) { mir::Graph g; /* * Create graph: @@ -78,7 +78,7 @@ TEST(SinkTest, sinkTrReLU) { } /* This tests swapping concat and transpose */ -TEST(SinkTest, sinkTrConcat) { +TEST(OptPass, sinkTrConcat) { mir::Graph g; /* * Create graph: @@ -115,7 +115,7 @@ TEST(SinkTest, sinkTrConcat) { } /* This tests swapping concat and transpose */ -TEST(SinkTest, sinkReluConcat) { +TEST(OptPass, sinkReluConcat) { mir::Graph g; /* * Create graph: @@ -148,7 +148,7 @@ TEST(SinkTest, sinkReluConcat) { } /* This tests swapping relu and max_pool */ -TEST(SinkTest, sinkPoolReLU) { +TEST(OptPass, sinkPoolReLU) { mir::Graph g; /* * Create graph: diff --git a/contrib/nnc/unittests/optimizations/Util.h b/contrib/nnc/unittests/optimizations/Util.h index b847174..6abe8af 100644 --- a/contrib/nnc/unittests/optimizations/Util.h +++ b/contrib/nnc/unittests/optimizations/Util.h @@ -24,6 +24,9 @@ #include "mir/ops/ConcatOp.h" #include "mir/ops/OutputOp.h" #include "mir/ops/PoolOp.h" +#include "mir/ops/BiasAddOp.h" +#include "mir/ops/ScaleOp.h" +#include "mir/ops/Conv2DOp.h" #include "mir/Visitor.h" namespace nnc { @@ -36,12 +39,20 @@ public: void visit(mir::ops::TanhOp& op) override { _s << "th_" << op.getName() << "."; } + void visit(mir::ops::ScaleOp& op) override { _s << "s_" << op.getName() << "."; } + + void visit(mir::ops::BiasAddOp& op) override { _s << "b_" << op.getName() << "."; } + void visit(mir::ops::ReluOp& op) override { _s << "r_" << op.getName() << "."; } void visit(mir::ops::PoolOp& op) override { _s << "p_" << op.getName() << "."; } void visit(mir::ops::TransposeOp& op) override { _s << "t_" << op.getName() << "."; } + void visit(mir::ops::Conv2DOp& op) override { _s << "conv_" << op.getName() << "."; } + + void visit(mir::ops::ConstantOp& op) override { _s << "const_" << op.getName() << "."; } + void visit(mir::ops::ElementwiseOp& op) override { _s << "e_" << op.getName() << "."; } std::ostream& _s; -- 2.7.4