[mlir][spirv] Lower max/min vector.reduction for OpenCL
authorStanley Winata <stanley@nod-labs.com>
Tue, 20 Sep 2022 21:16:52 +0000 (17:16 -0400)
committerLei Zhang <antiagainst@google.com>
Tue, 20 Sep 2022 21:22:41 +0000 (17:22 -0400)
Templatizing vector reduction to enable lowering from
vector.reduction max/min to CL ops.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D134313

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

index 23f664c..7b1fad1 100644 (file)
@@ -272,6 +272,8 @@ struct VectorInsertStridedSliceOpConvert final
   }
 };
 
+template <class SPVFMaxOp, class SPVFMinOp, class SPVUMaxOp, class SPVUMinOp,
+          class SPVSMaxOp, class SPVSMinOp>
 struct VectorReductionPattern final
     : public OpConversionPattern<vector::ReductionOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -317,18 +319,18 @@ struct VectorReductionPattern final
 
 #define INT_OR_FLOAT_CASE(kind, fop)                                           \
   case vector::CombiningKind::kind:                                            \
-    result = rewriter.create<spirv::fop>(loc, resultType, result, next);       \
+    result = rewriter.create<fop>(loc, resultType, result, next);              \
     break
 
         INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
         INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
 
-        INT_OR_FLOAT_CASE(MAXF, GLFMaxOp);
-        INT_OR_FLOAT_CASE(MINF, GLFMinOp);
-        INT_OR_FLOAT_CASE(MINUI, GLUMinOp);
-        INT_OR_FLOAT_CASE(MINSI, GLSMinOp);
-        INT_OR_FLOAT_CASE(MAXUI, GLUMaxOp);
-        INT_OR_FLOAT_CASE(MAXSI, GLSMaxOp);
+        INT_OR_FLOAT_CASE(MAXF, SPVFMaxOp);
+        INT_OR_FLOAT_CASE(MINF, SPVFMinOp);
+        INT_OR_FLOAT_CASE(MINUI, SPVUMinOp);
+        INT_OR_FLOAT_CASE(MINSI, SPVSMinOp);
+        INT_OR_FLOAT_CASE(MAXUI, SPVUMaxOp);
+        INT_OR_FLOAT_CASE(MAXSI, SPVSMaxOp);
 
       case vector::CombiningKind::AND:
       case vector::CombiningKind::OR:
@@ -403,15 +405,23 @@ struct VectorShuffleOpConvert final
 };
 
 } // namespace
+#define CL_MAX_MIN_OPS                                                         \
+  spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp,          \
+      spirv::CLSMaxOp, spirv::CLSMinOp
+
+#define GL_MAX_MIN_OPS                                                         \
+  spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp,          \
+      spirv::GLSMaxOp, spirv::GLSMinOp
 
 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                          RewritePatternSet &patterns) {
-  patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
-               VectorExtractElementOpConvert, VectorExtractOpConvert,
-               VectorExtractStridedSliceOpConvert,
-               VectorFmaOpConvert<spirv::GLFmaOp>,
-               VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
-               VectorInsertOpConvert, VectorReductionPattern,
-               VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
-               VectorSplatPattern>(typeConverter, patterns.getContext());
+  patterns.add<
+      VectorBitcastConvert, VectorBroadcastConvert,
+      VectorExtractElementOpConvert, VectorExtractOpConvert,
+      VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
+      VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
+      VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
+      VectorReductionPattern<CL_MAX_MIN_OPS>, VectorInsertStridedSliceOpConvert,
+      VectorShuffleOpConvert, VectorSplatPattern>(typeConverter,
+                                                  patterns.getContext());
 }
index f483fc3..afce049 100644 (file)
@@ -33,6 +33,90 @@ func.func @cl_fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<
   return %0 : vector<1xf32>
 }
 
