[flang] Fold transformational bessels when host runtime has bessels
authorJean Perier <jperier@nvidia.com>
Fri, 22 Apr 2022 07:37:08 +0000 (09:37 +0200)
committerJean Perier <jperier@nvidia.com>
Fri, 22 Apr 2022 07:37:49 +0000 (09:37 +0200)
Transformational bessel intrinsic functions require the same math runtime
as elemental bessel intrinsics.

Currently elemental bessels could be folded if f18 was linked with pgmath
(cmake -DLIBPGMATH_DIR option). `j0`, `y0`, ... C libm functions were not
used because they are not standard C functions: they are Posix
extensions.

This patch enable:
- Using the Posix bessel host runtime functions when available.
- folding the transformational bessel using the elemental version.

Differential Revision: https://reviews.llvm.org/D124167

flang/lib/Evaluate/fold-real.cpp
flang/lib/Evaluate/intrinsics-library.cpp
flang/test/Evaluate/folding02.f90

index a1d12c4..e915c55 100644 (file)
 
 namespace Fortran::evaluate {
 
+template <typename T>
+static Expr<T> FoldTransformationalBessel(
+    FunctionRef<T> &&funcRef, FoldingContext &context) {
+  CHECK(funcRef.arguments().size() == 3);
+  /// Bessel runtime functions use `int` integer arguments. Convert integer
+  /// arguments to Int4, any overflow error will be reported during the
+  /// conversion folding.
+  using Int4 = Type<TypeCategory::Integer, 4>;
+  if (auto args{
+          GetConstantArguments<Int4, Int4, T>(context, funcRef.arguments())}) {
+    const std::string &name{std::get<SpecificIntrinsic>(funcRef.proc().u).name};
+    if (auto elementalBessel{GetHostRuntimeWrapper<T, Int4, T>(name)}) {
+      std::vector<Scalar<T>> results;
+      int n1{static_cast<int>(
+          std::get<0>(*args)->GetScalarValue().value().ToInt64())};
+      int n2{static_cast<int>(
+          std::get<1>(*args)->GetScalarValue().value().ToInt64())};
+      Scalar<T> x{std::get<2>(*args)->GetScalarValue().value()};
+      for (int i{n1}; i <= n2; ++i) {
+        results.emplace_back((*elementalBessel)(context, Scalar<Int4>{i}, x));
+      }
+      return Expr<T>{Constant<T>{
+          std::move(results), ConstantSubscripts{std::max(n2 - n1 + 1, 0)}}};
+    } else {
+      context.messages().Say(
+          "%s(integer(kind=4), real(kind=%d)) cannot be folded on host"_warn_en_US,
+          name, T::kind);
+    }
+  }
+  return Expr<T>{std::move(funcRef)};
+}
+
 template <int KIND>
 Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
     FoldingContext &context,
@@ -63,6 +95,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
             "%s(integer(kind=4), real(kind=%d)) cannot be folded on host"_warn_en_US,
             name, KIND);
       }
+    } else {
+      return FoldTransformationalBessel<T>(std::move(funcRef), context);
     }
   } else if (name == "abs") { // incl. zabs & cdabs
     // Argument can be complex or real
@@ -245,7 +279,6 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
   // TODO: dim, dot_product, fraction, matmul,
   // modulo, norm2, rrspacing,
   // set_exponent, spacing, transfer,
-  // bessel_jn (transformational) and bessel_yn (transformational)
   return Expr<T>{std::move(funcRef)};
 }
 
index 8230a59..ba3b95f 100644 (file)
@@ -192,7 +192,13 @@ private:
 
 // Define host runtime libraries that can be used for folding and
 // fill their description if they are available.
