From 90f0a8de6787ae7a971ad94c923014dadd539f5c Mon Sep 17 00:00:00 2001 From: =?utf8?q?Devansh=20Bansal/System=20SW=20/SRI-Bangalore/Engineer/?= =?utf8?q?=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 20 Nov 2018 09:34:45 +0530 Subject: [PATCH] Minor changes in TransposeConv op (#3619) * Minor changes in TransposeConv op Addressing the comments by jh1302-park in #3423 Signed-off-by: b.devansh * Minor changes in TransposeConv op Addressing the comments by jh1302-park in #3423 Signed-off-by: b.devansh --- runtimes/pure_arm_compute/src/compilation.cc | 2 +- .../src/internal/layers/SimpleTransposeConv.cc | 28 +++++++++++----------- .../src/internal/layers/SimpleTransposeConv.h | 8 ++++--- 3 files changed, 20 insertions(+), 18 deletions(-) diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index 3dc9e72..7bc0257 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -3820,7 +3820,7 @@ void Planner::visit(const ::internal::tflite::op::TransposeConv::Node &node) auto ifm_alloc = ctx.at(::internal::tflite::operand::Index{param.ifm_index}); auto ker_alloc = ctx.at(::internal::tflite::operand::Index{param.ker_index}); - auto fn = nnfw::make_unique(); + auto fn = nnfw::make_unique(); // Only rank 4 is supported const int rank = 4; diff --git a/runtimes/pure_arm_compute/src/internal/layers/SimpleTransposeConv.cc b/runtimes/pure_arm_compute/src/internal/layers/SimpleTransposeConv.cc index 43f16d0..5001446 100644 --- a/runtimes/pure_arm_compute/src/internal/layers/SimpleTransposeConv.cc +++ b/runtimes/pure_arm_compute/src/internal/layers/SimpleTransposeConv.cc @@ -18,9 +18,9 @@ #include #include -void TransposeConv::configure(::arm_compute::ITensor *input, ::arm_compute::ITensor *weights, - ::arm_compute::ITensor *output, int32_t stride_width, - int32_t stride_height, ::arm_compute::Coordinates axises) +void SimpleTransposeConv::configure(::arm_compute::ITensor *input, ::arm_compute::ITensor *weights, + ::arm_compute::ITensor *output, int32_t stride_width, + int32_t stride_height, ::arm_compute::Coordinates axises) { auto rank = axises.num_dimensions(); @@ -35,14 +35,14 @@ void TransposeConv::configure(::arm_compute::ITensor *input, ::arm_compute::ITen } template -inline void SimpleTransposeConv(const ::arm_compute::TensorShape &input_shape, - const ::arm_compute::ITensor *input_data, - const ::arm_compute::TensorShape &filter_shape, - const ::arm_compute::ITensor *filter_data, - const ::arm_compute::TensorShape &output_shape, - const ::arm_compute::ITensor *output_data, - const int32_t stride_width, const int32_t stride_height, - const ::arm_compute::Coordinates axises) +inline void ApplyTransposeConv(const ::arm_compute::TensorShape &input_shape, + const ::arm_compute::ITensor *input_data, + const ::arm_compute::TensorShape &filter_shape, + const ::arm_compute::ITensor *filter_data, + const ::arm_compute::TensorShape &output_shape, + const ::arm_compute::ITensor *output_data, + const int32_t stride_width, const int32_t stride_height, + const ::arm_compute::Coordinates axises) { const int batches = input_shape[axises[0]]; const int input_depth = input_shape[axises[3]]; @@ -113,7 +113,7 @@ inline void SimpleTransposeConv(const ::arm_compute::TensorShape &input_shape, } } -void TransposeConv::run() +void SimpleTransposeConv::run() { if (::internal::arm_compute::isGpuMode()) { @@ -127,12 +127,12 @@ void TransposeConv::run() switch (_input->info()->data_type()) { case ::arm_compute::DataType::S32: - SimpleTransposeConv( + ApplyTransposeConv( _input->info()->tensor_shape(), _input, _weights->info()->tensor_shape(), _weights, _output->info()->tensor_shape(), _output, _stride_width, _stride_height, _axises); break; case ::arm_compute::DataType::F32: - SimpleTransposeConv( + ApplyTransposeConv( _input->info()->tensor_shape(), _input, _weights->info()->tensor_shape(), _weights, _output->info()->tensor_shape(), _output, _stride_width, _stride_height, _axises); break; diff --git a/runtimes/pure_arm_compute/src/internal/layers/SimpleTransposeConv.h b/runtimes/pure_arm_compute/src/internal/layers/SimpleTransposeConv.h index db08667..fd380d4 100644 --- a/runtimes/pure_arm_compute/src/internal/layers/SimpleTransposeConv.h +++ b/runtimes/pure_arm_compute/src/internal/layers/SimpleTransposeConv.h @@ -24,10 +24,10 @@ #include #include "internal/arm_compute/Cast.h" -class TransposeConv : public ::arm_compute::IFunction +class SimpleTransposeConv : public ::arm_compute::IFunction { public: - TransposeConv() + SimpleTransposeConv() : _input(nullptr), _weights(nullptr), _output(nullptr), _stride_width(0), _stride_height(0) { // DO NOTHING @@ -38,7 +38,9 @@ public: * @param[in] input First tensor input. * @param[in] weights Weights * @param[out] output Output tensor. - * @param[in] info Pad and Stride info + * @param[in] stride_width Stride width value + * @param[in] stride_height Stride height value + * @param[in] axises Axises of rank 4 */ void configure(::arm_compute::ITensor *input, ::arm_compute::ITensor *weights, ::arm_compute::ITensor *output, int32_t stride_width, int32_t stride_height, -- 2.7.4