From b335356c4c259a5a592ee86650e955df50e2619d Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9E=A5=EC=A7=80=EC=84=AD/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 17 Jan 2019 10:06:59 +0900 Subject: [PATCH] [PACL] Support TransposeConv for GPU (#4223) This commit supports TransposeConv operation for GPU. Signed-off-by: jiseob.jang --- runtimes/pure_arm_compute/src/compilation.cc | 34 ++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index a80ab16..b0da882 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -50,6 +50,7 @@ #include #include #include +#include #include #include #include @@ -4060,16 +4061,39 @@ void Planner::visit(const ::internal::tflite::op::TransposeConv::Node &node) auto ifm_alloc = ctx.at(::internal::tflite::operand::Index{param.ifm_index}); auto ker_alloc = ctx.at(::internal::tflite::operand::Index{param.ker_index}); - auto fn = nnfw::cpp14::make_unique(); - // Only rank 4 is supported const int rank = 4; - auto tconv_info = asPadStrideInfo(param.padding, param.stride); + if (from_env(std::getenv("USE_SIMPLE_TRANSPOSECONV"))) + { + auto fn = nnfw::cpp14::make_unique(); + + auto tconv_info = asPadStrideInfo(param.padding, param.stride); + fn->configure(ifm_alloc, ker_alloc, ofm_alloc, tconv_info, getARMComputeAxises(rank)); + + builder.append("TransposeConv", std::move(fn)); + } + else if (::internal::arm_compute::isGpuMode()) + { + auto fn = nnfw::cpp14::make_unique<::arm_compute::CLDeconvolutionLayerEx>(); - fn->configure(ifm_alloc, ker_alloc, ofm_alloc, tconv_info, getARMComputeAxises(rank)); + auto padding = param.padding; + auto inner_border_right = padding.right - padding.left; + auto inner_border_top = padding.bottom - padding.top; - builder.append("TransposeConv", std::move(fn)); + padding.left = padding.right; + padding.top = padding.bottom; + auto symmetric_tconv_info = asPadStrideInfo(padding, param.stride); + + // TODO Support WeightInfo in some cases in order to performance improvement + fn->configure(CAST_CL(ifm_alloc), CAST_CL(ker_alloc), nullptr, CAST_CL(ofm_alloc), + symmetric_tconv_info, inner_border_right, inner_border_top); + builder.append("TransposeConv", std::move(fn)); + } + else + { + throw std::runtime_error("Not supported, yet"); + } }; _builder.addStage(stage); } -- 2.7.4