-enum class LibraryVersion { Libm, PgmathFast, PgmathRelaxed, PgmathPrecise };
+enum class LibraryVersion {
+  Libm,
+  LibmExtensions,
+  PgmathFast,
+  PgmathRelaxed,
+  PgmathPrecise
+};
 template <typename HostT, LibraryVersion> struct HostRuntimeLibrary {
   // When specialized, this class holds a static constexpr table containing
   // all the HostRuntimeLibrary for functions of library LibraryVersion
@@ -277,6 +283,64 @@ struct HostRuntimeLibrary<std::complex<HostT>, LibraryVersion::Libm> {
   static constexpr HostRuntimeMap map{table};
   static_assert(map.Verify(), "map must be sorted");
 };
+// Note regarding cmath:
+//  - cmath does not have modulo and erfc_scaled equivalent
+//  - C++17 defined standard Bessel math functions std::cyl_bessel_j
+//    and std::cyl_neumann that can be used for Fortran j and y
+//    bessel functions. However, they are not yet implemented in
+//    clang libc++ (ok in GNU libstdc++). Instead, the Posix libm
+//    extensions are used when available below.
+
+#if _POSIX_C_SOURCE >= 200112L || _XOPEN_SOURCE >= 600
+/// Define libm extensions
+/// Bessel functions are defined in POSIX.1-2001.
+
+template <> struct HostRuntimeLibrary<float, LibraryVersion::LibmExtensions> {
+  using F = FuncPointer<float, float>;
+  using FN = FuncPointer<float, int, float>;
+  static constexpr HostRuntimeFunction table[]{
+      FolderFactory<F, F{::j0f}>::Create("bessel_j0"),
+      FolderFactory<F, F{::j1f}>::Create("bessel_j1"),
+      FolderFactory<FN, FN{::jnf}>::Create("bessel_jn"),
+      FolderFactory<F, F{::y0f}>::Create("bessel_y0"),
+      FolderFactory<F, F{::y1f}>::Create("bessel_y1"),
+      FolderFactory<FN, FN{::ynf}>::Create("bessel_yn"),
+  };
+  static constexpr HostRuntimeMap map{table};
+  static_assert(map.Verify(), "map must be sorted");
+};
+
+template <> struct HostRuntimeLibrary<double, LibraryVersion::LibmExtensions> {
+  using F = FuncPointer<double, double>;
+  using FN = FuncPointer<double, int, double>;
+  static constexpr HostRuntimeFunction table[]{
+      FolderFactory<F, F{::j0}>::Create("bessel_j0"),
+      FolderFactory<F, F{::j1}>::Create("bessel_j1"),
+      FolderFactory<FN, FN{::jn}>::Create("bessel_jn"),
+      FolderFactory<F, F{::y0}>::Create("bessel_y0"),
+      FolderFactory<F, F{::y1}>::Create("bessel_y1"),
+      FolderFactory<FN, FN{::yn}>::Create("bessel_yn"),
+  };
+  static constexpr HostRuntimeMap map{table};
+  static_assert(map.Verify(), "map must be sorted");
+};
+
+template <>
+struct HostRuntimeLibrary<long double, LibraryVersion::LibmExtensions> {
+  using F = FuncPointer<long double, long double>;
+  using FN = FuncPointer<long double, int, long double>;
+  static constexpr HostRuntimeFunction table[]{
+      FolderFactory<F, F{::j0l}>::Create("bessel_j0"),
+      FolderFactory<F, F{::j1l}>::Create("bessel_j1"),
+      FolderFactory<FN, FN{::jnl}>::Create("bessel_jn"),
+      FolderFactory<F, F{::y0l}>::Create("bessel_y0"),
+      FolderFactory<F, F{::y1l}>::Create("bessel_y1"),
+      FolderFactory<FN, FN{::ynl}>::Create("bessel_yn"),
+  };
+  static constexpr HostRuntimeMap map{table};
+  static_assert(map.Verify(), "map must be sorted");
+};
+#endif
 
 /// Define pgmath description
 #if LINK_WITH_LIBPGMATH
@@ -409,6 +473,8 @@ static const HostRuntimeMap *GetHostRuntimeMap(
   switch (version) {
   case LibraryVersion::Libm:
     return GetHostRuntimeMapVersion<LibraryVersion::Libm>(resultType);
+  case LibraryVersion::LibmExtensions:
+    return GetHostRuntimeMapVersion<LibraryVersion::LibmExtensions>(resultType);
   case LibraryVersion::PgmathPrecise:
     return GetHostRuntimeMapVersion<LibraryVersion::PgmathPrecise>(resultType);
   case LibraryVersion::PgmathRelaxed:
@@ -454,6 +520,13 @@ static const HostRuntimeFunction *SearchHostRuntime(const std::string &name,
       return hostFunction;
     }
   }
+  if (const auto *map{
+          GetHostRuntimeMap(LibraryVersion::LibmExtensions, resultType)}) {
+    if (const auto *hostFunction{
+            SearchInHostRuntimeMap(*map, name, resultType, argTypes)}) {
+      return hostFunction;
+    }
+  }
   return nullptr;
 }
 
index 7ee3652..32a4650 100644 (file)
@@ -249,6 +249,22 @@ module m
     (-0.93219375976297402797143831776338629424571990966796875_8))
    TEST_R8(erfc_scaled, erfc_scaled(0.1_8), &
     0.89645697996912654392787089818739332258701324462890625_8)
+
+  real(4), parameter :: bessel_jn_transformational(*) = bessel_jn(1,3, 3.2_4)
+  logical, parameter :: test_bessel_jn_shape = size(bessel_jn_transformational, 1).eq.3
+  logical, parameter :: test_bessel_jn_t1 = bessel_jn_transformational(1).eq.bessel_jn(1, 3.2_4)
+  logical, parameter :: test_bessel_jn_t2 = bessel_jn_transformational(2).eq.bessel_jn(2, 3.2_4)
+  logical, parameter :: test_bessel_jn_t3 = bessel_jn_transformational(3).eq.bessel_jn(3, 3.2_4)
+  real(4), parameter :: bessel_jn_empty(*) = bessel_jn(3,1, 3.2_4)
+  logical, parameter :: test_bessel_jn_empty = size(bessel_jn_empty, 1).eq.0
+
+  real(4), parameter :: bessel_yn_transformational(*) = bessel_yn(1,3, 1.6_4)
+  logical, parameter :: test_bessel_yn_shape = size(bessel_yn_transformational, 1).eq.3
+  logical, parameter :: test_bessel_yn_t1 = bessel_yn_transformational(1).eq.bessel_yn(1, 1.6_4)
+  logical, parameter :: test_bessel_yn_t2 = bessel_yn_transformational(2).eq.bessel_yn(2, 1.6_4)
+  logical, parameter :: test_bessel_yn_t3 = bessel_yn_transformational(3).eq.bessel_yn(3, 1.6_4)
+  real(4), parameter :: bessel_yn_empty(*) = bessel_yn(3,1, 3.2_4)
+  logical, parameter :: test_bessel_yn_empty = size(bessel_yn_empty, 1).eq.0
 #endif
 
 ! Test exponentiation by real or complex folding (it is using host runtime)