Refactor SubgraphContext (#5658)
author이한종/On-Device Lab(SR)/Engineer/삼성전자 <hanjoung.lee@samsung.com>
Wed, 17 Jul 2019 05:13:52 +0000 (14:13 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 17 Jul 2019 05:13:52 +0000 (14:13 +0900)
- `SubgraphContext` inherits `ObjectManager`
- Rename method `append` to `emplace`
- Rename method `hasNode` to `containsOperation`
- Rename method `findNode` to `findOperation`

Signed-off-by: Hanjoung Lee <hanjoung.lee@samsung.com>
runtimes/neurun/core/include/model/SubgraphContext.h
runtimes/neurun/core/src/graph/Graph.cc
runtimes/neurun/core/src/graph/pass/PermutationEliminationPass.cc
runtimes/neurun/core/src/graph/pass/PermutationInsertionPass.cc
runtimes/neurun/core/src/model/SubgraphContext.cc

index 773f719..582bd19 100644 (file)
 #ifndef __NEURUN_MODEL_SUBGRAPH_CONTEXT_H__
 #define __NEURUN_MODEL_SUBGRAPH_CONTEXT_H__
 
-#include <memory>
-
-#include "Subgraph.h"
-#include "Index.h"
-
-#include <unordered_map>
+#include "model/Index.h"
+#include "model/Subgraph.h"
+#include "util/ObjectManager.h"
 
 namespace neurun
 {
 namespace model
 {
 
-class SubgraphContext
+/**
+ * @brief Class that manages Subgraph objects
+ */
+class SubgraphContext : public util::ObjectManager<SubgraphIndex, Subgraph>
 {
 public:
-  SubgraphContext() : _index_count(0) {}
-
-public:
-  SubgraphIndex append(std::unique_ptr<Subgraph> &&subg);
-  SubgraphIndex append(const OperationIndex &index, const Operation &node, Layout layout);
-  void remove(const SubgraphIndex &index) { _subgs.erase(index); };
-
-public:
-  const Subgraph &at(const SubgraphIndex &) const;
-  Subgraph &at(const SubgraphIndex &);
-  bool exist(const SubgraphIndex &) const;
-  bool hasNode(const OperationIndex &node_index) const;
-  SubgraphIndex findNode(const OperationIndex &node_index) const;
-
-  uint32_t size() const { return _subgs.size(); }
-
-  void iterate(const std::function<void(const SubgraphIndex &, const Subgraph &)> &fn) const;
-  void iterate(const std::function<void(const SubgraphIndex &, Subgraph &)> &fn);
-
-  // TODO: Extract this into external helper function
+  /**
+   * @brief Create an instance of Subgraph with given op and push it to objects
+   *
+   * @param[in] op_idx Operation index that is emplaced
+   * @param[in] op Operation that is emplaced
+   * @param[in] layout Subgraph's layout
+   * @return SubgraphIndex
+   */
+  SubgraphIndex emplace(const OperationIndex &op_index, const Operation &op, Layout layout);
+  /**
+   * @brief Check if an operation does exist in any subgraphs
+   *
+   * @param operation_index Operation index to find
+   * @return true If such operation exists in any subgraphs otherwise false
+   */
+  bool containsOperation(const OperationIndex &operation_index) const;
+  /**
+   * @brief Find an operation from all subgraphs
+   *
+   * @param operation_index Operation index to find
+   * @return SubgraphIndex Index of Subgraph that contains given operation index
+   */
+  SubgraphIndex findOperation(const OperationIndex &operation_index) const;
+  /**
+   * @brief Dump subgraphs
+   *
+   * @param msg Message that will be displayed
+   */
   void dump(const std::string &msg) const;
-
-  void clear() { _subgs.clear(); }
-
-private:
-  SubgraphIndex generateIndex();
-
-private:
-  std::unordered_map<SubgraphIndex, std::unique_ptr<Subgraph>> _subgs;
-  uint32_t _index_count;
 };
 
 } // namespace model
index 6dc7d32..bc31d74 100644 (file)
@@ -179,7 +179,7 @@ void Graph::lower(void)
 
     auto make_subgraph = [&](const model::OperationIndex &node_index, const model::Operation &node,
                              model::Layout layout) {
-      auto subg_index = _subg_ctx->append(node_index, node, layout);
+      auto subg_index = _subg_ctx->emplace(node_index, node, layout);
       auto &subg = _subg_ctx->at(subg_index);
       subg.setOutputs(node.getOutputs());
       subg.setInputs(node.getInputs());
index 8760bde..3ad7324 100644 (file)
@@ -79,8 +79,8 @@ void PermutationEliminationPass::eliminateInput(const model::OperandIndex &inp_i
     _graph.removeOperand(inp_index);
 
     // remove permutation operation
-    assert(_graph.subg_ctx().hasNode(input_use));
-    auto subg_idx = _graph.subg_ctx().findNode(input_use);
+    assert(_graph.subg_ctx().containsOperation(input_use));
+    auto subg_idx = _graph.subg_ctx().findOperation(input_use);
     _graph.subg_ctx().remove(subg_idx);
     _graph.operations().remove(input_use);
 
@@ -134,8 +134,8 @@ void PermutationEliminationPass::eliminateOutput(const model::OperandIndex &out_
     _graph.removeOperand(out_index);
 
     // remove permutation operation
-    assert(_graph.subg_ctx().hasNode(output_def));
-    auto subg_idx = _graph.subg_ctx().findNode(output_def);
+    assert(_graph.subg_ctx().containsOperation(output_def));
+    auto subg_idx = _graph.subg_ctx().findOperation(output_def);
     _graph.subg_ctx().remove(subg_idx);
     _graph.operations().remove(output_def);
 
index 9d64609..2f65dfa 100644 (file)
@@ -85,8 +85,8 @@ void PermutationInsertionPass::callback(const model::OperandIndex &index, model:
         continue;
 
       auto &operation = _graph.operations().at(use);
-      assert(_graph.subg_ctx().hasNode(use));
-      auto subg_index = _graph.subg_ctx().findNode(use);
+      assert(_graph.subg_ctx().containsOperation(use));
+      auto subg_index = _graph.subg_ctx().findOperation(use);
       auto subg_li = _graph.getLowerInfo(subg_index);
       assert(subg_li);
       const auto subg_layout = _graph.subg_ctx().at(subg_index).getLayout();
@@ -174,7 +174,7 @@ PermutationInsertionPass::insertPermute(const model::OperandIndex &operand_index
 
   // Subgraph
   {
-    auto subg_index = _graph.subg_ctx().append(node_index, node, permute_node_layout);
+    auto subg_index = _graph.subg_ctx().emplace(node_index, node, permute_node_layout);
     auto &subg = _graph.subg_ctx().at(subg_index);
     subg.setInputs(node.getInputs());
     subg.setOutputs(node.getOutputs());
index 79b01dd..12842e9 100644 (file)
 
 #include <cassert>
 #include <string>
-#include <stdint.h>
 
 namespace neurun
 {
 namespace model
 {
 
-SubgraphIndex SubgraphContext::generateIndex()
-{
-  assert((_index_count) <= UINT32_MAX);
-
-  return SubgraphIndex{_index_count++};
-}
-
-SubgraphIndex SubgraphContext::append(std::unique_ptr<Subgraph> &&subgraph)
-{
-  auto index = generateIndex();
-  _subgs[index] = std::move(subgraph);
-  return index;
-}
-
-SubgraphIndex SubgraphContext::append(const OperationIndex &index, const Operation &node,
-                                      Layout layout)
+SubgraphIndex SubgraphContext::emplace(const OperationIndex &index, const Operation &node,
+                                       Layout layout)
 {
   std::unique_ptr<Subgraph> subg = nnfw::cpp14::make_unique<model::Subgraph>(layout);
   subg->appendOperation(index, node);
-  return append(std::move(subg));
-}
-
-const Subgraph &SubgraphContext::at(const SubgraphIndex &index) const
-{
-  return *(_subgs.at(index));
-}
-
-Subgraph &SubgraphContext::at(const SubgraphIndex &index) { return *(_subgs.at(index)); }
-
-bool SubgraphContext::exist(const SubgraphIndex &index) const
-{
-  return _subgs.find(index) != _subgs.end();
+  return push(std::move(subg));
 }
 
-bool SubgraphContext::hasNode(const OperationIndex &node_index) const
+bool SubgraphContext::containsOperation(const OperationIndex &operation_index) const
 {
-  for (auto it = _subgs.begin(); it != _subgs.end(); ++it)
-  {
-    const auto &subg = *it->second;
-    for (const auto &elem : subg.operations())
+  bool ret = false;
+  iterate([&](const SubgraphIndex &, const Subgraph &object) {
+    for (const auto &elem : object.operations())
     {
-      if (elem.index == node_index)
-        return true;
+      if (elem.index == operation_index)
+        ret = true;
     }
-  }
-
-  return false;
+  });
+  return ret;
 }
 
-SubgraphIndex SubgraphContext::findNode(const OperationIndex &node_index) const
+SubgraphIndex SubgraphContext::findOperation(const OperationIndex &operation_index) const
 {
-  for (auto it = _subgs.begin(); it != _subgs.end(); ++it)
-  {
-    auto &subg_index = it->first;
-    const auto &subg = *it->second;
-    for (const auto &elem : subg.operations())
+  SubgraphIndex ret;
+  iterate([&](const SubgraphIndex &index, const Subgraph &object) {
+    for (const auto &elem : object.operations())
     {
-      if (elem.index == node_index)
-      {
-        return subg_index;
-      }
+      if (elem.index == operation_index)
+        ret = index;
     }
-  }
-
-  assert(true && "DO NOT ENTER HERE");
-  return SubgraphIndex(UINT32_MAX); // CAN'T ENTER HERE
-}
-
-void SubgraphContext::iterate(
-    const std::function<void(const SubgraphIndex &, const Subgraph &)> &fn) const
-{
-  for (auto it = _subgs.begin(); it != _subgs.end(); ++it)
-  {
-    fn(it->first, *it->second);
-  }
-}
-
-void SubgraphContext::iterate(const std::function<void(const SubgraphIndex &, Subgraph &)> &fn)
-{
-  for (auto it = _subgs.begin(); it != _subgs.end(); ++it)
-  {
-    fn(it->first, *it->second);
-  }
+  });
+  assert(ret.valid());
+  return ret;
 }
 
 // TODO: Extract this into external helper function