}
}
-void Graph::iteratePostDfs(const std::function<void(const operation::Node &)> &fn) const
+std::unique_ptr<linear::Linear> Graph::linearize(void)
{
- std::vector<bool> visited(_operations.size(), false);
+ assert(_phase == Phase::LOWERED);
+
+ auto linear = nnfw::make_unique<linear::Linear>(*this);
+
+ // TODO Move the operations and operands to linear object
+
+ _phase = Phase::LINEARIZED;
+
+ return std::move(linear);
+}
+
+} // namespace graph
+} // namespace neurun
- std::function<void(const operation::Index &index, const operation::Node &)> dfs_recursive =
- [&](const operation::Index &index, const operation::Node &node) -> void {
+namespace neurun
+{
+namespace graph
+{
+
+template class Graph::PostDfsIterator<true>;
+template class Graph::PostDfsIterator<false>;
+
+template <bool is_const>
+void Graph::PostDfsIterator<is_const>::iterate(GraphRef graph, const IterFn &fn) const
+{
+ std::vector<bool> visited(graph._operations.size(), false);
+
+ std::function<void(const operation::Index &, NodeRef)> dfs_recursive =
+ [&](const operation::Index &index, NodeRef node) -> void {
if (visited[index.asInt()])
return;
visited[index.asInt()] = true;
// TODO Fix traversing algorithm
// Every time need to search for operations that has `outgoing` as incoming from all
// operations but we can hold that info cached
- _operations.iterate(
- [&](const operation::Index &cand_index, const operation::Node &cand_node) {
- auto inputs = cand_node.inputs();
- for (auto input : inputs.list())
- {
- if (output == input)
- {
- dfs_recursive(cand_index, cand_node);
- }
- }
- });
+ graph._operations.iterate([&](const operation::Index &cand_index, NodeRef cand_node) -> void {
+ auto inputs = cand_node.inputs();
+ for (auto input : inputs.list())
+ {
+ if (output == input)
+ {
+ dfs_recursive(cand_index, cand_node);
+ }
+ }
+ });
}
fn(node);
};
- _operations.iterate(dfs_recursive);
+ graph._operations.iterate(dfs_recursive);
// All of the operations(nodes) must have been visited.
assert(std::all_of(visited.begin(), visited.end(), [](bool v) { return v; }));
}
-std::unique_ptr<linear::Linear> Graph::linearize(void)
-{
- assert(_phase == Phase::LOWERED);
-
- auto linear = nnfw::make_unique<linear::Linear>(*this);
-
- // TODO Move the operations and operands to linear object
-
- _phase = Phase::LINEARIZED;
-
- return std::move(linear);
-}
-
} // namespace graph
} // namespace neurun
};
public:
+ template <bool is_const> class Iterator
+ {
+ public:
+ using GraphRef = typename std::conditional<is_const, const Graph &, Graph &>::type;
+ using NodeRef =
+ typename std::conditional<is_const, const operation::Node &, operation::Node &>::type;
+ using IterFn = std::function<void(NodeRef)>;
+
+ public:
+ virtual ~Iterator() = default;
+ virtual void iterate(GraphRef graph, const IterFn &fn) const = 0;
+ };
+
+ template <bool is_const = false> class PostDfsIterator final : public Iterator<is_const>
+ {
+ public:
+ using GraphRef = typename Iterator<is_const>::GraphRef;
+ using NodeRef = typename Iterator<is_const>::NodeRef;
+ using IterFn = typename Iterator<is_const>::IterFn;
+
+ public:
+ void iterate(GraphRef graph, const IterFn &fn) const;
+ };
+ using PostDfsConstIterator = PostDfsIterator<true>;
+
+public:
Graph(void) = default;
// Graph Building
operand::Set &operands() { return _operands; } // TODO Remove this non-const accessor
const operation::Set &operations() const { return _operations; }
-public:
- // TODO Introduce Iterator class to support many kinds of interation
- void iteratePostDfs(const std::function<void(const operation::Node &)> &fn) const;
-
private:
Phase _phase{Phase::BUILDING};
operation::Set _operations;
// 2. Append the node to vector when DFS for the node finishes(post order)
// 3. Reverse the order of nodes
- graph.iteratePostDfs([&](const neurun::graph::operation::Node &node) {
- auto op = node.op();
- _operations.emplace_back(op);
- });
+ graph::Graph::PostDfsConstIterator().iterate(graph,
+ [&](const neurun::graph::operation::Node &node) {
+ auto op = node.op();
+ _operations.emplace_back(op);
+ });
std::reverse(std::begin(_operations), std::end(_operations));
}