[flang] add support to fold elemental intrisics over arrays
authorJean Perier <jperier@hsw1.pgi.net>
Mon, 25 Feb 2019 17:33:12 +0000 (09:33 -0800)
committerGitHub <noreply@github.com>
Wed, 27 Mar 2019 17:16:07 +0000 (10:16 -0700)
Original-commit: flang-compiler/f18@c2fec22856b9bae7221036173c785d0732c7a7c3
Tree-same-pre-rewrite: false

flang/lib/evaluate/fold.cc

index 9a42d5e..b7edd21 100644 (file)
@@ -165,6 +165,7 @@ ComplexPart FoldOperation(FoldingContext &context, ComplexPart &&complexPart) {
 // helpers to fold intrinsic function references
 // Define callable types used in a common utility that
 // takes care of array and cast/conversion aspects for elemental intrinsics
+
 template<typename TR, typename... TArgs>
 using ScalarFunc = std::function<Scalar<TR>(const Scalar<TArgs> &...)>;
 template<typename TR, typename... TArgs>
@@ -178,18 +179,69 @@ static inline Expr<TR> FoldElementalIntrinsicHelper(FoldingContext &context,
     std::index_sequence<I...>) {
   static_assert(
       (... && IsSpecificIntrinsicType<TA>));  // TODO derived types for MERGE?
-  std::tuple<const std::optional<Scalar<TA>>...> scalars{
-      GetScalarConstantValue<TA>(*funcRef.arguments()[I]->value)...};
-  if ((... && std::get<I>(scalars).has_value())) {
-    if constexpr (std::is_same_v<WrapperType<TR, TA...>,
-                      ScalarFuncWithContext<TR, TA...>>) {
-      return Expr<TR>{Constant<TR>{func(context, *std::get<I>(scalars)...)}};
-    } else if constexpr (std::is_same_v<WrapperType<TR, TA...>,
-                             ScalarFunc<TR, TA...>>) {
-      return Expr<TR>{Constant<TR>{func(*std::get<I>(scalars)...)}};
+  static_assert(sizeof...(TA) > 0);
+  std::tuple<const Constant<TA> *...> args{
+      UnwrapExpr<Constant<TA>>(*funcRef.arguments()[I]->value)...};
+  if ((... && (std::get<I>(args) != nullptr))) {
+    // Compute the shape of the result based on shapes of arguments
+    std::vector<std::int64_t> shape;
+    int rank;
+    const std::vector<std::int64_t> *shapes[sizeof...(TA)]{
+        &std::get<I>(args)->shape()...};
+    const int ranks[sizeof...(TA)]{std::get<I>(args)->Rank()...};
+    for (unsigned int i{0}; i < sizeof...(TA); ++i) {
+      if (ranks[i] > 0) {
+        if (rank == 0) {
+          rank = ranks[i];
+          shape = *shapes[i];
+        } else {
+          if (shape != *shapes[i]) {
+            // TODO: Rank compatibility was already checked but it seems to be
+            // the first place where the actual shapes are checked to be the
+            // same. Shouldn't this be checked elsewhere so that this is also
+            // checked for non constexpr call to elemental intrinsics function?
+            context.messages().Say(
+                "arguments in elemental intrinsic function are not conformable"_err_en_US);
+            return Expr<TR>{std::move(funcRef)};
+          }
+        }
+      }
+    }
+
+    // Compute all the scalar values of the results
+    std::size_t size{1};
+    for (std::int64_t dim : shape) {
+      size *= dim;
+    }
+    std::vector<Scalar<TR>> results;
+    std::vector<std::int64_t> index(shape.size(), 1);
+    for (std::size_t n{size}; n-- > 0;) {
+      if constexpr (std::is_same_v<WrapperType<TR, TA...>,
+                        ScalarFuncWithContext<TR, TA...>>) {
+        results.emplace_back(func(context,
+            (ranks[I] ? std::get<I>(args)->At(index)
+                      : **std::get<I>(args))...));
+      } else if constexpr (std::is_same_v<WrapperType<TR, TA...>,
+                               ScalarFunc<TR, TA...>>) {
+        results.emplace_back(func((
+            ranks[I] ? std::get<I>(args)->At(index) : **std::get<I>(args))...));
+      }
+      for (int d{0}; d < rank; ++d) {
+        if (++index[d] <= shape[d]) {
+          break;
+        }
+        index[d] = 1;
+      }
+    }
+    // Build and return constant result
+    if constexpr (TR::category == TypeCategory::Character) {
+      std::int64_t len{
+          static_cast<std::int64_t>(results.size() ? results[0].length() : 0)};
+      return Expr<TR>{Constant<TR>{len, std::move(results), std::move(shape)}};
+    } else {
+      return Expr<TR>{Constant<TR>{std::move(results), std::move(shape)}};
     }
   }
-  // TODO: handle Constant<T> that hold arrays
   return Expr<TR>{std::move(funcRef)};
 }
 
@@ -276,7 +328,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldOperation(FoldingContext &context,
   }
   if (auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)}) {
     const std::string name{intrinsic->name};
-    if (name == "acos" || name == "acosh") {
+    if (name == "acos" || name == "acosh" ||
+        (name == "atan" && funcRef.arguments().size() == 1)) {
       if (auto callable{
               context.hostRte().GetHostProcedureWrapper<Scalar, T, T>(name)}) {
         return FoldElementalIntrinsic<T, T>(
@@ -286,6 +339,18 @@ Expr<Type<TypeCategory::Real, KIND>> FoldOperation(FoldingContext &context,
             "%s(real(kind=%d)) cannot be folded on host"_en_US, name.c_str(),
             KIND);
       }
+    }
+    if (name == "atan") {
+      if (auto callable{
+              context.hostRte().GetHostProcedureWrapper<Scalar, T, T, T>(
+                  name)}) {
+        return FoldElementalIntrinsic<T, T, T>(
+            context, std::move(funcRef), *callable);
+      } else {
+        context.messages().Say(
+            "%s(real(kind=%d), real(kind%d)) cannot be folded on host"_en_US,
+            name.c_str(), KIND);
+      }
     } else if (name == "bessel_jn" || name == "bessel_yn") {
       if (funcRef.arguments().size() == 2) {  // elemental
         using Int8 = Type<TypeCategory::Integer, 8>;