void KernelGenerator::visit(const model::operation::DivNode &node)
{
- (void)node;
- throw std::runtime_error("Not supported, yet");
+ const auto ofm_index{node.getOutputs().at(0)};
+ const auto lhs_index{node.getInputs().at(model::operation::DivNode::Input::LHS)};
+ const auto rhs_index{node.getInputs().at(model::operation::DivNode::Input::RHS)};
+
+ const auto activation = node.param().activation;
+
+ auto ofm_alloc = _tensor_builder->at(ofm_index).get();
+ auto lhs_alloc = _tensor_builder->at(lhs_index).get();
+ auto rhs_alloc = _tensor_builder->at(rhs_index).get();
+
+ std::unique_ptr<::arm_compute::IFunction> fn;
+
+ auto l = nnfw::cpp14::make_unique<::arm_compute::NEElementwiseDivision>();
+
+ l->configure(lhs_alloc->handle(), rhs_alloc->handle(), ofm_alloc->handle());
+
+ fn = std::move(l);
+
+ auto acl_fn = asAclFunction(std::move(fn));
+
+ _execution_builder->append(std::move(acl_fn));
+
+ ActivationBuilder{*_execution_builder}.append(activation, ofm_alloc->handle());
}
void KernelGenerator::visit(const model::operation::ExpNode &node)
}
}
+void ShapeFixer::visit(const model::operation::DivNode &node)
+{
+ const auto lhs_index{node.getInputs().at(model::operation::DivNode::Input::LHS)};
+ const auto rhs_index{node.getInputs().at(model::operation::DivNode::Input::RHS)};
+
+ if (!(_ctx.at(lhs_index).shape() == _ctx.at(rhs_index).shape()))
+ {
+ const auto broadcast_rank =
+ std::max(_ctx.at(lhs_index).shape().rank(), _ctx.at(rhs_index).shape().rank());
+
+ // TODO remove const_cast later. For example, _ctx may need to be a non const variable or
+ // a node to extend shape may be inserted in front of this operation
+ const_cast<::neurun::model::Shape &>(_ctx.at(lhs_index).shape()).extendRank(broadcast_rank);
+ const_cast<::neurun::model::Shape &>(_ctx.at(rhs_index).shape()).extendRank(broadcast_rank);
+ }
+}
+
} // namespace acl_neon
} // namespace backend
} // namespace neurun
void visit(const model::operation::SquaredDifferenceNode &) override;
void visit(const model::operation::SubNode &) override;
void visit(const model::operation::AddNode &) override;
+ void visit(const model::operation::DivNode &) override;
void visit(const model::operation::ComparisonNode &) override;
private:
GeneratedTests.svdf*
GeneratedTests.tanh_
GeneratedTests.batch_to_space*
-GeneratedTests.div_*
GeneratedTests.space_to_batch*
GeneratedTests.strided_slice*
GeneratedTests.transpose*