From 186796269d4068eb3737dd431a6fca006da2e4b7 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Prasanna=20R/SNAP=20/SRI-Bangalore/Engineer/=EC=82=BC?= =?utf8?q?=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 17 Jan 2019 15:07:52 +0530 Subject: [PATCH] Add Broadcast support for PReLU in PACL (#4072) This patch adds broadcast support for PReLU in PACL. Signed-off-by: prasannar --- runtimes/pure_arm_compute/src/compilation.cc | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index 51f4fe4..9c12294 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -3016,10 +3016,20 @@ void Planner::visit(const ::internal::tflite::op::PReLU::Node &node) const ::internal::tflite::operand::Index ifm_index{node.param().ifm_index}; const ::internal::tflite::operand::Index alpha_index{node.param().alpha_index}; - // Set shape constraints + // Set Shape Constraints and TensorInfo _builder.addShapeConstr( ofm_index, asTensorInfo(asTensorShape(_ctx.at(ofm_index).shape()), _ctx.at(ofm_index).type(), _ctx.at(ofm_index).scale(), _ctx.at(ofm_index).zeroPoint())); + + if (!(_ctx.at(ifm_index).shape() == _ctx.at(alpha_index).shape())) + { + const auto broadcast_rank = + std::max(_ctx.at(ifm_index).shape().rank(), _ctx.at(alpha_index).shape().rank()); + const_cast<::internal::tflite::operand::Shape &>(_ctx.at(ifm_index).shape()) + .extendRank(broadcast_rank); + const_cast<::internal::tflite::operand::Shape &>(_ctx.at(alpha_index).shape()) + .extendRank(broadcast_rank); + } _builder.addShapeConstr( ifm_index, asTensorInfo(asTensorShape(_ctx.at(ifm_index).shape()), _ctx.at(ifm_index).type(), _ctx.at(ifm_index).scale(), _ctx.at(ifm_index).zeroPoint())); -- 2.7.4