[mlir][arith] Convert fastmath to LLVM dialect for some arith ops.
authorSlava Zakharin <szakharin@nvidia.com>
Tue, 8 Nov 2022 03:38:36 +0000 (19:38 -0800)
committerSlava Zakharin <szakharin@nvidia.com>
Tue, 8 Nov 2022 03:39:51 +0000 (19:39 -0800)
This is a follow-up on D126305 and D136225.
We can now preserve fastmath for arith::MaxFOp,MinFOp,RemFOp during
ArithToLLVM conversion.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D137456

mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir

index 1409b7f..3ad0155 100644 (file)
@@ -52,16 +52,16 @@ using FPToSIOpLowering =
     VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>;
 using FPToUIOpLowering =
     VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
-// TODO: Add LLVM intrinsic support for fastmath
-using MaxFOpLowering = VectorConvertToLLVMPattern<arith::MaxFOp, LLVM::MaxNumOp,
-                                                  arith::AttrDropFastMath>;
+using MaxFOpLowering =
+    VectorConvertToLLVMPattern<arith::MaxFOp, LLVM::MaxNumOp,
+                               arith::AttrConvertFastMathToLLVM>;
 using MaxSIOpLowering =
     VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
 using MaxUIOpLowering =
     VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>;
-// TODO: Add LLVM intrinsic support for fastmath
-using MinFOpLowering = VectorConvertToLLVMPattern<arith::MinFOp, LLVM::MinNumOp,
-                                                  arith::AttrDropFastMath>;
+using MinFOpLowering =
+    VectorConvertToLLVMPattern<arith::MinFOp, LLVM::MinNumOp,
+                               arith::AttrConvertFastMathToLLVM>;
 using MinSIOpLowering =
     VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
 using MinUIOpLowering =
@@ -74,9 +74,9 @@ using NegFOpLowering =
     VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
                                arith::AttrConvertFastMathToLLVM>;
 using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
-// TODO: Add LLVM intrinsic support for fastmath
-using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
-                                                  arith::AttrDropFastMath>;
+using RemFOpLowering =
+    VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
+                               arith::AttrConvertFastMathToLLVM>;
 using RemSIOpLowering =
     VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
 using RemUIOpLowering =
index eccd875..d8e49a5 100644 (file)
@@ -453,11 +453,11 @@ func.func @minmaxf(%arg0 : f32, %arg1 : f32) -> f32 {
 
 // CHECK-LABEL: @fastmath
 func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
-// CHECK: {{.*}} = llvm.fadd %arg0, %arg1  {fastmathFlags = #llvm.fastmath<fast>} : f32
-// CHECK: {{.*}} = llvm.fmul %arg0, %arg1  {fastmathFlags = #llvm.fastmath<fast>} : f32
-// CHECK: {{.*}} = llvm.fneg %arg0  {fastmathFlags = #llvm.fastmath<fast>} : f32
-// CHECK: {{.*}} = llvm.fadd %arg0, %arg1  : f32
-// CHECK: {{.*}} = llvm.fadd %arg0, %arg1  {fastmathFlags = #llvm.fastmath<nnan, ninf>} : f32
+// CHECK: llvm.fadd %arg0, %arg1  {fastmathFlags = #llvm.fastmath<fast>} : f32
+// CHECK: llvm.fmul %arg0, %arg1  {fastmathFlags = #llvm.fastmath<fast>} : f32
+// CHECK: llvm.fneg %arg0  {fastmathFlags = #llvm.fastmath<fast>} : f32
+// CHECK: llvm.fadd %arg0, %arg1  : f32
+// CHECK: llvm.fadd %arg0, %arg1  {fastmathFlags = #llvm.fastmath<nnan, ninf>} : f32
   %0 = arith.addf %arg0, %arg1 fastmath<fast> : f32
   %1 = arith.mulf %arg0, %arg1 fastmath<fast> : f32
   %2 = arith.negf %arg0 fastmath<fast> : f32
@@ -465,3 +465,26 @@ func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
   %4 = arith.addf %arg0, %arg1 fastmath<nnan,ninf> : f32
   return
 }
+
+// -----
+
+// CHECK-LABEL: @ops_supporting_fastmath
+func.func @ops_supporting_fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
+// CHECK: llvm.fadd %arg0, %arg1  {fastmathFlags = #llvm.fastmath<fast>} : f32
+  %0 = arith.addf %arg0, %arg1 fastmath<fast> : f32
+// CHECK: llvm.fdiv %arg0, %arg1  {fastmathFlags = #llvm.fastmath<fast>} : f32
+  %1 = arith.divf %arg0, %arg1 fastmath<fast> : f32
+// CHECK: llvm.intr.maxnum(%arg0, %arg1) {fastmathFlags = #llvm.fastmath<fast>} : (f32, f32) -> f32
+  %2 = arith.maxf %arg0, %arg1 fastmath<fast> : f32
+// CHECK: llvm.intr.minnum(%arg0, %arg1) {fastmathFlags = #llvm.fastmath<fast>} : (f32, f32) -> f32
+  %3 = arith.minf %arg0, %arg1 fastmath<fast> : f32
+// CHECK: llvm.fmul %arg0, %arg1  {fastmathFlags = #llvm.fastmath<fast>} : f32
+  %4 = arith.mulf %arg0, %arg1 fastmath<fast> : f32
+// CHECK: llvm.fneg %arg0  {fastmathFlags = #llvm.fastmath<fast>} : f32
+  %5 = arith.negf %arg0 fastmath<fast> : f32
+// CHECK: llvm.frem %arg0, %arg1  {fastmathFlags = #llvm.fastmath<fast>} : f32
+  %6 = arith.remf %arg0, %arg1 fastmath<fast> : f32
+// CHECK: llvm.fsub %arg0, %arg1  {fastmathFlags = #llvm.fastmath<fast>} : f32
+  %7 = arith.subf %arg0, %arg1 fastmath<fast> : f32
+  return
+}