--- /dev/null
+#include "Graph.h"
+
+#include <algorithm>
+#include <bitset>
+
+#include "logging.h"
+
+namespace neurun
+{
+namespace graph
+{
+
+operand::Index Graph::addOperand(const ::internal::tflite::operand::Shape &shape)
+{
+ return _operands.append(shape);
+}
+
+operation::Index Graph::addOperation(std::unique_ptr<operation::Node> &&node)
+{
+
+ return _operations.append(std::move(node));
+}
+
+void Graph::setOperandValue(const operand::Index &ind,
+ std::unique_ptr<::internal::tflite::operand::Data> &&data)
+{
+ assert(_operands.exist(ind));
+ _operands.at(ind).data(std::move(data));
+}
+
+void Graph::addInput(const operand::Index &ind) { _inputs.append(ind); }
+
+void Graph::addOutput(const operand::Index &ind) { _outputs.append(ind); }
+
+void Graph::iteratePostDfs(const std::function<void(const operation::Node &)> &fn) const
+{
+ std::vector<bool> visited(_operations.size(), false);
+
+ std::function<void(const operation::Index &index, const operation::Node &)> dfs_recursive =
+ [&](const operation::Index &index, const operation::Node &node) -> void {
+ if (visited[index.asInt()])
+ return;
+ visited[index.asInt()] = true;
+
+ auto outputs = node.outputs();
+ for (auto output : outputs.list())
+ {
+ // TODO Fix traversing algorithm
+ // Every time need to search for operations that has `outgoing` as incoming from all
+ // operations but we can hold that info cached
+ _operations.iterate(
+ [&](const operation::Index &cand_index, const operation::Node &cand_node) {
+ auto inputs = cand_node.inputs();
+ for (auto input : inputs.list())
+ {
+ if (output == input)
+ {
+ dfs_recursive(cand_index, cand_node);
+ }
+ }
+ });
+ }
+
+ fn(node);
+ };
+
+ _operations.iterate(dfs_recursive);
+
+ // All of the operations(nodes) must have been visited.
+ assert(std::all_of(visited.begin(), visited.end(), [](bool v) { return v; }));
+}
+
+} // namespace graph
+} // namespace neurun
--- /dev/null
+#ifndef __NEURUN_GRAPH_GRAPH_H__
+#define __NEURUN_GRAPH_GRAPH_H__
+
+#include <functional>
+
+#include "graph/operation/Node.h"
+#include "graph/operation/Set.h"
+#include "graph/operand/IndexSet.h"
+#include "graph/operand/Set.h"
+
+namespace neurun
+{
+namespace graph
+{
+
+class Graph
+{
+public:
+ Graph(void) = default;
+
+ // Graph Building
+public:
+ operand::Index addOperand(const ::internal::tflite::operand::Shape &shape);
+ operation::Index addOperation(std::unique_ptr<operation::Node> &&node);
+ void setOperandValue(const operand::Index &ind,
+ std::unique_ptr<::internal::tflite::operand::Data> &&data);
+ void addInput(const operand::Index &ind);
+ void addOutput(const operand::Index &ind);
+
+ // Accessors
+public:
+ const operand::IndexSet &inputs() const { return _inputs; }
+ const operand::IndexSet &outputs() const { return _outputs; }
+ const operand::Set &operands() const { return _operands; }
+
+public:
+ // TODO Introduce Iterator class to support many kinds of interation
+ void iteratePostDfs(const std::function<void(const operation::Node &)> &fn) const;
+
+private:
+ operation::Set _operations;
+ operand::Set _operands;
+ operand::IndexSet _inputs;
+ operand::IndexSet _outputs;
+};
+
+} // namespace graph
+} // namespace neurun
+
+#endif // __NEURUN_GRAPH_GRAPH_H__
bool isModelInput(void) const { return _usage == OperandUsage::MODEL_INPUT; }
private:
- void data(std::unique_ptr<Data> &&data) { _data = std::move(data); }
bool setUsage(OperandUsage usage);
public:
+ void data(std::unique_ptr<Data> &&data) { _data = std::move(data); }
const Data &data(void) const { return *_data; }
public:
--- /dev/null
+#include <gtest/gtest.h>
+
+#include "graph/Graph.h"
+
+TEST(Graph, inputs_and_outputs)
+{
+ ::neurun::graph::Graph graph;
+
+ graph.addInput({0u});
+ graph.addInput({1u});
+
+ graph.addOutput({10u});
+ graph.addOutput({11u});
+ graph.addOutput({12u});
+
+ ASSERT_EQ(graph.inputs().size(), 2);
+ ASSERT_EQ(graph.outputs().size(), 3);
+}