+// CHECK-LABEL: func @cl_reduction_maxf
+//  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+//       CHECK:   %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+//       CHECK:   %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+//       CHECK:   %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+//       CHECK:   %[[MAX0:.+]] = spv.CL.fmax %[[S0]], %[[S1]]
+//       CHECK:   %[[MAX1:.+]] = spv.CL.fmax %[[MAX0]], %[[S2]]
+//       CHECK:   %[[MAX2:.+]] = spv.CL.fmax %[[MAX1]], %[[S]]
+//       CHECK:   return %[[MAX2]]
+func.func @cl_reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 {
+  %reduce = vector.reduction <maxf>, %v, %s : vector<3xf32> into f32
+  return %reduce : f32
+}
+
+// CHECK-LABEL: func @cl_reduction_minf
+//  CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
+//       CHECK:   %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+//       CHECK:   %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+//       CHECK:   %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+//       CHECK:   %[[MIN0:.+]] = spv.CL.fmin %[[S0]], %[[S1]]
+//       CHECK:   %[[MIN1:.+]] = spv.CL.fmin %[[MIN0]], %[[S2]]
+//       CHECK:   %[[MIN2:.+]] = spv.CL.fmin %[[MIN1]], %[[S]]
+//       CHECK:   return %[[MIN2]]
+func.func @cl_reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 {
+  %reduce = vector.reduction <minf>, %v, %s : vector<3xf32> into f32
+  return %reduce : f32
+}
+
+// CHECK-LABEL: func @cl_reduction_maxsi
+//  CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+//       CHECK:   %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+//       CHECK:   %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+//       CHECK:   %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+//       CHECK:   %[[MAX0:.+]] = spv.CL.s_max %[[S0]], %[[S1]]
+//       CHECK:   %[[MAX1:.+]] = spv.CL.s_max %[[MAX0]], %[[S2]]
+//       CHECK:   %[[MAX2:.+]] = spv.CL.s_max %[[MAX1]], %[[S]]
+//       CHECK:   return %[[MAX2]]
+func.func @cl_reduction_maxsi(%v : vector<3xi32>, %s: i32) -> i32 {
+  %reduce = vector.reduction <maxsi>, %v, %s : vector<3xi32> into i32
+  return %reduce : i32
+}
+
+// CHECK-LABEL: func @cl_reduction_minsi
+//  CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+//       CHECK:   %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+//       CHECK:   %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+//       CHECK:   %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+//       CHECK:   %[[MIN0:.+]] = spv.CL.s_min %[[S0]], %[[S1]]
+//       CHECK:   %[[MIN1:.+]] = spv.CL.s_min %[[MIN0]], %[[S2]]
+//       CHECK:   %[[MIN2:.+]] = spv.CL.s_min %[[MIN1]], %[[S]]
+//       CHECK:   return %[[MIN2]]
+func.func @cl_reduction_minsi(%v : vector<3xi32>, %s: i32) -> i32 {
+  %reduce = vector.reduction <minsi>, %v, %s : vector<3xi32> into i32
+  return %reduce : i32
+}
+
+// CHECK-LABEL: func @cl_reduction_maxui
+//  CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+//       CHECK:   %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+//       CHECK:   %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+//       CHECK:   %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+//       CHECK:   %[[MAX0:.+]] = spv.CL.u_max %[[S0]], %[[S1]]
+//       CHECK:   %[[MAX1:.+]] = spv.CL.u_max %[[MAX0]], %[[S2]]
+//       CHECK:   %[[MAX2:.+]] = spv.CL.u_max %[[MAX1]], %[[S]]
+//       CHECK:   return %[[MAX2]]
+func.func @cl_reduction_maxui(%v : vector<3xi32>, %s: i32) -> i32 {
+  %reduce = vector.reduction <maxui>, %v, %s : vector<3xi32> into i32
+  return %reduce : i32
+}
+
+// CHECK-LABEL: func @cl_reduction_minui
+//  CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
+//       CHECK:   %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
+//       CHECK:   %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32>
+//       CHECK:   %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32>
+//       CHECK:   %[[MIN0:.+]] = spv.CL.u_min %[[S0]], %[[S1]]
+//       CHECK:   %[[MIN1:.+]] = spv.CL.u_min %[[MIN0]], %[[S2]]
+//       CHECK:   %[[MIN2:.+]] = spv.CL.u_min %[[MIN1]], %[[S]]
+//       CHECK:   return %[[MIN2]]
+func.func @cl_reduction_minui(%v : vector<3xi32>, %s: i32) -> i32 {
+  %reduce = vector.reduction <minui>, %v, %s : vector<3xi32> into i32
+  return %reduce : i32
+}
+
 } // end module
 
 // -----