From 0e5a09253a3cd17526c7b3b19f1405353e490475 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=D0=90=D0=BD=D0=B4=D1=80=D0=B5=D0=B9=20=D0=A8=D0=B5=D0=B4?= =?utf8?q?=D1=8C=D0=BA=D0=BE/AI=20Tools=20Lab=20/SRR/Engineer/=EC=82=BC?= =?utf8?q?=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 31 Jan 2019 18:59:19 +0300 Subject: [PATCH] [nnc] Initial optimization Pass (#2940) Initial support for transpose composition Added a `Matcher` class for matching graph patterns Signed-off-by: Andrei Shedko --- contrib/nnc/core/CMakeLists.txt | 1 + contrib/nnc/core/modelIR/Graph.cpp | 22 +++++++---- contrib/nnc/core/modelIR/GraphPatternMatcher.cpp | 45 +++++++++++++++++++++ contrib/nnc/core/modelIR/Operation.cpp | 24 +++++++++++ contrib/nnc/driver/Driver.cpp | 14 ++++++- contrib/nnc/driver/Driver.h | 10 ++--- contrib/nnc/driver/Options.cpp | 17 ++++++++ contrib/nnc/include/core/modelIR/Graph.h | 19 ++++++--- .../nnc/include/core/modelIR/GraphPatternMatcher.h | 46 ++++++++++++++++++++++ contrib/nnc/include/core/modelIR/Operation.h | 13 ++++++ contrib/nnc/include/option/Options.h | 3 ++ .../passes/optimizations/CombineTransposes.h | 39 ++++++++++++++++++ contrib/nnc/passes/CMakeLists.txt | 5 +++ contrib/nnc/passes/dot_dumper/DumperPass.cpp | 1 - .../nnc/passes/interpreter/interpreter_pass.cpp | 1 - contrib/nnc/passes/optimizations/CMakeLists.txt | 6 +++ .../nnc/passes/optimizations/CombineTransposes.cpp | 32 +++++++++++++++ contrib/nnc/passes/soft_backend/SBSerializer.cpp | 1 - 18 files changed, 275 insertions(+), 24 deletions(-) create mode 100644 contrib/nnc/core/modelIR/GraphPatternMatcher.cpp create mode 100644 contrib/nnc/include/core/modelIR/GraphPatternMatcher.h create mode 100644 contrib/nnc/include/passes/optimizations/CombineTransposes.h create mode 100644 contrib/nnc/passes/optimizations/CMakeLists.txt create mode 100644 contrib/nnc/passes/optimizations/CombineTransposes.cpp diff --git a/contrib/nnc/core/CMakeLists.txt b/contrib/nnc/core/CMakeLists.txt index 3251204..efef88a 100644 --- 
a/contrib/nnc/core/CMakeLists.txt +++ b/contrib/nnc/core/CMakeLists.txt @@ -15,6 +15,7 @@ set(SOURCES "modelIR/operations/ConcatOp.cpp" "modelIR/Index.cpp" "modelIR/ir_dot_builder.cpp" "modelIR/IrDotDumper.cpp" + "modelIR/GraphPatternMatcher.cpp" "modelIR/ir_dot_node_info.cpp" "modelIR/Operation.cpp" "modelIR/Shape.cpp" diff --git a/contrib/nnc/core/modelIR/Graph.cpp b/contrib/nnc/core/modelIR/Graph.cpp index c82ff64..6a0cc6c 100644 --- a/contrib/nnc/core/modelIR/Graph.cpp +++ b/contrib/nnc/core/modelIR/Graph.cpp @@ -14,12 +14,12 @@ * limitations under the License. */ +#include "core/modelIR/Graph.h" + #include #include #include -#include "core/modelIR/Graph.h" - namespace nnc { namespace mir { @@ -93,7 +93,7 @@ Graph::~Graph() { } void Graph::registerOp(Operation* op) { - _ops.push_back(op); + _ops.emplace(op); if (auto* input_op = dynamic_cast(op)) _inputs.emplace_back(input_op); @@ -102,7 +102,7 @@ void Graph::registerOp(Operation* op) { _outputs.emplace_back(output_op); } -void Graph::replaceNode(const Operation* op, Operation* with) { +void Graph::replaceNode(Operation* op, Operation* with) { replaceUsages(op, with); _inputs.erase(std::remove_if(_inputs.begin(), _inputs.end(), [op](ops::InputOp* n) { @@ -113,12 +113,11 @@ void Graph::replaceNode(const Operation* op, Operation* with) { return n == op; }), _outputs.end()); - _ops.erase(std::remove_if(_ops.begin(), _ops.end(), [op](Operation* n) { - return n == op; - }), _ops.end()); + _ops.erase(op); + } -ops::InputOp* Graph::replaceWithInputNode(const Operation* op) { +ops::InputOp* Graph::replaceWithInputNode(Operation* op) { assert(op->getNumOutputs() <= 1 && "Only operations with single output value can be replaced with input node"); assert(op->getNextNodes().size() <= 1 @@ -154,5 +153,12 @@ void Graph::replaceInputNodes(const std::vector& new_inputs) { } } +void Graph::removeNode(Operation* op) { + op->removeFromPrev(); + op->removeFromNext(); + _ops.erase(op); + delete op; +} + } // namespace mir } // 
namespace nnc diff --git a/contrib/nnc/core/modelIR/GraphPatternMatcher.cpp b/contrib/nnc/core/modelIR/GraphPatternMatcher.cpp new file mode 100644 index 0000000..94ef068 --- /dev/null +++ b/contrib/nnc/core/modelIR/GraphPatternMatcher.cpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "core/modelIR/GraphPatternMatcher.h" + +#include +#include + +namespace nnc { +namespace mir { + +std::vector> GraphPatternMatcher::matchEdge( + GraphPatternMatcher::Predicate p1, + GraphPatternMatcher::Predicate p2) { + + std::vector> matches; + for (auto* start: _g->getNodes()) { + if (p1(start)) { + const auto& next_nodes = start->getNextNodes(); + for (auto* end: next_nodes) { + if (p2(end)) { + matches.emplace_back(std::make_pair(start, end)); + break; + } + } + } + } + return matches; +} + +} // namespace mir +} // namespace nnc diff --git a/contrib/nnc/core/modelIR/Operation.cpp b/contrib/nnc/core/modelIR/Operation.cpp index 20a1ee6..47956a1 100644 --- a/contrib/nnc/core/modelIR/Operation.cpp +++ b/contrib/nnc/core/modelIR/Operation.cpp @@ -79,6 +79,11 @@ void Operation::setOutputShape(std::size_t index, const Shape& shape) { _outputShapes[index] = shape; } +void Operation::setInput(const IODescriptor& descr, size_t i) { + descr.op->_outputs.emplace_back(this); + _inputs[i] = descr; +} + void Operation::accept(IVisitor* v) { switch (getType()) { #define 
HANDLE_OP(OpType, OpClass) \ @@ -92,5 +97,24 @@ void Operation::accept(IVisitor* v) { } } +void Operation::removeFromPrev() { + for (const auto& prev : _inputs) { + auto& mutable_next = prev.op->_outputs; + mutable_next.erase(std::find(std::begin(mutable_next), std::end(mutable_next), this)); + } + _inputs.clear(); +} + +void Operation::removeFromNext() { + for (auto* next : _outputs) { + auto& mutable_prev = next->_inputs; + mutable_prev.erase( + std::remove_if(mutable_prev.begin(), mutable_prev.end(), [this](IODescriptor n) { + return n.op == this; + }), mutable_prev.end()); + } + _outputs.clear(); +} + } // namespace mir } // namespace nnc diff --git a/contrib/nnc/driver/Driver.cpp b/contrib/nnc/driver/Driver.cpp index 2d641cf..3bf6ea4 100644 --- a/contrib/nnc/driver/Driver.cpp +++ b/contrib/nnc/driver/Driver.cpp @@ -22,12 +22,12 @@ #include "passes/soft_backend/CPPGenerator.h" #include "passes/dot_dumper/DumperPass.h" #include "passes/acl_soft_backend/AclCppGenerator.h" +#include "passes/optimizations/CombineTransposes.h" #include "support/CommandLine.h" #include "Definitions.h" #include "option/Options.h" #include "Driver.h" - namespace nnc { /** @@ -110,10 +110,22 @@ void Driver::registerBackendPass() { } // registerBackendPass +static void registerDumper(PassManager& pass_manager) { + if (cli::dumpGraph) + pass_manager.registerPass(std::unique_ptr(new DumperPass())); +} + +void Driver::registerOptimizationPass() { + if (cli::doOptimizationPass) { + _passManager.registerPass(std::unique_ptr(new CombineTransposes())); + } +} // registerOptimizationPass + void Driver::runDriver() { // register passes registerFrontendPass(); + registerOptimizationPass(); registerBackendPass(); // run registered passes diff --git a/contrib/nnc/driver/Driver.h b/contrib/nnc/driver/Driver.h index 06059a6..7a4459d 100644 --- a/contrib/nnc/driver/Driver.h +++ b/contrib/nnc/driver/Driver.h @@ -22,14 +22,12 @@ #include "pass/PassManager.h" -namespace nnc -{ +namespace nnc { /** * 
@brief exceptions description class for compiler driver */ -class DriverException : public std::exception -{ +class DriverException : public std::exception { public: DriverException() = default; explicit DriverException(const std::string& reason) : _msg(reason) {} @@ -44,8 +42,7 @@ private: /** * @brief Compiler Driver manages the whole pipeline compilation process */ -class Driver -{ +class Driver { public: /** * @brief main method to run compiler driver @@ -57,6 +54,7 @@ public: private: void registerFrontendPass(); void registerBackendPass(); + void registerOptimizationPass(); void runPasses(); PassManager _passManager; diff --git a/contrib/nnc/driver/Options.cpp b/contrib/nnc/driver/Options.cpp index c073d31..0a1bb7c 100644 --- a/contrib/nnc/driver/Options.cpp +++ b/contrib/nnc/driver/Options.cpp @@ -142,6 +142,23 @@ Option inputFile(optname("--nnmodel, -m"), checkInFile); /** + * Options for *optimizer* + */ +Option doOptimizationPass(optname("-O"), + overview("whether to optimize model or not"), + false, + optional(true), optvalues(""), nullptr, + separators(""), + showopt(true)); + +Option dumpGraph(optname("--dump, -D"), + overview("dump graph to dot files after optimization passes"), + false, + optional(true), optvalues(""), nullptr, + separators(""), + showopt(true)); + +/** * Options for *backend* */ // options for soft backend diff --git a/contrib/nnc/include/core/modelIR/Graph.h b/contrib/nnc/include/core/modelIR/Graph.h index 2f82fb2..87e5f77 100644 --- a/contrib/nnc/include/core/modelIR/Graph.h +++ b/contrib/nnc/include/core/modelIR/Graph.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -55,26 +56,32 @@ class Graph { * @brief Returns all graph nodes * @return vector containing all graph nodes */ - const std::vector& getNodes() const { return _ops; } + std::unordered_set getNodes() const { return _ops; } /** * @brief Returns all graph input nodes * @returns vector containing all graph input nodes */ - const std::vector& 
getInputs() const { return _inputs; } + std::vector getInputs() const { return _inputs; } /** * @brief Returns all graph output nodes * @returns vector containing all graph output nodes */ - const std::vector& getOutputs() const { return _outputs; } + std::vector getOutputs() const { return _outputs; } + + /** + * @brief remove node from graph, along with its links in other nodes + * @param op node to be removed + */ + void removeNode(Operation* op); /** * @brief Subsitude node in graph with another keeping all edges * @param op Node to subsitude * @param with Node to place instead */ - void replaceNode(const Operation* op, Operation* with); + void replaceNode(Operation* op, Operation* with); /** * @brief Replaces referenced node with input(VariableOp) node @@ -82,7 +89,7 @@ class Graph { * @return Input node which is placed in graph instead of passed node * @warning deletes passed node */ - ops::InputOp* replaceWithInputNode(const Operation* op); + ops::InputOp* replaceWithInputNode(Operation* op); /** * @brief Change graph inputs to nodes with names in newInputs @@ -94,7 +101,7 @@ class Graph { private: void registerOp(Operation* op); - std::vector _ops; + std::unordered_set _ops; size_t _lastNodeId = 0; std::vector _inputs; std::vector _outputs; diff --git a/contrib/nnc/include/core/modelIR/GraphPatternMatcher.h b/contrib/nnc/include/core/modelIR/GraphPatternMatcher.h new file mode 100644 index 0000000..bdf939b --- /dev/null +++ b/contrib/nnc/include/core/modelIR/GraphPatternMatcher.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNCC_GRAPH_PATTERN_MATCHER_H +#define NNCC_GRAPH_PATTERN_MATCHER_H + +#include "Graph.h" + +namespace nnc { +namespace mir { + +class Operation; + +class GraphPatternMatcher { +public: + using Predicate = bool(const Operation*); + explicit GraphPatternMatcher(Graph* g) : _g(g) {}; + + /** + * @brief Match an edge with 2 predicates for ends of the edge + * @param pattern + * @return Vector of topmost ops of all matches; empty if no matches are found + */ + std::vector> matchEdge(Predicate p1, Predicate p2); + +private: + Graph* _g; +}; + +#endif //NNCC_GRAPH_PATTERN_MATCHER_H + +} // namespace mir +} // namespace nnc diff --git a/contrib/nnc/include/core/modelIR/Operation.h b/contrib/nnc/include/core/modelIR/Operation.h index 9794ccc..6a96b79 100644 --- a/contrib/nnc/include/core/modelIR/Operation.h +++ b/contrib/nnc/include/core/modelIR/Operation.h @@ -72,6 +72,19 @@ public: const nnc::mir::Shape& getInputShape(std::size_t index) const; const nnc::mir::Shape& getOutputShape(std::size_t index) const; + /// @brief Removes links to this node from its parents + void removeFromPrev(); + + /// @brief Removes links to this node from its children + void removeFromNext(); + + /** + * @brief Set `descr` as `i`-th input of this node + * @param descr the tensor to be set as input + * @param i input index + */ + void setInput(const IODescriptor& descr, size_t i); + void accept(IVisitor* v); protected: diff --git a/contrib/nnc/include/option/Options.h b/contrib/nnc/include/option/Options.h index 803fdcc..d49b17e 100644 --- 
a/contrib/nnc/include/option/Options.h +++ b/contrib/nnc/include/option/Options.h @@ -36,6 +36,9 @@ extern Option caffeFrontend; // frontend for CAFFE AI framework extern Option tflFrontend; // frontend for TensorFlow Lite AI framework extern Option onnxFrontend; // frontend for ONNX AI framework +extern Option doOptimizationPass; // enable optimization pass +extern Option dumpGraph; // enable Dumping graph to .dot files + // valid values for target option #define NNC_TARGET_ARM_CPP "arm-c++" #define NNC_TARGET_X86_CPP "x86-c++" diff --git a/contrib/nnc/include/passes/optimizations/CombineTransposes.h b/contrib/nnc/include/passes/optimizations/CombineTransposes.h new file mode 100644 index 0000000..311624e --- /dev/null +++ b/contrib/nnc/include/passes/optimizations/CombineTransposes.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNCC_COMBINE_TRANSPOSES_H +#define NNCC_COMBINE_TRANSPOSES_H + +#include "pass/Pass.h" +#include "pass/PassData.h" + +namespace nnc { + +/** + * @brief This pass combines sequential transposes and removes identity transposes if + * the combination results in an identity permutation. 
+ */ +class CombineTransposes : public Pass { +public: + PassData run(PassData data) override; + +private: +}; + +} // namespace nnc + + +#endif //NNCC_COMBINE_TRANSPOSES_H diff --git a/contrib/nnc/passes/CMakeLists.txt b/contrib/nnc/passes/CMakeLists.txt index 1fa2195..5a413bc 100644 --- a/contrib/nnc/passes/CMakeLists.txt +++ b/contrib/nnc/passes/CMakeLists.txt @@ -23,6 +23,11 @@ if(NNC_FRONTEND_CAFFE2_ENABLED) endif() # +# MIDDLE PASSES +# +add_subdirectory(optimizations) + +# # BACKENDs # add_subdirectory(interpreter) diff --git a/contrib/nnc/passes/dot_dumper/DumperPass.cpp b/contrib/nnc/passes/dot_dumper/DumperPass.cpp index 9ee1328..9ed90ed 100644 --- a/contrib/nnc/passes/dot_dumper/DumperPass.cpp +++ b/contrib/nnc/passes/dot_dumper/DumperPass.cpp @@ -19,7 +19,6 @@ #include "pass/PassException.h" #include -#include namespace nnc { diff --git a/contrib/nnc/passes/interpreter/interpreter_pass.cpp b/contrib/nnc/passes/interpreter/interpreter_pass.cpp index aab24e9..6205d25 100644 --- a/contrib/nnc/passes/interpreter/interpreter_pass.cpp +++ b/contrib/nnc/passes/interpreter/interpreter_pass.cpp @@ -140,7 +140,6 @@ PassData InterpreterPass::run(PassData data) { #else std::cout << "Result <" << out_node->getName() << "> wasn't saved, due to lack of HDF5" << std::endl; - #endif // NNC_HDF5_SUPPORTED } diff --git a/contrib/nnc/passes/optimizations/CMakeLists.txt b/contrib/nnc/passes/optimizations/CMakeLists.txt new file mode 100644 index 0000000..1a1e13a --- /dev/null +++ b/contrib/nnc/passes/optimizations/CMakeLists.txt @@ -0,0 +1,6 @@ +set(OPTIMIZATIONS_SRC "CombineTransposes.cpp") +nnc_add_library(nnc_optimizations SHARED ${OPTIMIZATIONS_SRC}) +target_link_libraries(nnc_optimizations PRIVATE nnc_core nnc_support) + +# install optimizations library +nnc_install_library(nnc_optimizations) diff --git a/contrib/nnc/passes/optimizations/CombineTransposes.cpp b/contrib/nnc/passes/optimizations/CombineTransposes.cpp new file mode 100644 index 0000000..0a6f2b9 --- 
/dev/null +++ b/contrib/nnc/passes/optimizations/CombineTransposes.cpp @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "passes/optimizations/CombineTransposes.h" +#include "core/modelIR/operations/TransposeOp.h" +#include "core/modelIR/Graph.h" +#include "core/modelIR/GraphPatternMatcher.h" + +namespace nnc { + +using namespace mir; + +PassData CombineTransposes::run(PassData data) { + auto g = static_cast(data); + assert(g); + return g; +} + +} //namespace nnc diff --git a/contrib/nnc/passes/soft_backend/SBSerializer.cpp b/contrib/nnc/passes/soft_backend/SBSerializer.cpp index 202eca6..008536f 100644 --- a/contrib/nnc/passes/soft_backend/SBSerializer.cpp +++ b/contrib/nnc/passes/soft_backend/SBSerializer.cpp @@ -51,7 +51,6 @@ #include "pass/PassException.h" #include -#include #define UNUSED(x) ((void)(x)) -- 2.7.4