[flang] Lower integer exponentiation into math::IPowI.
authorSlava Zakharin <szakharin@nvidia.com>
Fri, 26 Aug 2022 23:12:25 +0000 (16:12 -0700)
committerSlava Zakharin <szakharin@nvidia.com>
Tue, 30 Aug 2022 21:09:05 +0000 (14:09 -0700)
Differential Revision: https://reviews.llvm.org/D132770

flang/lib/Lower/IntrinsicCall.cpp
flang/lib/Optimizer/CodeGen/CMakeLists.txt
flang/lib/Optimizer/CodeGen/CodeGen.cpp
flang/test/Intrinsics/math-codegen.fir
flang/test/Lower/power-operator.f90

index 42289ff..a40dfe9 100644 (file)
@@ -1153,6 +1153,12 @@ static mlir::FunctionType genF32IntF32FuncType(mlir::MLIRContext *context) {
   return mlir::FunctionType::get(context, {itype, ftype}, {ftype});
 }
 
+template <int Bits>
+static mlir::FunctionType genIntIntIntFuncType(mlir::MLIRContext *context) {
+  auto itype = mlir::IntegerType::get(context, Bits);
+  return mlir::FunctionType::get(context, {itype, itype}, {itype});
+}
+
 /// Callback type for generating lowering for a math operation.
 using MathGeneratorTy = mlir::Value (*)(fir::FirOpBuilder &, mlir::Location,
                                         llvm::StringRef, mlir::FunctionType,
@@ -1220,7 +1226,12 @@ static mlir::Value genMathOp(fir::FirOpBuilder &builder, mlir::Location loc,
   //           can be also lowered to libm calls for "fast" and "relaxed"
   //           modes.
   mlir::Value result;
-  if (mathRuntimeVersion == preciseVersion) {
+  if (mathRuntimeVersion == preciseVersion &&
+      // Some operations do not have to be lowered as conservative
+      // calls, since they do not affect strict FP behavior.
+      // For example, purely integer operations like exponentiation
+      // with integer operands fall into this class.
+      !mathLibFuncName.empty()) {
     result = genLibCall(builder, loc, mathLibFuncName, mathLibFuncType, args);
   } else {
     LLVM_DEBUG(llvm::dbgs() << "Generating '" << mathLibFuncName
@@ -1310,6 +1321,10 @@ static constexpr MathOperation mathOperations[] = {
     {"nint", "llvm.lround.i64.f32", genIntF32FuncType<64>, genLibCall},
     {"nint", "llvm.lround.i32.f64", genIntF64FuncType<32>, genLibCall},
     {"nint", "llvm.lround.i32.f32", genIntF32FuncType<32>, genLibCall},
+    {"pow", {}, genIntIntIntFuncType<8>, genMathOp<mlir::math::IPowIOp>},
+    {"pow", {}, genIntIntIntFuncType<16>, genMathOp<mlir::math::IPowIOp>},
+    {"pow", {}, genIntIntIntFuncType<32>, genMathOp<mlir::math::IPowIOp>},
+    {"pow", {}, genIntIntIntFuncType<64>, genMathOp<mlir::math::IPowIOp>},
     {"pow", "powf", genF32F32F32FuncType, genMathOp<mlir::math::PowFOp>},
     {"pow", "pow", genF64F64F64FuncType, genMathOp<mlir::math::PowFOp>},
     // TODO: add PowIOp in math and complex dialects.
index e4bc0d8..41e7908 100644 (file)
@@ -17,6 +17,7 @@ add_flang_library(FIRCodeGen
   FIRBuilder
   FIRDialect
   FIRSupport
+  MLIRMathToFuncs
   MLIRMathToLLVM
   MLIRMathToLibm
   MLIROpenMPToLLVM
index 00cb71b..a6b313a 100644 (file)
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 #include "llvm/ADT/ArrayRef.h"
 
@@ -3291,6 +3293,18 @@ public:
     if (!forcedTargetTriple.empty())
       fir::setTargetTriple(mod, forcedTargetTriple);
 
+    // Run dynamic pass pipeline for converting Math dialect
+    // operations into other dialects (llvm, func, etc.).
+    // Some conversions of Math operations cannot be done
+    // by just using conversion patterns. This is true for
+    // conversions that affect the ModuleOp, e.g. create new
+    // function operations in it. We have to run such conversions
+    // as passes here.
+    mlir::OpPassManager mathConvertionPM("builtin.module");
+    mathConvertionPM.addPass(mlir::createConvertMathToFuncsPass());
+    if (mlir::failed(runPipeline(mathConvertionPM, mod)))
+      return signalPassFailure();
+
     auto *context = getModule().getContext();
     fir::LLVMTypeConverter typeConverter{getModule()};
     mlir::RewritePatternSet pattern(context);
index c8f99b5..2c658d4 100644 (file)
@@ -1466,6 +1466,19 @@ func.func private @powf(f32, f32) -> f32
 func.func private @llvm.powi.f64.i32(f64, i32) -> f64
 func.func private @pow(f64, f64) -> f64
 
+//--- exponentiation_integer.fir
+// RUN: fir-opt %t/exponentiation_integer.fir --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" | FileCheck %t/exponentiation_integer.fir
+// CHECK: @_QPtest_int4
+// CHECK: llvm.call @__mlir_math_ipowi_i32({{%[A-Za-z0-9._]+}}, {{%[A-Za-z0-9._]+}}) : (i32, i32) -> i32
+
+func.func @_QPtest_int4(%arg0: !fir.ref<i32> {fir.bindc_name = "x"}, %arg1: !fir.ref<i32> {fir.bindc_name = "y"}, %arg2: !fir.ref<i32> {fir.bindc_name = "z"}) {
+  %0 = fir.load %arg0 : !fir.ref<i32>
+  %1 = fir.load %arg1 : !fir.ref<i32>
+  %2 = math.ipowi %0, %1 : i32
+  fir.store %2 to %arg2 : !fir.ref<i32>
+  return
+}
+
 //--- sign_fast.fir
 // RUN: fir-opt %t/sign_fast.fir --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" | FileCheck %t/sign_fast.fir
 // CHECK: @_QPtest_real4
index d542935..d99a46b 100644 (file)
@@ -57,18 +57,32 @@ subroutine pow_r4_r8(x, y, z)
   ! CHECK: math.powf %{{.*}}, %{{.*}} : f64
 end subroutine
 
+! CHECK-LABEL: pow_i1_i1
+subroutine pow_i1_i1(x, y, z)
+  integer(1) :: x, y, z
+  z = x ** y
+  ! CHECK: math.ipowi %{{.*}}, %{{.*}} : i8
+end subroutine
+
+! CHECK-LABEL: pow_i2_i2
+subroutine pow_i2_i2(x, y, z)
+  integer(2) :: x, y, z
+  z = x ** y
+  ! CHECK: math.ipowi %{{.*}}, %{{.*}} : i16
+end subroutine
+
 ! CHECK-LABEL: pow_i4_i4
 subroutine pow_i4_i4(x, y, z)
   integer(4) :: x, y, z
   z = x ** y
-  ! CHECK: call @__mth_i_ipowi
+  ! CHECK: math.ipowi %{{.*}}, %{{.*}} : i32
 end subroutine
 
 ! CHECK-LABEL: pow_i8_i8
 subroutine pow_i8_i8(x, y, z)
   integer(8) :: x, y, z
   z = x ** y
-  ! CHECK: call @__mth_i_kpowk
+  ! CHECK: math.ipowi %{{.*}}, %{{.*}} : i64
 end subroutine
 
 ! CHECK-LABEL: pow_c4_i4