[Pure ACL Runtime] Set Conv2D strides (#608)
author박종현/동작제어Lab(SR)/Senior Engineer/삼성전자 <jh1302.park@samsung.com>
Thu, 12 Apr 2018 04:50:53 +0000 (13:50 +0900)
committer박세희/동작제어Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Thu, 12 Apr 2018 04:50:53 +0000 (13:50 +0900)
This commit set Conv2D stride parameter with values read from operands.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
tools/nnapi_bindings/bindings/pure_arm_compute/src/compilation.cc
tools/nnapi_bindings/bindings/pure_arm_compute/src/internal/Model.h

index 0d75cd0..e3479f9 100644 (file)
@@ -87,6 +87,9 @@ void Planner::visit(const ::internal::tflite::op::Conv2D::implicit::Node &node)
   const ::internal::tflite::operand::Index ker_index{node.param().ker_index};
   const ::internal::tflite::operand::Index bias_index{node.param().bias_index};
 
+  const ::internal::tflite::operand::Index vstride_index{node.param().vstride_index};
+  const ::internal::tflite::operand::Index hstride_index{node.param().hstride_index};
+
   const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature();
   const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature();
   const auto ker_shape = _ctx.at(ker_index).shape().asKernel();
@@ -141,8 +144,8 @@ void Planner::visit(const ::internal::tflite::op::Conv2D::implicit::Node &node)
   param.padding.left = 0;
   param.padding.right = 0;
 
-  param.stride.vertical = 0;
-  param.stride.horizontal = 0;
+  param.stride.vertical = _ctx.at(vstride_index).asScala<int32_t>();
+  param.stride.horizontal = _ctx.at(hstride_index).asScala<int32_t>();
 
   auto stage = [param] (const IAllocationContext &ctx, IExecutionBuilder &builder)
   {
index 9378f2d..357bc4a 100644 (file)
@@ -128,6 +128,7 @@ private:
 } // namespace internal
 
 #include <memory>
+#include <cassert>
 
 namespace internal
 {
@@ -156,6 +157,16 @@ public:
     data(std::unique_ptr<T>(new T{std::forward<Args>(args)...}));
   }
 
+public:
+  template<typename T> T asScala(void) const
+  {
+    assert((_shape.rank() == 0) || ((_shape.rank() == 1) && (_shape.dim(0) == 1)));
+    assert(_data != nullptr);
+    assert((_data->base() != nullptr) && (_data->size() == sizeof(T)));
+
+    return *(reinterpret_cast<const T *>(_data->base()));
+  }
+
 private:
   const Shape _shape;
   std::unique_ptr<Data> _data;