Add Broadcast support for PReLU in PACL (#4072)
authorPrasanna R/SNAP /SRI-Bangalore/Engineer/삼성전자 <prasanna.r@samsung.com>
Thu, 17 Jan 2019 09:37:52 +0000 (15:07 +0530)
committer오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Thu, 17 Jan 2019 09:37:52 +0000 (18:37 +0900)
This patch adds broadcast support for PReLU in PACL.

Signed-off-by: prasannar <prasanna.r@samsung.com>
runtimes/pure_arm_compute/src/compilation.cc

index 51f4fe4..9c12294 100644 (file)
@@ -3016,10 +3016,20 @@ void Planner::visit(const ::internal::tflite::op::PReLU::Node &node)
   const ::internal::tflite::operand::Index ifm_index{node.param().ifm_index};
   const ::internal::tflite::operand::Index alpha_index{node.param().alpha_index};
 
-  // Set shape constraints
+  // Set Shape Constraints and TensorInfo
   _builder.addShapeConstr(
       ofm_index, asTensorInfo(asTensorShape(_ctx.at(ofm_index).shape()), _ctx.at(ofm_index).type(),
                               _ctx.at(ofm_index).scale(), _ctx.at(ofm_index).zeroPoint()));
+
+  if (!(_ctx.at(ifm_index).shape() == _ctx.at(alpha_index).shape()))
+  {
+    const auto broadcast_rank =
+        std::max(_ctx.at(ifm_index).shape().rank(), _ctx.at(alpha_index).shape().rank());
+    const_cast<::internal::tflite::operand::Shape &>(_ctx.at(ifm_index).shape())
+        .extendRank(broadcast_rank);
+    const_cast<::internal::tflite::operand::Shape &>(_ctx.at(alpha_index).shape())
+        .extendRank(broadcast_rank);
+  }
   _builder.addShapeConstr(
       ifm_index, asTensorInfo(asTensorShape(_ctx.at(ifm_index).shape()), _ctx.at(ifm_index).type(),
                               _ctx.at(ifm_index).scale(), _ctx.at(ifm_index).zeroPoint()));