Add permutation type member into PermuteNode (#7085)
author장지섭/On-Device Lab(SR)/Engineer/삼성전자 <jiseob.jang@samsung.com>
Wed, 4 Sep 2019 03:39:18 +0000 (12:39 +0900)
committer이한종/On-Device Lab(SR)/Engineer/삼성전자 <hanjoung.lee@samsung.com>
Wed, 4 Sep 2019 03:39:18 +0000 (12:39 +0900)
This comit adds permutation type member into PermuteNode.

Signed-off-by: jiseob.jang <jiseob.jang@samsung.com>
runtimes/neurun/backend/cpu/KernelGenerator.cc
runtimes/neurun/core/include/model/operation/PermuteNode.h
runtimes/neurun/core/src/graph/dumper/Dumper.cc
runtimes/neurun/core/src/graph/pass/PermutationInsertionPass.cc
runtimes/neurun/core/src/model/operation/PermuteNode.cc

index f92e5dc..225428a 100644 (file)
@@ -381,8 +381,9 @@ void KernelGenerator::visit(const model::operation::PermuteNode &node)
     out_shape.dim(3) = shape.dim(2);
   }
 
-  // Find Permutation Type
-  auto permuteType = [&]() {
+  const auto permute_type = node.getPermuteType();
+  // Check Permutation Type
+  const auto inferPermuteType = [&]() {
     if (input_object->ptr()->layout() == model::Layout::NHWC &&
         output_object->ptr()->layout() == model::Layout::NCHW)
     {
@@ -398,8 +399,10 @@ void KernelGenerator::visit(const model::operation::PermuteNode &node)
       return model::operation::PermuteNode::Type::COPY;
     }
   }();
+  UNUSED_RELEASE(inferPermuteType);
+  assert(permute_type == inferPermuteType);
 
-  fn->configure(input_object, output_object, out_shape, permuteType, data_type);
+  fn->configure(input_object, output_object, out_shape, permute_type, data_type);
 
   input_backend_ctx->tensor_builder->postVisit(node);
 
index a536be5..2339f35 100644 (file)
@@ -57,15 +57,17 @@ public:
 public:
   PermuteNode(const OperandIndex &input, const OperandIndex &output,
               const backend::BackendContext *input_backend_ctx,
-              const backend::BackendContext *output_backend_ctx,
+              const backend::BackendContext *output_backend_ctx, Type type,
               model::DataType data_type = model::DataType::FLOAT32);
 
 public:
   const Param &param() const { return _param; }
   model::DataType getDataType() const { return _dataType; }
+  Type getPermuteType() const { return _type; }
 
 private:
   Param _param;
+  Type _type;
   model::DataType _dataType;
 };
 
index 42b1148..315e2ce 100644 (file)
@@ -340,7 +340,21 @@ void Dumper::visit(const NegNode &node)
 
 void Dumper::visit(const PermuteNode &node)
 {
-  VERBOSE(LIR) << "* Permute" << std::endl;
+  std::string permute_type = "Unknown";
+  switch (node.getPermuteType())
+  {
+    case PermuteNode::Type::COPY:
+      permute_type = "Copy";
+      break;
+    case PermuteNode::Type::NHWC_TO_NCHW:
+      permute_type = "NHWC to NCHW";
+      break;
+    case PermuteNode::Type::NCHW_TO_NHWC:
+      permute_type = "NCHW to NHWC";
+      break;
+  }
+
+  VERBOSE(LIR) << "* Permute(" + permute_type + ")" << std::endl;
   VERBOSE(LIR) << "  - Inputs : Input(" << node.getInputs().at(0).value() << ")" << std::endl;
   VERBOSE(LIR) << "  - Output : Output(" << node.getOutputs().at(0).value() << ")" << std::endl;
 }
index 1f54334..0f07b47 100644 (file)
@@ -164,9 +164,26 @@ PermutationInsertionPass::insertPermute(const model::OperandIndex &operand_index
   auto output_backend_ctx = _graph.backend_resolver()->getBackendContext(output_backend);
 
   // Insert permute operation to the graph
+  const auto input_layout =
+      _graph.getLowerInfo(operand_index)->def_factors().getOnlyElement().layout();
+  const auto output_layout = factor.layout();
   using PermuteNode = model::operation::PermuteNode;
-  auto insert_node = nnfw::cpp14::make_unique<PermuteNode>(operand_index, out_operand_index,
-                                                           input_backend_ctx, output_backend_ctx);
+  const auto permute_type = [&]() {
+    if (input_layout == model::Layout::NHWC && output_layout == model::Layout::NCHW)
+    {
+      return PermuteNode::Type::NHWC_TO_NCHW;
+    }
+    else if (input_layout == model::Layout::NCHW && output_layout == model::Layout::NHWC)
+    {
+      return PermuteNode::Type::NCHW_TO_NHWC;
+    }
+    else
+    {
+      return PermuteNode::Type::COPY;
+    }
+  }();
+  auto insert_node = nnfw::cpp14::make_unique<PermuteNode>(
+      operand_index, out_operand_index, input_backend_ctx, output_backend_ctx, permute_type);
 
   auto node_index = _graph.operations().push(std::move(insert_node));
   const auto &node = _graph.operations().at(node_index);
index 3f0b4cf..8affca1 100644 (file)
@@ -31,10 +31,10 @@ void PermuteNode::accept(OperationVisitor &v) const { v.visit(*this); }
 
 PermuteNode::PermuteNode(const OperandIndex &input, const OperandIndex &output,
                          const backend::BackendContext *input_backend_ctx,
-                         const backend::BackendContext *output_backend_ctx,
+                         const backend::BackendContext *output_backend_ctx, Type type,
                          model::DataType data_type)
     : model::Operation{OperandConstraint::createExact(1u)},
-      _param{input_backend_ctx, output_backend_ctx}, _dataType{data_type}
+      _param{input_backend_ctx, output_backend_ctx}, _type{type}, _dataType{data_type}
 {
   setInputs({input});
   setOutputs({output});