[Flang][HLFIR] Intrinsics: Propagate fast math flags
authorKiran Chandramohan <kiran.chandramohan@arm.com>
Tue, 18 Jul 2023 09:10:45 +0000 (09:10 +0000)
committerKiran Chandramohan <kiran.chandramohan@arm.com>
Tue, 18 Jul 2023 09:26:59 +0000 (09:26 +0000)
Add a new FirOpBuilder constructor to propagate the fast math flag
from an operation. Use this constructor in the LowerHLFIRIntrinsics
pass.

This fixes the performance issue with the hlfir intrinsics flow
for polyhedron/test_fpu2.

Reviewed By: tblah, vzakhari

Differential Revision: https://reviews.llvm.org/D155438

flang/include/flang/Optimizer/Builder/FIRBuilder.h
flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
flang/test/HLFIR/matmul-lowering.fir
flang/test/HLFIR/product-lowering.fir
flang/test/HLFIR/sum-lowering.fir

index 606e79c..943e313 100644 (file)
@@ -53,6 +53,16 @@ public:
         kindMap{getKindMapping(mod)} {
     setListener(this);
   }
+  explicit FirOpBuilder(mlir::OpBuilder &builder, fir::KindMapping kindMap,
+                        mlir::Operation *op)
+      : OpBuilder(builder), OpBuilder::Listener(), kindMap{std::move(kindMap)} {
+    setListener(this);
+    auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op);
+    if (fmi) {
+      // Set the builder with FastMathFlags attached to the operation.
+      setFastMathFlags(fmi.getFastMathFlagsAttr().getValue());
+    }
+  }
 
   // The listener self-reference has to be updated in case of copy-construction.
   FirOpBuilder(const FirOpBuilder &other)
index d3c604e..9dfc6d0 100644 (file)
@@ -61,7 +61,7 @@ protected:
     mlir::Type desiredType;
   };
 
-  /// Lower the arguments to the intrinsic: adding nesecarry boxing and
+  /// Lower the arguments to the intrinsic: adding necessary boxing and
   /// conversion to match the signature of the intrinsic in the runtime library.
   llvm::SmallVector<fir::ExtendedValue, 3>
   lowerArguments(mlir::Operation *op,
@@ -70,7 +70,7 @@ protected:
                  const fir::IntrinsicArgumentLoweringRules *argLowering) const {
     mlir::Location loc = op->getLoc();
     fir::KindMapping kindMapping{rewriter.getContext()};
-    fir::FirOpBuilder builder{rewriter, kindMapping};
+    fir::FirOpBuilder builder{rewriter, kindMapping, op};
 
     llvm::SmallVector<fir::ExtendedValue, 3> ret;
     llvm::SmallVector<std::function<void()>, 2> cleanupFns;
@@ -227,7 +227,7 @@ public:
     }
 
     fir::KindMapping kindMapping{rewriter.getContext()};
-    fir::FirOpBuilder builder{rewriter, kindMapping};
+    fir::FirOpBuilder builder{rewriter, kindMapping, operation};
     const mlir::Location &loc = operation->getLoc();
 
     mlir::Type i32 = builder.getI32Type();
