[mlir][spirv] Support more max/min vector.reduction
authorLei Zhang <antiagainst@google.com>
Fri, 2 Sep 2022 21:21:57 +0000 (17:21 -0400)
committerLei Zhang <antiagainst@google.com>
Fri, 2 Sep 2022 21:21:57 +0000 (17:21 -0400)
Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D133168

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

index 483619b..5349261 100644 (file)
@@ -283,7 +283,8 @@ struct VectorReductionPattern final
     Value result = values.front();
     for (Value next : llvm::makeArrayRef(values).drop_front()) {
       switch (reduceOp.getKind()) {
-#define INT_FLOAT_CASE(kind, iop, fop)                                         \
+
+#define INT_AND_FLOAT_CASE(kind, iop, fop)                                     \
   case vector::CombiningKind::kind:                                            \
     if (resultType.isa<IntegerType>()) {                                       \
       result = rewriter.create<spirv::iop>(loc, resultType, result, next);     \
@@ -293,15 +294,21 @@ struct VectorReductionPattern final
     }                                                                          \
     break
 
-        INT_FLOAT_CASE(ADD, IAddOp, FAddOp);
-        INT_FLOAT_CASE(MUL, IMulOp, FMulOp);
+#define INT_OR_FLOAT_CASE(kind, fop)                                           \
+  case vector::CombiningKind::kind:                                            \
+    result = rewriter.create<spirv::fop>(loc, resultType, result, next);       \
+    break
+
+        INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
+        INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
+
+        INT_OR_FLOAT_CASE(MAXF, GLFMaxOp);
+        INT_OR_FLOAT_CASE(MINF, GLFMinOp);
+        INT_OR_FLOAT_CASE(MINUI, GLUMinOp);
+        INT_OR_FLOAT_CASE(MINSI, GLSMinOp);
+        INT_OR_FLOAT_CASE(MAXUI, GLUMaxOp);
+        INT_OR_FLOAT_CASE(MAXSI, GLSMaxOp);
 
-      case vector::CombiningKind::MINUI:
-      case vector::CombiningKind::MINSI:
-      case vector::CombiningKind::MINF:
-      case vector::CombiningKind::MAXUI:
-      case vector::CombiningKind::MAXSI:
-      case vector::CombiningKind::MAXF:
       case vector::CombiningKind::AND:
       case vector::CombiningKind::OR:
       case vector::CombiningKind::XOR:
index a5af59e..f1de62c 100644 (file)
@@ -254,7 +254,7 @@ func.func @shuffle(%v0 : vector<2x16xf32>, %v1: vector<1x16xf32>) -> vector<3x16
 
 // -----
 
-// CHECK-LABEL: func @reduction
+// CHECK-LABEL: func @reduction_add
 //  CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
 //       CHECK:   %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>
 //       CHECK:   %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<4xi32>
@@ -264,23 +264,119 @@ func.func @shuffle(%v0 : vector<2x16xf32>, %v1: vector<1x16xf32>) -> vector<3x16
 //       CHECK:   %[[ADD1:.+]] = spv.IAdd %[[ADD0]], %[[S2]]
 //       CHECK:   %[[ADD2:.+]] = spv.IAdd %[[ADD1]], %[[S3]]
 //       CHECK:   return %[[ADD2]]
-func.func @reduction(%v : vector<4xi32>) -> i32 {
+func.func @reduction_add(%v : vector<4xi32>) -> i32 {
   %reduce = vector.reduction <add>, %v : vector<4xi32> into i32
   return %reduce : i32
 }
 
 // -----
 
-// CHECK-LABEL: func @reduction
+// CHECK-LABEL: func @reduction_mul
 //  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
 //       CHECK:   %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
 //       CHECK:   %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
 //       CHECK:   %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
-//       CHECK:   %[[ADD0:.+]] = spv.FMul %[[S0]], %[[S1]]
-//       CHECK:   %[[ADD1:.+]] = spv.FMul %[[ADD0]], %[[S2]]
-//       CHECK:   %[[ADD2:.+]] = spv.FMul %[[ADD1]], %[[S]]
-//       CHECK:   return %[[ADD2]]
-func.func @reduction(%v : vector<3xf32>, %s: f32) -> f32 {
+//       CHECK:   %[[MUL0:.+]] = spv.FMul %[[S0]], %[[S1]]
+//       CHECK:   %[[MUL1:.+]] = spv.FMul %[[MUL0]], %[[S2]]
+//       CHECK:   %[[MUL2:.+]] = spv.FMul %[[MUL1]], %[[S]]
+//       CHECK:   return %[[MUL2]]
+func.func @reduction_mul(%v : vector<3xf32>, %s: f32) -> f32 {
   %reduce = vector.reduction <mul>, %v, %s : vector<3xf32> into f32
   return %reduce : f32
 }
+
+// -----
+
+// CHECK-LABEL: func @reduction_maxf
+//  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+//       CHECK:   %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+//       CHECK:   %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+//       CHECK:   %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+//       CHECK:   %[[MAX0:.+]] = spv.GL.FMax %[[S0]], %[[S1]]
+//       CHECK:   %[[MAX1:.+]] = spv.GL.FMax %[[MAX0]], %[[S2]]
+//       CHECK:   %[[MAX2:.+]] = spv.GL.FMax %[[MAX1]], %[[S]]
+//       CHECK:   return %[[MAX2]]
+func.func @reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 {
+  %reduce = vector.reduction <maxf>, %v, %s : vector<3xf32> into f32
+  return %reduce : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_minf
+//  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+//       CHECK:   %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+//       CHECK:   %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+//       CHECK:   %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+//       CHECK:   %[[MIN0:.+]] = spv.GL.FMin %[[S0]], %[[S1]]
+//       CHECK:   %[[MIN1:.+]] = spv.GL.FMin %[[MIN0]], %[[S2]]
+//       CHECK:   %[[MIN2:.+]] = spv.GL.FMin %[[MIN1]], %[[S]]
+//       CHECK:   return %[[MIN2]]
+func.func @reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 {
+  %reduce = vector.reduction <minf>, %v, %s : vector<3xf32> into f32
+  return %reduce : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_maxsi
+//  CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+//       CHECK:   %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+//       CHECK:   %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+//       CHECK:   %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+//       CHECK:   %[[MAX0:.+]] = spv.GL.SMax %[[S0]], %[[S1]]
+//       CHECK:   %[[MAX1:.+]] = spv.GL.SMax %[[MAX0]], %[[S2]]
+//       CHECK:   %[[MAX2:.+]] = spv.GL.SMax %[[MAX1]], %[[S]]
+//       CHECK:   return %[[MAX2]]
+func.func @reduction_maxsi(%v : vector<3xi32>, %s: i32) -> i32 {
+  %reduce = vector.reduction <maxsi>, %v, %s : vector<3xi32> into i32
+  return %reduce : i32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_minsi
+//  CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+//       CHECK:   %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+//       CHECK:   %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+//       CHECK:   %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+//       CHECK:   %[[MIN0:.+]] = spv.GL.SMin %[[S0]], %[[S1]]
+//       CHECK:   %[[MIN1:.+]] = spv.GL.SMin %[[MIN0]], %[[S2]]
+//       CHECK:   %[[MIN2:.+]] = spv.GL.SMin %[[MIN1]], %[[S]]
+//       CHECK:   return %[[MIN2]]
+func.func @reduction_minsi(%v : vector<3xi32>, %s: i32) -> i32 {
+  %reduce = vector.reduction <minsi>, %v, %s : vector<3xi32> into i32
+  return %reduce : i32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_maxui
+//  CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+//       CHECK:   %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+//       CHECK:   %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+//       CHECK:   %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+//       CHECK:   %[[MAX0:.+]] = spv.GL.UMax %[[S0]], %[[S1]]
+//       CHECK:   %[[MAX1:.+]] = spv.GL.UMax %[[MAX0]], %[[S2]]
+//       CHECK:   %[[MAX2:.+]] = spv.GL.UMax %[[MAX1]], %[[S]]
+//       CHECK:   return %[[MAX2]]
+func.func @reduction_maxui(%v : vector<3xi32>, %s: i32) -> i32 {
+  %reduce = vector.reduction <maxui>, %v, %s : vector<3xi32> into i32
+  return %reduce : i32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduction_minui
+//  CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+//       CHECK:   %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+//       CHECK:   %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+//       CHECK:   %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+//       CHECK:   %[[MIN0:.+]] = spv.GL.UMin %[[S0]], %[[S1]]
+//       CHECK:   %[[MIN1:.+]] = spv.GL.UMin %[[MIN0]], %[[S2]]
+//       CHECK:   %[[MIN2:.+]] = spv.GL.UMin %[[MIN1]], %[[S]]
+//       CHECK:   return %[[MIN2]]
+func.func @reduction_minui(%v : vector<3xi32>, %s: i32) -> i32 {
+  %reduce = vector.reduction <minui>, %v, %s : vector<3xi32> into i32
+  return %reduce : i32
+}