From d8256278436bd10e0f5ca59b27db17715f3475d4 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Senior=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Thu, 12 Apr 2018 13:50:53 +0900 Subject: [PATCH] [Pure ACL Runtime] Set Conv2D strides (#608) This commit set Conv2D stride parameter with values read from operands. Signed-off-by: Jonghyun Park --- .../bindings/pure_arm_compute/src/compilation.cc | 7 +++++-- .../bindings/pure_arm_compute/src/internal/Model.h | 11 +++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/tools/nnapi_bindings/bindings/pure_arm_compute/src/compilation.cc b/tools/nnapi_bindings/bindings/pure_arm_compute/src/compilation.cc index 0d75cd0..e3479f9 100644 --- a/tools/nnapi_bindings/bindings/pure_arm_compute/src/compilation.cc +++ b/tools/nnapi_bindings/bindings/pure_arm_compute/src/compilation.cc @@ -87,6 +87,9 @@ void Planner::visit(const ::internal::tflite::op::Conv2D::implicit::Node &node) const ::internal::tflite::operand::Index ker_index{node.param().ker_index}; const ::internal::tflite::operand::Index bias_index{node.param().bias_index}; + const ::internal::tflite::operand::Index vstride_index{node.param().vstride_index}; + const ::internal::tflite::operand::Index hstride_index{node.param().hstride_index}; + const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature(); const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature(); const auto ker_shape = _ctx.at(ker_index).shape().asKernel(); @@ -141,8 +144,8 @@ void Planner::visit(const ::internal::tflite::op::Conv2D::implicit::Node &node) param.padding.left = 0; param.padding.right = 0; - param.stride.vertical = 0; - param.stride.horizontal = 0; + param.stride.vertical = _ctx.at(vstride_index).asScala(); + param.stride.horizontal = _ctx.at(hstride_index).asScala(); auto stage = [param] (const IAllocationContext &ctx, IExecutionBuilder &builder) { diff --git a/tools/nnapi_bindings/bindings/pure_arm_compute/src/internal/Model.h b/tools/nnapi_bindings/bindings/pure_arm_compute/src/internal/Model.h index 9378f2d..357bc4a 100644 --- a/tools/nnapi_bindings/bindings/pure_arm_compute/src/internal/Model.h +++ b/tools/nnapi_bindings/bindings/pure_arm_compute/src/internal/Model.h @@ -128,6 +128,7 @@ private: } // namespace internal #include +#include namespace internal { @@ -156,6 +157,16 @@ public: data(std::unique_ptr(new T{std::forward(args)...})); } +public: + template T asScala(void) const + { + assert((_shape.rank() == 0) || ((_shape.rank() == 1) && (_shape.dim(0) == 1))); + assert(_data != nullptr); + assert((_data->base() != nullptr) && (_data->size() == sizeof(T))); + + return *(reinterpret_cast(_data->base())); + } + private: const Shape _shape; std::unique_ptr _data; -- 2.7.4