From ceb1b327b53c82203deaa0b0ede3fb07ec8f823a Mon Sep 17 00:00:00 2001 From: aartbik Date: Fri, 26 Jun 2020 11:03:11 -0700 Subject: [PATCH] [mlir] [VectorOps] Add the ability to mark FP reductions with "reassociate" attribute Rationale: In general, passing "fastmath" from MLIR to LLVM backend is not supported, and even just providing such a feature for experimentation is under debate. However, passing fine-grained fastmath related attributes on individual operations is generally accepted. This CL introduces an option to instruct the vector-to-llvm lowering phase to annotate floating-point reductions with the "reassociate" fastmath attribute, which allows the LLVM backend to use SIMD implementations for such constructs. Other lowering passes can start using this mechanism right away in cases where reassociation is allowed. Benefit: For some microbenchmarks on x86-avx2, speedups over 20 were observed for longer vectors (due to cleaner, spill-free and SIMD exploiting code). Usage: mlir-opt --convert-vector-to-llvm="reassociate-fp-reductions" Reviewed By: ftynse, mehdi_amini Differential Revision: https://reviews.llvm.org/D82624 --- mlir/include/mlir/Conversion/Passes.td | 5 +++ .../Conversion/VectorToLLVM/ConvertVectorToLLVM.h | 5 +-- mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td | 28 ++++++++++++--- .../VectorToLLVM/ConvertVectorToLLVM.cpp | 26 +++++++++----- .../VectorToLLVM/vector-reduction-to-llvm.mlir | 42 ++++++++++++++++++++++ .../Conversion/VectorToLLVM/vector-to-llvm.mlir | 2 ++ mlir/test/Target/llvmir-intrinsics.mlir | 4 +++ 7 files changed, 98 insertions(+), 14 deletions(-) create mode 100644 mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 48149ce..89b63e8 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -308,6 +308,11 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> { let 
summary = "Lower the operations from the vector dialect into the LLVM " "dialect"; let constructor = "mlir::createConvertVectorToLLVMPass()"; + let options = [ + Option<"reassociateFPReductions", "reassociate-fp-reductions", + "bool", /*default=*/"false", + "Allows llvm to reassociate floating-point reductions for speed"> + ]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h index e09a74d..cdff188 100644 --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -23,8 +23,9 @@ void populateVectorToLLVMMatrixConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns); /// Collect a set of patterns to convert from the Vector dialect to LLVM. -void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, - OwningRewritePatternList &patterns); +void populateVectorToLLVMConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns, + bool reassociateFPReductions = false); /// Create a pass to convert vector operations to the LLVMIR dialect. std::unique_ptr> createConvertVectorToLLVMPass(); diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td index 48eecb4..d88b372 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -214,10 +214,30 @@ class LLVM_VectorReduction : LLVM_OneResultIntrOp<"experimental.vector.reduce." # mnem, [], [0], []>, Arguments<(ins LLVM_Type)>; -// LLVM vector reduction over a single vector, with an initial value. +// LLVM vector reduction over a single vector, with an initial value, +// and with permission to reassociate the reduction operations. 
class LLVM_VectorReductionV2 - : LLVM_OneResultIntrOp<"experimental.vector.reduce.v2." # mnem, - [0], [1], []>, - Arguments<(ins LLVM_Type, LLVM_Type)>; + : LLVM_OpBase, + Results<(outs LLVM_Type:$res)>, + Arguments<(ins LLVM_Type, LLVM_Type, + DefaultValuedAttr:$reassoc)> { + let llvmBuilder = [{ + llvm::Module *module = builder.GetInsertBlock()->getModule(); + llvm::Function *fn = llvm::Intrinsic::getDeclaration( + module, + llvm::Intrinsic::experimental_vector_reduce_v2_}] # mnem # [{, + { }] # StrJoin.lst, + ListIntSubst.lst)>.result # [{ + }); + auto operands = lookupValues(opInst.getOperands()); + llvm::FastMathFlags origFM = builder.getFastMathFlags(); + llvm::FastMathFlags tempFM = origFM; + tempFM.setAllowReassoc($reassoc); + builder.setFastMathFlags(tempFM); // set fastmath flag + $res = builder.CreateCall(fn, operands); + builder.setFastMathFlags(origFM); // restore fastmath flag + }]; +} #endif // LLVMIR_OP_BASE diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index bd9ec93..6b43a1e 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -255,9 +255,11 @@ public: class VectorReductionOpConversion : public ConvertToLLVMPattern { public: explicit VectorReductionOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) + LLVMTypeConverter &typeConverter, + bool reassociateFP) : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context, - typeConverter) {} + typeConverter), + reassociateFPReductions(reassociateFP) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -302,7 +304,8 @@ public: op->getLoc(), llvmType, rewriter.getZeroAttr(eltType)); rewriter.replaceOpWithNewOp( - op, llvmType, acc, operands[0]); + op, llvmType, acc, operands[0], + rewriter.getBoolAttr(reassociateFPReductions)); } else if (kind == "mul") { // Optional accumulator (or one). 
Value acc = operands.size() > 1 @@ -311,7 +314,8 @@ public: op->getLoc(), llvmType, rewriter.getFloatAttr(eltType, 1.0)); rewriter.replaceOpWithNewOp( - op, llvmType, acc, operands[0]); + op, llvmType, acc, operands[0], + rewriter.getBoolAttr(reassociateFPReductions)); } else if (kind == "min") rewriter.replaceOpWithNewOp( op, llvmType, operands[0]); @@ -324,6 +328,9 @@ public: } return failure(); } + +private: + const bool reassociateFPReductions; }; class VectorShuffleOpConversion : public ConvertToLLVMPattern { @@ -1139,16 +1146,18 @@ public: /// Populate the given list with patterns that convert from Vector to LLVM. void mlir::populateVectorToLLVMConversionPatterns( - LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + LLVMTypeConverter &converter, OwningRewritePatternList &patterns, + bool reassociateFPReductions) { MLIRContext *ctx = converter.getDialect()->getContext(); // clang-format off patterns.insert(ctx); + patterns.insert( + ctx, converter, reassociateFPReductions); patterns - .insert">) +// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float +// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]]) +// CHECK-SAME: {reassoc = false} : (!llvm.float, !llvm<"<16 x float>">) -> !llvm.float +// CHECK: llvm.return %[[V]] : !llvm.float +// +// REASSOC-LABEL: llvm.func @reduce_add_f32( +// REASSOC-SAME: %[[A:.*]]: !llvm<"<16 x float>">) +// REASSOC: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float +// REASSOC: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]]) +// REASSOC-SAME: {reassoc = true} : (!llvm.float, !llvm<"<16 x float>">) -> !llvm.float +// REASSOC: llvm.return %[[V]] : !llvm.float +// +func @reduce_add_f32(%arg0: vector<16xf32>) -> f32 { + %0 = vector.reduction "add", %arg0 : vector<16xf32> into f32 + return %0 : f32 +} + +// +// CHECK-LABEL: llvm.func @reduce_mul_f32( +// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>">) +// CHECK: %[[C:.*]] = 
llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float +// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fmul"(%[[C]], %[[A]]) +// CHECK-SAME: {reassoc = false} : (!llvm.float, !llvm<"<16 x float>">) -> !llvm.float +// CHECK: llvm.return %[[V]] : !llvm.float +// +// REASSOC-LABEL: llvm.func @reduce_mul_f32( +// REASSOC-SAME: %[[A:.*]]: !llvm<"<16 x float>">) +// REASSOC: %[[C:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float +// REASSOC: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fmul"(%[[C]], %[[A]]) +// REASSOC-SAME: {reassoc = true} : (!llvm.float, !llvm<"<16 x float>">) -> !llvm.float +// REASSOC: llvm.return %[[V]] : !llvm.float +// +func @reduce_mul_f32(%arg0: vector<16xf32>) -> f32 { + %0 = vector.reduction "mul", %arg0 : vector<16xf32> into f32 + return %0 : f32 +} diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 8351fb2..2b2adf0 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -721,6 +721,7 @@ func @reduce_f32(%arg0: vector<16xf32>) -> f32 { // CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>">) // CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float // CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]]) +// CHECK-SAME: {reassoc = false} : (!llvm.float, !llvm<"<16 x float>">) -> !llvm.float // CHECK: llvm.return %[[V]] : !llvm.float func @reduce_f64(%arg0: vector<16xf64>) -> f64 { @@ -731,6 +732,7 @@ func @reduce_f64(%arg0: vector<16xf64>) -> f64 { // CHECK-SAME: %[[A:.*]]: !llvm<"<16 x double>">) // CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f64) : !llvm.double // CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]]) +// CHECK-SAME: {reassoc = false} : (!llvm.double, !llvm<"<16 x double>">) -> !llvm.double // CHECK: llvm.return %[[V]] : !llvm.double func @reduce_i32(%arg0: 
vector<16xi32>) -> i32 { diff --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir index 4529212..ffbbf35 100644 --- a/mlir/test/Target/llvmir-intrinsics.mlir +++ b/mlir/test/Target/llvmir-intrinsics.mlir @@ -161,6 +161,10 @@ llvm.func @vector_reductions(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">, %a "llvm.intr.experimental.vector.reduce.v2.fadd"(%arg0, %arg1) : (!llvm.float, !llvm<"<8 x float>">) -> !llvm.float // CHECK: call float @llvm.experimental.vector.reduce.v2.fmul.f32.v8f32 "llvm.intr.experimental.vector.reduce.v2.fmul"(%arg0, %arg1) : (!llvm.float, !llvm<"<8 x float>">) -> !llvm.float + // CHECK: call reassoc float @llvm.experimental.vector.reduce.v2.fadd.f32.v8f32 + "llvm.intr.experimental.vector.reduce.v2.fadd"(%arg0, %arg1) {reassoc = true} : (!llvm.float, !llvm<"<8 x float>">) -> !llvm.float + // CHECK: call reassoc float @llvm.experimental.vector.reduce.v2.fmul.f32.v8f32 + "llvm.intr.experimental.vector.reduce.v2.fmul"(%arg0, %arg1) {reassoc = true} : (!llvm.float, !llvm<"<8 x float>">) -> !llvm.float // CHECK: call i32 @llvm.experimental.vector.reduce.xor.v8i32 "llvm.intr.experimental.vector.reduce.xor"(%arg2) : (!llvm<"<8 x i32>">) -> !llvm.i32 llvm.return -- 2.7.4