const ::internal::tflite::operand::Index ifm_index{node.param().ifm_index};
const ::internal::tflite::operand::Index alpha_index{node.param().alpha_index};
- // Set shape constraints
+ // Set Shape Constraints and TensorInfo
_builder.addShapeConstr(
ofm_index, asTensorInfo(asTensorShape(_ctx.at(ofm_index).shape()), _ctx.at(ofm_index).type(),
_ctx.at(ofm_index).scale(), _ctx.at(ofm_index).zeroPoint()));
+
+ if (!(_ctx.at(ifm_index).shape() == _ctx.at(alpha_index).shape()))
+ {
+ const auto broadcast_rank =
+ std::max(_ctx.at(ifm_index).shape().rank(), _ctx.at(alpha_index).shape().rank());
+ const_cast<::internal::tflite::operand::Shape &>(_ctx.at(ifm_index).shape())
+ .extendRank(broadcast_rank);
+ const_cast<::internal::tflite::operand::Shape &>(_ctx.at(alpha_index).shape())
+ .extendRank(broadcast_rank);
+ }
_builder.addShapeConstr(
ifm_index, asTensorInfo(asTensorShape(_ctx.at(ifm_index).shape()), _ctx.at(ifm_index).type(),
_ctx.at(ifm_index).scale(), _ctx.at(ifm_index).zeroPoint()));