[mlir][Vector] Add canonicalization pattern for vector.transpose(vector.constant_mask)
authorDiego Caballero <diegocaballero@google.com>
Wed, 29 Mar 2023 19:20:22 +0000 (19:20 +0000)
committerDiego Caballero <diegocaballero@google.com>
Wed, 29 Mar 2023 19:53:29 +0000 (19:53 +0000)
We already had vector.transpose(vector.create_mask) ->
vector.create_mask. This patch adds the constant mask version of it.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D147099

mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir

index abd8962..8ee5965 100644 (file)
@@ -5269,23 +5269,37 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(TransposeOp transposeOp,
+  LogicalResult matchAndRewrite(TransposeOp transpOp,
                                 PatternRewriter &rewriter) const override {
-    auto createMaskOp =
-        transposeOp.getVector().getDefiningOp<vector::CreateMaskOp>();
-    if (!createMaskOp)
+    Value transposeSrc = transpOp.getVector();
+    auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
+    auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
+    if (!createMaskOp && !constantMaskOp)
       return failure();
 
-    // Get the transpose permutation and apply it to the vector.create_mask
-    // operands.
-    auto maskOperands = createMaskOp.getOperands();
+    // Get the transpose permutation and apply it to the vector.create_mask or
+    // vector.constant_mask operands.
     SmallVector<int64_t> permutation;
-    transposeOp.getTransp(permutation);
-    SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
-    applyPermutationToVector(newOperands, permutation);
+    transpOp.getTransp(permutation);
+
+    if (createMaskOp) {
+      auto maskOperands = createMaskOp.getOperands();
+      SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
+      applyPermutationToVector(newOperands, permutation);
+
+      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
+          transpOp, transpOp.getResultVectorType(), newOperands);
+      return success();
+    }
+
+    // ConstantMaskOp case.
+    auto maskDimSizes = constantMaskOp.getMaskDimSizes();
+    SmallVector<Attribute> newMaskDimSizes(maskDimSizes.getValue());
+    applyPermutationToVector(newMaskDimSizes, permutation);
 
-    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
-        transposeOp, transposeOp.getResultVectorType(), newOperands);
+    rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
+        transpOp, transpOp.getResultVectorType(),
+        ArrayAttr::get(transpOp.getContext(), newMaskDimSizes));
     return success();
   }
 };
index f82540c..88c91ff 100644 (file)
@@ -58,8 +58,9 @@ func.func @create_vector_mask_to_constant_mask_truncation_zero() -> (vector<4x3x
 //  CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index, %[[DIM2:.*]]: index
 func.func @create_mask_transpose_to_transposed_create_mask(
   %dim0: index, %dim1: index, %dim2: index) -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
-  // CHECK: vector.create_mask %[[DIM0]], %[[DIM1]], %[[DIM2]] : vector<2x3x4xi1>
-  // CHECK: vector.create_mask %[[DIM2]], %[[DIM0]], %[[DIM1]] : vector<4x2x3xi1>
+  //     CHECK: vector.create_mask %[[DIM0]], %[[DIM1]], %[[DIM2]] : vector<2x3x4xi1>
+  //     CHECK: vector.create_mask %[[DIM2]], %[[DIM0]], %[[DIM1]] : vector<4x2x3xi1>
+  // CHECK-NOT: vector.transpose
   %0 = vector.create_mask %dim0, %dim1, %dim2 : vector<2x3x4xi1>
   %1 = vector.transpose %0, [2, 0, 1] : vector<2x3x4xi1> to vector<4x2x3xi1>
   return %0, %1 : vector<2x3x4xi1>, vector<4x2x3xi1>
@@ -67,6 +68,18 @@ func.func @create_mask_transpose_to_transposed_create_mask(
 
 // -----
 
+// CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
+func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
+  //     CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
+  //     CHECK: vector.constant_mask [3, 1, 2] : vector<4x2x3xi1>
+  // CHECK-NOT: vector.transpose
+  %0 = vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
+  %1 = vector.transpose %0, [2, 0, 1] : vector<2x3x4xi1> to vector<4x2x3xi1>
+  return %0, %1 : vector<2x3x4xi1>, vector<4x2x3xi1>
+}
+
+// -----
+
 func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
   %0 = vector.constant_mask [2, 2] : vector<4x3xi1>
   %1 = vector.extract_strided_slice %0