const auto lhs_index{node.getInputs().at(model::operation::AddNode::Input::LHS)};
const auto rhs_index{node.getInputs().at(model::operation::AddNode::Input::RHS)};
- // Broadcasting and quantization
- if (!(_ctx.at(lhs_index).shape() == _ctx.at(rhs_index).shape()) ||
- _ctx.at(lhs_index).typeInfo().type() == model::DataType::QUANT8_ASYMM)
+ // Quantization : not supported
+ if (_ctx.at(lhs_index).typeInfo().type() == model::DataType::QUANT8_ASYMM)
{
throw std::runtime_error{"NYI"};
}
+ // Broadcast
+ if (!(_ctx.at(lhs_index).shape() == _ctx.at(rhs_index).shape()))
+ {
+ const auto broadcast_rank =
+ std::max(_ctx.at(lhs_index).shape().rank(), _ctx.at(rhs_index).shape().rank());
+ const_cast<::neurun::model::Shape &>(_ctx.at(lhs_index).shape()).extendRank(broadcast_rank);
+ const_cast<::neurun::model::Shape &>(_ctx.at(rhs_index).shape()).extendRank(broadcast_rank);
+ }
}
void ShapeFixer::visit(const model::operation::PermuteNode &) { /* DO NOTHING */}