Revise BackendResolver to be per operation (#5262)
author이한종/On-Device Lab(SR)/Engineer/삼성전자 <hanjoung.lee@samsung.com>
Mon, 27 May 2019 23:26:48 +0000 (08:26 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 27 May 2019 23:26:48 +0000 (08:26 +0900)
Revise BackendResolver to be per operation, not the type of operation

Part of #5115

Signed-off-by: Hanjoung Lee <hanjoung.lee@samsung.com>
runtimes/neurun/core/include/util/config/Config.lst
runtimes/neurun/core/src/compiler/BackendResolver.h
runtimes/neurun/core/src/graph/Graph.cc

index 3efa289..9254c6a 100644 (file)
@@ -21,7 +21,7 @@
 //     Name                    | Type         | Default
 CONFIG(GRAPH_DOT_DUMP          , int          , "0")
 CONFIG(BACKENDS                , std::string  , "cpu;acl_cl;acl_neon")
-CONFIG(OP_BACKEND_ALLOPS       , std::string  , "none")
+CONFIG(OP_BACKEND_ALLOPS       , std::string  , "acl_cl")
 CONFIG(DISABLE_COMPILE         , bool         , "0")
 CONFIG(NEURUN_LOG_ENABLE       , bool         , "0")
 CONFIG(CPU_MEMORY_PLANNER      , std::string  , "FirstFit")
@@ -32,7 +32,7 @@ CONFIG(ACL_DEFAULT_LAYOUT      , std::string  , "NHWC")
 // Auto-generate all operations
 
 #define OP(InternalName, IsNnApi) \
-    CONFIG(OP_BACKEND_ ## InternalName, std::string, "acl_cl")
+    CONFIG(OP_BACKEND_ ## InternalName, std::string, "")
 #include "model/Operations.lst"
 #undef OP
 
index bb0105a..3058a77 100644 (file)
@@ -34,40 +34,9 @@ namespace compiler
 class BackendResolver
 {
 public:
-  BackendResolver(const neurun::model::Operands &operands)
+  BackendResolver(const std::shared_ptr<backend::BackendManager> &backend_manager)
+      : _backend_manager{backend_manager}
   {
-    _backend_manager = std::make_shared<backend::BackendManager>(operands);
-
-    const auto backend_all_str =
-        config::ConfigManager::instance().get<std::string>(config::OP_BACKEND_ALLOPS);
-    if (backend_all_str.compare("none") != 0)
-    {
-      VERBOSE(BackendResolver) << "Use backend for all ops: " << backend_all_str << std::endl;
-#define OP(InternalName, IsNnApi)                                          \
-  if (IsNnApi)                                                             \
-  {                                                                        \
-    auto backend = _backend_manager->get(backend_all_str);                 \
-    _gen_map_deprecated[typeid(model::operation::InternalName)] = backend; \
-  }
-#include "model/Operations.lst"
-#undef OP
-    }
-    else
-    {
-#define OP(InternalName, IsNnApi)                                                              \
-  if (IsNnApi)                                                                                 \
-  {                                                                                            \
-    const auto &backend_str =                                                                  \
-        config::ConfigManager::instance().get<std::string>(config::OP_BACKEND_##InternalName); \
-    auto backend = _backend_manager->get(backend_str);                                         \
-    VERBOSE(BackendResolver) << "backend for " << #InternalName << ": " << backend_str         \
-                             << std::endl;                                                     \
-    _gen_map_deprecated[typeid(model::operation::InternalName)] = backend;                     \
-  }
-
-#include "model/Operations.lst"
-#undef OP
-    }
   }
 
 public:
@@ -91,6 +60,15 @@ public:
     _gen_map[index] = backend;
   }
 
+  void iterate(
+      const std::function<void(const model::OperationIndex &, const backend::Backend *)> &fn) const
+  {
+    for (const auto &e : _gen_map)
+    {
+      fn(e.first, e.second);
+    }
+  }
+
 private:
   // TODO: remove all usage of _gen_map_deprecated
   std::unordered_map<std::type_index, backend::Backend *> _gen_map_deprecated;
index 5ad84a7..6ce055c 100644 (file)
@@ -104,7 +104,59 @@ void Graph::lower(void)
           nnfw::cpp14::make_unique<operand::LowerInfo>(graph::operand::asShape4D(object.shape()));
     });
 
-    _backend_resolver = nnfw::cpp14::make_unique<compiler::BackendResolver>(_model->operands);
+    auto backend_manager = std::make_shared<backend::BackendManager>(_model->operands);
+    _backend_resolver = nnfw::cpp14::make_unique<compiler::BackendResolver>(backend_manager);
+
+    // BackendResolver building
+    // TODO When IScheduler interface is introduced, this building should be done in a derivative of
+    // IScheduler
+    {
+      // 1. Backend for All operations
+      const auto backend_all_str =
+          config::ConfigManager::instance().get<std::string>(config::OP_BACKEND_ALLOPS);
+      auto backend_all = backend_manager->get(backend_all_str);
+
+      VERBOSE(Lower) << "Use backend for all ops: " << backend_all_str << std::endl;
+
+      _model->operations.iterate([&](const model::OperationIndex &index, model::Operation &) {
+        _backend_resolver->setBackend(index, backend_all);
+      });
+
+      // 2. Backend per operation type
+      std::unordered_map<std::type_index, backend::Backend *> op_type_map;
+#define OP(InternalName, IsNnApi)                                                              \
+  if (IsNnApi)                                                                                 \
+  {                                                                                            \
+    const auto &backend_str =                                                                  \
+        config::ConfigManager::instance().get<std::string>(config::OP_BACKEND_##InternalName); \
+    if (!backend_str.empty())                                                                  \
+    {                                                                                          \
+      auto backend = backend_manager->get(backend_str);                                        \
+      VERBOSE(Lower) << "backend for " << #InternalName << ": " << backend_str << std::endl;   \
+      op_type_map[typeid(model::operation::InternalName)] = backend;                           \
+    }                                                                                          \
+  }
+#include "model/Operations.lst"
+#undef OP
+      _model->operations.iterate([&](const model::OperationIndex &index, model::Operation &) {
+        auto itr = op_type_map.find(typeid(index));
+        if (itr != op_type_map.end())
+        {
+          _backend_resolver->setBackend(index, itr->second);
+        }
+      });
+
+      // 3. Backend per operation
+      // TODO TBD
+
+      // Dump final assignment
+      _backend_resolver->iterate(
+          [&](const model::OperationIndex &index, const backend::Backend *backend) {
+            VERBOSE(Lower) << "backend for operation #" << index.value() << ": "
+                           << backend->config()->id() << std::endl;
+          });
+    }
+
     _lower_info_map = nnfw::cpp14::make_unique<LowerInfoMap>();
 
     // Are they mergeable?
@@ -118,7 +170,7 @@ void Graph::lower(void)
       // The same backend id?
       {
         const auto &subg_backend_id = getLowerInfo(subg_index)->backend()->config()->id();
-        const auto &node_backend_id = _backend_resolver->getBackend(typeid(node))->config()->id();
+        const auto &node_backend_id = _backend_resolver->getBackend(node_index)->config()->id();
         VERBOSE(Lower) << "SUBG#" << subg_index.value() << " { " << subg_backend_id << " } "
                        << " NODE#" << node_index.value() << " (" << node.getName() << ") { "
                        << node_backend_id << " }" << std::endl;
@@ -194,7 +246,7 @@ void Graph::lower(void)
     Graph::PostDfsConstIterator().iterate(
         *this, [&](const model::OperationIndex &node_index, const model::Operation &node) {
           // LowerInfo for in/output operands
-          auto backend = _backend_resolver->getBackend(typeid(node));
+          auto backend = _backend_resolver->getBackend(node_index);
           for (auto operand : node.getInputs())
           {
             auto &&lower_info = operands_lower_info.at(operand);
@@ -213,7 +265,7 @@ void Graph::lower(void)
 
             // Subgraph LowerInfo
             setLowerInfo(new_subg_index, nnfw::cpp14::make_unique<graph::operation::LowerInfo>(
-                                             _backend_resolver->getBackend(typeid(node))));
+                                             _backend_resolver->getBackend(node_index)));
 
             subg_index = new_subg_index;
             subg = &(_subg_ctx->at(new_subg_index));