Replace BackendSet with PermuteFactorSet (#5307)
author장지섭/On-Device Lab(SR)/Engineer/삼성전자 <jiseob.jang@samsung.com>
Thu, 30 May 2019 03:47:03 +0000 (12:47 +0900)
committer오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Thu, 30 May 2019 03:47:03 +0000 (12:47 +0900)
This commit replaces BackendSet with PermuteFactorSet.
  - replace BackendSet with PermuteFactorSet in operand::LowerInfo
  - Add the layout as a factor separating sub-graphs

Signed-off-by: jiseob.jang <jiseob.jang@samsung.com>
runtimes/neurun/core/include/graph/operand/LowerInfo.h
runtimes/neurun/core/src/compiler/ExecutorFactory.cc
runtimes/neurun/core/src/dumper/dot/DotOperandInfo.cc
runtimes/neurun/core/src/exec/ExecutorBase.cc
runtimes/neurun/core/src/exec/ExecutorBase.h
runtimes/neurun/core/src/graph/Graph.cc
runtimes/neurun/core/src/graph/pass/PermutationEliminationPass.cc
runtimes/neurun/core/src/graph/pass/PermutationInsertionPass.cc
runtimes/neurun/core/src/graph/pass/PermutationInsertionPass.h
runtimes/neurun/core/src/linear/Linear.cc

index 0b5b703..0c2e7e8 100644 (file)
@@ -20,7 +20,6 @@
 #include <functional>
 #include <stdint.h>
 
-#include "graph/BackendSet.h"
 #include "graph/operand/Layout.h"
 #include "graph/operand/PermuteFactor.h"
 #include "util/Set.h"
@@ -39,6 +38,7 @@ namespace graph
 {
 namespace operand
 {
+using PermuteFactorSet = util::Set<PermuteFactor>;
 
 class LowerInfo
 {
@@ -71,22 +71,11 @@ public:
   }
 
 public:
-  using PermuteFactorSet = util::Set<PermuteFactor>;
-
-public:
   const Shape4D &shape(void) const { return _shape; }
-  const BackendSet &def_backends(void) const { return _def_backends; }
-  const BackendSet &use_backends(void) const { return _use_backends; }
   const PermuteFactorSet &def_factors(void) const { return _def_factors; }
   const PermuteFactorSet &use_factors(void) const { return _use_factors; }
 
 public:
-  void addDefBackend(const backend::Backend *backend) { _def_backends.add(backend); }
-  void addUseBackend(const backend::Backend *backend) { _use_backends.add(backend); }
-  void removeDefBackend(const backend::Backend *backend) { _def_backends.remove(backend); }
-  void removeUseBackend(const backend::Backend *backend) { _use_backends.remove(backend); }
-
-public:
   void setLayout(const Layout &layout) { _layout = layout; }
   Layout layout() const { return _layout; }
 
@@ -98,8 +87,6 @@ public:
 
 private:
   Shape4D _shape;
-  BackendSet _def_backends;
-  BackendSet _use_backends;
   Layout _layout{Layout::NHWC};
   PermuteFactorSet _def_factors;
   PermuteFactorSet _use_factors;
index d75c4d5..882e767 100644 (file)
@@ -128,9 +128,9 @@ exec::IExecutor *ExecutorFactory::createDataflowExecutor(graph::Graph &graph)
 
   graph.operands().iterate([&](const model::OperandIndex &ind, const model::Operand &obj) {
     const auto lower_info = graph.getLowerInfo(ind);
-    for (auto backend : lower_info->def_backends())
+    for (auto factor : lower_info->def_factors())
     {
-      auto tensor_builder = backend->tensor_builder();
+      auto tensor_builder = factor.backend()->tensor_builder();
       const auto info = obj.info();
       const auto layout = lower_info->layout();
       tensor_builder->registerTensorInfo(ind, info, layout);
@@ -217,9 +217,9 @@ exec::IExecutor *ExecutorFactory::createParallelExecutor(graph::Graph &graph)
 
   graph.operands().iterate([&](const model::OperandIndex &ind, const model::Operand &obj) {
     const auto lower_info = graph.getLowerInfo(ind);
-    for (auto backend : lower_info->def_backends())
+    for (auto factor : lower_info->def_factors())
     {
-      auto tensor_builder = backend->tensor_builder();
+      auto tensor_builder = factor.backend()->tensor_builder();
       const auto info = obj.info();
       const auto layout = lower_info->layout();
       tensor_builder->registerTensorInfo(ind, info, layout);
index b14fdad..069efe2 100644 (file)
@@ -91,10 +91,10 @@ std::string DotOperandInfo::bg_color() const
   if (!_lower_info)
     return DEFAULT_BG_COLOR;
 
-  const auto &def_backends = _lower_info->def_backends();
-  assert(def_backends.size() == 1);
+  const auto &def_factors = _lower_info->def_factors();
+  assert(def_factors.size() == 1);
 
-  std::string backend_id = def_backends.getOnlyElement()->config()->id();
+  std::string backend_id = def_factors.getOnlyElement().backend()->config()->id();
   // TODO : This is just workaround it can be made more efficient.
   if (backend_id == "acl_cl")
   {
@@ -116,11 +116,11 @@ void DotOperandInfo::addBackendLabel()
     return;
 
   std::string label;
-  const auto &def_backends = _lower_info->def_backends();
-  assert(def_backends.size() == 1);
+  const auto &def_factors = _lower_info->def_factors();
+  assert(def_factors.size() == 1);
 
   label += "[";
-  label += def_backends.getOnlyElement()->config()->id();
+  label += def_factors.getOnlyElement().backend()->config()->id();
   label += "]";
   _labels.emplace_back(label);
 }
index 7481419..99407bc 100644 (file)
@@ -145,7 +145,7 @@ void ExecutorBase::execute()
 
     ::neurun::model::OperandIndex index{_model->inputs.at(input_index)};
     const auto operand_li = _lower_info->operand.at(index).get();
-    if (operand_li->def_backends().empty())
+    if (operand_li->def_factors().empty())
     {
       // This input is not used (i.e. constant, EX. reshape's axis)
       continue;
index 9ceb58c..796c231 100644 (file)
@@ -66,7 +66,7 @@ private:
     const auto &operand = _model->operands.at(operand_index);
     const auto operand_li = _lower_info->operand.at(operand_index).get();
 
-    if (operand_li->def_backends().empty())
+    if (operand_li->def_factors().empty())
     {
       // This input is not used (i.e. constant, EX. reshape's axis)
       return;
index fe7bd3e..afa9827 100644 (file)
@@ -25,6 +25,7 @@
 #include "linear/Linear.h"
 #include "graph/operation/LowerInfo.h"
 #include "graph/operand/LowerInfo.h"
+#include "graph/operand/PermuteFactor.h"
 #include "operand/Shape4DConvert.h"
 #include "compiler/BackendResolver.h"
 #include "compiler/IScheduler.h"
@@ -203,25 +204,53 @@ void Graph::lower(void)
         *this, [&](const model::OperationIndex &node_index, const model::Operation &node) {
           // LowerInfo for in/output operands
           auto backend = _backend_resolver->getBackend(node_index);
+          // TODO Set layout of this node
+          auto layout = graph::operand::Layout::NHWC;
+          // TODO Change ACL_DEFAULT_LAYOUT configuration to setting the layout forcibly if it
+          // exists
+          const std::string layout_str =
+              config::ConfigManager::instance().get<std::string>(config::ACL_DEFAULT_LAYOUT);
+          if (layout_str == "NHWC")
+          {
+            layout = graph::operand::Layout::NHWC;
+          }
+          else if (layout_str == "NCHW")
+          {
+            layout = graph::operand::Layout::NCHW;
+          }
+          else
+          {
+            throw std::runtime_error("Invalid ACL_DEFAULT_LAYOUT settings");
+          }
+
+          // CPU supports only NHWC now
+          if (backend->config()->id() == "cpu")
+          {
+            layout = graph::operand::Layout::NHWC;
+          }
+
           for (auto operand : node.getInputs())
           {
             auto &&lower_info = operands_lower_info.at(operand);
-            lower_info->addUseBackend(backend);
+            lower_info->addUsePermuteFactor(operand::PermuteFactor{backend, layout});
+            lower_info->setLayout(layout);
           }
           for (auto operand : node.getOutputs())
           {
             auto &&lower_info = operands_lower_info.at(operand);
-            lower_info->addDefBackend(backend);
+            lower_info->addDefPermuteFactor(operand::PermuteFactor{backend, layout});
+            // TODO Remove this
+            lower_info->setLayout(layout);
           }
 
           if (!subg || !mergeable(subg_index, node_index))
           {
             // TODO Determines how to set the layout of the subgraph
-            auto new_subg_index = make_subgraph(node_index, node, operand::Layout::NHWC);
+            auto new_subg_index = make_subgraph(node_index, node, layout);
 
             // Subgraph LowerInfo
-            setLowerInfo(new_subg_index, nnfw::cpp14::make_unique<graph::operation::LowerInfo>(
-                                             _backend_resolver->getBackend(node_index)));
+            setLowerInfo(new_subg_index,
+                         nnfw::cpp14::make_unique<graph::operation::LowerInfo>(backend));
 
             subg_index = new_subg_index;
             subg = &(_subg_ctx->at(new_subg_index));
@@ -289,53 +318,9 @@ void Graph::lower(void)
             subg = nullptr;
         });
 
-    _subg_ctx->iterate([&](const model::SubgraphIndex &ind, model::Subgraph &subg) {
+    _subg_ctx->iterate([&](const model::SubgraphIndex &, model::Subgraph &subg) {
       assert(subg.operations().size() > 0);
       std::reverse(std::begin(subg.operations()), std::end(subg.operations()));
-
-      auto layout = subg.getLayout();
-      // TODO Change ACL_DEFAULT_LAYOUT configuration to setting the layout forcibly if it exists
-      const std::string layout_str =
-          config::ConfigManager::instance().get<std::string>(config::ACL_DEFAULT_LAYOUT);
-      if (layout_str == "NHWC")
-      {
-        layout = graph::operand::Layout::NHWC;
-      }
-      else if (layout_str == "NCHW")
-      {
-        layout = graph::operand::Layout::NCHW;
-      }
-      else
-      {
-        throw std::runtime_error("Invalid ACL_DEFAULT_LAYOUT settings");
-      }
-
-      // CPU supports only NHWC now
-      if (getLowerInfo(ind)->backend()->config()->id() == "cpu")
-      {
-        layout = graph::operand::Layout::NHWC;
-      }
-
-      // TODO Remove This workarounds
-      // This implementation is a workaround
-      // The unit setting layout can be replaced with operation or subgraph
-      const auto operations = subg.operations();
-      for (const auto operation : operations)
-      {
-        const auto inputs = operation.node->getInputs();
-        for (auto it = inputs.begin(); it != inputs.end(); ++it)
-        {
-          // This is a workaround
-          // The unit setting layout can be replaced with `def` and `use` such as backend
-          operands_lower_info.at(*it)->setLayout(layout);
-        }
-
-        const auto outputs = operation.node->getOutputs();
-        for (auto it = outputs.begin(); it != outputs.end(); ++it)
-        {
-          operands_lower_info.at(*it)->setLayout(layout);
-        }
-      }
     });
 
     _subg_ctx->dump("merged and sorted operations without permutation");
@@ -357,29 +342,29 @@ void Graph::lower(void)
     }
 #endif
 
-    // Add DefBackend constants same as UseBackend
+    // Add DefFactor constants same as UseFactor
     // NOTE This assumes a constant operand is used by only one operation
     _model->operations.iterate([&](const model::OperationIndex &, model::Operation &node) {
       // LowerInfo for input operands
       for (auto operand : node.getInputs())
       {
         auto &&lower_info = operands_lower_info.at(operand);
-        if (lower_info->def_backends().empty())
+        if (lower_info->def_factors().empty())
         {
           // NOTE Handling model inputs here is not ideal. See above NOTE comment.
           // If it is a model input, not a constant
           if (_model->inputs.contains(operand))
           {
-            // If one or more elements then any backend is OK so pick first one
-            if (!lower_info->use_backends().empty())
+            // If one or more elements then any PermuteFactor is OK so pick first one
+            if (!lower_info->use_factors().empty())
             {
-              lower_info->addDefBackend(*lower_info->use_backends().begin());
+              lower_info->addDefPermuteFactor(*lower_info->use_factors().begin());
             }
           }
           // If it is a constant
           else
           {
-            lower_info->addDefBackend(lower_info->use_backends().getOnlyElement());
+            lower_info->addDefPermuteFactor(lower_info->use_factors().getOnlyElement());
           }
         }
       }
@@ -391,14 +376,30 @@ void Graph::lower(void)
 
       // Dump operand LowerInfo
       // TODO Extract this dumping procedure to be reusable
-      if (!getLowerInfo(index)->def_backends().empty() ||
-          !getLowerInfo(index)->use_backends().empty())
+      if (!getLowerInfo(index)->def_factors().empty() ||
+          !getLowerInfo(index)->use_factors().empty())
       {
-        auto backends_to_string = [](const BackendSet &backends) {
+        auto layout_to_string = [](const operand::Layout &layout) {
+          if (layout == operand::Layout::NHWC)
+          {
+            return std::string{"NHWC"};
+          }
+          else if (layout == operand::Layout::NCHW)
+          {
+            return std::string{"NCHW"};
+          }
+          else if (layout == operand::Layout::UNKNOWN)
+          {
+            return std::string{"UNKNOWN"};
+          }
+          return std::string{""};
+        };
+        auto factors_to_string = [&layout_to_string](const operand::PermuteFactorSet &factors) {
           std::string str;
-          for (auto backend : backends)
+          for (auto factor : factors)
           {
-            str += backend->config()->id();
+            str += factor.backend()->config()->id();
+            str += "(" + layout_to_string(factor.layout()) + ")";
             str += " ";
           }
           return "{ " + str + "}";
@@ -419,8 +420,8 @@ void Graph::lower(void)
         const auto &lower_shape = lower_info->shape();
         std::string def_ops = operation_index_to_string(object.getDef());
         std::string use_ops = operation_index_to_string(object.getUses());
-        std::string def_layouts = backends_to_string(lower_info->def_backends());
-        std::string use_layouts = backends_to_string(lower_info->use_backends());
+        std::string def_layouts = factors_to_string(lower_info->def_factors());
+        std::string use_layouts = factors_to_string(lower_info->use_factors());
         VERBOSE(Lower) << "* Operand #" << index.value() << " LowerInfo" << std::endl;
         VERBOSE(Lower) << "  - Shape           : { " << shape.dim(0) << " "
                        << (shape.rank() > 1 ? shape.dim(1) : 0) << " "
index 2bc7eec..0cbbb9e 100644 (file)
@@ -150,31 +150,31 @@ bool PermutationEliminationPass::isPermuteLayerToEliminate(
     const model::OperandIndexSequence &inp_indexes, const model::OperandIndexSequence &out_indexes,
     bool is_for_model_input)
 {
-  auto input_def_backends = _graph.getLowerInfo(inp_indexes.at(0))->def_backends();
-  auto output_def_backends = _graph.getLowerInfo(out_indexes.at(0))->def_backends();
+  auto input_def_factors = _graph.getLowerInfo(inp_indexes.at(0))->def_factors();
+  auto output_def_factors = _graph.getLowerInfo(out_indexes.at(0))->def_factors();
 
   auto input_layout = _graph.getLowerInfo(inp_indexes.at(0))->layout();
   auto output_layout = _graph.getLowerInfo(out_indexes.at(0))->layout();
 
-  if (input_def_backends.size() != 1 || output_def_backends.size() != 1)
+  if (input_def_factors.size() != 1 || output_def_factors.size() != 1)
   {
     return false;
   }
 
-  // all operands' backend must be the same
+  // all operands' factor must be the same
   for (auto index : inp_indexes)
   {
-    auto op_backend_set = _graph.getLowerInfo(index)->def_backends();
-    if (op_backend_set.size() != 1 || input_layout != _graph.getLowerInfo(index)->layout())
+    auto op_factor_set = _graph.getLowerInfo(index)->def_factors();
+    if (op_factor_set.size() != 1 || input_layout != _graph.getLowerInfo(index)->layout())
     {
       return false;
     }
   }
-  // all operands' backend must be the same
+  // all operands' factor must be the same
   for (auto index : out_indexes)
   {
-    auto op_backend_set = _graph.getLowerInfo(index)->def_backends();
-    if (op_backend_set.size() != 1 || output_layout != _graph.getLowerInfo(index)->layout())
+    auto op_factor_set = _graph.getLowerInfo(index)->def_factors();
+    if (op_factor_set.size() != 1 || output_layout != _graph.getLowerInfo(index)->layout())
     {
       return false;
     }
index f987d10..24a37ad 100644 (file)
@@ -44,7 +44,7 @@ void PermutationInsertionPass::callback(const model::OperandIndex &index, model:
 
   // NOTE Later, constants also will have Def
   // Ignore constants
-  if (operand_li->def_backends().size() == 0)
+  if (operand_li->def_factors().size() == 0)
   {
     return;
   }
@@ -52,24 +52,24 @@ void PermutationInsertionPass::callback(const model::OperandIndex &index, model:
   std::list<model::OperationIndex> permute_indexes;
 
   // Build a map for all necessary type of operands
-  std::unordered_map<const backend::Backend *, model::OperandIndex> backend_to_index;
+  std::unordered_map<operand::PermuteFactor, model::OperandIndex> factor_to_index;
   {
-    assert(operand_li->def_backends().size() == 1);
-    for (auto backend : operand_li->def_backends())
+    assert(operand_li->def_factors().size() == 1);
+    for (auto factor : operand_li->def_factors())
     {
-      backend_to_index.insert({backend, index});
+      factor_to_index.insert({factor, index});
     }
 
-    auto insert_set = operand_li->use_backends() - operand_li->def_backends();
-    for (auto backend : insert_set)
+    auto insert_set = operand_li->use_factors() - operand_li->def_factors();
+    for (auto factor : insert_set)
     {
-      const auto permute_operation_index = insertPermute(index, backend);
+      const auto permute_operation_index = insertPermute(index, factor);
       permute_indexes.push_back(permute_operation_index);
       VERBOSE(PermutationInsertionPass) << "Insert 'Permute' operation for operand "
                                         << index.value() << std::endl;
       const auto &permute_operation = _graph.operations().at(permute_operation_index);
       const auto permuted_operand_index = permute_operation.getOutputs().at(0);
-      backend_to_index.insert({backend, permuted_operand_index});
+      factor_to_index.insert({factor, permuted_operand_index});
     }
   }
 
@@ -89,12 +89,13 @@ void PermutationInsertionPass::callback(const model::OperandIndex &index, model:
       auto subg_index = _graph.subg_ctx().findNode(use);
       auto subg_li = _graph.getLowerInfo(subg_index);
       assert(subg_li);
+      const auto subg_layout = _graph.subg_ctx().at(subg_index).getLayout();
       const backend::Backend *backend = subg_li->backend();
       assert(backend);
       auto use_node_inputs = operation.getInputs();
       assert(use_node_inputs.contains(index));
 
-      auto new_index = backend_to_index.at(backend);
+      auto new_index = factor_to_index.at({backend, subg_layout});
       if (index != new_index)
       {
         // Update from subgraph
@@ -119,35 +120,12 @@ void PermutationInsertionPass::callback(const model::OperandIndex &index, model:
 
 model::OperationIndex
 PermutationInsertionPass::insertPermute(const model::OperandIndex &operand_index,
-                                        const backend::Backend *backend)
+                                        const operand::PermuteFactor &factor)
 {
   assert(!_graph.isBuildingPhase());
 
   auto &operand = _graph.operands().at(operand_index);
 
-  // TODO It can change to get the layout from def Subgraph
-  const std::string layout_str =
-      config::ConfigManager::instance().get<std::string>(config::ACL_DEFAULT_LAYOUT);
-  graph::operand::Layout input_layout;
-  if (layout_str == "NHWC")
-  {
-    input_layout = graph::operand::Layout::NHWC;
-  }
-  else if (layout_str == "NCHW")
-  {
-    input_layout = graph::operand::Layout::NCHW;
-  }
-  else
-  {
-    throw std::runtime_error("Invalid ACL_DEFAULT_LAYOUT settings");
-  }
-
-  // CPU supports only NHWC now
-  if (_graph.getLowerInfo(operand_index)->def_backends().getOnlyElement()->config()->id() == "cpu")
-  {
-    input_layout = graph::operand::Layout::NHWC;
-  }
-
   // Generate output operand and permute operation
   auto out_operand_index = _graph.addOperand(operand.shape(), operand.typeInfo());
   // change model output if operand_index is model output index
@@ -157,26 +135,39 @@ PermutationInsertionPass::insertPermute(const model::OperandIndex &operand_index
     model_outputs.replace(operand_index, out_operand_index);
   }
 
+  // Find PermuteNode information
+  const auto input_layout =
+      _graph.getLowerInfo(operand_index)->def_factors().getOnlyElement().layout();
+  const auto output_layout = factor.layout();
+  auto input_backend = _graph.getLowerInfo(operand_index)->def_factors().getOnlyElement().backend();
+  auto output_backend = factor.backend();
+  // NOTE PermuteNode may not have specific layout because the layout of input and output may be
+  // different.
+  const auto permute_node_layout =
+      input_layout == output_layout ? output_layout : graph::operand::Layout::UNKNOWN;
+  const auto permute_node_backend = _graph.backend_resolver()->getDefaultBackend();
+  const operand::PermuteFactor permute_node_factor{permute_node_backend, permute_node_layout};
+
   // Update LowerInfo of input operand
   auto operand_lower_info = _graph.getLowerInfo(operand_index);
-  const auto output_layout = operand_lower_info->layout();
-  operand_lower_info->removeUseBackend(backend);
-  operand_lower_info->addUseBackend(operand_lower_info->def_backends().getOnlyElement());
+  operand_lower_info->removeUsePermuteFactor(factor);
+  operand_lower_info->addUsePermuteFactor(permute_node_factor);
   operand_lower_info->setLayout(input_layout);
 
+  // Update LowerInfo of output operand
   auto out_operand_li =
       nnfw::cpp14::make_unique<operand::LowerInfo>(operand::asShape4D(operand.shape()));
-  out_operand_li->addDefBackend(backend);
-  out_operand_li->addUseBackend(backend);
+
+  // The input and output factors of all nodes will be the same except PermuteNode. So Tensor's
+  // allocators allocates memory using only the information of def permutation factor now.
+  // TODO Change param to permute_node_factor
+  out_operand_li->addDefPermuteFactor(factor);
+  out_operand_li->addUsePermuteFactor(factor);
   out_operand_li->setLayout(output_layout);
   _graph.setLowerInfo(out_operand_index, std::move(out_operand_li));
 
-  using PermuteNode = model::operation::PermuteNode;
-
-  auto input_backend = _graph.getLowerInfo(operand_index)->def_backends().getOnlyElement();
-  auto output_backend = _graph.getLowerInfo(out_operand_index)->def_backends().getOnlyElement();
-
   // Insert permute operation to the graph
+  using PermuteNode = model::operation::PermuteNode;
   auto insert_node = nnfw::cpp14::make_unique<PermuteNode>(operand_index, out_operand_index,
                                                            input_backend, output_backend);
 
@@ -185,24 +176,18 @@ PermutationInsertionPass::insertPermute(const model::OperandIndex &operand_index
 
   // Subgraph
   {
-    // NOTE Subgraph for Permutation node may not have specific layout because the layout of input
-    // and output may be different.
-    auto layout = input_layout == output_layout ? output_layout : graph::operand::Layout::UNKNOWN;
-    auto subg_index = _graph.subg_ctx().append(node_index, node, layout);
+    auto subg_index = _graph.subg_ctx().append(node_index, node, permute_node_layout);
     auto &subg = _graph.subg_ctx().at(subg_index);
     subg.setInputs(node.getInputs());
     subg.setOutputs(node.getOutputs());
-    _graph.setLowerInfo(subg_index, nnfw::cpp14::make_unique<graph::operation::LowerInfo>(
-                                        _graph.backend_resolver()->getDefaultBackend()));
+    _graph.setLowerInfo(
+        subg_index, nnfw::cpp14::make_unique<graph::operation::LowerInfo>(permute_node_backend));
   }
 
   // Update Use/Def info
   {
     _graph.operands().at(operand_index).appendUse(node_index);
-
-    auto node_out_indexes = node.getOutputs();
-    auto node_out_index = node_out_indexes.at(model::IOIndex{0});
-    _graph.operands().at(node_out_index).appendDef(node_index);
+    _graph.operands().at(out_operand_index).appendDef(node_index);
   }
   return node_index;
 }
index fb37911..b430be8 100644 (file)
@@ -20,6 +20,7 @@
 #include "OperandPass.h"
 #include "model/Operand.h" //for model::OperationIndex
 #include "backend/BackendManager.h"
+#include "graph/operand/PermuteFactor.h"
 
 namespace neurun
 {
@@ -41,12 +42,12 @@ public:
    * @brief Insert Permute operation that has given operand as input
    *
    * @param operand_index is the target operand index for the insertion
-   * @param backend is the output operand's backend type
+   * @param factor is the output operand's backend type and layout
    *
    * @return model::OperationIndex
    */
   model::OperationIndex insertPermute(const model::OperandIndex &operand_index,
-                                      const backend::Backend *backend);
+                                      const operand::PermuteFactor &factor);
 
 private:
 };
index d0c1120..7609c04 100644 (file)
@@ -160,9 +160,9 @@ backend::TensorBuilderSet Linear::planTensors()
 
   auto iterTensorBuilders = [this](const model::OperandIndex &ind, FnOnTensorBuilder fn) {
     const auto lower_info = getLowerInfo(ind);
-    for (auto backend : lower_info->def_backends())
+    for (auto factor : lower_info->def_factors())
     {
-      auto tensor_builder = backend->tensor_builder();
+      auto tensor_builder = factor.backend()->tensor_builder();
       fn(ind, tensor_builder);
     }
   };
@@ -184,9 +184,10 @@ backend::TensorBuilderSet Linear::planTensors()
       uses_map[ind]++;
     }
 
-    for (auto backend : lower_info->def_backends())
+    for (auto factor : lower_info->def_factors())
     {
       bool isSubTensor = false;
+      auto backend = factor.backend();
       auto tensor_builder = backend->tensor_builder();
 
       if (backend->config()->SupportSubTensorAlloc())