[neurun] Move Backend info to PermuteNode (#4180)
author이한종/On-Device Lab(SR)/Engineer/삼성전자 <hanjoung.lee@samsung.com>
Wed, 9 Jan 2019 04:32:05 +0000 (13:32 +0900)
committer오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Wed, 9 Jan 2019 04:32:05 +0000 (13:32 +0900)
In order for StageGenerator to not have dependency on backends.
This commit moves the backend info for input/output from StageGenerator
to creation of PermuteNode.
(which is created only in PermutationInsertionPass)

Signed-off-by: Hanjoung Lee <hanjoung.lee@samsung.com>
runtimes/neurun/src/backend/cpu/StageGenerator.cc
runtimes/neurun/src/graph/pass/PermutationInsertionPass.cc
runtimes/neurun/src/model/operation/PermuteNode.cc
runtimes/neurun/src/model/operation/PermuteNode.h

index c53b320..66a1e80 100644 (file)
@@ -508,6 +508,8 @@ void StageGenerator::visit(const model::operation::PermuteNode &node)
     model::operand::Shape shape;
 
     PermuteType type{PermuteType::COPY};
+    const backend::Backend *input_backend;
+    const backend::Backend *output_backend;
   };
 
   Param param;
@@ -517,20 +519,14 @@ void StageGenerator::visit(const model::operation::PermuteNode &node)
 
   param.shape = _ctx.at(output_index).shape();
   param.type = node.param().type;
+  param.input_backend = node.param().input_backend;
+  param.output_backend = node.param().output_backend;
 
   //  assert(param.shape == _ctx.at(input_index));
 
-  const auto &input_li = _ctx.at(input_index).lower_info();
-  const auto &output_li = _ctx.at(output_index).lower_info();
-  const auto input_backend = input_li->def_backends().getOnlyElement();
-  const auto output_backend = output_li->def_backends().getOnlyElement();
-
-  const auto input_tensors = input_backend->tensor_builder();
-  const auto output_tensors = output_backend->tensor_builder();
-
-  returnStage([input_tensors, output_tensors, param](IExecutionBuilder &builder) {
-    auto output_object = output_tensors->wrapTensor(param.output_index);
-    auto input_object = input_tensors->wrapTensor(param.input_index);
+  returnStage([param](IExecutionBuilder &builder) {
+    auto output_object = param.output_backend->tensor_builder()->wrapTensor(param.output_index);
+    auto input_object = param.input_backend->tensor_builder()->wrapTensor(param.input_index);
 
     auto fn = nnfw::cpp14::make_unique<::neurun::kernel::cpu::PermuteLayer>();
 
index 9b833b8..de592ce 100644 (file)
@@ -144,12 +144,13 @@ PermutationInsertionPass::insertPermute(const model::operand::Index &operand_ind
 
   using PermuteNode = model::operation::PermuteNode;
 
+  auto input_backend = operand.lower_info()->def_backends().getOnlyElement();
+  auto output_backend = out_operand.lower_info()->def_backends().getOnlyElement();
+
   // Find Permutation Type
   auto type = [&]() {
-    auto input_layout =
-        operand.lower_info()->def_backends().getOnlyElement()->config()->getOperandLayout();
-    auto output_layout =
-        out_operand.lower_info()->def_backends().getOnlyElement()->config()->getOperandLayout();
+    auto input_layout = input_backend->config()->getOperandLayout();
+    auto output_layout = output_backend->config()->getOperandLayout();
 
     if (input_layout == graph::operand::Layout::NHWC &&
         output_layout == graph::operand::Layout::NCHW)
@@ -168,7 +169,8 @@ PermutationInsertionPass::insertPermute(const model::operand::Index &operand_ind
   }();
 
   // Insert permute operation to the graph
-  auto insert_node = nnfw::cpp14::make_unique<PermuteNode>(operand_index, out_operand_index, type);
+  auto insert_node = nnfw::cpp14::make_unique<PermuteNode>(operand_index, out_operand_index, type,
+                                                           input_backend, output_backend);
 
   auto node_index = _graph.operations().append(std::move(insert_node));
   const auto &node = _graph.operations().at(node_index);
index 174d2a8..c894869 100644 (file)
@@ -29,8 +29,11 @@ namespace operation
 
 void PermuteNode::accept(NodeVisitor &&v) const { v.visit(*this); }
 
-PermuteNode::PermuteNode(const operand::Index &input, const operand::Index &output, Type type)
-    : model::operation::Node{OperandConstraint::createExact(1u)}, _param{type}
+PermuteNode::PermuteNode(const operand::Index &input, const operand::Index &output, Type type,
+                         const backend::Backend *input_backend,
+                         const backend::Backend *output_backend)
+    : model::operation::Node{OperandConstraint::createExact(1u)},
+      _param{type, input_backend, output_backend}
 {
   setInputs({input});
   setOutputs({output});
index b589975..d1f2580 100644 (file)
@@ -39,6 +39,8 @@ public:
   struct Param
   {
     Type type;
+    const backend::Backend *input_backend;
+    const backend::Backend *output_backend;
   };
 
 public:
@@ -46,7 +48,8 @@ public:
   virtual std::string getName() const override { return "Permute"; }
 
 public:
-  PermuteNode(const operand::Index &input, const operand::Index &output, Type type);
+  PermuteNode(const operand::Index &input, const operand::Index &output, Type type,
+              const backend::Backend *input_backend, const backend::Backend *output_backend);
 
 public:
   const Param &param() const { return _param; }