#ifndef __NEURUN_MODEL_SUBGRAPH_CONTEXT_H__
#define __NEURUN_MODEL_SUBGRAPH_CONTEXT_H__
-#include <memory>
-
-#include "Subgraph.h"
-#include "Index.h"
-
-#include <unordered_map>
+#include "model/Index.h"
+#include "model/Subgraph.h"
+#include "util/ObjectManager.h"
namespace neurun
{
namespace model
{
-class SubgraphContext
+/**
+ * @brief Class that manages Subgraph objects
+ */
+class SubgraphContext : public util::ObjectManager<SubgraphIndex, Subgraph>
{
public:
- SubgraphContext() : _index_count(0) {}
-
-public:
- SubgraphIndex append(std::unique_ptr<Subgraph> &&subg);
- SubgraphIndex append(const OperationIndex &index, const Operation &node, Layout layout);
- void remove(const SubgraphIndex &index) { _subgs.erase(index); };
-
-public:
- const Subgraph &at(const SubgraphIndex &) const;
- Subgraph &at(const SubgraphIndex &);
- bool exist(const SubgraphIndex &) const;
- bool hasNode(const OperationIndex &node_index) const;
- SubgraphIndex findNode(const OperationIndex &node_index) const;
-
- uint32_t size() const { return _subgs.size(); }
-
- void iterate(const std::function<void(const SubgraphIndex &, const Subgraph &)> &fn) const;
- void iterate(const std::function<void(const SubgraphIndex &, Subgraph &)> &fn);
-
- // TODO: Extract this into external helper function
+ /**
+ * @brief Create an instance of Subgraph with given op and push it to objects
+ *
+ * @param[in] op_idx Operation index that is emplaced
+ * @param[in] op Operation that is emplaced
+ * @param[in] layout Subgraph's layout
+ * @return SubgraphIndex
+ */
+ SubgraphIndex emplace(const OperationIndex &op_index, const Operation &op, Layout layout);
+ /**
+ * @brief Check if an operation does exist in any subgraphs
+ *
+ * @param operation_index Operation index to find
+ * @return true If such operation exists in any subgraphs otherwise false
+ */
+ bool containsOperation(const OperationIndex &operation_index) const;
+ /**
+ * @brief Find an operation from all subgraphs
+ *
+ * @param operation_index Operation index to find
+ * @return SubgraphIndex Index of Subgraph that contains given operation index
+ */
+ SubgraphIndex findOperation(const OperationIndex &operation_index) const;
+ /**
+ * @brief Dump subgraphs
+ *
+ * @param msg Message that will be displayed
+ */
void dump(const std::string &msg) const;
-
- void clear() { _subgs.clear(); }
-
-private:
- SubgraphIndex generateIndex();
-
-private:
- std::unordered_map<SubgraphIndex, std::unique_ptr<Subgraph>> _subgs;
- uint32_t _index_count;
};
} // namespace model
auto make_subgraph = [&](const model::OperationIndex &node_index, const model::Operation &node,
model::Layout layout) {
- auto subg_index = _subg_ctx->append(node_index, node, layout);
+ auto subg_index = _subg_ctx->emplace(node_index, node, layout);
auto &subg = _subg_ctx->at(subg_index);
subg.setOutputs(node.getOutputs());
subg.setInputs(node.getInputs());
_graph.removeOperand(inp_index);
// remove permutation operation
- assert(_graph.subg_ctx().hasNode(input_use));
- auto subg_idx = _graph.subg_ctx().findNode(input_use);
+ assert(_graph.subg_ctx().containsOperation(input_use));
+ auto subg_idx = _graph.subg_ctx().findOperation(input_use);
_graph.subg_ctx().remove(subg_idx);
_graph.operations().remove(input_use);
_graph.removeOperand(out_index);
// remove permutation operation
- assert(_graph.subg_ctx().hasNode(output_def));
- auto subg_idx = _graph.subg_ctx().findNode(output_def);
+ assert(_graph.subg_ctx().containsOperation(output_def));
+ auto subg_idx = _graph.subg_ctx().findOperation(output_def);
_graph.subg_ctx().remove(subg_idx);
_graph.operations().remove(output_def);
continue;
auto &operation = _graph.operations().at(use);
- assert(_graph.subg_ctx().hasNode(use));
- auto subg_index = _graph.subg_ctx().findNode(use);
+ assert(_graph.subg_ctx().containsOperation(use));
+ auto subg_index = _graph.subg_ctx().findOperation(use);
auto subg_li = _graph.getLowerInfo(subg_index);
assert(subg_li);
const auto subg_layout = _graph.subg_ctx().at(subg_index).getLayout();
// Subgraph
{
- auto subg_index = _graph.subg_ctx().append(node_index, node, permute_node_layout);
+ auto subg_index = _graph.subg_ctx().emplace(node_index, node, permute_node_layout);
auto &subg = _graph.subg_ctx().at(subg_index);
subg.setInputs(node.getInputs());
subg.setOutputs(node.getOutputs());
#include <cassert>
#include <string>
-#include <stdint.h>
namespace neurun
{
namespace model
{
-SubgraphIndex SubgraphContext::generateIndex()
-{
- assert((_index_count) <= UINT32_MAX);
-
- return SubgraphIndex{_index_count++};
-}
-
-SubgraphIndex SubgraphContext::append(std::unique_ptr<Subgraph> &&subgraph)
-{
- auto index = generateIndex();
- _subgs[index] = std::move(subgraph);
- return index;
-}
-
-SubgraphIndex SubgraphContext::append(const OperationIndex &index, const Operation &node,
- Layout layout)
+SubgraphIndex SubgraphContext::emplace(const OperationIndex &index, const Operation &node,
+ Layout layout)
{
std::unique_ptr<Subgraph> subg = nnfw::cpp14::make_unique<model::Subgraph>(layout);
subg->appendOperation(index, node);
- return append(std::move(subg));
-}
-
-const Subgraph &SubgraphContext::at(const SubgraphIndex &index) const
-{
- return *(_subgs.at(index));
-}
-
-Subgraph &SubgraphContext::at(const SubgraphIndex &index) { return *(_subgs.at(index)); }
-
-bool SubgraphContext::exist(const SubgraphIndex &index) const
-{
- return _subgs.find(index) != _subgs.end();
+ return push(std::move(subg));
}
-bool SubgraphContext::hasNode(const OperationIndex &node_index) const
+bool SubgraphContext::containsOperation(const OperationIndex &operation_index) const
{
- for (auto it = _subgs.begin(); it != _subgs.end(); ++it)
- {
- const auto &subg = *it->second;
- for (const auto &elem : subg.operations())
+ bool ret = false;
+ iterate([&](const SubgraphIndex &, const Subgraph &object) {
+ for (const auto &elem : object.operations())
{
- if (elem.index == node_index)
- return true;
+ if (elem.index == operation_index)
+ ret = true;
}
- }
-
- return false;
+ });
+ return ret;
}
-SubgraphIndex SubgraphContext::findNode(const OperationIndex &node_index) const
+SubgraphIndex SubgraphContext::findOperation(const OperationIndex &operation_index) const
{
- for (auto it = _subgs.begin(); it != _subgs.end(); ++it)
- {
- auto &subg_index = it->first;
- const auto &subg = *it->second;
- for (const auto &elem : subg.operations())
+ SubgraphIndex ret;
+ iterate([&](const SubgraphIndex &index, const Subgraph &object) {
+ for (const auto &elem : object.operations())
{
- if (elem.index == node_index)
- {
- return subg_index;
- }
+ if (elem.index == operation_index)
+ ret = index;
}
- }
-
- assert(true && "DO NOT ENTER HERE");
- return SubgraphIndex(UINT32_MAX); // CAN'T ENTER HERE
-}
-
-void SubgraphContext::iterate(
- const std::function<void(const SubgraphIndex &, const Subgraph &)> &fn) const
-{
- for (auto it = _subgs.begin(); it != _subgs.end(); ++it)
- {
- fn(it->first, *it->second);
- }
-}
-
-void SubgraphContext::iterate(const std::function<void(const SubgraphIndex &, Subgraph &)> &fn)
-{
- for (auto it = _subgs.begin(); it != _subgs.end(); ++it)
- {
- fn(it->first, *it->second);
- }
+ });
+ assert(ret.valid());
+ return ret;
}
// TODO: Extract this into external helper function