Add layout as a condition for merging subgraphs (#6216)
author장지섭/On-Device Lab(SR)/Engineer/삼성전자 <jiseob.jang@samsung.com>
Mon, 5 Aug 2019 11:12:18 +0000 (20:12 +0900)
committer이한종/On-Device Lab(SR)/Engineer/삼성전자 <hanjoung.lee@samsung.com>
Mon, 5 Aug 2019 11:12:18 +0000 (20:12 +0900)
This commit adds layout as a condition for merging subgraph.

Signed-off-by: jiseob.jang <jiseob.jang@samsung.com>
runtimes/neurun/core/include/model/Layout.h
runtimes/neurun/core/src/graph/Graph.cc

index ef9bd1c..db46f42 100644 (file)
@@ -18,6 +18,7 @@
 #define __NEURUN_MODEL_LAYOUT_H__
 
 #include <functional>
+#include <string>
 
 namespace neurun
 {
@@ -31,6 +32,21 @@ enum class Layout
   NCHW
 };
 
+inline std::string to_string(model::Layout layout)
+{
+  switch (layout)
+  {
+    case Layout::NHWC:
+      return std::string{"NHWC"};
+    case model::Layout::NCHW:
+      return std::string{"NCHW"};
+    case model::Layout::UNKNOWN:
+      return std::string{"UNKNOWN"};
+    default:
+      throw std::runtime_error("WRONG LAYOUT");
+  }
+}
+
 } // namespace model
 } // namespace neurun
 
index f5ff189..3e8506d 100644 (file)
@@ -112,18 +112,20 @@ void Graph::lower(void)
     // 1. the same backend id?
     // 2. if 1 is true, the subg and a node are connected?
     auto mergeable = [&](const model::SubgraphIndex &subg_index,
-                         const model::OperationIndex &node_index) {
+                         const model::OperationIndex &node_index, model::Layout layout) {
       const auto &subg = _subgraphs->at(subg_index);
       const auto &node = _model->operations.at(node_index);
 
       // The same backend id?
       {
+        auto subg_layout = subg.getLayout();
         const auto &subg_backend_id = getLowerInfo(subg_index)->backend()->config()->id();
         const auto &node_backend_id = _backend_resolver->getBackend(node_index)->config()->id();
-        VERBOSE(Lower) << "SUBG#" << subg_index.value() << " { " << subg_backend_id << " } "
+        VERBOSE(Lower) << "SUBG#" << subg_index.value() << " { " << subg_backend_id << "("
+                       << model::to_string(subg_layout) << ") } "
                        << " NODE#" << node_index.value() << " (" << node.getName() << ") { "
-                       << node_backend_id << " }" << std::endl;
-        if (subg_backend_id != node_backend_id)
+                       << node_backend_id << "(" << model::to_string(layout) << ") } " << std::endl;
+        if (subg_backend_id != node_backend_id || subg_layout != layout)
           return false;
       }
 
@@ -233,7 +235,7 @@ void Graph::lower(void)
           }
           /*for profiling each subgraph must contain just one node,
             so that we can measure a node separately*/
-          if (!subg || is_profiling || !mergeable(subg_index, node_index))
+          if (!subg || is_profiling || !mergeable(subg_index, node_index, layout))
           {
             // TODO Determines how to set the layout of the subgraph
             auto new_subg_index = append_fresh_single_op_subgraph(node_index, node, layout);
@@ -369,27 +371,12 @@ void Graph::lower(void)
       if (!getLowerInfo(index)->def_factors().empty() ||
           !getLowerInfo(index)->use_factors().empty())
       {
-        auto layout_to_string = [](const model::Layout &layout) {
-          if (layout == model::Layout::NHWC)
-          {
-            return std::string{"NHWC"};
-          }
-          else if (layout == model::Layout::NCHW)
-          {
-            return std::string{"NCHW"};
-          }
-          else if (layout == model::Layout::UNKNOWN)
-          {
-            return std::string{"UNKNOWN"};
-          }
-          return std::string{""};
-        };
-        auto factors_to_string = [&layout_to_string](const operand::PermuteFactorSet &factors) {
+        auto factors_to_string = [](const operand::PermuteFactorSet &factors) {
           std::string str;
           for (auto factor : factors)
           {
             str += factor.backend()->config()->id();
-            str += "(" + layout_to_string(factor.layout()) + ")";
+            str += "(" + model::to_string(factor.layout()) + ")";
             str += " ";
           }
           return "{ " + str + "}";