"modelIR/Index.cpp"
"modelIR/ir_dot_builder.cpp"
"modelIR/IrDotDumper.cpp"
+ "modelIR/GraphPatternMatcher.cpp"
"modelIR/ir_dot_node_info.cpp"
"modelIR/Operation.cpp"
"modelIR/Shape.cpp"
* limitations under the License.
*/
+#include "core/modelIR/Graph.h"
+
#include <deque>
#include <set>
#include <algorithm>
-#include "core/modelIR/Graph.h"
-
namespace nnc {
namespace mir {
}
void Graph::registerOp(Operation* op) {
- _ops.push_back(op);
+ _ops.emplace(op);
if (auto* input_op = dynamic_cast<ops::InputOp*>(op))
_inputs.emplace_back(input_op);
_outputs.emplace_back(output_op);
}
-void Graph::replaceNode(const Operation* op, Operation* with) {
+void Graph::replaceNode(Operation* op, Operation* with) {
replaceUsages(op, with);
_inputs.erase(std::remove_if(_inputs.begin(), _inputs.end(), [op](ops::InputOp* n) {
return n == op;
}), _outputs.end());
- _ops.erase(std::remove_if(_ops.begin(), _ops.end(), [op](Operation* n) {
- return n == op;
- }), _ops.end());
+ _ops.erase(op);
+
}
-ops::InputOp* Graph::replaceWithInputNode(const Operation* op) {
+ops::InputOp* Graph::replaceWithInputNode(Operation* op) {
assert(op->getNumOutputs() <= 1
&& "Only operations with single output value can be replaced with input node");
assert(op->getNextNodes().size() <= 1
}
}
+void Graph::removeNode(Operation* op) {
+ op->removeFromPrev();
+ op->removeFromNext();
+ _ops.erase(op);
+ delete op;
+}
+
} // namespace mir
} // namespace nnc
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "core/modelIR/GraphPatternMatcher.h"
+
+#include <algorithm>
+#include <vector>
+
+namespace nnc {
+namespace mir {
+
+std::vector<std::pair<Operation*, Operation*>> GraphPatternMatcher::matchEdge(
+ GraphPatternMatcher::Predicate p1,
+ GraphPatternMatcher::Predicate p2) {
+
+ std::vector<std::pair<Operation*, Operation*>> matches;
+ for (auto* start: _g->getNodes()) {
+ if (p1(start)) {
+ const auto& next_nodes = start->getNextNodes();
+ for (auto* end: next_nodes) {
+ if (p2(end)) {
+ matches.emplace_back(std::make_pair(start, end));
+ break;
+ }
+ }
+ }
+ }
+ return matches;
+}
+
+} // namespace nnc
+} // namespace mir
_outputShapes[index] = shape;
}
+void Operation::setInput(const IODescriptor& descr, size_t i) {
+ descr.op->_outputs.emplace_back(this);
+ _inputs[i] = descr;
+}
+
void Operation::accept(IVisitor* v) {
switch (getType()) {
#define HANDLE_OP(OpType, OpClass) \
}
}
+void Operation::removeFromPrev() {
+ for (const auto& prev : _inputs) {
+ auto& mutable_next = prev.op->_outputs;
+ mutable_next.erase(std::find(std::begin(mutable_next), std::end(mutable_next), this));
+ }
+ _inputs.clear();
+}
+
+void Operation::removeFromNext() {
+ for (auto* next : _outputs) {
+ auto& mutable_prev = next->_inputs;
+ mutable_prev.erase(
+ std::remove_if(mutable_prev.begin(), mutable_prev.end(), [this](IODescriptor n) {
+ return n.op == this;
+ }), mutable_prev.end());
+ }
+ _outputs.clear();
+}
+
} // namespace mir
} // namespace nnc
#include "passes/soft_backend/CPPGenerator.h"
#include "passes/dot_dumper/DumperPass.h"
#include "passes/acl_soft_backend/AclCppGenerator.h"
+#include "passes/optimizations/CombineTransposes.h"
#include "support/CommandLine.h"
#include "Definitions.h"
#include "option/Options.h"
#include "Driver.h"
-
namespace nnc {
/**
} // registerBackendPass
+static void registerDumper(PassManager& pass_manager) {
+ if (cli::dumpGraph)
+ pass_manager.registerPass(std::unique_ptr<Pass>(new DumperPass()));
+}
+
+void Driver::registerOptimizationPass() {
+ if (cli::doOptimizationPass) {
+ _passManager.registerPass(std::unique_ptr<Pass>(new CombineTransposes()));
+ }
+} // registerOptimizationPass
+
void Driver::runDriver() {
// register passes
registerFrontendPass();
+ registerOptimizationPass();
registerBackendPass();
// run registered passes
#include "pass/PassManager.h"
-namespace nnc
-{
+namespace nnc {
/**
* @brief exceptions description class for compiler driver
*/
-class DriverException : public std::exception
-{
+class DriverException : public std::exception {
public:
DriverException() = default;
explicit DriverException(const std::string& reason) : _msg(reason) {}
/**
* @brief Compiler Driver manages the whole pipeline compilation process
*/
-class Driver
-{
+class Driver {
public:
/**
* @brief main method to run compiler driver
private:
void registerFrontendPass();
void registerBackendPass();
+ void registerOptimizationPass();
void runPasses();
PassManager _passManager;
checkInFile);
/**
+ * Options for *optimizer*
+ */
+Option<bool> doOptimizationPass(optname("-O"),
+ overview("whether to optimize model or not"),
+ false,
+ optional(true), optvalues(""), nullptr,
+ separators(""),
+ showopt(true));
+
+Option<bool> dumpGraph(optname("--dump, -D"),
+ overview("dump graph to dot files after optimization passes"),
+ false,
+ optional(true), optvalues(""), nullptr,
+ separators(""),
+ showopt(true));
+
+/**
* Options for *backend*
*/
// options for soft backend
#include <string>
#include <vector>
#include <type_traits>
+#include <unordered_set>
#include <unordered_map>
#include <set>
* @brief Returns all graph nodes
* @return vector containing all graph nodes
*/
- const std::vector<Operation*>& getNodes() const { return _ops; }
+ std::unordered_set<Operation*> getNodes() const { return _ops; }
/**
* @brief Returns all graph input nodes
* @returns vector containing all graph input nodes
*/
- const std::vector<ops::InputOp*>& getInputs() const { return _inputs; }
+ std::vector<ops::InputOp*> getInputs() const { return _inputs; }
/**
* @brief Returns all graph output nodes
* @returns vector containing all graph output nodes
*/
- const std::vector<ops::OutputOp*>& getOutputs() const { return _outputs; }
+ std::vector<ops::OutputOp*> getOutputs() const { return _outputs; }
+
+ /**
+ * @brief remove node from graph, along with its links in other nodes
+ * @param op node to be removed
+ */
+ void removeNode(Operation* op);
/**
* @brief Subsitude node in graph with another keeping all edges
* @param op Node to subsitude
* @param with Node to place instead
*/
- void replaceNode(const Operation* op, Operation* with);
+ void replaceNode(Operation* op, Operation* with);
/**
* @brief Replaces referenced node with input(VariableOp) node
* @return Input node which is placed in graph instead of passed node
* @warning deletes passed node
*/
- ops::InputOp* replaceWithInputNode(const Operation* op);
+ ops::InputOp* replaceWithInputNode(Operation* op);
/**
* @brief Change graph inputs to nodes with names in newInputs
private:
void registerOp(Operation* op);
- std::vector<Operation*> _ops;
+ std::unordered_set<Operation*> _ops;
size_t _lastNodeId = 0;
std::vector<ops::InputOp*> _inputs;
std::vector<ops::OutputOp*> _outputs;
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNCC_GRAPH_PATTERN_MATCHER_H
+#define NNCC_GRAPH_PATTERN_MATCHER_H
+
+#include "Graph.h"
+
+namespace nnc {
+namespace mir {
+
+class Operation;
+
+class GraphPatternMatcher {
+public:
+ using Predicate = bool(const Operation*);
+ explicit GraphPatternMatcher(Graph* g) : _g(g) {};
+
+ /**
+ * @brief Match an edge with 2 predicates for ends of the edge
+ * @param pattern
+ * @return Vector of topmost ops of all matches; empty if no mathces are found
+ */
+ std::vector<std::pair<Operation*, Operation*>> matchEdge(Predicate p1, Predicate p2);
+
+private:
+ Graph* _g;
+};
+
+#endif //NNCC_GRAPH_PATTERN_MATCHER_H
+
+} // namespace nnc
+} // namespace mir
const nnc::mir::Shape& getInputShape(std::size_t index) const;
const nnc::mir::Shape& getOutputShape(std::size_t index) const;
+ /// @brief Removes links to this node from it's parents
+ void removeFromPrev();
+
+ /// @brief Removes links to this node from it's children
+ void removeFromNext();
+
+ /**
+ * @brief Set `descr` as `i`-th input of this node
+ * @param descr the tensor to be set as input
+ * @param i input index
+ */
+ void setInput(const IODescriptor& descr, size_t i);
+
void accept(IVisitor* v);
protected:
extern Option<bool> tflFrontend; // frontend for TensorFlow Lite AI framework
extern Option<bool> onnxFrontend; // frontend for ONNX AI framework
+extern Option<bool> doOptimizationPass; // enable optimization pass
+extern Option<bool> dumpGraph; // enable Dumping graph to .dot files
+
// valid values for target option
#define NNC_TARGET_ARM_CPP "arm-c++"
#define NNC_TARGET_X86_CPP "x86-c++"
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNCC_COMBINE_TRANSPOSES_H
+#define NNCC_COMBINE_TRANSPOSES_H
+
+#include "pass/Pass.h"
+#include "pass/PassData.h"
+
+namespace nnc {
+
+/**
+ * @brief This pass combines sequential transposes and removes identity transposes if
+ * the combination results in an identity permutation.
+ */
+class CombineTransposes : public Pass {
+public:
+ PassData run(PassData data) override;
+
+private:
+};
+
+} // namespace nnc
+
+
+#endif //NNCC_COMBINE_TRANSPOSES_H
endif()
#
+# MIDDLE PASSES
+#
+add_subdirectory(optimizations)
+
+#
# BACKENDs
#
add_subdirectory(interpreter)
#include "pass/PassException.h"
#include <fstream>
-#include <fstream>
namespace nnc {
#else
std::cout << "Result <" << out_node->getName()
<< "> wasn't saved, due to lack of HDF5" << std::endl;
-
#endif // NNC_HDF5_SUPPORTED
}
--- /dev/null
+set(OPTIMIZATIONS_SRC "CombineTransposes.cpp")
+nnc_add_library(nnc_optimizations SHARED ${OPTIMIZATIONS_SRC})
+target_link_libraries(nnc_optimizations PRIVATE nnc_core nnc_support)
+
+# install optimizations library
+nnc_install_library(nnc_optimizations)
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "passes/optimizations/CombineTransposes.h"
+#include "core/modelIR/operations/TransposeOp.h"
+#include "core/modelIR/Graph.h"
+#include "core/modelIR/GraphPatternMatcher.h"
+
+namespace nnc {
+
+using namespace mir;
+
+PassData CombineTransposes::run(PassData data) {
+ auto g = static_cast<Graph*>(data);
+ assert(g);
+ return g;
+}
+
+} //namespace nnc
#include "pass/PassException.h"
#include <algorithm>
-#include <core/modelIR/TensorUtil.h>
#define UNUSED(x) ((void)(x))