From f5253058144aca1e9fcacdca53accdc975e804cf Mon Sep 17 00:00:00 2001 From: Stanley Winata Date: Tue, 20 Sep 2022 17:16:52 -0400 Subject: [PATCH] [mlir][spirv] Lower max/min vector.reduction for OpenCL Templatizing vector reduction to enable lowering from vector.reduction max/min to CL ops. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D134313 --- .../lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 40 +++++++---- .../Conversion/VectorToSPIRV/vector-to-spirv.mlir | 84 ++++++++++++++++++++++ 2 files changed, 109 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 23f664c..7b1fad1 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -272,6 +272,8 @@ struct VectorInsertStridedSliceOpConvert final } }; +template <typename SPVFMaxOp, typename SPVFMinOp, typename SPVUMaxOp, + typename SPVUMinOp, typename SPVSMaxOp, typename SPVSMinOp> struct VectorReductionPattern final : public OpConversionPattern<vector::ReductionOp> { using OpConversionPattern::OpConversionPattern; @@ -317,18 +319,18 @@ struct VectorReductionPattern final #define INT_OR_FLOAT_CASE(kind, fop) \ case vector::CombiningKind::kind: \ - result = rewriter.create<spirv::fop>(loc, resultType, result, next); \ + result = rewriter.create<fop>(loc, resultType, result, next); \ break INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp); INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp); - INT_OR_FLOAT_CASE(MAXF, GLFMaxOp); - INT_OR_FLOAT_CASE(MINF, GLFMinOp); - INT_OR_FLOAT_CASE(MINUI, GLUMinOp); - INT_OR_FLOAT_CASE(MINSI, GLSMinOp); - INT_OR_FLOAT_CASE(MAXUI, GLUMaxOp); - INT_OR_FLOAT_CASE(MAXSI, GLSMaxOp); + INT_OR_FLOAT_CASE(MAXF, SPVFMaxOp); + INT_OR_FLOAT_CASE(MINF, SPVFMinOp); + INT_OR_FLOAT_CASE(MINUI, SPVUMinOp); + INT_OR_FLOAT_CASE(MINSI, SPVSMinOp); + INT_OR_FLOAT_CASE(MAXUI, SPVUMaxOp); + INT_OR_FLOAT_CASE(MAXSI, SPVSMaxOp); case vector::CombiningKind::AND: case vector::CombiningKind::OR: @@ -403,15 +405,23 @@ struct VectorShuffleOpConvert final }; } // namespace +#define 
CL_MAX_MIN_OPS \ + spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp, \ + spirv::CLSMaxOp, spirv::CLSMinOp + +#define GL_MAX_MIN_OPS \ + spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp, \ + spirv::GLSMaxOp, spirv::GLSMinOp void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns.add<VectorBitcastConvert, VectorBroadcastConvert, - VectorExtractElementOpConvert, VectorExtractOpConvert, - VectorExtractStridedSliceOpConvert, - VectorFmaOpConvert<spirv::GLFmaOp>, - VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert, - VectorInsertOpConvert, VectorReductionPattern, - VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, - VectorSplatPattern>(typeConverter, patterns.getContext()); + patterns.add< + VectorBitcastConvert, VectorBroadcastConvert, + VectorExtractElementOpConvert, VectorExtractOpConvert, + VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>, + VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert, + VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>, + VectorReductionPattern<CL_MAX_MIN_OPS>, VectorInsertStridedSliceOpConvert, + VectorShuffleOpConvert, VectorSplatPattern>(typeConverter, + patterns.getContext()); } diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index f483fc3e..afce049 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -33,6 +33,90 @@ func.func @cl_fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector< return %0 : vector<1xf32> } +// CHECK-LABEL: func @cl_reduction_maxf +// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32) +// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xf32> +// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> +// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> +// CHECK: %[[MAX0:.+]] = spv.CL.fmax %[[S0]], %[[S1]] +// CHECK: %[[MAX1:.+]] = spv.CL.fmax %[[MAX0]], %[[S2]] +// CHECK: %[[MAX2:.+]] = spv.CL.fmax %[[MAX1]], %[[S]] +// CHECK: return %[[MAX2]] +func.func 
@cl_reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 { + %reduce = vector.reduction <maxf>, %v, %s : vector<3xf32> into f32 + return %reduce : f32 +} + +// CHECK-LABEL: func @cl_reduction_minf +// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32) +// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xf32> +// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> +// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> +// CHECK: %[[MIN0:.+]] = spv.CL.fmin %[[S0]], %[[S1]] +// CHECK: %[[MIN1:.+]] = spv.CL.fmin %[[MIN0]], %[[S2]] +// CHECK: %[[MIN2:.+]] = spv.CL.fmin %[[MIN1]], %[[S]] +// CHECK: return %[[MIN2]] +func.func @cl_reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 { + %reduce = vector.reduction <minf>, %v, %s : vector<3xf32> into f32 + return %reduce : f32 +} + +// CHECK-LABEL: func @cl_reduction_maxsi +// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32) +// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32> +// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32> +// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32> +// CHECK: %[[MAX0:.+]] = spv.CL.s_max %[[S0]], %[[S1]] +// CHECK: %[[MAX1:.+]] = spv.CL.s_max %[[MAX0]], %[[S2]] +// CHECK: %[[MAX2:.+]] = spv.CL.s_max %[[MAX1]], %[[S]] +// CHECK: return %[[MAX2]] +func.func @cl_reduction_maxsi(%v : vector<3xi32>, %s: i32) -> i32 { + %reduce = vector.reduction <maxsi>, %v, %s : vector<3xi32> into i32 + return %reduce : i32 +} + +// CHECK-LABEL: func @cl_reduction_minsi +// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32) +// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32> +// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32> +// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32> +// CHECK: %[[MIN0:.+]] = spv.CL.s_min %[[S0]], %[[S1]] +// CHECK: %[[MIN1:.+]] = spv.CL.s_min %[[MIN0]], %[[S2]] +// CHECK: %[[MIN2:.+]] = 
spv.CL.s_min %[[MIN1]], %[[S]] +// CHECK: return %[[MIN2]] +func.func @cl_reduction_minsi(%v : vector<3xi32>, %s: i32) -> i32 { + %reduce = vector.reduction <minsi>, %v, %s : vector<3xi32> into i32 + return %reduce : i32 +} + +// CHECK-LABEL: func @cl_reduction_maxui +// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32) +// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32> +// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32> +// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32> +// CHECK: %[[MAX0:.+]] = spv.CL.u_max %[[S0]], %[[S1]] +// CHECK: %[[MAX1:.+]] = spv.CL.u_max %[[MAX0]], %[[S2]] +// CHECK: %[[MAX2:.+]] = spv.CL.u_max %[[MAX1]], %[[S]] +// CHECK: return %[[MAX2]] +func.func @cl_reduction_maxui(%v : vector<3xi32>, %s: i32) -> i32 { + %reduce = vector.reduction <maxui>, %v, %s : vector<3xi32> into i32 + return %reduce : i32 +} + +// CHECK-LABEL: func @cl_reduction_minui +// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32) +// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32> +// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32> +// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32> +// CHECK: %[[MIN0:.+]] = spv.CL.u_min %[[S0]], %[[S1]] +// CHECK: %[[MIN1:.+]] = spv.CL.u_min %[[MIN0]], %[[S2]] +// CHECK: %[[MIN2:.+]] = spv.CL.u_min %[[MIN1]], %[[S]] +// CHECK: return %[[MIN2]] +func.func @cl_reduction_minui(%v : vector<3xi32>, %s: i32) -> i32 { + %reduce = vector.reduction <minui>, %v, %s : vector<3xi32> into i32 + return %reduce : i32 +} + } // end module // ----- -- 2.7.4