Move Graph::insertPermute to PermutationInsertionPass (#3408)
authorДилшоджон Умронхонович Пошшоев/AI Tools Lab /SRR/Engineer/삼성전자 <d.poshshoev@samsung.com>
Wed, 31 Oct 2018 10:03:50 +0000 (13:03 +0300)
committer오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Wed, 31 Oct 2018 10:03:50 +0000 (19:03 +0900)
Since this method is used just in `PermutationInsertionPass` it's
reasonable to move into this class
Related issue is https://github.sec.samsung.net/STAR/nnfw/issues/3388

Signed-off-by: Poshshoev Dilshodzhon <d.poshshoev@samsung.com>
runtimes/neurun/src/codegen/BackendResolver.h
runtimes/neurun/src/graph/Graph.cc
runtimes/neurun/src/graph/Graph.h
runtimes/neurun/src/graph/pass/PermutationInsertionPass.cc
runtimes/neurun/src/graph/pass/PermutationInsertionPass.h

index e3191f4..d410690 100644 (file)
@@ -70,7 +70,7 @@ public:
 
 public:
   const backend::Backend *getBackend(const std::type_index &type) { return _gen_map[type]; }
-  const backend::Backend *getDefaultBackend()
+  const backend::Backend *getDefaultBackend() const
   {
     backend::Backend *default_backend = _backend_manager->get("cpu");
     return default_backend;
index dc6b604..ce51ebe 100644 (file)
@@ -47,87 +47,32 @@ operand::Index Graph::addOperand(const operand::Shape &shape, const operand::Typ
 
 operation::Index Graph::addOperation(std::unique_ptr<operation::Node> &&node)
 {
-  assert(_phase == Phase::BUILDING);
+  assert(isBuildingPhase());
   return _operations.append(std::move(node));
 }
 
-/**
- * @brief Insert Permute operation that has given operand as input
- *
- * @param operand_index is the target operand index for the insertion
- * @param backend is the output operand's backend type
- *
- * @return operation::Index
- */
-operation::Index Graph::insertPermute(const operand::Index &operand_index,
-                                      const backend::Backend *backend)
-{
-  assert(_phase == Phase::LOWERED);
-
-  auto &operand = _operands.at(operand_index);
-
-  // Generate output operand and permute operation
-  auto out_operand_index = addOperand(operand.shape(), operand.typeInfo());
-  auto &out_operand = _operands.at(out_operand_index);
-  out_operand.setAsOperationOutput();
-  // change model output if operand_index is model output index
-  auto &model_outputs = getOutputs();
-  if (model_outputs.contains(operand_index))
-  {
-    model_outputs.replace(operand_index, out_operand_index);
-  }
-  out_operand.setAsOperationOutput();
-  auto out_operand_li = nnfw::make_unique<operand::LowerInfo>(operand::asShape4D(operand.shape()));
-  out_operand_li->addDefBackend(backend);
-  out_operand_li->addUseBackend(backend);
-  out_operand.lower_info(std::move(out_operand_li));
-
-  // Update LowerInfo of input operand
-  operand.lower_info()->removeUseBackend(backend);
-  operand.lower_info()->addUseBackend(operand.lower_info()->def_backends().getOnlyElement());
-
-  // Insert permute operation to the graph
-  auto insert_node = nnfw::make_unique<operation::Permute::Node>(operand_index, out_operand_index);
-  insert_node->lower_info(
-      nnfw::make_unique<operation::LowerInfo>(_backend_resolver->getDefaultBackend()));
-
-  auto node_index = _operations.append(std::move(insert_node));
-  auto &node = _operations.at(node_index);
-
-  // Update Use/Def info
-  {
-    _operands.at(operand_index).appendUse(node_index);
-
-    auto node_out_indexes = node.getOutputs();
-    auto node_out_index = node_out_indexes.at(operand::IO::Index{0});
-    _operands.at(node_out_index).appendDef(node_index);
-  }
-
-  return node_index;
-}
-
 void Graph::setOperandValue(const operand::Index &ind, std::unique_ptr<operand::Data> &&data)
 {
-  assert(_phase == Phase::BUILDING);
+  assert(isBuildingPhase());
   assert(_operands.exist(ind));
   _operands.at(ind).data(std::move(data));
 }
 
 void Graph::addInput(const operand::Index &ind)
 {
-  assert(_phase == Phase::BUILDING);
+  assert(isBuildingPhase());
   _inputs.append(ind);
 }
 
 void Graph::addOutput(const operand::Index &ind)
 {
-  assert(_phase == Phase::BUILDING);
+  assert(isBuildingPhase());
   _outputs.append(ind);
 }
 
 void Graph::finishBuilding(void)
 {
-  assert(_phase == Phase::BUILDING);
+  assert(isBuildingPhase());
   _phase = Phase::MODEL;
 
   // Initialize operand use-def
@@ -273,7 +218,7 @@ void Graph::lower(void)
 
 std::unique_ptr<linear::Linear> Graph::linearize(void)
 {
-  assert(_phase == Phase::LOWERED);
+  assert(isLowered());
 
   auto linear = nnfw::make_unique<linear::Linear>(*this);
 
index 3632f7b..de7ee34 100644 (file)
@@ -102,8 +102,6 @@ public:
 public:
   operand::Index addOperand(const operand::Shape &shape, const operand::TypeInfo &type);
   operation::Index addOperation(std::unique_ptr<operation::Node> &&node);
-  operation::Index insertPermute(const operand::Index &operand_index,
-                                 const backend::Backend *backend);
   void setOperandValue(const operand::Index &ind, std::unique_ptr<operand::Data> &&data);
   void addInput(const operand::Index &ind);
   void addOutput(const operand::Index &ind);
@@ -111,6 +109,7 @@ public:
   void lower(void);
   std::unique_ptr<linear::Linear> linearize(void);
   bool isBuildingPhase(void) const { return _phase == Phase::BUILDING; }
+  bool isLowered(void) const { return _phase == Phase::LOWERED; }
 
 private:
   void initializeUseDef();
@@ -124,6 +123,7 @@ public:
   operand::Set &operands() { return _operands; } // TODO Remove this non-const accessor
   const operation::Set &operations() const { return _operations; }
   operation::Set &operations() { return _operations; }
+  const codegen::BackendResolver *backend_resolver() const { return _backend_resolver.get(); }
 
 private:
   Phase _phase{Phase::BUILDING};
index 17100e3..329f53d 100644 (file)
 #include "graph/Graph.h"
 #include "backend/interface/IConfig.h"
 #include "logging.h"
+#include "nnfw/std/memory.h"
+#include "graph/operation/Permute.h"
+#include "graph/operand/Shape4DConvert.h"
+#include "codegen/BackendResolver.h"
 
 namespace neurun
 {
@@ -59,7 +63,7 @@ void PermutationInsertionPass::callback(const operand::Index &index, operand::Ob
     auto insert_set = operand_li->use_backends() - operand_li->def_backends();
     for (auto backend : insert_set)
     {
-      const auto permute_operation_index = _graph.insertPermute(index, backend);
+      const auto permute_operation_index = insertPermute(index, backend);
       permute_indexes.push_back(permute_operation_index);
       VERBOSE(PermutationInsertionPass) << "Insert 'Permute' operation for operand "
                                         << index.value() << std::endl;
@@ -108,6 +112,51 @@ void PermutationInsertionPass::callback(const operand::Index &index, operand::Ob
   }
 }
 
+operation::Index PermutationInsertionPass::insertPermute(const operand::Index &operand_index,
+                                                         const backend::Backend *backend)
+{
+  assert(_graph.isLowered());
+
+  auto &operand = _graph.operands().at(operand_index);
+
+  // Generate output operand and permute operation
+  auto out_operand_index = _graph.addOperand(operand.shape(), operand.typeInfo());
+  auto &out_operand = _graph.operands().at(out_operand_index);
+  out_operand.setAsOperationOutput();
+  // change model output if operand_index is model output index
+  auto &model_outputs = _graph.getOutputs();
+  if (model_outputs.contains(operand_index))
+  {
+    model_outputs.replace(operand_index, out_operand_index);
+  }
+  out_operand.setAsOperationOutput();
+  auto out_operand_li = nnfw::make_unique<operand::LowerInfo>(operand::asShape4D(operand.shape()));
+  out_operand_li->addDefBackend(backend);
+  out_operand_li->addUseBackend(backend);
+  out_operand.lower_info(std::move(out_operand_li));
+
+  // Update LowerInfo of input operand
+  operand.lower_info()->removeUseBackend(backend);
+  operand.lower_info()->addUseBackend(operand.lower_info()->def_backends().getOnlyElement());
+
+  // Insert permute operation to the graph
+  auto insert_node = nnfw::make_unique<operation::Permute::Node>(operand_index, out_operand_index);
+  insert_node->lower_info(
+      nnfw::make_unique<operation::LowerInfo>(_graph.backend_resolver()->getDefaultBackend()));
+
+  auto node_index = _graph.operations().append(std::move(insert_node));
+  const auto &node = _graph.operations().at(node_index);
+
+  // Update Use/Def info
+  {
+    _graph.operands().at(operand_index).appendUse(node_index);
+
+    auto node_out_indexes = node.getOutputs();
+    auto node_out_index = node_out_indexes.at(operand::IO::Index{0});
+    _graph.operands().at(node_out_index).appendDef(node_index);
+  }
+  return node_index;
+}
 } // namespace pass
 } // namespace graph
 } // namespace neurun
index 1654f68..d68485d 100644 (file)
@@ -18,6 +18,7 @@
 #define __NEURUN_GRAPH_PASS_PERMUTATION_INSERTION_PASS_H__
 
 #include "OperandPass.h"
+#include "graph/operand/Object.h" //for operation::Index
 
 namespace neurun
 {
@@ -35,6 +36,17 @@ public:
   virtual std::string id() override { return "PermutationInsertionPass"; }
   virtual void callback(const operand::Index &index, operand::Object &object);
 
+  /**
+   * @brief Insert Permute operation that has given operand as input
+   *
+   * @param operand_index is the target operand index for the insertion
+   * @param backend is the output operand's backend type
+   *
+   * @return operation::Index
+   */
+  operation::Index insertPermute(const operand::Index &operand_index,
+                                 const backend::Backend *backend);
+
 private:
 };