From 6c3c5f8069d97e635b1887a6f9ac410391b89fae Mon Sep 17 00:00:00 2001
From: Matthias Springer
Date: Tue, 5 Jul 2022 16:39:29 +0200
Subject: [PATCH] [mlir][memref] Improve type inference for rank-reducing subviews

The result shape of a rank-reducing subview cannot be inferred in the
general case; the result rank alone is not enough. The only thing that
can be inferred is the layout map.

This change also improves the bufferization patterns of
tensor.extract_slice and tensor.insert_slice to fully support
rank-reducing operations.

Differential Revision: https://reviews.llvm.org/D129144
---
 mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td   | 12 +++-
 .../Transforms/AllocTensorElimination.cpp          | 21 +------
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp           | 46 +++++++--------
 mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp |  4 +-
 .../Transforms/BufferizableOpInterfaceImpl.cpp     | 67 ++++++++++------------
 .../Transforms/VectorTransferOpTransforms.cpp      |  4 +-
 mlir/test/Dialect/Tensor/bufferize.mlir            | 36 +++++++++++-
 mlir/test/Dialect/Tensor/one-shot-bufferize.mlir   | 24 ++++++++
 mlir/unittests/Dialect/MemRef/InferShapeTest.cpp   |  6 +-
 9 files changed, 132 insertions(+), 88 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 097ce28..daeb7b8 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1645,12 +1645,20 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
                                 ArrayRef<OpFoldResult> staticOffsets,
                                 ArrayRef<OpFoldResult> staticSizes,
                                 ArrayRef<OpFoldResult> staticStrides);
-    static Type inferRankReducedResultType(unsigned resultRank,
+
+    /// A rank-reducing result type can be inferred from the desired result
+    /// shape. Only the layout map is inferred.
+    ///
+    /// Note: The result shape cannot be inferred with just the result rank
+    /// and the desired sizes. In case there are more "ones" among the sizes
+    /// than the difference in source/result rank, it is not clear which dims
+    /// of size one should be dropped.
+    static Type inferRankReducedResultType(ArrayRef<int64_t> resultShape,
                                            MemRefType sourceMemRefType,
                                            ArrayRef<int64_t> staticOffsets,
                                            ArrayRef<int64_t> staticSizes,
                                            ArrayRef<int64_t> staticStrides);
-    static Type inferRankReducedResultType(unsigned resultRank,
+    static Type inferRankReducedResultType(ArrayRef<int64_t> resultShape,
                                            MemRefType sourceMemRefType,
                                            ArrayRef<OpFoldResult> staticOffsets,
                                            ArrayRef<OpFoldResult> staticSizes,
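Why the result rank alone is not enough: with sizes [1, 5, 1], the full-rank inferred type has two unit dimensions, and a rank-2 result can be obtained by dropping either one. The sketch below illustrates this with invented shapes (it is not one of the patch's tests; layouts are spelled out as affine maps, as in the test files further down):

```mlir
func.func @result_rank_is_ambiguous(%src: memref<8x16x4xf32>) {
  // Full-rank inference for offsets [0, 0, 0], sizes [1, 5, 1], strides
  // [1, 1, 1] yields
  // memref<1x5x1xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>.

  // Dropping dim 0 leaves shape [5, 1]:
  %a = memref.subview %src[0, 0, 0] [1, 5, 1] [1, 1, 1]
      : memref<8x16x4xf32>
      to memref<5x1xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>

  // Dropping dim 2 leaves shape [1, 5]:
  %b = memref.subview %src[0, 0, 0] [1, 5, 1] [1, 1, 1]
      : memref<8x16x4xf32>
      to memref<1x5xf32, affine_map<(d0, d1) -> (d0 * 64 + d1 * 4)>>
  return
}
```

Both results have rank 2, so a request for "result rank 2" cannot tell them apart; a desired result shape of [5, 1] or [1, 5] can.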
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp
index 6c6bcab..719797a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp
@@ -215,25 +215,10 @@ mlir::bufferization::insertSliceAnchoredAllocTensorEliminationStep(
       /*rewriteFunc=*/
       [](OpBuilder &b, Location loc, OpOperand &operand) {
         auto insertOp = cast<tensor::InsertSliceOp>(operand.getOwner());
-        // Expand offsets, sizes and strides to the full rank to handle the
-        // rank-reducing case.
-        SmallVector<OpFoldResult> mixedOffsets = insertOp.getMixedOffsets();
-        SmallVector<OpFoldResult> mixedSizes = insertOp.getMixedSizes();
-        SmallVector<OpFoldResult> mixedStrides = insertOp.getMixedStrides();
-        OffsetSizeAndStrideOpInterface::expandToRank(
-            insertOp.getDest(), mixedOffsets, mixedSizes, mixedStrides,
-            [&](Value target, int64_t dim) -> OpFoldResult {
-              auto shapedType = target.getType().cast<ShapedType>();
-              if (shapedType.isDynamicDim(dim))
-                return b.create<tensor::DimOp>(loc, target, dim).getResult();
-              return b.getIndexAttr(shapedType.getDimSize(dim));
-            });
-        auto t = tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
-            insertOp.getSourceType().getRank(),
-            insertOp.getDest().getType().cast<RankedTensorType>(), mixedOffsets,
-            mixedSizes, mixedStrides);
         auto extractOp = b.create<tensor::ExtractSliceOp>(
-            loc, t, insertOp.getDest(), mixedOffsets, mixedSizes, mixedStrides);
+            loc, insertOp.getSourceType(), insertOp.getDest(),
+            insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
+            insertOp.getMixedStrides());
         return extractOp.getResult();
       });
 }
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 000bac1..8e54936c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2145,7 +2145,7 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                 staticSizes, staticStrides);
 }
 
-Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
+Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
                                            MemRefType sourceRankedTensorType,
                                            ArrayRef<int64_t> offsets,
                                            ArrayRef<int64_t> sizes,
@@ -2153,27 +2153,26 @@ Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
   auto inferredType =
       inferResultType(sourceRankedTensorType, offsets, sizes, strides)
           .cast<MemRefType>();
-  assert(inferredType.getRank() >= resultRank && "expected ");
-  int rankDiff = inferredType.getRank() - resultRank;
-  if (rankDiff > 0) {
-    auto shape = inferredType.getShape();
-    llvm::SmallBitVector dimsToProject =
-        getPositionsOfShapeOne(rankDiff, shape);
-    SmallVector<int64_t> projectedShape;
-    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
-      if (!dimsToProject.test(pos))
-        projectedShape.push_back(shape[pos]);
-
-    AffineMap map =
-        getProjectedMap(inferredType.getLayout().getAffineMap(), dimsToProject);
-    inferredType =
-        MemRefType::get(projectedShape, inferredType.getElementType(), map,
-                        inferredType.getMemorySpace());
-  }
-  return inferredType;
-}
-
-Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
+  assert(inferredType.getRank() >= resultShape.size() && "expected result rank <= inferred rank");
+  if (inferredType.getRank() == resultShape.size())
+    return inferredType;
+
+  // Compute which dimensions are dropped.
+  Optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
+      computeRankReductionMask(inferredType.getShape(), resultShape);
+  assert(dimsToProject.hasValue() && "invalid rank reduction");
+  llvm::SmallBitVector dimsToProjectVector(inferredType.getRank());
+  for (unsigned dim : *dimsToProject)
+    dimsToProjectVector.set(dim);
+
+  // Compute layout map and result type.
+  AffineMap map = getProjectedMap(inferredType.getLayout().getAffineMap(),
+                                  dimsToProjectVector);
+  return MemRefType::get(resultShape, inferredType.getElementType(), map,
+                         inferredType.getMemorySpace());
+}
+
+Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
                                            MemRefType sourceRankedTensorType,
                                            ArrayRef<OpFoldResult> offsets,
                                            ArrayRef<OpFoldResult> sizes,
@@ -2187,9 +2186,10 @@ Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                              ShapedType::kDynamicStrideOrOffset);
   return SubViewOp::inferRankReducedResultType(
-      resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
+      resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
       staticStrides);
 }
+
 // Build a SubViewOp with mixed static and dynamic entries and custom result
 // type. If the type passed is nullptr, it is inferred.
 void SubViewOp::build(OpBuilder &b, OperationState &result,
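The new inference matches the desired result shape against the full-rank inferred shape (via computeRankReductionMask) and then projects the dropped dimensions out of the layout map. The sketch below mirrors the inferRankReducedShapeNonIdentity unit test at the end of this patch, rewritten as IR with an invented function name:

```mlir
#layout = affine_map<(d0, d1) -> (d0 * 1000 + d1)>

func.func @projected_layout(%m: memref<10x5xindex, #layout>) {
  // Offsets [2, 3] contribute 2 * 1000 + 3 = 2003 to the offset. Sizes
  // [1, 2] with desired result shape [2] drop dim 0, so the full-rank map
  // (d0, d1) -> (d0 * 1000 + d1 + 2003) projects to (d0) -> (d0 + 2003).
  %r = memref.subview %m[2, 3] [1, 2] [1, 1]
      : memref<10x5xindex, #layout>
      to memref<2xindex, affine_map<(d0) -> (d0 + 2003)>>
  return
}
```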
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index 2c09145..51f6a69 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -44,7 +44,7 @@ static void replaceUsesAndPropagateType(Operation *oldOp, Value val,
     }
     builder.setInsertionPoint(subviewUse);
     Type newType = memref::SubViewOp::inferRankReducedResultType(
-        subviewUse.getType().getRank(), val.getType().cast<MemRefType>(),
+        subviewUse.getType().getShape(), val.getType().cast<MemRefType>(),
         extractFromI64ArrayAttr(subviewUse.static_offsets()),
         extractFromI64ArrayAttr(subviewUse.static_sizes()),
         extractFromI64ArrayAttr(subviewUse.static_strides()));
@@ -136,7 +136,7 @@ LogicalResult mlir::memref::multiBuffer(memref::AllocOp allocOp,
     sizes.push_back(builder.getIndexAttr(size));
   auto dstMemref =
       memref::SubViewOp::inferRankReducedResultType(
-          allocOp.getType().getRank(), newMemref, offsets, sizes, strides)
+          allocOp.getType().getShape(), newMemref, offsets, sizes, strides)
           .cast<MemRefType>();
   Value subview = builder.create<memref::SubViewOp>(loc, dstMemref, newAlloc,
                                                     offsets, sizes, strides);
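Multi-buffering is a typical client: it selects the current buffer with a rank-reducing subview that drops the leading buffer-index dimension, and the subview's type is now inferred from the original allocation's shape rather than its rank. A hypothetical sketch, assuming a multi-buffering factor of 2 and a precomputed buffer index %idx:

```mlir
func.func @multi_buffered_view(%idx: index) {
  // The original memref<4x128xf32> allocation, multi-buffered by 2:
  %alloc = memref.alloc() : memref<2x4x128xf32>
  // Rank-reducing subview selecting the active buffer; the dynamic
  // offset surfaces as the layout symbol s0:
  %view = memref.subview %alloc[%idx, 0, 0] [1, 4, 128] [1, 1, 1]
      : memref<2x4x128xf32>
      to memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 128 + d1 + s0)>>
  return
}
```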
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 784bd8e..97da596 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -278,36 +278,24 @@ struct ExtractSliceOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
+    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
     Location loc = extractSliceOp.getLoc();
 
-    // Even if this op was decided to bufferize out-of-place, do not insert the
-    // buffer copy yet. This is done later in this function.
+    // Get source buffer.
     FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractSliceOp.getSource(), options);
     if (failed(srcMemref))
       return failure();
     auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
-    auto dstTensorType =
-        extractSliceOp.getResult().getType().cast<RankedTensorType>();
 
-    // Expand offsets, sizes and strides to the full rank to handle the
-    // rank-reducing case.
-    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
-    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
-    OffsetSizeAndStrideOpInterface::expandToRank(
-        *srcMemref, mixedOffsets, mixedSizes, mixedStrides,
-        [&](Value target, int64_t dim) -> OpFoldResult {
-          auto shapedType = target.getType().cast<ShapedType>();
-          if (shapedType.isDynamicDim(dim))
-            return rewriter.create<memref::DimOp>(loc, target, dim).result();
-          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
-        });
-    // Bufferize to subview.
-    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
-                                 dstTensorType.getRank(), srcMemrefType,
-                                 mixedOffsets, mixedSizes, mixedStrides)
-                                 .cast<MemRefType>();
+    // Take a subview of the source buffer.
+    auto subviewMemRefType =
+        memref::SubViewOp::inferRankReducedResultType(
+            extractSliceOp.getType().getShape(), srcMemrefType, mixedOffsets,
+            mixedSizes, mixedStrides)
+            .cast<MemRefType>();
     Value subView = rewriter.create<memref::SubViewOp>(
         loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
         mixedStrides);
@@ -690,30 +678,22 @@ struct InsertSliceOpInterface
     // catastrophically bad scheduling decision.
     // TODO: be very loud about it or even consider failing the pass.
     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
     Location loc = insertSliceOp.getLoc();
+
+    // Get destination buffer.
     FailureOr<Value> dstMemref =
         getBuffer(rewriter, insertSliceOp.getDest(), options);
     if (failed(dstMemref))
       return failure();
 
-    // Expand offsets, sizes and strides to the full rank to handle the
-    // rank-reducing case.
-    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
-    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
-    OffsetSizeAndStrideOpInterface::expandToRank(
-        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
-        [&](Value target, int64_t dim) -> OpFoldResult {
-          auto shapedType = target.getType().cast<ShapedType>();
-          if (shapedType.isDynamicDim(dim))
-            return rewriter.create<memref::DimOp>(loc, target, dim).result();
-          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
-        });
-    // Take a subview of the dst.
+    // Take a subview of the destination buffer.
     auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
     auto subviewMemRefType =
         memref::SubViewOp::inferRankReducedResultType(
-            insertSliceOp.getSourceType().getRank(), dstMemrefType,
+            insertSliceOp.getSourceType().getShape(), dstMemrefType,
             mixedOffsets, mixedSizes, mixedStrides)
             .cast<MemRefType>();
     Value subView = rewriter.create<memref::SubViewOp>(
@@ -946,11 +926,22 @@ struct ParallelInsertSliceOpInterface
         getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
     if (failed(srcBuffer))
       return failure();
+
+    // Take a subview of the destination buffer.
+    auto destBufferType = destBuffer->getType().cast<MemRefType>();
+    auto subviewMemRefType =
+        memref::SubViewOp::inferRankReducedResultType(
+            parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
+            parallelInsertSliceOp.getMixedOffsets(),
+            parallelInsertSliceOp.getMixedSizes(),
+            parallelInsertSliceOp.getMixedStrides())
+            .cast<MemRefType>();
     Value subview = rewriter.create<memref::SubViewOp>(
-        parallelInsertSliceOp.getLoc(), *destBuffer,
+        parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
         parallelInsertSliceOp.getMixedOffsets(),
         parallelInsertSliceOp.getMixedSizes(),
         parallelInsertSliceOp.getMixedStrides());
+
+    // This memcpy will fold away if everything bufferizes in-place.
     if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
                                     *srcBuffer, subview)))
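With the expandToRank workaround gone, rank-reducing slice ops bufferize directly to rank-reduced subviews. A minimal before/after sketch with invented shapes (the second function shows the IR the pattern produces, written out by hand):

```mlir
// Rank-reducing extract_slice: sizes [1, 5] with a rank-1 result.
func.func @slice(%t: tensor<8x16xf32>) -> tensor<5xf32> {
  %0 = tensor.extract_slice %t[2, 3] [1, 5] [1, 1]
      : tensor<8x16xf32> to tensor<5xf32>
  return %0 : tensor<5xf32>
}

// After bufferization: a subview whose type is inferred from the result
// shape [5]; the static offset folds to 2 * 16 + 3 = 35.
func.func @slice_bufferized(%m: memref<8x16xf32>)
    -> memref<5xf32, affine_map<(d0) -> (d0 + 35)>> {
  %sv = memref.subview %m[2, 3] [1, 5] [1, 1]
      : memref<8x16xf32> to memref<5xf32, affine_map<(d0) -> (d0 + 35)>>
  return %sv : memref<5xf32, affine_map<(d0) -> (d0 + 35)>>
}
```

tensor.insert_slice and tensor.parallel_insert_slice are symmetric: the rank-reduced subview is taken of the destination buffer, and the source buffer is copied into it.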
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 198ece1..6cddef2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -216,8 +216,10 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
 static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
                                ArrayRef<int64_t> sizes,
                                ArrayRef<int64_t> strides) {
+  SmallVector<int64_t> targetShape = llvm::to_vector(
+      llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; }));
   Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
-      0, inputType, offsets, sizes, strides);
+      targetShape, inputType, offsets, sizes, strides);
   return canonicalizeStridedLayout(rankReducedType.cast<MemRefType>());
 }
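Here dropUnitDims derives the target shape by filtering out the unit sizes, so every size-1 dimension is dropped before the strided layout is canonicalized. (The old code passed a result rank of 0 to request maximal rank reduction; the explicit target shape states the same intent directly.) A sketch of the resulting subview type, with invented shapes:

```mlir
func.func @transfer_without_unit_dims(%m: memref<1x512x1x4xf32>) {
  // Sizes [1, 512, 1, 4] filter to the target shape [512, 4]. The slice
  // is contiguous and starts at offset 0, so the strided layout
  // (d0, d1) -> (d0 * 4 + d1) canonicalizes to the identity layout:
  %v = memref.subview %m[0, 0, 0, 0] [1, 512, 1, 4] [1, 1, 1, 1]
      : memref<1x512x1x4xf32> to memref<512x4xf32>
  return
}
```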
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 6a3c4e1..937588e 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -292,7 +292,7 @@ func.func @tensor.extract_slice_rank_reducing(
 // CHECK-SAME:     %[[t1:.*]]: tensor<?x?xf32>, %[[t2:.*]]: tensor<?x10xf32>,
 // CHECK-SAME:     %[[idx1:.*]]: index, %[[idx2:.*]]: index
 func.func @tensor.insert_slice(%t1: tensor<?x?xf32>, %t2: tensor<?x10xf32>,
-    %idx1: index, %idx2: index) -> tensor<?x?xf32> {
+                               %idx1: index, %idx2: index) -> tensor<?x?xf32> {
   // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
   // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
   // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x?xf32>
@@ -313,6 +313,40 @@ func.func @tensor.insert_slice(%t1: tensor<?x?xf32>, %t2: tensor<?x10xf32>,
 
 // -----
 
+// CHECK: #[[$MAP11:.*]] = affine_map<()[s0] -> (s0)>
+
+// CHECK-LABEL: func @tensor.insert_slice_rank_reducing_1(
+func.func @tensor.insert_slice_rank_reducing_1(
+    %t1: tensor<?x?xf32>, %f: tensor<f32>, %idx1: index, %idx2: index)
+  -> tensor<?x?xf32>
+{
+  // CHECK: %[[alloc:.*]] = memref.alloc{{.*}} : memref<?x?xf32>
+  // CHECK: memref.subview %[[alloc]][%{{.*}}, %{{.*}}] [1, 1] [1, 1] : memref<?x?xf32> to memref<f32, #[[$MAP11]]>
+  // CHECK: memref.copy {{.*}} : memref<f32> to memref<f32, #[[$MAP11]]>
+  %0 = tensor.insert_slice %f into %t1[%idx1, %idx2][1, 1][1, 1]
+      : tensor<f32> into tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP12:.*]] = affine_map<(d0, d1, d2, d3, d4)[s0, s1, s2, s3, s4, s5] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5)>
+
+// CHECK-LABEL: func @tensor.insert_slice_rank_reducing_2(
+func.func @tensor.insert_slice_rank_reducing_2(
+    %t1: tensor<?x?x?x?x?x?x?xf32>, %t2: tensor<2x1x4x1x1xf32>, %i: index)
+  -> tensor<?x?x?x?x?x?x?xf32>
+{
+  // CHECK: %[[alloc:.*]] = memref.alloc{{.*}} : memref<?x?x?x?x?x?x?xf32>
+  // CHECK: memref.subview %[[alloc]][{{.*}}] [1, 2, 1, 4, 1, 1, 1] [1, 1, 1, 1, 1, 1, 1] : memref<?x?x?x?x?x?x?xf32> to memref<2x1x4x1x1xf32, #[[$MAP12]]>
+  // CHECK: memref.copy {{.*}} : memref<2x1x4x1x1xf32> to memref<2x1x4x1x1xf32, #[[$MAP12]]>
+  %0 = tensor.insert_slice %t2 into %t1[%i, %i, %i, %i, %i, %i, %i][1, 2, 1, 4, 1, 1, 1][1, 1, 1, 1, 1, 1, 1]
+      : tensor<2x1x4x1x1xf32> into tensor<?x?x?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?x?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @tensor.insert(
 // CHECK-SAME:     %[[t1:.*]]: tensor<5xf32>, %[[idx1:.*]]: index,
 // CHECK-SAME:     %[[f:.*]]: f32
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index 7249d54..4b462f6 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -193,3 +193,27 @@ func.func @rank_reducing(
   }
   return %5: tensor<?x1x6x8xf32>
 }
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+
+// CHECK-LABEL: func.func @rank_reducing_parallel_insert_slice
+func.func @rank_reducing_parallel_insert_slice(%in: tensor<100xf32>, %out: tensor<200x100xf32>) {
+  %c1 = arith.constant 1 : index
+  %num_threads = arith.constant 100 : index
+
+  // CHECK: scf.foreach_thread {{.*}} {
+  %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<200x100xf32> {
+    %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
+    scf.foreach_thread.perform_concurrently {
+      // CHECK: memref.subview %{{.*}}[%{{.*}}] [1] [1] : memref<100xf32, #[[$MAP0]]> to memref<1xf32, #[[$MAP0]]>
+      // CHECK: memref.subview %{{.*}}[1, %{{.*}}] [1, 1] [1, 1] : memref<200x100xf32, #[[$MAP1]]> to memref<1xf32, #[[$MAP0]]>
+      tensor.parallel_insert_slice %1 into %out[1, %thread_idx][1, 1][1, 1] :
+        tensor<1xf32> into tensor<200x100xf32>
+    }
+  }
+  // CHECK: }
+  return
+}
diff --git a/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp b/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp
index 1899755..28dc768 100644
--- a/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp
+++ b/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp
@@ -21,7 +21,7 @@ TEST(InferShapeTest, inferRankReducedShapeIdentity) {
   OpBuilder b(&ctx);
   auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType());
   auto reducedType = SubViewOp::inferRankReducedResultType(
-      /*resultRank=*/1, sourceMemref, {2, 3}, {1, 2}, {1, 1});
+      /*resultShape=*/{2}, sourceMemref, {2, 3}, {1, 2}, {1, 1});
   AffineExpr dim0;
   bindDims(&ctx, dim0);
   auto expectedType =
@@ -38,7 +38,7 @@ TEST(InferShapeTest, inferRankReducedShapeNonIdentity) {
   auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType(),
                                       AffineMap::get(2, 0, 1000 * dim0 + dim1));
   auto reducedType = SubViewOp::inferRankReducedResultType(
-      /*resultRank=*/1, sourceMemref, {2, 3}, {1, 2}, {1, 1});
+      /*resultShape=*/{2}, sourceMemref, {2, 3}, {1, 2}, {1, 1});
   auto expectedType =
       MemRefType::get({2}, b.getIndexType(), AffineMap::get(1, 0, dim0 + 2003));
   EXPECT_EQ(reducedType, expectedType);
@@ -52,7 +52,7 @@ TEST(InferShapeTest, inferRankReducedShapeToScalar) {
   auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType(),
                                       AffineMap::get(2, 0, 1000 * dim0 + dim1));
   auto reducedType = SubViewOp::inferRankReducedResultType(
-      /*resultRank=*/0, sourceMemref, {2, 3}, {1, 1}, {1, 1});
+      /*resultShape=*/{}, sourceMemref, {2, 3}, {1, 1}, {1, 1});
   auto expectedType =
       MemRefType::get({}, b.getIndexType(),
                       AffineMap::get(0, 0, b.getAffineConstantExpr(2003)));
-- 
2.7.4