Enable Split op for ACL neon (#7226)
author장지섭/On-Device Lab(SR)/Engineer/삼성전자 <jiseob.jang@samsung.com>
Thu, 5 Sep 2019 11:16:37 +0000 (20:16 +0900)
committer오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Thu, 5 Sep 2019 11:16:37 +0000 (20:16 +0900)
This commit enables to support Split op for ACL neon.

Signed-off-by: jiseob.jang <jiseob.jang@samsung.com>
runtimes/neurun/backend/acl_neon/KernelGenerator.cc
runtimes/neurun/backend/acl_neon/KernelGenerator.h
runtimes/neurun/backend/acl_neon/ShapeFixer.cc
runtimes/neurun/backend/acl_neon/ShapeFixer.h
tests/nnapi/nnapi_gtest.skip.armv7l-linux.acl_neon

index f5e71e4..f4eb91e 100644 (file)
@@ -1233,6 +1233,45 @@ void KernelGenerator::visit(const model::operation::SoftmaxNode &node)
   _execution_builder->append(std::move(acl_fn));
 }
 
+void KernelGenerator::visit(const model::operation::SplitNode &node)
+{
+  // TODO Support this op by SubTensor
+  const auto ifm_index{node.getInputs().at(model::operation::SplitNode::Input::INPUT)};
+  const auto axis_index{node.param().axis_index};
+  const auto num_of_splits_index{node.param().num_of_splits_index};
+
+  assert(_ctx.at(num_of_splits_index).asScalar<unsigned int>() == node.getOutputs().size());
+
+  const auto ifm_rank = _ctx.at(ifm_index).shape().rank();
+  std::vector<model::OperandIndex> output_indexes;
+  for (const auto &output : node.getOutputs())
+    output_indexes.emplace_back(output);
+
+  auto ifm_alloc = _tensor_builder->at(ifm_index).get();
+  std::vector<arm_compute::ITensor *> output_allocs;
+  for (const auto &ofm_ind : output_indexes)
+    output_allocs.emplace_back(_tensor_builder->at(ofm_ind).get()->handle());
+
+  const auto frontend_layout = _current_subg_layout;
+  const auto backend_layout = ifm_alloc->layout();
+  auto axis = _ctx.at(axis_index).asScalar<int32_t>();
+  if (axis < 0)
+    axis += ifm_rank;
+  axis = acl_common::ToARMComputeAxis(ifm_rank, axis, frontend_layout, backend_layout).value();
+
+  std::unique_ptr<::arm_compute::IFunction> fn;
+
+  auto l = nnfw::cpp14::make_unique<::arm_compute::NESplit>();
+
+  l->configure(ifm_alloc->handle(), output_allocs, axis);
+
+  fn = std::move(l);
+
+  auto acl_fn = asAclFunction(std::move(fn));
+
+  _execution_builder->append(std::move(acl_fn));
+}
+
 void KernelGenerator::visit(const model::operation::SQRTNode &node)
 {
   const auto output_index{node.getOutputs().at(0)};
index 00284d3..4937ae3 100644 (file)
@@ -65,6 +65,7 @@ public:
   void visit(const model::operation::SqueezeNode &) override;
   void visit(const model::operation::TanhNode &) override;
   void visit(const model::operation::SoftmaxNode &) override;
+  void visit(const model::operation::SplitNode &) override;
   void visit(const model::operation::SQRTNode &) override;
   void visit(const model::operation::SquaredDifferenceNode &) override;
   void visit(const model::operation::SubNode &) override;
index 7b0b3fd..4fa59dd 100644 (file)
@@ -219,6 +219,14 @@ void ShapeFixer::visit(const model::operation::StridedSliceNode &) { /* DO NOTHI
 
 void ShapeFixer::visit(const model::operation::SoftmaxNode &) { /* DO NOTHING */}
 
+void ShapeFixer::visit(const model::operation::SplitNode &node)
+{
+  const auto input_index{node.getInputs().at(model::operation::SplitNode::Input::INPUT)};
+  _tensor_builder->dimCorrection(input_index, false);
+  for (const auto &output : node.getOutputs())
+    _tensor_builder->dimCorrection(output, false);
+}
+
 void ShapeFixer::visit(const model::operation::SQRTNode &) { /* DO NOTHING */}
 
 void ShapeFixer::visit(const model::operation::SquaredDifferenceNode &node)
index 161f5a8..386a8f9 100644 (file)
@@ -67,6 +67,7 @@ public:
   void visit(const model::operation::SqueezeNode &) override;
   void visit(const model::operation::TanhNode &) override;
   void visit(const model::operation::SoftmaxNode &) override;
+  void visit(const model::operation::SplitNode &) override;
   void visit(const model::operation::SQRTNode &) override;
   void visit(const model::operation::SquaredDifferenceNode &) override;
   void visit(const model::operation::SubNode &) override;
index 63eb5e3..04e924c 100644 (file)
@@ -29,7 +29,6 @@ GeneratedTests.reduce_max_ex*
 GeneratedTests.reduce_sum_ex*
 GeneratedTests.topk_v2*
 # Unexpected result
-GeneratedTests.split*
 GeneratedTests.pack*
 # Float error
 GeneratedTests.exp_ex_1D_float