ACL NEON Strided Slice (#5946)
authorNikita Sizov/AI Tools Lab /SRR/Professional/삼성전자 <n.sizov@samsung.com>
Mon, 5 Aug 2019 13:51:36 +0000 (16:51 +0300)
committerEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Mon, 5 Aug 2019 13:51:36 +0000 (16:51 +0300)
Add support of Strided Slice in ACL NEON

Signed-off-by: Sizov Nikita <n.sizov@samsung.com>
runtimes/neurun/backend/acl_neon/KernelGenerator.cc
runtimes/neurun/backend/acl_neon/ShapeFixer.cc
runtimes/neurun/backend/acl_neon/ShapeFixer.h
tests/nnapi/nnapi_gtest.skip.armv7l-linux.acl_neon

index a4bb098..72be6bd 100644 (file)
@@ -28,6 +28,7 @@
 #include <arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h>
 #include <arm_compute/runtime/NEON/functions/NEReduceMeanEx.h>
 #include <arm_compute/runtime/NEON/functions/NEReshapeLayer.h>
+#include <arm_compute/runtime/NEON/functions/NEStridedSlice.h>
 #include <arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h>
 #include <arm_compute/runtime/NEON/functions/NEFullyConnectedReshapingLayer.h>
 #include <arm_compute/runtime/NEON/functions/NETransposeConvLayer.h>
@@ -833,8 +834,95 @@ void KernelGenerator::visit(const model::operation::SubNode &node)
 
 void KernelGenerator::visit(const model::operation::StridedSliceNode &node)
 {
-  (void)node;
-  throw std::runtime_error("Not supported, yet");
+  const auto output_index{node.getOutputs().at(0)};
+  const auto input_index{node.getInputs().at(model::operation::StridedSliceNode::Input::INPUT)};
+  const auto startData_index{node.param().startData_index};
+  const auto endData_index{node.param().endData_index};
+  const auto stridesData_index{node.param().stridesData_index};
+  const auto beginMask_index{node.param().beginMask_index};
+  const auto endMask_index{node.param().endMask_index};
+  const auto shrinkAxisMask_index{node.param().shrinkAxisMask_index};
+
+  // Set initializers for indices data such as order of inputData
+  int input_rank = _ctx.at(input_index).shape().rank();
+  std::vector<int32_t> starts;
+  std::vector<int32_t> ends;
+  std::vector<int32_t> strides;
+  starts.resize(input_rank, 0);
+  ends.resize(input_rank, 0);
+  strides.resize(input_rank, 0);
+  {
+    auto input_shape = _ctx.at(input_index).shape();
+    auto startData_base = _ctx.at(startData_index).data().base();
+    auto endData_base = _ctx.at(endData_index).data().base();
+    auto stridesData_base = _ctx.at(stridesData_index).data().base();
+    const int startData_size = _ctx.at(startData_index).shape().num_elements();
+    const int endData_size = _ctx.at(endData_index).shape().num_elements();
+    const int stridesData_size = _ctx.at(stridesData_index).shape().num_elements();
+
+    using neurun::model::DataType;
+
+    UNUSED_RELEASE(startData_size);
+    UNUSED_RELEASE(endData_size);
+    UNUSED_RELEASE(stridesData_size);
+
+    assert(_ctx.at(startData_index).typeInfo().type() == DataType::INT32);
+    assert(_ctx.at(endData_index).typeInfo().type() == DataType::INT32);
+    assert(_ctx.at(stridesData_index).typeInfo().type() == DataType::INT32);
+    assert(startData_size == input_rank);
+    assert(endData_size == input_rank);
+    assert(stridesData_size == input_rank);
+
+    assert(startData_base != nullptr);
+    for (int n = 0; n < input_rank; ++n)
+    {
+      auto axis = ::neurun::backend::acl_common::ToARMComputeAxis(input_rank, n).value();
+
+      int32_t start_value = *(reinterpret_cast<const int32_t *>(startData_base) + n);
+      starts[axis] = start_value;
+
+      int32_t end_value = *(reinterpret_cast<const int32_t *>(endData_base) + n);
+      ends[axis] = end_value;
+
+      int32_t strides_value = *(reinterpret_cast<const int32_t *>(stridesData_base) + n);
+      strides[axis] = strides_value;
+    }
+  }
+
+  // Set mask bits such as order of inputData
+  const auto beginMask = ::neurun::backend::acl_common::ReorderBits<int32_t>(
+      _ctx.at(beginMask_index).asScalar<int32_t>(), input_rank);
+  const auto endMask = ::neurun::backend::acl_common::ReorderBits<int32_t>(
+      _ctx.at(endMask_index).asScalar<int32_t>(), input_rank);
+  const auto shrinkAxisMask = ::neurun::backend::acl_common::ReorderBits<int32_t>(
+      _ctx.at(shrinkAxisMask_index).asScalar<int32_t>(), input_rank);
+
+  auto outputData_alloc = _tensor_builder->at(output_index).get();
+  auto inputData_alloc = _tensor_builder->at(input_index).get();
+
+  ::arm_compute::Coordinates starts_set;
+  ::arm_compute::Coordinates ends_set;
+  ::arm_compute::BiStrides strides_set;
+
+  for (size_t i = 0; i < starts.size(); ++i)
+  {
+    starts_set.set(i, starts[i]);
+    ends_set.set(i, ends[i]);
+    strides_set.set(i, strides[i]);
+  }
+
+  std::unique_ptr<::arm_compute::IFunction> fn;
+
+  auto l = nnfw::cpp14::make_unique<::arm_compute::NEStridedSlice>();
+
+  l->configure(inputData_alloc->handle(), outputData_alloc->handle(), starts_set, ends_set,
+               strides_set, beginMask, endMask, shrinkAxisMask);
+
+  fn = std::move(l);
+
+  auto acl_fn = asAclFunction(std::move(fn));
+
+  _execution_builder->append(std::move(acl_fn));
 }
 
 void KernelGenerator::visit(const model::operation::TransposeConvNode &node)
index 2afc3de..5eda116 100644 (file)
@@ -158,6 +158,8 @@ void ShapeFixer::visit(const model::operation::SqueezeNode &node)
 
 void ShapeFixer::visit(const model::operation::TanhNode &) { /* DO NOTHING */}
 
+void ShapeFixer::visit(const model::operation::StridedSliceNode &) { /* DO NOTHING */}
+
 void ShapeFixer::visit(const model::operation::SoftmaxNode &) { /* DO NOTHING */}
 
 void ShapeFixer::visit(const model::operation::SQRTNode &) { /* DO NOTHING */}
index 60f7a2e..392df91 100644 (file)
@@ -57,6 +57,7 @@ public:
   void visit(const model::operation::SQRTNode &) override;
   void visit(const model::operation::SquaredDifferenceNode &) override;
   void visit(const model::operation::SubNode &) override;
+  void visit(const model::operation::StridedSliceNode &) override;
   void visit(const model::operation::TransposeConvNode &) override;
   void visit(const model::operation::AddNode &) override;
   void visit(const model::operation::DivNode &) override;
index 5678c7d..3f1216c 100644 (file)
@@ -52,7 +52,6 @@ GeneratedTests.svdf*
 GeneratedTests.tanh_
 GeneratedTests.batch_to_space*
 GeneratedTests.space_to_batch*
-GeneratedTests.strided_slice*
 GeneratedTests.transpose*
 GeneratedTests.cast_ex*
 GeneratedTests.gather_ex*