void KernelGenerator::visit(const model::operation::MulNode &node)
{
- // TODO: fix, tests are failing
- throw std::runtime_error("NYI");
-
const auto ofm_index{node.getOutputs().at(0)};
const auto lhs_index{node.getInputs().at(model::operation::MulNode::Input::LHS)};
const auto rhs_index{node.getInputs().at(model::operation::MulNode::Input::RHS)};
auto l = nnfw::cpp14::make_unique<::arm_compute::NEPixelWiseMultiplication>();
+  // For scale 1.0, ARM Compute Library only supports RoundingPolicy::TO_ZERO
l->configure(lhs_alloc->handle(), rhs_alloc->handle(), ofm_alloc->handle(), 1.0, // scale
- arm_compute::ConvertPolicy::SATURATE, arm_compute::RoundingPolicy::TO_NEAREST_EVEN);
+ arm_compute::ConvertPolicy::SATURATE, arm_compute::RoundingPolicy::TO_ZERO);
fn = std::move(l);
{
const auto broadcast_rank =
std::max(_ctx.at(lhs_index).shape().rank(), _ctx.at(rhs_index).shape().rank());
- auto lhs_shape = _ctx.at(lhs_index).shape();
- auto rhs_shape = _ctx.at(rhs_index).shape();
- lhs_shape.extendRank(broadcast_rank);
- rhs_shape.extendRank(broadcast_rank);
- }
-
- // Nontrivial broadcasting isn't supported yet
- if (!(_ctx.at(lhs_index).shape() == _ctx.at(rhs_index).shape()))
- {
- throw std::runtime_error("NYI");
+    // TODO Remove this const_cast later. For example, _ctx may need to become a non-const
+    // variable, or a shape-extending node may be inserted in front of this operation.
+ const_cast<::neurun::model::Shape &>(_ctx.at(lhs_index).shape()).extendRank(broadcast_rank);
+ const_cast<::neurun::model::Shape &>(_ctx.at(rhs_index).shape()).extendRank(broadcast_rank);
}
}
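
For reference, here is a minimal standalone sketch (not part of this patch) that checks the
rounding-policy constraint through NEPixelWiseMultiplication::validate. It assumes an ARM
Compute Library version that exposes the static validate() overload and that the program is
built against and linked with ACL; tensor shapes and data types are arbitrary illustrations.

#include <arm_compute/core/Error.h>
#include <arm_compute/core/TensorInfo.h>
#include <arm_compute/core/Types.h>
#include <arm_compute/runtime/NEON/functions/NEPixelWiseMultiplication.h>
#include <iostream>

int main()
{
  using namespace arm_compute;

  // Two 4x4 F32 operands and an F32 output, described by metadata only.
  const TensorInfo lhs(TensorShape(4U, 4U), 1, DataType::F32);
  const TensorInfo rhs(TensorShape(4U, 4U), 1, DataType::F32);
  const TensorInfo out(TensorShape(4U, 4U), 1, DataType::F32);

  // With scale == 1.0, ACL documents TO_ZERO as the only supported rounding
  // policy, which is why the kernel above switched away from TO_NEAREST_EVEN.
  const Status to_zero = NEPixelWiseMultiplication::validate(
      &lhs, &rhs, &out, 1.0f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
  const Status nearest = NEPixelWiseMultiplication::validate(
      &lhs, &rhs, &out, 1.0f, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);

  std::cout << "TO_ZERO:         "
            << (to_zero.error_code() == ErrorCode::OK ? "supported" : to_zero.error_description())
            << '\n';
  std::cout << "TO_NEAREST_EVEN: "
            << (nearest.error_code() == ErrorCode::OK ? "supported" : nearest.error_description())
            << '\n';
  return 0;
}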
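
The broadcast block in the kernel only equalizes operand ranks before configuring the ACL
function. The sketch below uses a hypothetical extend_rank helper (not neurun's actual
model::Shape::extendRank) to illustrate the intended rank extension, under the assumption
that missing leading dimensions are filled with 1:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for neurun's shape rank extension: prepend size-1
// dimensions until the shape reaches the requested rank, so an element-wise
// kernel can broadcast e.g. [3] against [2, 3] by treating it as [1, 3].
std::vector<int32_t> extend_rank(std::vector<int32_t> dims, size_t rank)
{
  while (dims.size() < rank)
    dims.insert(dims.begin(), 1);
  return dims;
}

int main()
{
  std::vector<int32_t> lhs{2, 3};
  std::vector<int32_t> rhs{3};

  const size_t broadcast_rank = std::max(lhs.size(), rhs.size());
  rhs = extend_rank(rhs, broadcast_rank); // {3} -> {1, 3}

  for (int32_t d : rhs)
    std::cout << d << ' ';
  std::cout << '\n'; // prints: 1 3
  return 0;
}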
GeneratedTests.lsh_projection*
GeneratedTests.lstm*
GeneratedTests.mobilenet*
-GeneratedTests.mul*
GeneratedTests.neg*
GeneratedTests.notequal*
GeneratedTests.prelu_ex*