[mlir][vector] Disable folding for masked reductions
authorDiego Caballero <diegocaballero@google.com>
Thu, 19 Jan 2023 22:59:02 +0000 (22:59 +0000)
committerDiego Caballero <diegocaballero@google.com>
Thu, 19 Jan 2023 23:06:53 +0000 (23:06 +0000)
Reductions can't be folded into plain arith ops until we can mask
those arith ops.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D141645

mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir

index 777133d..974854e 100644 (file)
@@ -361,6 +361,12 @@ struct ElideUnitDimsInMultiDimReduction
 
   LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
                                 PatternRewriter &rewriter) const override {
+    // Masked reductions can't be folded until we can propagate the mask to the
+    // resulting operation.
+    auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
+    if (maskableOp.isMasked())
+      return failure();
+
     ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
     for (const auto &dim : enumerate(shape)) {
       if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
@@ -518,6 +524,12 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
 
   LogicalResult matchAndRewrite(ReductionOp reductionOp,
                                 PatternRewriter &rewriter) const override {
+    // Masked reductions can't be folded until we can propagate the mask to the
+    // resulting operation.
+    auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
+    if (maskableOp.isMasked())
+      return failure();
+
     if (reductionOp.getVectorType().getDimSize(0) != 1)
       return failure();
 
index 081d94e..fb890b6 100644 (file)
@@ -1371,6 +1371,20 @@ func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32
 
 // -----
 
+// Masked reduction can't be folded.
+
+// CHECK-LABEL: func @masked_vector_multi_reduction_unit_dimensions
+func.func @masked_vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32>,
+                                                         %acc: vector<5x4x20xf32>,
+                                                         %mask: vector<5x1x4x1x20xi1>) -> vector<5x4x20xf32> {
+//       CHECK:   vector.mask %{{.*}} { vector.multi_reduction <mul>
+    %0 = vector.mask %mask { vector.multi_reduction <mul>, %source, %acc [1, 3] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32> } :
+           vector<5x1x4x1x20xi1> -> vector<5x4x20xf32>
+    return %0 : vector<5x4x20xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @vector_multi_reduction_unit_dimensions_fail(
 //  CHECK-SAME: %[[SRC:.+]]: vector<5x1x4x1x20xf32>, %[[ACCUM:.+]]: vector<5x1x20xf32>
 func.func @vector_multi_reduction_unit_dimensions_fail(%source: vector<5x1x4x1x20xf32>, %acc: vector<5x1x20xf32>) -> vector<5x1x20xf32> {
@@ -1921,6 +1935,18 @@ func.func @reduce_one_element_vector_addf(%a : vector<1xf32>, %b: f32) -> f32 {
 
 // -----
 
+// CHECK-LABEL: func @masked_reduce_one_element_vector_addf
+//       CHECK:   vector.mask %{{.*}} { vector.reduction <add>
+func.func @masked_reduce_one_element_vector_addf(%a: vector<1xf32>,
+                                                 %b: f32,
+                                                 %mask: vector<1xi1>) -> f32 {
+  %s = vector.mask %mask { vector.reduction <add>, %a, %b : vector<1xf32> into f32 }
+         : vector<1xi1> -> f32
+  return %s : f32
+}
+
+// -----
+
 // CHECK-LABEL: func @reduce_one_element_vector_mulf
 //  CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
 //       CHECK:   %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32>