[flang] Enable folding of some more intrinsic functions
authorJean Perier <jperier@nvidia.com>
Fri, 22 Mar 2019 16:22:00 +0000 (09:22 -0700)
committerGitHub <noreply@github.com>
Wed, 27 Mar 2019 17:16:07 +0000 (10:16 -0700)
Enable folding of the following 80 intrinsic functions:

+ Without runtime:

++ Integer:
abs, dim, dshiftl, dshiftr, exponent, iand, ibclr, ibset, ieor, int,
ior, ishft, kind, len, leadz, maskl, maskr, merge_bits, popcnt, poppar,
shifta, shiftl, shiftr, trailz

++ Real:
abs, aimag, aint, dprod, real

+ Complex:
cmplx, conjg

++ Logical:
bge, bgt, ble, blt

+ With Runtime :

+ Real:
acos, acosh, asinh, atan, atan2, atanh, bessel_j0, bessel_j1,
bessel_jn (elemental), bessel_y0, bessel_y1, bessel_yn (elemental),
cos, cosh, erf, erfc, erfc_scaled, exp, gamma, hypot, log, log10,
log_gamma, mod, sin, sqrt, sinh, sqrt, tan, tanh

++ Complex:
acos, acosh, asin, asinh, atan, atanh, cos, cosh, exp, log, sin,
sinh, sqrt, tan, tanh

Original-commit: flang-compiler/f18@7e7d1920f882e7ca22c1320dd9b7e0a3d6eaec28
Tree-same-pre-rewrite: false

flang/lib/evaluate/fold.cc

