Make to distinguish frontend layout from backend layout in Graph (#6253)
author장지섭/On-Device Lab(SR)/Engineer/삼성전자 <jiseob.jang@samsung.com>
Tue, 6 Aug 2019 05:34:15 +0000 (14:34 +0900)
committer이한종/On-Device Lab(SR)/Engineer/삼성전자 <hanjoung.lee@samsung.com>
Tue, 6 Aug 2019 05:34:15 +0000 (14:34 +0900)
This commit makes to distinguish frontend layout fron backend layout in Gragh.

Signed-off-by: jiseob.jang <jiseob.jang@samsung.com>
runtimes/neurun/core/src/graph/Graph.cc

index 3e8506d..207b909 100644 (file)
@@ -109,23 +109,23 @@ void Graph::lower(void)
     _lower_info_map = nnfw::cpp14::make_unique<LowerInfoMap>();
 
     // Are they mergeable?
-    // 1. the same backend id?
+    // 1. the same backend id and layout?
     // 2. if 1 is true, the subg and a node are connected?
     auto mergeable = [&](const model::SubgraphIndex &subg_index,
                          const model::OperationIndex &node_index, model::Layout layout) {
       const auto &subg = _subgraphs->at(subg_index);
       const auto &node = _model->operations.at(node_index);
 
-      // The same backend id?
+      // The same backend id and layout?
       {
-        auto subg_layout = subg.getLayout();
+        const auto subg_backend_layout = getLowerInfo(subg_index)->layout();
         const auto &subg_backend_id = getLowerInfo(subg_index)->backend()->config()->id();
         const auto &node_backend_id = _backend_resolver->getBackend(node_index)->config()->id();
         VERBOSE(Lower) << "SUBG#" << subg_index.value() << " { " << subg_backend_id << "("
-                       << model::to_string(subg_layout) << ") } "
+                       << model::to_string(subg_backend_layout) << ") } "
                        << " NODE#" << node_index.value() << " (" << node.getName() << ") { "
                        << node_backend_id << "(" << model::to_string(layout) << ") } " << std::endl;
-        if (subg_backend_id != node_backend_id || subg_layout != layout)
+        if (subg_backend_id != node_backend_id || subg_backend_layout != layout)
           return false;
       }
 
@@ -201,114 +201,114 @@ void Graph::lower(void)
     // Make subgraphs while checking whether a node can be merged into a subgraph.
     // NOTE: The below method appends nodes while making one subgraph if needed. If something better
     // ways, happy to update this code.
-    Graph::PostDfsConstIterator().iterate(
-        *this, [&](const model::OperationIndex &node_index, const model::Operation &node) {
-          // LowerInfo for in/output operands
-          auto backend = _backend_resolver->getBackend(node_index);
-          // TODO Set layout of this node
-          auto layout = model::Layout::NHWC;
-          const std::string acl_layout_str = util::getConfigString(util::config::ACL_LAYOUT);
-          if (acl_layout_str == "NHWC")
-          {
-            layout = model::Layout::NHWC;
-          }
-          else if (acl_layout_str == "NCHW")
-          {
-            layout = model::Layout::NCHW;
-          }
+    Graph::PostDfsConstIterator().iterate(*this, [&](const model::OperationIndex &node_index,
+                                                     const model::Operation &node) {
+      // LowerInfo for in/output operands
+      auto backend = _backend_resolver->getBackend(node_index);
+      // TODO How to get layout of this node from IR
+      auto frontend_layout = model::Layout::NHWC;
+      auto backend_layout = frontend_layout;
+      const std::string acl_layout_str = util::getConfigString(util::config::ACL_LAYOUT);
+      if (acl_layout_str == "NHWC")
+      {
+        backend_layout = model::Layout::NHWC;
+      }
+      else if (acl_layout_str == "NCHW")
+      {
+        backend_layout = model::Layout::NCHW;
+      }
 
-          // CPU supports only NHWC now
-          if (backend->config()->id() == "cpu")
-          {
-            layout = model::Layout::NHWC;
-          }
+      // CPU supports only NHWC now
+      if (backend->config()->id() == "cpu")
+      {
+        backend_layout = model::Layout::NHWC;
+      }
 
-          for (auto operand : node.getInputs())
-          {
-            auto &&lower_info = operands_lower_info.at(operand);
-            lower_info->addUsePermuteFactor(operand::PermuteFactor{backend, layout});
-          }
-          for (auto operand : node.getOutputs())
-          {
-            auto &&lower_info = operands_lower_info.at(operand);
-            lower_info->addDefPermuteFactor(operand::PermuteFactor{backend, layout});
-          }
-          /*for profiling each subgraph must contain just one node,
-            so that we can measure a node separately*/
-          if (!subg || is_profiling || !mergeable(subg_index, node_index, layout))
-          {
-            // TODO Determines how to set the layout of the subgraph
-            auto new_subg_index = append_fresh_single_op_subgraph(node_index, node, layout);
+      for (auto operand : node.getInputs())
+      {
+        auto &&lower_info = operands_lower_info.at(operand);
+        lower_info->addUsePermuteFactor(operand::PermuteFactor{backend, backend_layout});
+      }
+      for (auto operand : node.getOutputs())
+      {
+        auto &&lower_info = operands_lower_info.at(operand);
+        lower_info->addDefPermuteFactor(operand::PermuteFactor{backend, backend_layout});
+      }
+      /*for profiling each subgraph must contain just one node,
+        so that we can measure a node separately*/
+      if (!subg || is_profiling || !mergeable(subg_index, node_index, backend_layout))
+      {
+        auto new_subg_index = append_fresh_single_op_subgraph(node_index, node, frontend_layout);
 
-            // Subgraph LowerInfo
-            setLowerInfo(new_subg_index,
-                         nnfw::cpp14::make_unique<graph::operation::LowerInfo>(backend, layout));
+        // Subgraph LowerInfo
+        setLowerInfo(new_subg_index, nnfw::cpp14::make_unique<graph::operation::LowerInfo>(
+                                         backend, backend_layout));
 
-            subg_index = new_subg_index;
-            subg = &(_subgraphs->at(new_subg_index));
+        subg_index = new_subg_index;
+        subg = &(_subgraphs->at(new_subg_index));
 
-            VERBOSE(Lower) << "SUBG#" << subg_index.value() << " is created for "
-                           << "NODE#" << node_index.value() << "(" << node.getName() << ")"
-                           << std::endl;
-          }
-          else
-          {
-            subg->appendOperation(node_index, node);
-            subg->setInputs(node.getInputs());
+        VERBOSE(Lower) << "SUBG#" << subg_index.value() << " is created for "
+                       << "NODE#" << node_index.value() << "(" << node.getName() << ")"
+                       << std::endl;
+      }
+      else
+      {
+        subg->appendOperation(node_index, node);
+        subg->setInputs(node.getInputs());
 
-            VERBOSE(Lower) << "SUBG#" << subg_index.value() << " merges "
-                           << "NODE#" << node_index.value() << "(" << node.getName() << ")"
-                           << std::endl;
-          }
+        VERBOSE(Lower) << "SUBG#" << subg_index.value() << " merges "
+                       << "NODE#" << node_index.value() << "(" << node.getName() << ")"
+                       << std::endl;
+      }
 
-          bool finish = false;
+      bool finish = false;
+      {
+        size_t prev_op_cnt = 0;
+        for (auto input : node.getInputs())
+        {
+          // only valid_inputs
+          const auto &operand = _model->operands.at(input);
+          if (operand.isConstant())
+            continue;
+
+          // This operand is input of operation, not weight or bias
+          if (operand.getDef().list().size() > 0)
+            ++prev_op_cnt;
+
+          // Test the node is Concat or BeginningBranch
+          // About (1)isConcat and (2)isBeginningBranch
+          //   (1) Current node has multiple inputs as concat?
+          //     - Does current node have two or more than previous operation?
+          //
+          //        [CONV] [CONV] [CONV]  [MAX_POOL]
+          //         |      |      |       |
+          //        [0]    [1]    [2]     [3]
+          //         \      |      |      /
+          //          [    C O N C A T   ]  # current node
+          //
+          //   (2) Current node is on the separated branch at the beginning?
+          //     - Does current node's input operand's uses have two or more than?
+          //
+          //       [CONV]
+          //         |
+          //        [0]----.
+          //         |     |
+          //       [CONV] [CONV]  # current node
+          //         |      |
+          //        [1]    [2]
+          //         \      /
+          //         [CONCAT]
+          if (prev_op_cnt > 1 || operand.getUses().list().size() > 1)
           {
-            size_t prev_op_cnt = 0;
-            for (auto input : node.getInputs())
-            {
-              // only valid_inputs
-              const auto &operand = _model->operands.at(input);
-              if (operand.isConstant())
-                continue;
-
-              // This operand is input of operation, not weight or bias
-              if (operand.getDef().list().size() > 0)
-                ++prev_op_cnt;
-
-              // Test the node is Concat or BeginningBranch
-              // About (1)isConcat and (2)isBeginningBranch
-              //   (1) Current node has multiple inputs as concat?
-              //     - Does current node have two or more than previous operation?
-              //
-              //        [CONV] [CONV] [CONV]  [MAX_POOL]
-              //         |      |      |       |
-              //        [0]    [1]    [2]     [3]
-              //         \      |      |      /
-              //          [    C O N C A T   ]  # current node
-              //
-              //   (2) Current node is on the separated branch at the beginning?
-              //     - Does current node's input operand's uses have two or more than?
-              //
-              //       [CONV]
-              //         |
-              //        [0]----.
-              //         |     |
-              //       [CONV] [CONV]  # current node
-              //         |      |
-              //        [1]    [2]
-              //         \      /
-              //         [CONCAT]
-              if (prev_op_cnt > 1 || operand.getUses().list().size() > 1)
-              {
-                finish = true;
-                break;
-              }
-            }
+            finish = true;
+            break;
           }
+        }
+      }
 
-          if (finish)
-            subg = nullptr;
-        });
+      if (finish)
+        subg = nullptr;
+    });
 
     _subgraphs->iterate([&](const model::SubgraphIndex &, model::Subgraph &subg) {
       assert(subg.operations().size() > 0);