From 76d71f3792b2b1864992446f7b1028b026dccd11 Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Wed, 31 May 2023 18:07:09 +0000 Subject: [PATCH] Revert "[mlir][Vector] Extend xfer drop unit dim patterns" This reverts commit a53cd03deac5e6272e9dae88a90cd51410d312d5. This commit is exposing some implementation gaps in other patterns. Reverting for now. --- .../Transforms/VectorTransferOpTransforms.cpp | 67 ++++----------- .../vector-transfer-drop-unit-dims-patterns.mlir | 99 ---------------------- 2 files changed, 15 insertions(+), 151 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index 0e9dcf2..af0fcd0 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -63,7 +63,6 @@ private: std::vector opToErase; }; -} // namespace /// Return true if there is a path from start operation to dest operation, /// otherwise return false. The operations have to be in the same region. bool TransferOptimization::isReachable(Operation *start, Operation *dest) { @@ -289,25 +288,14 @@ static int getReducedRank(ArrayRef shape) { return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; }); } -/// Returns a copy of `shape` without unit dims. -static SmallVector getReducedShape(ArrayRef shape) { - SmallVector reducedShape; - llvm::copy_if(shape, std::back_inserter(reducedShape), - [](int64_t dimSize) { return dimSize != 1; }); - return reducedShape; -} - /// Returns true if all values are `arith.constant 0 : index` static bool isZero(Value v) { auto cst = v.getDefiningOp(); return cst && cst.value() == 0; } -namespace { - -/// Rewrites `vector.transfer_read` ops where the source has unit dims, by -/// inserting a memref.subview dropping those unit dims. The vector shapes are -/// also reduced accordingly. +/// Rewrites vector.transfer_read ops where the source has unit dims, by +/// inserting a memref.subview dropping those unit dims. class TransferReadDropUnitDimsPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -329,15 +317,12 @@ class TransferReadDropUnitDimsPattern return failure(); if (!transferReadOp.getPermutationMap().isMinorIdentity()) return failure(); - // Check if the source shape can be further reduced. int reducedRank = getReducedRank(sourceType.getShape()); if (reducedRank == sourceType.getRank()) - return failure(); - // Check if the reduced vector shape matches the reduced source shape. - // Otherwise, this case is not supported yet. - int vectorReducedRank = getReducedRank(vectorType.getShape()); - if (reducedRank != vectorReducedRank) - return failure(); + return failure(); // The source shape can't be further reduced. + if (reducedRank != vectorType.getRank()) + return failure(); // This pattern requires the vector shape to match the + // reduced source shape. if (llvm::any_of(transferReadOp.getIndices(), [](Value v) { return !isZero(v); })) return failure(); @@ -346,22 +331,14 @@ class TransferReadDropUnitDimsPattern Value c0 = rewriter.create(loc, 0); SmallVector zeros(reducedRank, c0); auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank); - auto reducedVectorType = VectorType::get( - getReducedShape(vectorType.getShape()), vectorType.getElementType()); - - auto newTransferReadOp = rewriter.create( - loc, reducedVectorType, reducedShapeSource, zeros, identityMap); - auto shapeCast = rewriter.createOrFold( - loc, vectorType, newTransferReadOp); - rewriter.replaceOp(transferReadOp, shapeCast); - + rewriter.replaceOpWithNewOp( + transferReadOp, vectorType, reducedShapeSource, zeros, identityMap); return success(); } }; -/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination) -/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The -/// vector shapes are also reduced accordingly. +/// Rewrites vector.transfer_write ops where the "source" (i.e. destination) has +/// unit dims, by inserting a memref.subview dropping those unit dims. class TransferWriteDropUnitDimsPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -383,15 +360,12 @@ class TransferWriteDropUnitDimsPattern return failure(); if (!transferWriteOp.getPermutationMap().isMinorIdentity()) return failure(); - // Check if the destination shape can be further reduced. int reducedRank = getReducedRank(sourceType.getShape()); if (reducedRank == sourceType.getRank()) - return failure(); - // Check if the reduced vector shape matches the reduced destination shape. - // Otherwise, this case is not supported yet. - int vectorReducedRank = getReducedRank(vectorType.getShape()); - if (reducedRank != vectorReducedRank) - return failure(); + return failure(); // The source shape can't be further reduced. + if (reducedRank != vectorType.getRank()) + return failure(); // This pattern requires the vector shape to match the + // reduced source shape. if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) { return !isZero(v); })) return failure(); @@ -400,20 +374,12 @@ class TransferWriteDropUnitDimsPattern Value c0 = rewriter.create(loc, 0); SmallVector zeros(reducedRank, c0); auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank); - VectorType reducedVectorType = VectorType::get( - getReducedShape(vectorType.getShape()), vectorType.getElementType()); - - auto shapeCast = rewriter.createOrFold( - loc, reducedVectorType, vector); rewriter.replaceOpWithNewOp( - transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap); - + transferWriteOp, vector, reducedShapeSource, zeros, identityMap); return success(); } }; -} // namespace - /// Return true if the memref type has its inner dimension matching the given /// shape. Otherwise return false. static int64_t hasMatchingInnerContigousShape(MemRefType memrefType, @@ -473,8 +439,6 @@ checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse, return success(); } -namespace { - /// Rewrites contiguous row-major vector.transfer_read ops by inserting /// memref.collapse_shape on the source so that the resulting /// vector.transfer_read has a 1D source. Requires the source shape to be @@ -768,7 +732,6 @@ class RewriteScalarWrite : public OpRewritePattern { return success(); } }; - } // namespace void mlir::vector::transferOpflowOpt(RewriterBase &rewriter, diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir index 3efa069..e4e2e3b 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir @@ -15,14 +15,6 @@ func.func @transfer_read_rank_reducing( // CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}> // CHECK: vector.transfer_read %[[SUBVIEW]] -transform.sequence failures(propagate) { -^bb1(%module_op: !pdl.operation): - transform.vector.apply_rank_reducing_subview_patterns %module_op - : (!pdl.operation) -> !pdl.operation -} - -// ----- - func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) { %c0 = arith.constant 0 : index vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : @@ -36,97 +28,6 @@ func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, // CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}> // CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]] -transform.sequence failures(propagate) { -^bb1(%module_op: !pdl.operation): - transform.vector.apply_rank_reducing_subview_patterns %module_op - : (!pdl.operation) -> !pdl.operation -} - -// ----- - -func.func @transfer_read_and_vector_rank_reducing( - %arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> { - %c0 = arith.constant 0 : index - %cst = arith.constant 0.0 : f32 - %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst : - memref<1x1x3x2x1xf32>, vector<3x2x1xf32> - return %v : vector<3x2x1xf32> -} - -// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing -// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2x1xf32> -// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0] [1, 1, 3, 2, 1] [1, 1, 1, 1, 1] -// CHECK-SAME: memref<1x1x3x2x1xf32> to memref<3x2xf32> -// CHECK: vector.transfer_read %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : memref<3x2xf32>, vector<3x2xf32> - -transform.sequence failures(propagate) { -^bb1(%module_op: !pdl.operation): - transform.vector.apply_rank_reducing_subview_patterns %module_op - : (!pdl.operation) -> !pdl.operation -} - -// ----- - -func.func @transfer_write_and_vector_rank_reducing( - %arg : memref<1x1x3x2x1xf32>, - %vec : vector<3x2x1xf32>) { - %c0 = arith.constant 0 : index - vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0, %c0] : - vector<3x2x1xf32>, memref<1x1x3x2x1xf32> - return -} - -// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing -// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2x1xf32> -// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0] [1, 1, 3, 2, 1] [1, 1, 1, 1, 1] -// CHECK-SAME: memref<1x1x3x2x1xf32> to memref<3x2xf32> -// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : vector<3x2xf32>, memref<3x2xf32> - -transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - transform.vector.apply_rank_reducing_subview_patterns %module_op - : (!transform.any_op) -> !transform.any_op -} - -// ----- - -func.func @transfer_read_and_vector_rank_reducing_to_0d( - %arg : memref<1x1x1x1x1xf32>) -> vector<1x1x1xf32> { - %c0 = arith.constant 0 : index - %cst = arith.constant 0.0 : f32 - %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst : - memref<1x1x1x1x1xf32>, vector<1x1x1xf32> - return %v : vector<1x1x1xf32> -} - -// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing_to_0d -// CHECK-SAME: %[[MEMREF:.+]]: memref<1x1x1x1x1xf32> -// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[MEMREF]][0, 0, 0, 0, 0] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : memref<1x1x1x1x1xf32> to memref -// CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} : memref, vector -// CHECK: vector.shape_cast %[[READ]] : vector to vector<1x1x1xf32> - -transform.sequence failures(propagate) { -^bb1(%module_op: !pdl.operation): - transform.vector.apply_rank_reducing_subview_patterns %module_op - : (!pdl.operation) -> !pdl.operation -} - -// ----- - -func.func @transfer_write_and_vector_rank_reducing_to_0d( - %arg : memref<1x1x1x1x1xf32>, - %vec : vector<1x1x1xf32>) { - %c0 = arith.constant 0 : index - vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0, %c0] : - vector<1x1x1xf32>, memref<1x1x1x1x1xf32> - return -} - -// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing_to_0d -// CHECK-SAME: %[[MEMREF:.+]]: memref<1x1x1x1x1xf32>, %[[VECTOR:.+]]: vector<1x1x1xf32> -// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[MEMREF]][0, 0, 0, 0, 0] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : memref<1x1x1x1x1xf32> to memref -// CHECK: %[[SHCAST:.+]] = vector.shape_cast %[[VECTOR]] : vector<1x1x1xf32> to vector -// CHECK: vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} : vector, memref transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): -- 2.7.4