// helpers to fold intrinsic function references
// Define callable types used in a common utility that
// takes care of array and cast/conversion aspects for elemental intrinsics
+
template<typename TR, typename... TArgs>
using ScalarFunc = std::function<Scalar<TR>(const Scalar<TArgs> &...)>;
template<typename TR, typename... TArgs>
std::index_sequence<I...>) {
static_assert(
(... && IsSpecificIntrinsicType<TA>)); // TODO derived types for MERGE?
- std::tuple<const std::optional<Scalar<TA>>...> scalars{
- GetScalarConstantValue<TA>(*funcRef.arguments()[I]->value)...};
- if ((... && std::get<I>(scalars).has_value())) {
- if constexpr (std::is_same_v<WrapperType<TR, TA...>,
- ScalarFuncWithContext<TR, TA...>>) {
- return Expr<TR>{Constant<TR>{func(context, *std::get<I>(scalars)...)}};
- } else if constexpr (std::is_same_v<WrapperType<TR, TA...>,
- ScalarFunc<TR, TA...>>) {
- return Expr<TR>{Constant<TR>{func(*std::get<I>(scalars)...)}};
+ static_assert(sizeof...(TA) > 0);
+ std::tuple<const Constant<TA> *...> args{
+ UnwrapExpr<Constant<TA>>(*funcRef.arguments()[I]->value)...};
+ if ((... && (std::get<I>(args) != nullptr))) {
+ // Compute the shape of the result based on shapes of arguments
+ std::vector<std::int64_t> shape;
+ int rank;
+ const std::vector<std::int64_t> *shapes[sizeof...(TA)]{
+ &std::get<I>(args)->shape()...};
+ const int ranks[sizeof...(TA)]{std::get<I>(args)->Rank()...};
+ for (unsigned int i{0}; i < sizeof...(TA); ++i) {
+ if (ranks[i] > 0) {
+ if (rank == 0) {
+ rank = ranks[i];
+ shape = *shapes[i];
+ } else {
+ if (shape != *shapes[i]) {
+ // TODO: Rank compatibility was already checked but it seems to be
+ // the first place where the actual shapes are checked to be the
+ // same. Shouldn't this be checked elsewhere so that this is also
+ // checked for non constexpr call to elemental intrinsics function?
+ context.messages().Say(
+ "arguments in elemental intrinsic function are not conformable"_err_en_US);
+ return Expr<TR>{std::move(funcRef)};
+ }
+ }
+ }
+ }
+
+ // Compute all the scalar values of the results
+ std::size_t size{1};
+ for (std::int64_t dim : shape) {
+ size *= dim;
+ }
+ std::vector<Scalar<TR>> results;
+ std::vector<std::int64_t> index(shape.size(), 1);
+ for (std::size_t n{size}; n-- > 0;) {
+ if constexpr (std::is_same_v<WrapperType<TR, TA...>,
+ ScalarFuncWithContext<TR, TA...>>) {
+ results.emplace_back(func(context,
+ (ranks[I] ? std::get<I>(args)->At(index)
+ : **std::get<I>(args))...));
+ } else if constexpr (std::is_same_v<WrapperType<TR, TA...>,
+ ScalarFunc<TR, TA...>>) {
+ results.emplace_back(func((
+ ranks[I] ? std::get<I>(args)->At(index) : **std::get<I>(args))...));
+ }
+ for (int d{0}; d < rank; ++d) {
+ if (++index[d] <= shape[d]) {
+ break;
+ }
+ index[d] = 1;
+ }
+ }
+ // Build and return constant result
+ if constexpr (TR::category == TypeCategory::Character) {
+ std::int64_t len{
+ static_cast<std::int64_t>(results.size() ? results[0].length() : 0)};
+ return Expr<TR>{Constant<TR>{len, std::move(results), std::move(shape)}};
+ } else {
+ return Expr<TR>{Constant<TR>{std::move(results), std::move(shape)}};
}
}
- // TODO: handle Constant<T> that hold arrays
return Expr<TR>{std::move(funcRef)};
}
}
if (auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)}) {
const std::string name{intrinsic->name};
- if (name == "acos" || name == "acosh") {
+ if (name == "acos" || name == "acosh" ||
+ (name == "atan" && funcRef.arguments().size() == 1)) {
if (auto callable{
context.hostRte().GetHostProcedureWrapper<Scalar, T, T>(name)}) {
return FoldElementalIntrinsic<T, T>(
"%s(real(kind=%d)) cannot be folded on host"_en_US, name.c_str(),
KIND);
}
+ }
+ if (name == "atan") {
+ if (auto callable{
+ context.hostRte().GetHostProcedureWrapper<Scalar, T, T, T>(
+ name)}) {
+ return FoldElementalIntrinsic<T, T, T>(
+ context, std::move(funcRef), *callable);
+ } else {
+ context.messages().Say(
+ "%s(real(kind=%d), real(kind%d)) cannot be folded on host"_en_US,
+ name.c_str(), KIND);
+ }
} else if (name == "bessel_jn" || name == "bessel_yn") {
if (funcRef.arguments().size() == 2) { // elemental
using Int8 = Type<TypeCategory::Integer, 8>;