index 742740c..b94d25a 100644 (file)
@@ -263,7 +263,8 @@ template<int KIND>
 Expr<Type<TypeCategory::Integer, KIND>> FoldOperation(FoldingContext &context,
     FunctionRef<Type<TypeCategory::Integer, KIND>> &&funcRef) {
   using T = Type<TypeCategory::Integer, KIND>;
-  for (std::optional<ActualArgument> &arg : funcRef.arguments()) {
+  ActualArguments &args{funcRef.arguments()};
+  for (std::optional<ActualArgument> &arg : args) {
     if (arg.has_value()) {
       arg.value().value() =
           FoldOperation(context, std::move(arg.value().value()));
@@ -271,32 +272,106 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldOperation(FoldingContext &context,
   }
   if (auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)}) {
     const std::string name{intrinsic->name};
-    if (name == "kind") {
-      if constexpr (common::HasMember<T, IntegerTypes>) {
-        return Expr<T>{funcRef.arguments()[0]->value()->GetType()->kind};
-      } else {
-        common::die("kind() result not integral");
+    // abs, dim, dshiftl, dshiftr, exponent, iand, ibclr, ibset, ieor, int, ior,
+    // ishft, kind, len, leadz, maskl, maskr, merge_bits shifta, shiftl, shiftr,
+    // trailz
+    if (name == "abs") {
+      return FoldElementalIntrinsic<T, T>(context, std::move(funcRef),
+          ScalarFunc<T, T>([&context](const Scalar<T> &i) -> Scalar<T> {
+            typename Scalar<T>::ValueWithOverflow j{i.ABS()};
+            if (j.overflow) {
+              context.messages().Say(
+                  "abs(integer(kind=%n)) folding overflowed"_en_US, KIND);
+            }
+            return j.value;
+          }));
+    } else if (name == "dim") {
+      return FoldElementalIntrinsic<T, T, T>(
+          context, std::move(funcRef), &Scalar<T>::DIM);
+    } else if (name == "dshiftl" || name == "dshiftr") {
+      // convert boz
+      for (int i{0}; i <= 1; ++i) {
+        if (auto *x{std::get_if<BOZLiteralConstant>(&args[i]->value->u)}) {
+          *args[i]->value = Fold(context, ConvertToType<T>(std::move(*x)));
+        }
       }
-    } else if (name == "len") {
-      if constexpr (std::is_same_v<T, SubscriptInteger>) {
-        if (auto *charExpr{UnwrapExpr<Expr<SomeCharacter>>(
-                *funcRef.arguments()[0]->value())}) {
-          return std::visit([](auto &kx) { return kx.LEN(); }, charExpr->u);
+      // Third argument can be of any kind. However, it must be smaller or equal
+      // than BIT_SIZE. It can be converted to Int4 to simplify.
+      using Int4 = Type<TypeCategory::Integer, 4>;
+      if (auto *n{std::get_if<Expr<SomeInteger>>(&args[2]->value->u)}) {
+        if (n->GetType()->kind != 4) {
+          *args[2]->value = Fold(context, ConvertToType<Int4>(std::move(*n)));
         }
+      }
+      const auto fptr{
+          name == "dshiftl" ? &Scalar<T>::DSHIFTL : &Scalar<T>::DSHIFTR};
+      return FoldElementalIntrinsic<T, T, T, Int4>(context, std::move(funcRef),
+          ScalarFunc<T, T, T, Int4>(
+              [&fptr](const Scalar<T> &i, const Scalar<T> &j,
+                  const Scalar<Int4> &shift) -> Scalar<T> {
+                return std::invoke(
+                    fptr, i, j, static_cast<int>(shift.ToInt64()));
+              }));
+    } else if (name == "exponent") {
+      if (auto *sx{std::get_if<Expr<SomeReal>>(&args[0]->value->u)}) {
+        return std::visit(
+            [&funcRef, &context](const auto &x) -> Expr<T> {
+              using TR = typename std::decay_t<decltype(x)>::Result;
+              return FoldElementalIntrinsic<T, TR>(context, std::move(funcRef),
+                  &Scalar<TR>::template EXPONENT<Scalar<T>>);
+            },
+            sx->u);
       } else {
-        common::die("len() result not SubscriptInteger");
+        common::die("exponent argument must be real");
       }
-    } else if (name == "iand") {
-      // TODO change intrinsic.cc so that it already has handled BOZ conversions
+    } else if (name == "iand" || name == "ior" || name == "ieor") {
+      // convert boz
       for (int i{0}; i <= 1; ++i) {
-        if (auto *x{std::get_if<BOZLiteralConstant>(
-                &funcRef.arguments()[i]->value->u)}) {
-          *funcRef.arguments()[i]->value =
-              Fold(context, ConvertToType<T>(std::move(*x)));
+        if (auto *x{std::get_if<BOZLiteralConstant>(&args[i]->value->u)}) {
+          *args[i]->value = Fold(context, ConvertToType<T>(std::move(*x)));
         }
       }
+      auto fptr{&Scalar<T>::IAND};
+      if (name == "iand") {  // done in fptr declaration
+      } else if (name == "ior") {
+        fptr = &Scalar<T>::IOR;
+      } else if (name == "ieor") {
+        fptr = &Scalar<T>::IEOR;
+      } else {
+        common::die("missing case to fold intrinsic function %s", name);
+      }
       return FoldElementalIntrinsic<T, T, T>(
-          context, std::move(funcRef), ScalarFunc<T, T, T>(&Scalar<T>::IAND));
+          context, std::move(funcRef), ScalarFunc<T, T, T>(fptr));
+    } else if (name == "ibclr" || name == "ibset" || name == "ishft" ||
+        name == "shifta" || name == "shiftr" || name == "shiftl") {
+      // Second argument can be of any kind. However, it must be smaller or
+      // equal than BIT_SIZE. It can be converted to Int4 to simplify.
+      using Int4 = Type<TypeCategory::Integer, 4>;
+      if (auto *n{std::get_if<Expr<SomeInteger>>(&args[1]->value->u)}) {
+        if (n->GetType()->kind != 4) {
+          *args[1]->value = Fold(context, ConvertToType<Int4>(std::move(*n)));
+        }
+      }
+      auto fptr{&Scalar<T>::IBCLR};
+      if (name == "ibclr") {  // done in fprt definition
+      } else if (name == "ibset") {
+        fptr = &Scalar<T>::IBSET;
+      } else if (name == "ibshft") {
+        fptr = &Scalar<T>::ISHFT;
+      } else if (name == "shifta") {
+        fptr = &Scalar<T>::SHIFTA;
+      } else if (name == "shiftr") {
+        fptr = &Scalar<T>::SHIFTR;
+      } else if (name == "shiftl") {
+        fptr = &Scalar<T>::SHIFTL;
+      } else {
+        common::die("missing case to fold intrinsic function %s", name);
+      }
+      return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
+          ScalarFunc<T, T, Int4>([&fptr](const Scalar<T> &i,
+                                     const Scalar<Int4> &pos) -> Scalar<T> {
+            return std::invoke(fptr, i, static_cast<int>(pos.ToInt64()));
+          }));
     } else if (name == "int") {
       return std::visit(
           [&](auto &&x) -> Expr<T> {
@@ -311,9 +386,89 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldOperation(FoldingContext &context,
               return Expr<T>{std::move(funcRef)};  // unreachable
             }
           },
-          std::move(funcRef.arguments()[0]->value->u));
+          std::move(args[0]->value->u));
+    } else if (name == "kind") {
+      if constexpr (common::HasMember<T, IntegerTypes>) {
+        return Expr<T>{args[0]->value->GetType()->kind};
+      } else {
+        common::die("kind() result not integral");
+      }
+    } else if (name == "leadz" || name == "trailz" || name == "poppar" ||
+        name == "popcnt") {
+      if (auto *sn{std::get_if<Expr<SomeInteger>>(&args[0]->value->u)}) {
+        return std::visit(
+            [&funcRef, &context, &name](const auto &n) -> Expr<T> {
+              using TI = typename std::decay_t<decltype(n)>::Result;
+              if (name == "poppar") {
+                return FoldElementalIntrinsic<T, TI>(context,
+                    std::move(funcRef),
+                    ScalarFunc<T, TI>([](const Scalar<TI> &i) -> Scalar<T> {
+                      return Scalar<T>{i.POPPAR() ? 1 : 0};
+                    }));
+              }
+              auto fptr{&Scalar<TI>::LEADZ};
+              if (name == "leadz") {  // done in fprt definition
+              } else if (name == "trailz") {
+                fptr = &Scalar<TI>::TRAILZ;
+              } else if (name == "popcnt") {
+                fptr = &Scalar<TI>::POPCNT;
+              } else {
+                common::die("missing case to fold intrinsic function %s", name);
+              }
+              return FoldElementalIntrinsic<T, TI>(context, std::move(funcRef),
+                  ScalarFunc<T, TI>([&fptr](const Scalar<TI> &i) -> Scalar<T> {
+                    return Scalar<T>{std::invoke(fptr, i)};
+                  }));
+            },
+            sn->u);
+      } else {
+        common::die("leadz argument must be integer");
+      }
+    } else if (name == "len") {
+      if (auto *charExpr{UnwrapExpr<Expr<SomeCharacter>>(*args[0]->value)}) {
+        return std::visit(
+            [&context](auto &kx) {
+              if constexpr (std::is_same_v<T, SubscriptInteger>) {
+                return kx.LEN();
+              } else {
+                return Fold(context, ConvertToType<T>(kx.LEN()));
+              }
+            },
+            charExpr->u);
+      } else {
+        common::die("len() result not SubscriptInteger");
+      }
+    } else if (name == "maskl" || name == "maskr") {
+      // Argument can be of any kind but value has to be smaller than bit_size.
+      // It can be safely converted to Int4 to simplify.
+      using Int4 = Type<TypeCategory::Integer, 4>;
+      if (auto *n{std::get_if<Expr<SomeInteger>>(&args[0]->value->u)}) {
+        if (n->GetType()->kind != 4) {
+          *args[0]->value = Fold(context, ConvertToType<Int4>(std::move(*n)));
+        }
+      }
+      const auto fptr{name == "maskl" ? &Scalar<T>::MASKL : &Scalar<T>::MASKR};
+      return FoldElementalIntrinsic<T, Int4>(context, std::move(funcRef),
+          ScalarFunc<T, Int4>([&fptr](const Scalar<Int4> &places) -> Scalar<T> {
+            return fptr(static_cast<int>(places.ToInt64()));
+          }));
+    } else if (name == "merge_bits") {
+      // convert boz
+      for (int i{0}; i <= 2; ++i) {
+        if (auto *x{std::get_if<BOZLiteralConstant>(&args[i]->value->u)}) {
+          *args[i]->value = Fold(context, ConvertToType<T>(std::move(*x)));
+        }
+      }
+      return FoldElementalIntrinsic<T, T, T, T>(
+          context, std::move(funcRef), &Scalar<T>::MERGE_BITS);
     }
-    // TODO: many more intrinsic functions
+    // TODO:
+    // ceiling, command_argument_count, count, cshift, dot_product, eoshift,
+    // findloc, floor, iachar, iall, iany, iparity, ibits, ichar, image_status,
+    // index, ishftc, lbound, len_trim, matmul, max, maxloc, maxval, merge, min,
+    // minloc, minval, mod, modulo, nint, not, pack, product, reduce, reshape,
+    // scan, selected_char_kind, selected_int_kind, selected_real_kind, shape,
+    // sign, size, spread, sum, transfer, transpose, ubound, unpack, verify
   }
   return Expr<T>{std::move(funcRef)};
 }
@@ -322,15 +477,25 @@ template<int KIND>
 Expr<Type<TypeCategory::Real, KIND>> FoldOperation(FoldingContext &context,
     FunctionRef<Type<TypeCategory::Real, KIND>> &&funcRef) {
   using T = Type<TypeCategory::Real, KIND>;
-  for (std::optional<ActualArgument> &arg : funcRef.arguments()) {
+  using ComplexT = Type<TypeCategory::Complex, KIND>;
+  ActualArguments &args{funcRef.arguments()};
+  for (std::optional<ActualArgument> &arg : args) {
     if (arg.has_value()) {
       *arg->value = FoldOperation(context, std::move(*arg->value));
     }
   }
   if (auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)}) {
     const std::string name{intrinsic->name};
-    if (name == "acos" || name == "acosh" ||
-        (name == "atan" && funcRef.arguments().size() == 1)) {
+    if (name == "acos" || name == "acosh" || name == "asin" ||
+        name == "asinh" || (name == "atan" && args.size() == 1) ||
+        name == "atanh" || name == "bessel_j0" || name == "bessel_j1" ||
+        name == "bessel_y0" || name == "bessel_y1" || name == "cos" ||
+        name == "cosh" || name == "erf" || name == "erfc" ||
+        name == "erfc_scaled" || name == "exp" || name == "gamma" ||
+        name == "log" || name == "log10" || name == "log_gamma" ||
+        name == "sin" || name == "sinh" || name == "sqrt" || name == "tan" ||
+        name == "tanh") {
+      CHECK(args.size() == 1);
       if (auto callable{context.hostIntrinsicsLibrary()
                             .GetHostProcedureWrapper<Scalar, T, T>(name)}) {
         return FoldElementalIntrinsic<T, T>(
@@ -341,46 +506,87 @@ Expr<Type<TypeCategory::Real, KIND>> FoldOperation(FoldingContext &context,
             KIND);
       }
     }
-    if (name == "atan") {
-      if (auto callable{context.hostIntrinsicsLibrary()
-                            .GetHostProcedureWrapper<Scalar, T, T, T>(name)}) {
+    if (name == "atan" || name == "atan2" || name == "hypot" || name == "mod") {
+      std::string localName{name == "atan2" ? "atan" : name};
+      CHECK(args.size() == 2);
+      if (auto callable{
+              context.hostIntrinsicsLibrary()
+                  .GetHostProcedureWrapper<Scalar, T, T, T>(localName)}) {
         return FoldElementalIntrinsic<T, T, T>(
             context, std::move(funcRef), *callable);
       } else {
         context.messages().Say(
             "%s(real(kind=%d), real(kind%d)) cannot be folded on host"_en_US,
-            name.c_str(), KIND);
+            name.c_str(), KIND, KIND);
       }
     } else if (name == "bessel_jn" || name == "bessel_yn") {
-      if (funcRef.arguments().size() == 2) {  // elemental
-        using Int8 = Type<TypeCategory::Integer, 8>;
-        if (auto *n{std::get_if<Expr<SomeInteger>>(
-                &funcRef.arguments()[0]->value->u)}) {
-          *funcRef.arguments()[0]->value =
-              Fold(context, ConvertToType<Int8>(std::move(*n)));
+      if (args.size() == 2) {  // elemental
+        // runtime functions use int arg
+        using Int4 = Type<TypeCategory::Integer, 4>;
+        if (auto *n{std::get_if<Expr<SomeInteger>>(&args[0]->value->u)}) {
+          if (n->GetType()->kind != 4) {
+            *args[0]->value = Fold(context, ConvertToType<Int4>(std::move(*n)));
+          }
         }
         if (auto callable{
                 context.hostIntrinsicsLibrary()
-                    .GetHostProcedureWrapper<Scalar, T, Int8, T>(name)}) {
-          return FoldElementalIntrinsic<T, Int8, T>(
+                    .GetHostProcedureWrapper<Scalar, T, Int4, T>(name)}) {
+          return FoldElementalIntrinsic<T, Int4, T>(
               context, std::move(funcRef), *callable);
         } else {
           context.messages().Say(
-              "%s(integer(kind=8), real(kind=%d)) cannot be folded on host"_en_US,
+              "%s(integer(kind=4), real(kind=%d)) cannot be folded on host"_en_US,
               name.c_str(), KIND);
         }
       }
+    } else if (name == "abs") {
+      // Argument can be complex or real
+      if (auto *x{std::get_if<Expr<SomeReal>>(&args[0]->value->u)}) {
+        return FoldElementalIntrinsic<T, T>(
+            context, std::move(funcRef), &Scalar<T>::ABS);
+      } else if (auto *z{std::get_if<Expr<SomeComplex>>(&args[0]->value->u)}) {
+        if (auto callable{
+                context.hostIntrinsicsLibrary()
+                    .GetHostProcedureWrapper<Scalar, T, ComplexT>("abs")}) {
+          return FoldElementalIntrinsic<T, ComplexT>(
+              context, std::move(funcRef), *callable);
+        } else {
+          context.messages().Say(
+              "abs(complex(kind=%d)) cannot be folded on host"_en_US, KIND);
+        }
+      } else {
+        common::die(" unexpected argument type inside abs");
+      }
+    } else if (name == "aimag") {
+      return FoldElementalIntrinsic<T, ComplexT>(
+          context, std::move(funcRef), &Scalar<ComplexT>::AIMAG);
+    } else if (name == "aint") {
+      // Convert argument to the requested kind before calling aint
+      if (auto *x{std::get_if<Expr<SomeReal>>(&args[0]->value->u)}) {
+        if (!(x->GetType()->kind == T::kind)) {
+          *args[0]->value = Fold(context, ConvertToType<T>(std::move(*x)));
+        }
+      }
+      return FoldElementalIntrinsic<T, T>(context, std::move(funcRef),
+          ScalarFunc<T, T>([&name, &context](const Scalar<T> &x) -> Scalar<T> {
+            ValueWithRealFlags<Scalar<T>> y{x.AINT()};
+            if (y.flags.test(RealFlag::Overflow)) {
+              context.messages().Say(
+                  "%s intrinsic folding overflow"_en_US, name.c_str());
+            }
+            return y.value;
+          }));
     } else if (name == "dprod") {
-      if (auto *x{
-              std::get_if<Expr<SomeReal>>(&funcRef.arguments()[0]->value->u)}) {
-        if (auto *y{std::get_if<Expr<SomeReal>>(
-                &funcRef.arguments()[1]->value->u)}) {
+      if (auto *x{std::get_if<Expr<SomeReal>>(&args[0]->value->u)}) {
+        if (auto *y{std::get_if<Expr<SomeReal>>(&args[1]->value->u)}) {
           return Fold(context,
               Expr<T>{Multiply<T>{ConvertToType<T>(std::move(*x)),
                   ConvertToType<T>(std::move(*y))}});
         }
       }
       common::die("Wrong argument type in dprod()");
+    } else if (name == "epsilon") {
+      return Expr<T>{Constant<T>{Scalar<T>::EPSILON()}};
     } else if (name == "real") {
       return std::visit(
           [&](auto &&x) -> Expr<T> {
@@ -402,9 +608,13 @@ Expr<Type<TypeCategory::Real, KIND>> FoldOperation(FoldingContext &context,
               return Expr<T>{std::move(funcRef)};  // unreachable
             }
           },
-          std::move(funcRef.arguments()[0]->value->u));
+          std::move(args[0]->value->u));
     }
-    // TODO: many more intrinsic functions
+    // TODO: anint, cshift, dim, dot_product, eoshift, fraction, huge, matmul,
+    // max, maxval, merge, min, minval, modulo, nearest, norm2, pack, product,
+    // reduce, reshape, rrspacing, scale, set_exponent, sign, spacing, spread,
+    // sum, tiny, transfer, transpose, unpack, bessel_jn (transformational) and
+    // bessel_yn (transformational)
   }
   return Expr<T>{std::move(funcRef)};
 }
@@ -413,15 +623,18 @@ template<int KIND>
 Expr<Type<TypeCategory::Complex, KIND>> FoldOperation(FoldingContext &context,
     FunctionRef<Type<TypeCategory::Complex, KIND>> &&funcRef) {
   using T = Type<TypeCategory::Complex, KIND>;
-  for (std::optional<ActualArgument> &arg : funcRef.arguments()) {
+  ActualArguments &args{funcRef.arguments()};
+  for (std::optional<ActualArgument> &arg : args) {
     if (arg.has_value()) {
       *arg->value = FoldOperation(context, std::move(*arg->value));
     }
   }
   if (auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)}) {
     const std::string name{intrinsic->name};
-    if (name == "acos" || name == "acosh" || name == "asin" || name == "atan" ||
-        name == "atanh") {
+    if (name == "acos" || name == "acosh" || name == "asin" ||
+        name == "asinh" || name == "atan" || name == "atanh" || name == "cos" ||
+        name == "cosh" || name == "exp" || name == "log" || name == "sin" ||
+        name == "sinh" || name == "sqrt" || name == "tan" || name == "tanh") {
       if (auto callable{context.hostIntrinsicsLibrary()
                             .GetHostProcedureWrapper<Scalar, T, T>(name)}) {
         return FoldElementalIntrinsic<T, T>(
@@ -431,22 +644,23 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldOperation(FoldingContext &context,
             "%s(complex(kind=%d)) cannot be folded on host"_en_US, name.c_str(),
             KIND);
       }
-    }
-    if (name == "cmplx") {
-      if (funcRef.arguments().size() == 2) {
-        if (auto *x{std::get_if<Expr<SomeComplex>>(
-                &funcRef.arguments()[0]->value->u)}) {
+    } else if (name == "conjg") {
+      return FoldElementalIntrinsic<T, T>(
+          context, std::move(funcRef), &Scalar<T>::CONJG);
+    } else if (name == "cmplx") {
+      if (args.size() == 2) {
+        if (auto *x{std::get_if<Expr<SomeComplex>>(&args[0]->value->u)}) {
           return Fold(context, ConvertToType<T>(std::move(*x)));
         } else {
           common::die("x must be complex in cmplx(x[, kind])");
         }
       } else {
-        CHECK(funcRef.arguments().size() == 3);
+        CHECK(args.size() == 3);
         using Part = typename T::Part;
-        Expr<SomeType> im{funcRef.arguments()[1].has_value()
-                ? std::move(*funcRef.arguments()[1]->value)
+        Expr<SomeType> im{args[1].has_value()
+                ? std::move(*args[1]->value)
                 : AsGenericExpr(Constant<Part>{Scalar<Part>{}})};
-        Expr<SomeType> re{std::move(*funcRef.arguments()[0]->value)};
+        Expr<SomeType> re{std::move(*args[0]->value)};
         int reRank{re.Rank()};
         int imRank{im.Rank()};
         semantics::Attrs attrs;
@@ -464,7 +678,8 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldOperation(FoldingContext &context,
                 Expr<Part>{std::move(reReal)}, Expr<Part>{std::move(imReal)}}});
       }
     }
-    // TODO: many more intrinsic functions
+    // TODO: cshift, dot_product, eoshift, matmul, merge, pack, product,
+    // reduce, reshape, spread, sum, transfer, transpose, unpack
   }
   return Expr<T>{std::move(funcRef)};
 }
@@ -473,42 +688,51 @@ template<int KIND>
 Expr<Type<TypeCategory::Logical, KIND>> FoldOperation(FoldingContext &context,
     FunctionRef<Type<TypeCategory::Logical, KIND>> &&funcRef) {
   using T = Type<TypeCategory::Logical, KIND>;
-  for (std::optional<ActualArgument> &arg : funcRef.arguments()) {
+  ActualArguments &args{funcRef.arguments()};
+  for (std::optional<ActualArgument> &arg : args) {
     if (arg.has_value()) {
       *arg->value = FoldOperation(context, std::move(*arg->value));
     }
   }
   if (auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)}) {
     std::string name{intrinsic->name};
-    if (name == "bge") {
+    if (name == "bge" || name == "bgt" || name == "ble" || name == "blt") {
       using LargestInt = Type<TypeCategory::Integer, 16>;
       static_assert(std::is_same_v<Scalar<LargestInt>, BOZLiteralConstant>);
-      if (auto *x{std::get_if<Expr<SomeInteger>>(
-              &funcRef.arguments()[0]->value->u)}) {
-        *funcRef.arguments()[0]->value =
-            Fold(context, ConvertToType<LargestInt>(std::move(*x)));
-      } else if (auto *x{std::get_if<BOZLiteralConstant>(
-                     &funcRef.arguments()[0]->value->u)}) {
-        *funcRef.arguments()[0]->value =
-            AsGenericExpr(Constant<LargestInt>{std::move(*x)});
+      // Arguments do not have to be of the same integer type. Convert all
+      // arguments to the biggest integer type before comparing them to
+      // simplify.
+      for (int i{0}; i <= 1; ++i) {
+        if (auto *x{std::get_if<Expr<SomeInteger>>(&args[i]->value->u)}) {
+          *args[i]->value =
+              Fold(context, ConvertToType<LargestInt>(std::move(*x)));
+        } else if (auto *x{
+                       std::get_if<BOZLiteralConstant>(&args[i]->value->u)}) {
+          *args[i]->value = AsGenericExpr(Constant<LargestInt>{std::move(*x)});
+        }
       }
-      if (auto *x{std::get_if<Expr<SomeInteger>>(
-              &funcRef.arguments()[1]->value->u)}) {
-        *funcRef.arguments()[1]->value =
-            Fold(context, ConvertToType<LargestInt>(std::move(*x)));
-      } else if (auto *x{std::get_if<BOZLiteralConstant>(
-                     &funcRef.arguments()[1]->value->u)}) {
-        *funcRef.arguments()[1]->value =
-            AsGenericExpr(Constant<LargestInt>{std::move(*x)});
+      auto fptr{&Scalar<LargestInt>::BGE};
+      if (name == "bge") {  // done in fptr declaration
+      } else if (name == "bgt") {
+        fptr = &Scalar<LargestInt>::BGT;
+      } else if (name == "ble") {
+        fptr = &Scalar<LargestInt>::BLE;
+      } else if (name == "blt") {
+        fptr = &Scalar<LargestInt>::BLT;
+      } else {
+        common::die("missing case to fold intrinsic function %s", name);
       }
       return FoldElementalIntrinsic<T, LargestInt, LargestInt>(context,
           std::move(funcRef),
           ScalarFunc<T, LargestInt, LargestInt>(
-              [](const Scalar<LargestInt> &i, const Scalar<LargestInt> &j) {
-                return Scalar<T>{i.BGE(j)};
+              [&fptr](
+                  const Scalar<LargestInt> &i, const Scalar<LargestInt> &j) {
+                return Scalar<T>{std::invoke(fptr, i, j)};
               }));
     }
-    // TODO: many more intrinsic functions
+    // TODO: all, any, btest, cshift, dot_product, eoshift, is_iostat_end,
+    // is_iostat_eor, lge, lgt, lle, llt, logical, matmul, merge, out_of_range,
+    // pack, parity, reduce, reshape, spread, transfer, transpose, unpack
   }
   return Expr<T>{std::move(funcRef)};
 }