@@ -269,7 +269,7 @@ struct CountOpConversion : public HlfirIntrinsicConversion<hlfir::CountOp> {
   matchAndRewrite(hlfir::CountOp count,
                   mlir::PatternRewriter &rewriter) const override {
     fir::KindMapping kindMapping{rewriter.getContext()};
-    fir::FirOpBuilder builder{rewriter, kindMapping};
+    fir::FirOpBuilder builder{rewriter, kindMapping, count};
     const mlir::Location &loc = count->getLoc();
 
     mlir::Type i32 = builder.getI32Type();
@@ -302,7 +302,7 @@ struct MatmulOpConversion : public HlfirIntrinsicConversion<hlfir::MatmulOp> {
   matchAndRewrite(hlfir::MatmulOp matmul,
                   mlir::PatternRewriter &rewriter) const override {
     fir::KindMapping kindMapping{rewriter.getContext()};
-    fir::FirOpBuilder builder{rewriter, kindMapping};
+    fir::FirOpBuilder builder{rewriter, kindMapping, matmul};
     const mlir::Location &loc = matmul->getLoc();
 
     mlir::Value lhs = matmul.getLhs();
@@ -334,7 +334,7 @@ struct DotProductOpConversion
   matchAndRewrite(hlfir::DotProductOp dotProduct,
                   mlir::PatternRewriter &rewriter) const override {
     fir::KindMapping kindMapping{rewriter.getContext()};
-    fir::FirOpBuilder builder{rewriter, kindMapping};
+    fir::FirOpBuilder builder{rewriter, kindMapping, dotProduct};
     const mlir::Location &loc = dotProduct->getLoc();
 
     mlir::Value lhs = dotProduct.getLhs();
@@ -366,7 +366,7 @@ class TransposeOpConversion
   matchAndRewrite(hlfir::TransposeOp transpose,
                   mlir::PatternRewriter &rewriter) const override {
     fir::KindMapping kindMapping{rewriter.getContext()};
-    fir::FirOpBuilder builder{rewriter, kindMapping};
+    fir::FirOpBuilder builder{rewriter, kindMapping, transpose};
     const mlir::Location &loc = transpose->getLoc();
 
     mlir::Value arg = transpose.getArray();
@@ -397,7 +397,7 @@ struct MatmulTransposeOpConversion
   matchAndRewrite(hlfir::MatmulTransposeOp multranspose,
                   mlir::PatternRewriter &rewriter) const override {
     fir::KindMapping kindMapping{rewriter.getContext()};
-    fir::FirOpBuilder builder{rewriter, kindMapping};
+    fir::FirOpBuilder builder{rewriter, kindMapping, multranspose};
     const mlir::Location &loc = multranspose->getLoc();
 
     mlir::Value lhs = multranspose.getLhs();
index 752afce..ee921da 100644 (file)
@@ -28,7 +28,7 @@ func.func @_QPmatmul1(%arg0: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "lh
 // CHECK:         %[[RET_ARG:.*]] = fir.convert %[[RET_BOX]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
 // CHECK-DAG:     %[[LHS_ARG:.*]] = fir.convert %[[LHS_VAR]]#1 : (!fir.box<!fir.array<?x?xi32>>) -> !fir.box<none>
 // CHECK-DAG:     %[[RHS_ARG:.*]] = fir.convert %[[RHS_VAR]]#1 : (!fir.box<!fir.array<?x?xi32>>) -> !fir.box<none>
-// CHECK:         %[[NONE:.*]] = fir.call @_FortranAMatmul(%[[RET_ARG]], %[[LHS_ARG]], %[[RHS_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]])
+// CHECK:         %[[NONE:.*]] = fir.call @_FortranAMatmul(%[[RET_ARG]], %[[LHS_ARG]], %[[RHS_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]]) fastmath<contract>
 
 // CHECK:         %[[RET:.*]] = fir.load %[[RET_BOX]]
 // CHECK-DAG:     %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[RET]]
@@ -71,11 +71,7 @@ func.func @_QPtest(%arg0: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "a"},
 }
 // just check that we apply the patterns successfully. The details are checked above
 // CHECK-LABEL: func.func @_QPtest(
-// CHECK:           %arg0: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "a"},
-// CHECK-SAME:      %arg1: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "b"},
-// CHECK-SAME:      %arg2: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "c"},
-// CHECK-SAME:      %arg3: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "out"}) {
-// CHECK:         fir.call @_FortranAMatmul(
-// CHECK;         fir.call @_FortranAMatmul(%40, %41, %42, %43, %c20_i32) : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> none
+// CHECK:         fir.call @_FortranAMatmul({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> none
+// CHECK:         fir.call @_FortranAMatmul({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> none
 // CHECK:         return
 // CHECK-NEXT:  }
index aedc499..337b5fc 100644 (file)
@@ -18,7 +18,7 @@ func.func @_QPproduct1(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "a"}
 // CHECK-DAG:     %[[MASK:.*]] = fir.absent !fir.box<i1>
 // CHECK-DAG:     %[[ARRAY_ARG:.*]] = fir.convert %[[ARRAY]]#1 :  (!fir.box<!fir.array<?xi32>>) -> !fir.box<none>
 // CHECK-DAG:     %[[MASK_ARG:.*]] = fir.convert %[[MASK]] : (!fir.box<i1>) -> !fir.box<none>
-// CHECK:         %[[RET:.*]] = fir.call @_FortranAProductInteger4(%[[ARRAY_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[INT:.*]], %[[MASK_ARG]]) : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i32
+// CHECK:         %[[RET:.*]] = fir.call @_FortranAProductInteger4(%[[ARRAY_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[INT:.*]], %[[MASK_ARG]]) fastmath<contract> : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i32
 // CHECK-NEXT:    hlfir.assign %[[RET]] to %[[RES]]#0 : i32, !fir.ref<i32>
 // CHECK-NEXT:    return
 // CHECK-NEXT:  }
@@ -58,7 +58,7 @@ func.func @_QPproduct2(%arg0: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "a
 // CHECK-DAG:     %[[ARRAY_ARG:.*]] = fir.convert %[[ARRAY]]
 // CHECK-DAG:     %[[MASK_ARG:.*]] = fir.convert %[[MASK]]
 
-// CHECK:         %[[NONE:.*]] = fir.call @_FortranAProductDim(%[[RET_ARG]], %[[ARRAY_ARG]], %[[DIM]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[MASK_ARG]])  : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, !fir.ref<i8>, i32, !fir.box<none>) -> none
+// CHECK:         %[[NONE:.*]] = fir.call @_FortranAProductDim(%[[RET_ARG]], %[[ARRAY_ARG]], %[[DIM]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[MASK_ARG]]) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, !fir.ref<i8>, i32, !fir.box<none>) -> none
 // CHECK:         %[[RET:.*]] = fir.load %[[RET_BOX]]
 // CHECK:         %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[RET]]
 // CHECK-NEXT:    %[[ADDR:.*]] = fir.box_addr %[[RET]]
@@ -91,7 +91,7 @@ func.func @_QPproduct3(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "a"}
 // CHECK-DAG:      %[[MASK_BOX:.*]] = fir.embox %[[MASK]]#1 : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
 // CHECK-DAG:      %[[ARRAY_ARG:.*]] = fir.convert %[[ARRAY]]#1 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<none>
 // CHECK-DAG:      %[[MASK_ARG:.*]] = fir.convert %[[MASK_BOX]]  : (!fir.box<!fir.logical<4>>) -> !fir.box<none>
-// CHECK:          %[[RET:.*]] = fir.call @_FortranAProductInteger4(%[[ARRAY_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[INT:.*]], %[[MASK_ARG]]) : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i32
+// CHECK:          %[[RET:.*]] = fir.call @_FortranAProductInteger4(%[[ARRAY_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[INT:.*]], %[[MASK_ARG]]) fastmath<contract> : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i32
 // CHECK-NEXT:     hlfir.assign %[[RET]] to %[[RES]]#0 : i32, !fir.ref<i32>
 // CHECK-NEXT:     return
 // CHECK-NEXT:   }
@@ -115,7 +115,7 @@ func.func @_QPproduct4(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "a"}
 // CHECK-DAG:     %[[MASK]]:2 = hlfir.declare %[[ARG2]]
 // CHECK-DAG:     %[[ARRAY_ARG:.*]] = fir.convert %[[ARRAY]]#1  : (!fir.box<!fir.array<?xi32>>) -> !fir.box<none>
 // CHECK-DAG:     %[[MASK_ARG:.*]] = fir.convert %[[MASK]]#1 : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.box<none>
-// CHECK:         %[[RET:.*]] = fir.call @_FortranAProductInteger4(%[[ARRAY_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[INT:.*]], %[[MASK_ARG]]) : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i32
+// CHECK:         %[[RET:.*]] = fir.call @_FortranAProductInteger4(%[[ARRAY_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[INT:.*]], %[[MASK_ARG]]) fastmath<contract> : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i32
 // CHECK-NEXT:    hlfir.assign %[[RET]] to %[[RES]]#0 : i32, !fir.ref<i32>
 // CHECK-NEXT:    return
 // CHECK-NEXT:  }
@@ -165,4 +165,4 @@ func.func @_QPproduct5(%arg0: !fir.ref<!fir.array<2xi32>> {fir.bindc_name = "s"}
 // CHECK-DAG:     %[[RET_ARG:.*]] = fir.convert %[[RET_BOX]]
 // CHECK-DAG:     %[[ARRAY_ARG:.*]] = fir.convert %[[ARRAY_BOX]] : (!fir.box<!fir.array<2x2xi32>>) -> !fir.box<none>
 // CHECK-DAG:     %[[MASK_ARG:.*]] = fir.convert %[[MASK_BOX]] : (!fir.box<!fir.logical<4>>) -> !fir.box<none>
-// CHECK:         %[[NONE:.*]] = fir.call @_FortranAProductDim(%[[RET_ARG]], %[[ARRAY_ARG]], %[[DIM]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[MASK_ARG]]) : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, !fir.ref<i8>, i32, !fir.box<none>) -> none
+// CHECK:         %[[NONE:.*]] = fir.call @_FortranAProductDim(%[[RET_ARG]], %[[ARRAY_ARG]], %[[DIM]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[MASK_ARG]]) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, !fir.ref<i8>, i32, !fir.box<none>) -> none
index 8f5ddd7..e33b9bc 100644 (file)
@@ -17,7 +17,7 @@ func.func @_QPsum1(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "a"}, %a
 // CHECK-DAG:     %[[MASK:.*]] = fir.absent !fir.box<i1>
 // CHECK-DAG:     %[[ARRAY_ARG:.*]] = fir.convert %[[ARRAY]]#1 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<none>
 // CHECK-DAG:     %[[MASK_ARG:.*]] = fir.convert %[[MASK]] : (!fir.box<i1>) -> !fir.box<none>
-// CHECK:         %[[RET:.*]] = fir.call @_FortranASumInteger4(%[[ARRAY_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[INT:.*]], %[[MASK_ARG]]) : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i32
+// CHECK:         %[[RET:.*]] = fir.call @_FortranASumInteger4(%[[ARRAY_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[INT:.*]], %[[MASK_ARG]]) fastmath<contract> : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i32
 // CHECK-NEXT:    hlfir.assign %[[RET]] to %[[RES]]#0 : i32, !fir.ref<i32>
 // CHECK-NEXT:    return
 // CHECK-NEXT:  }
@@ -55,7 +55,7 @@ func.func @_QPsum2(%arg0: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "a"},
 // CHECK-DAG:     %[[RET_ARG:.*]] = fir.convert %[[RET_BOX]]
 // CHECK-DAG:     %[[ARRAY_ARG:.*]] = fir.convert %[[ARRAY]]#1 : (!fir.box<!fir.array<?x?xi32>>) -> !fir.box<none>
 // CHECK-DAG:     %[[MASK_ARG:.*]] = fir.convert %[[MASK]] : (!fir.box<i1>) -> !fir.box<none>
-// CHECK:         %[[NONE:.*]] = fir.call @_FortranASumDim(%[[RET_ARG]], %[[ARRAY_ARG]], %[[DIM]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[MASK_ARG]]) : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, !fir.ref<i8>, i32, !fir.box<none>) -> none
+// CHECK:         %[[NONE:.*]] = fir.call @_FortranASumDim(%[[RET_ARG]], %[[ARRAY_ARG]], %[[DIM]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[MASK_ARG]]) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, !fir.ref<i8>, i32, !fir.box<none>) -> none
 
 // CHECK:         %[[RET:.*]] = fir.load %[[RET_BOX]]
 // CHECK:         %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[RET]]
@@ -89,7 +89,7 @@ func.func @_QPsum3(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "a"}, %a
 // CHECK-DAG:     %[[MASK_BOX:.*]] = fir.embox %[[MASK]]#1 : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
 // CHECK-DAG:     %[[ARRAY_ARG:.*]] = fir.convert %[[ARRAY]]#1 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<none>
 // CHECK-DAG:     %[[MASK_ARG:.*]] = fir.convert %[[MASK_BOX]] : (!fir.box<!fir.logical<4>>) -> !fir.box<none>
-// CHECK:         %[[RET:.*]] = fir.call @_FortranASumInteger4(%[[ARRAY_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[INT:.*]], %[[MASK_ARG]]) : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i32
+// CHECK:         %[[RET:.*]] = fir.call @_FortranASumInteger4(%[[ARRAY_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[INT:.*]], %[[MASK_ARG]]) fastmath<contract> : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i32
 // CHECK-NEXT:    hlfir.assign %[[RET]] to %[[RES]]#0 : i32, !fir.ref<i32>
 // CHECK-NEXT:    return
 // CHECK-NEXT:  }
@@ -112,7 +112,7 @@ func.func @_QPsum4(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "a"}, %a
 // CHECK-DAG:     %[[MASK:.*]]:2 = hlfir.declare %[[ARG2]]
 // CHECK-DAG:     %[[ARRAY_ARG:.*]] = fir.convert %[[ARRAY]]#1 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<none>
 // CHECK-DAG:     %[[MASK_ARG:.*]] = fir.convert %[[MASK]]#1 : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.box<none>
-// CHECK:         %[[RET:.*]] = fir.call @_FortranASumInteger4(%[[ARRAY_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[INT:.*]], %[[MASK_ARG]]) : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i32
+// CHECK:         %[[RET:.*]] = fir.call @_FortranASumInteger4(%[[ARRAY_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[INT:.*]], %[[MASK_ARG]]) fastmath<contract> : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i32
 // CHECK-NEXT:    hlfir.assign %[[RET]] to %[[RES]]#0 : i32, !fir.ref<i32>
 // CHECK-NEXT:    return
 // CHECK-NEXT:  }
@@ -175,4 +175,4 @@ func.func @_QPsum5(%arg0: !fir.ref<!fir.array<2xi32>> {fir.bindc_name = "s"}) {
 // CHECK-DAG:     %[[RET_ARG:.*]] = fir.convert %[[RET_BOX]]
 // CHECK-DAG:     %[[ARRAY_ARG:.*]] = fir.convert %[[ARRAY_BOX]] : (!fir.box<!fir.array<2x2xi32>>) -> !fir.box<none>
 // CHECK-DAG:     %[[MASK_ARG:.*]] = fir.convert %[[MASK_BOX]] : (!fir.box<!fir.logical<4>>) -> !fir.box<none>
-// CHECK:         %[[NONE:.*]] = fir.call @_FortranASumDim(%[[RET_ARG]], %[[ARRAY_ARG]], %[[DIM]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[MASK_ARG]]) : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, !fir.ref<i8>, i32, !fir.box<none>) -> none
+// CHECK:         %[[NONE:.*]] = fir.call @_FortranASumDim(%[[RET_ARG]], %[[ARRAY_ARG]], %[[DIM]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[MASK_ARG]]) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, !fir.ref<i8>, i32, !fir.box<none>) -> none