// Broadcast
if (!(_ctx.at(lhs_index).shape() == _ctx.at(rhs_index).shape()))
{
+ // See issue #8553
+ throw std::runtime_error{"ShapeFixer: NYI for broadcast Add"};
const auto broadcast_rank =
std::max(_ctx.at(lhs_index).shape().rank(), _ctx.at(rhs_index).shape().rank());
const_cast<::neurun::model::Shape &>(_ctx.at(lhs_index).shape()).extendRank(broadcast_rank);
// Broadcast
if (!(_ctx.at(lhs_index).shape() == _ctx.at(rhs_index).shape()))
{
+ // See issue #8553
+ throw std::runtime_error{"ShapeFixer: NYI for broadcast Sub"};
const auto broadcast_rank =
std::max(_ctx.at(lhs_index).shape().rank(), _ctx.at(rhs_index).shape().rank());
const_cast<::neurun::model::Shape &>(_ctx.at(lhs_index).shape()).extendRank(broadcast_rank);
// Broadcast
if (!(_ctx.at(lhs_index).shape() == _ctx.at(rhs_index).shape()))
{
+ // See issue #8553
+ throw std::runtime_error{"ShapeFixer: NYI for broadcast Mul"};
const auto broadcast_rank =
std::max(_ctx.at(lhs_index).shape().rank(), _ctx.at(rhs_index).shape().rank());
const_cast<::neurun::model::Shape &>(_ctx.at(lhs_index).shape()).extendRank(broadcast_rank);
Adding an exception in the stage doesn't help, because if there is a record for Add
without broadcast, scheduling will select it — it doesn't distinguish broadcast from
non-broadcast the way it does for quantized vs. non-quantized.*/
- if (backend->config()->id() == "cpu" && node.opcode() == model::OpCode::Add)
+ if (backend->config()->id() == "cpu" &&
+ (node.opcode() == model::OpCode::Add || node.opcode() == model::OpCode::Sub ||
+ node.opcode() == model::OpCode::Mul))
{
const auto lhs_index{node.getInputs().at(model::operation::Add::Input::LHS)};
const auto rhs_index{node.getInputs().at(model::operation::Add::Input::RHS)};
- /*Broadcasting isn't supported on CPU: no way to differ the existing exec_time record of
- * Add with and without broadcasting*/
- /*Quant is also unsupported: throws an exception in run(): in case of scheduling without warm-up
- it isn't catched by tryBackend()*/
- if (quant ||
- !(graph.operands().at(lhs_index).shape() == graph.operands().at(rhs_index).shape()))
+ /*Broadcasting isn't supported on CPU: no way to differ the existing exec_time record with and
+ * without broadcasting*/
+ if (!(graph.operands().at(lhs_index).shape() == graph.operands().at(rhs_index).shape()))
{
return true;
}