From 0f6103af97e71ef17d2ac3d8bb00bc2f173a2ffb Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Fri, 2 Sep 2022 17:21:57 -0400 Subject: [PATCH] [mlir][spirv] Support more max/min vector.reduction Reviewed By: kuhar Differential Revision: https://reviews.llvm.org/D133168 --- .../lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 25 +++-- .../Conversion/VectorToSPIRV/vector-to-spirv.mlir | 112 +++++++++++++++++++-- 2 files changed, 120 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 483619b..5349261 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -283,7 +283,8 @@ struct VectorReductionPattern final Value result = values.front(); for (Value next : llvm::makeArrayRef(values).drop_front()) { switch (reduceOp.getKind()) { -#define INT_FLOAT_CASE(kind, iop, fop) \ + +#define INT_AND_FLOAT_CASE(kind, iop, fop) \ case vector::CombiningKind::kind: \ if (resultType.isa<IntegerType>()) { \ result = rewriter.create<spirv::iop>(loc, resultType, result, next); \ @@ -293,15 +294,21 @@ struct VectorReductionPattern final } \ break - INT_FLOAT_CASE(ADD, IAddOp, FAddOp); - INT_FLOAT_CASE(MUL, IMulOp, FMulOp); +#define INT_OR_FLOAT_CASE(kind, fop) \ + case vector::CombiningKind::kind: \ + result = rewriter.create<spirv::fop>(loc, resultType, result, next); \ + break + + INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp); + INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp); + + INT_OR_FLOAT_CASE(MAXF, GLFMaxOp); + INT_OR_FLOAT_CASE(MINF, GLFMinOp); + INT_OR_FLOAT_CASE(MINUI, GLUMinOp); + INT_OR_FLOAT_CASE(MINSI, GLSMinOp); + INT_OR_FLOAT_CASE(MAXUI, GLUMaxOp); + INT_OR_FLOAT_CASE(MAXSI, GLSMaxOp); - case vector::CombiningKind::MINUI: - case vector::CombiningKind::MINSI: - case vector::CombiningKind::MINF: - case vector::CombiningKind::MAXUI: - case vector::CombiningKind::MAXSI: - case vector::CombiningKind::MAXF: case vector::CombiningKind::AND: case 
vector::CombiningKind::OR: case vector::CombiningKind::XOR: diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index a5af59e..f1de62c 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -254,7 +254,7 @@ func.func @shuffle(%v0 : vector<2x16xf32>, %v1: vector<1x16xf32>) -> vector<3x16 // ----- -// CHECK-LABEL: func @reduction +// CHECK-LABEL: func @reduction_add // CHECK-SAME: (%[[V:.+]]: vector<4xi32>) // CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<4xi32> // CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<4xi32> @@ -264,23 +264,119 @@ func.func @shuffle(%v0 : vector<2x16xf32>, %v1: vector<1x16xf32>) -> vector<3x16 // CHECK: %[[ADD1:.+]] = spv.IAdd %[[ADD0]], %[[S2]] // CHECK: %[[ADD2:.+]] = spv.IAdd %[[ADD1]], %[[S3]] // CHECK: return %[[ADD2]] -func.func @reduction(%v : vector<4xi32>) -> i32 { +func.func @reduction_add(%v : vector<4xi32>) -> i32 { %reduce = vector.reduction <add>, %v : vector<4xi32> into i32 return %reduce : i32 } // ----- -// CHECK-LABEL: func @reduction +// CHECK-LABEL: func @reduction_mul // CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32) // CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xf32> // CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> // CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> -// CHECK: %[[ADD0:.+]] = spv.FMul %[[S0]], %[[S1]] -// CHECK: %[[ADD1:.+]] = spv.FMul %[[ADD0]], %[[S2]] -// CHECK: %[[ADD2:.+]] = spv.FMul %[[ADD1]], %[[S]] -// CHECK: return %[[ADD2]] -func.func @reduction(%v : vector<3xf32>, %s: f32) -> f32 { +// CHECK: %[[MUL0:.+]] = spv.FMul %[[S0]], %[[S1]] +// CHECK: %[[MUL1:.+]] = spv.FMul %[[MUL0]], %[[S2]] +// CHECK: %[[MUL2:.+]] = spv.FMul %[[MUL1]], %[[S]] +// CHECK: return %[[MUL2]] +func.func @reduction_mul(%v : vector<3xf32>, %s: f32) -> f32 
{ %reduce = vector.reduction <mul>, %v, %s : vector<3xf32> into f32 return %reduce : f32 } + +// ----- + +// CHECK-LABEL: func @reduction_maxf +// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32) +// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xf32> +// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> +// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> +// CHECK: %[[MAX0:.+]] = spv.GL.FMax %[[S0]], %[[S1]] +// CHECK: %[[MAX1:.+]] = spv.GL.FMax %[[MAX0]], %[[S2]] +// CHECK: %[[MAX2:.+]] = spv.GL.FMax %[[MAX1]], %[[S]] +// CHECK: return %[[MAX2]] +func.func @reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 { + %reduce = vector.reduction <maxf>, %v, %s : vector<3xf32> into f32 + return %reduce : f32 +} + +// ----- + +// CHECK-LABEL: func @reduction_minf +// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32) +// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xf32> +// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> +// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> +// CHECK: %[[MIN0:.+]] = spv.GL.FMin %[[S0]], %[[S1]] +// CHECK: %[[MIN1:.+]] = spv.GL.FMin %[[MIN0]], %[[S2]] +// CHECK: %[[MIN2:.+]] = spv.GL.FMin %[[MIN1]], %[[S]] +// CHECK: return %[[MIN2]] +func.func @reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 { + %reduce = vector.reduction <minf>, %v, %s : vector<3xf32> into f32 + return %reduce : f32 +} + +// ----- + +// CHECK-LABEL: func @reduction_maxsi +// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32) +// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32> +// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32> +// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32> +// CHECK: %[[MAX0:.+]] = spv.GL.SMax %[[S0]], %[[S1]] +// CHECK: %[[MAX1:.+]] = spv.GL.SMax %[[MAX0]], %[[S2]] +// CHECK: %[[MAX2:.+]] = spv.GL.SMax %[[MAX1]], %[[S]] +// CHECK: return 
%[[MAX2]] +func.func @reduction_maxsi(%v : vector<3xi32>, %s: i32) -> i32 { + %reduce = vector.reduction <maxsi>, %v, %s : vector<3xi32> into i32 + return %reduce : i32 +} + +// ----- + +// CHECK-LABEL: func @reduction_minsi +// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32) +// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32> +// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32> +// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32> +// CHECK: %[[MIN0:.+]] = spv.GL.SMin %[[S0]], %[[S1]] +// CHECK: %[[MIN1:.+]] = spv.GL.SMin %[[MIN0]], %[[S2]] +// CHECK: %[[MIN2:.+]] = spv.GL.SMin %[[MIN1]], %[[S]] +// CHECK: return %[[MIN2]] +func.func @reduction_minsi(%v : vector<3xi32>, %s: i32) -> i32 { + %reduce = vector.reduction <minsi>, %v, %s : vector<3xi32> into i32 + return %reduce : i32 +} + +// ----- + +// CHECK-LABEL: func @reduction_maxui +// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32) +// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32> +// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32> +// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32> +// CHECK: %[[MAX0:.+]] = spv.GL.UMax %[[S0]], %[[S1]] +// CHECK: %[[MAX1:.+]] = spv.GL.UMax %[[MAX0]], %[[S2]] +// CHECK: %[[MAX2:.+]] = spv.GL.UMax %[[MAX1]], %[[S]] +// CHECK: return %[[MAX2]] +func.func @reduction_maxui(%v : vector<3xi32>, %s: i32) -> i32 { + %reduce = vector.reduction <maxui>, %v, %s : vector<3xi32> into i32 + return %reduce : i32 +} + +// ----- + +// CHECK-LABEL: func @reduction_minui +// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32) +// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32> +// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32> +// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32> +// CHECK: %[[MIN0:.+]] = spv.GL.UMin %[[S0]], %[[S1]] +// CHECK: %[[MIN1:.+]] = spv.GL.UMin %[[MIN0]], 
%[[S2]] +// CHECK: %[[MIN2:.+]] = spv.GL.UMin %[[MIN1]], %[[S]] +// CHECK: return %[[MIN2]] +func.func @reduction_minui(%v : vector<3xi32>, %s: i32) -> i32 { + %reduce = vector.reduction <minui>, %v, %s : vector<3xi32> into i32 + return %reduce : i32 +} -- 2.7.4