[mlir][memref] Improve type inference for rank-reducing subviews
author Matthias Springer <springerm@google.com>
Tue, 5 Jul 2022 14:39:29 +0000 (16:39 +0200)
committer Matthias Springer <springerm@google.com>
Tue, 5 Jul 2022 14:49:07 +0000 (16:49 +0200)
In the general case, the result shape of a rank-reducing subview cannot be inferred from the result rank alone; callers must specify the desired result shape, and only the layout map is inferred.
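
For illustration (a hypothetical example, not part of this change): given a
source memref<8x8x8xf32>, offsets [0, 0, 0], sizes [1, 2, 1] and strides
[1, 1, 1], a result rank of 2 is ambiguous: dropping dim 0 yields
memref<2x1xf32> with strides [8, 1], while dropping dim 2 yields
memref<1x2xf32> with strides [64, 8]. Passing the desired result shape
resolves the ambiguity. A minimal C++ sketch, modeled on the updated unit
tests below (the function name `example` is made up):

  #include "mlir/Dialect/MemRef/IR/MemRef.h"
  #include "mlir/IR/Builders.h"

  using namespace mlir;

  void example() {
    MLIRContext ctx;
    OpBuilder b(&ctx);
    auto srcType = MemRefType::get({8, 8, 8}, b.getF32Type());
    // Result shape {2, 1} keeps dims 1 and 2 and drops unit dim 0
    // (strides [8, 1]); result shape {1, 2} would drop unit dim 2 instead
    // (strides [64, 8]).
    Type reduced = memref::SubViewOp::inferRankReducedResultType(
        /*resultShape=*/{2, 1}, srcType,
        /*offsets=*/{0, 0, 0}, /*sizes=*/{1, 2, 1}, /*strides=*/{1, 1, 1});
    (void)reduced;
  }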

This change also improves the bufferization patterns of tensor.extract_slice and tensor.insert_slice to fully support rank-reducing operations.

Differential Revision: https://reviews.llvm.org/D129144

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/test/Dialect/Tensor/bufferize.mlir
mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
mlir/unittests/Dialect/MemRef/InferShapeTest.cpp

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 097ce28..daeb7b8 100644 (file)
@@ -1645,12 +1645,20 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
                                 ArrayRef<OpFoldResult> staticOffsets,
                                 ArrayRef<OpFoldResult> staticSizes,
                                 ArrayRef<OpFoldResult> staticStrides);
-    static Type inferRankReducedResultType(unsigned resultRank,
+
+    /// A rank-reducing result type can be inferred from the desired result
+    /// shape. Only the layout map is inferred.
+    ///
+    /// Note: The result shape cannot be inferred with just the result rank and
+    /// the desired sizes. If there are more "ones" among the sizes than the
+    /// difference between source and result rank, it is unclear which dims of
+    /// size one should be dropped.
+    static Type inferRankReducedResultType(ArrayRef<int64_t> resultShape,
                                            MemRefType sourceMemRefType,
                                            ArrayRef<int64_t> staticOffsets,
                                            ArrayRef<int64_t> staticSizes,
                                            ArrayRef<int64_t> staticStrides);
-    static Type inferRankReducedResultType(unsigned resultRank,
+    static Type inferRankReducedResultType(ArrayRef<int64_t> resultShape,
                                            MemRefType sourceMemRefType,
                                            ArrayRef<OpFoldResult> staticOffsets,
                                            ArrayRef<OpFoldResult> staticSizes,
mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp
index 6c6bcab..719797a 100644 (file)
@@ -215,25 +215,10 @@ mlir::bufferization::insertSliceAnchoredAllocTensorEliminationStep(
       /*rewriteFunc=*/
       [](OpBuilder &b, Location loc, OpOperand &operand) {
         auto insertOp = cast<tensor::InsertSliceOp>(operand.getOwner());
-        // Expand offsets, sizes and strides to the full rank to handle the
-        // rank-reducing case.
-        SmallVector<OpFoldResult> mixedOffsets = insertOp.getMixedOffsets();
-        SmallVector<OpFoldResult> mixedSizes = insertOp.getMixedSizes();
-        SmallVector<OpFoldResult> mixedStrides = insertOp.getMixedStrides();
-        OffsetSizeAndStrideOpInterface::expandToRank(
-            insertOp.getDest(), mixedOffsets, mixedSizes, mixedStrides,
-            [&](Value target, int64_t dim) -> OpFoldResult {
-              auto shapedType = target.getType().cast<ShapedType>();
-              if (shapedType.isDynamicDim(dim))
-                return b.create<tensor::DimOp>(loc, target, dim).getResult();
-              return b.getIndexAttr(shapedType.getDimSize(dim));
-            });
-        auto t = tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
-            insertOp.getSourceType().getRank(),
-            insertOp.getDest().getType().cast<RankedTensorType>(), mixedOffsets,
-            mixedSizes, mixedStrides);
         auto extractOp = b.create<tensor::ExtractSliceOp>(
-            loc, t, insertOp.getDest(), mixedOffsets, mixedSizes, mixedStrides);
+            loc, insertOp.getSourceType(), insertOp.getDest(),
+            insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
+            insertOp.getMixedStrides());
         return extractOp.getResult();
       });
 }
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 000bac1..8e54936 100644 (file)
@@ -2145,7 +2145,7 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                     staticSizes, staticStrides);
 }
 
-Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
+Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
                                            MemRefType sourceRankedTensorType,
                                            ArrayRef<int64_t> offsets,
                                            ArrayRef<int64_t> sizes,
@@ -2153,27 +2153,26 @@ Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
   auto inferredType =
       inferResultType(sourceRankedTensorType, offsets, sizes, strides)
           .cast<MemRefType>();
-  assert(inferredType.getRank() >= resultRank && "expected ");
-  int rankDiff = inferredType.getRank() - resultRank;
-  if (rankDiff > 0) {
-    auto shape = inferredType.getShape();
-    llvm::SmallBitVector dimsToProject =
-        getPositionsOfShapeOne(rankDiff, shape);
-    SmallVector<int64_t> projectedShape;
-    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
-      if (!dimsToProject.test(pos))
-        projectedShape.push_back(shape[pos]);
-
-    AffineMap map =
-        getProjectedMap(inferredType.getLayout().getAffineMap(), dimsToProject);
-    inferredType =
-        MemRefType::get(projectedShape, inferredType.getElementType(), map,
-                        inferredType.getMemorySpace());
-  }
-  return inferredType;
-}
-
-Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
+  assert(inferredType.getRank() >= resultShape.size() && "expected rank reduction");
+  if (inferredType.getRank() == resultShape.size())
+    return inferredType;
+
+  // Compute which dimensions are dropped.
+  Optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
+      computeRankReductionMask(inferredType.getShape(), resultShape);
+  assert(dimsToProject.hasValue() && "invalid rank reduction");
+  llvm::SmallBitVector dimsToProjectVector(inferredType.getRank());
+  for (unsigned dim : *dimsToProject)
+    dimsToProjectVector.set(dim);
+
+  // Compute layout map and result type.
+  AffineMap map = getProjectedMap(inferredType.getLayout().getAffineMap(),
+                                  dimsToProjectVector);
+  return MemRefType::get(resultShape, inferredType.getElementType(), map,
+                         inferredType.getMemorySpace());
+}
+
+Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
                                            MemRefType sourceRankedTensorType,
                                            ArrayRef<OpFoldResult> offsets,
                                            ArrayRef<OpFoldResult> sizes,
@@ -2187,9 +2186,10 @@ Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                              ShapedType::kDynamicStrideOrOffset);
   return SubViewOp::inferRankReducedResultType(
-      resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
+      resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
       staticStrides);
 }
+
 // Build a SubViewOp with mixed static and dynamic entries and custom result
 // type. If the type passed is nullptr, it is inferred.
 void SubViewOp::build(OpBuilder &b, OperationState &result,
mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index 2c09145..51f6a69 100644 (file)
@@ -44,7 +44,7 @@ static void replaceUsesAndPropagateType(Operation *oldOp, Value val,
     }
     builder.setInsertionPoint(subviewUse);
     Type newType = memref::SubViewOp::inferRankReducedResultType(
-        subviewUse.getType().getRank(), val.getType().cast<MemRefType>(),
+        subviewUse.getType().getShape(), val.getType().cast<MemRefType>(),
         extractFromI64ArrayAttr(subviewUse.static_offsets()),
         extractFromI64ArrayAttr(subviewUse.static_sizes()),
         extractFromI64ArrayAttr(subviewUse.static_strides()));
@@ -136,7 +136,7 @@ LogicalResult mlir::memref::multiBuffer(memref::AllocOp allocOp,
     sizes.push_back(builder.getIndexAttr(size));
   auto dstMemref =
       memref::SubViewOp::inferRankReducedResultType(
-          allocOp.getType().getRank(), newMemref, offsets, sizes, strides)
+          allocOp.getType().getShape(), newMemref, offsets, sizes, strides)
           .cast<MemRefType>();
   Value subview = builder.create<memref::SubViewOp>(loc, dstMemref, newAlloc,
                                                     offsets, sizes, strides);
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 784bd8e..97da596 100644 (file)
@@ -278,36 +278,24 @@ struct ExtractSliceOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
+    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
     Location loc = extractSliceOp.getLoc();
 
-    // Even if this op was decided to bufferize out-of-place, do not insert the
-    // buffer copy yet. This is done later in this function.
+    // Get source buffer.
     FailureOr<Value> srcMemref =
         getBuffer(rewriter, extractSliceOp.getSource(), options);
     if (failed(srcMemref))
       return failure();
     auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
-    auto dstTensorType =
-        extractSliceOp.getResult().getType().cast<RankedTensorType>();
 
-    // Expand offsets, sizes and strides to the full rank to handle the
-    // rank-reducing case.
-    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
-    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
-    OffsetSizeAndStrideOpInterface::expandToRank(
-        *srcMemref, mixedOffsets, mixedSizes, mixedStrides,
-        [&](Value target, int64_t dim) -> OpFoldResult {
-          auto shapedType = target.getType().cast<ShapedType>();
-          if (shapedType.isDynamicDim(dim))
-            return rewriter.create<memref::DimOp>(loc, target, dim).result();
-          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
-        });
-    // Bufferize to subview.
-    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
-                                 dstTensorType.getRank(), srcMemrefType,
-                                 mixedOffsets, mixedSizes, mixedStrides)
-                                 .cast<MemRefType>();
+    // Take a subview of the source buffer.
+    auto subviewMemRefType =
+        memref::SubViewOp::inferRankReducedResultType(
+            extractSliceOp.getType().getShape(), srcMemrefType, mixedOffsets,
+            mixedSizes, mixedStrides)
+            .cast<MemRefType>();
     Value subView = rewriter.create<memref::SubViewOp>(
         loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
         mixedStrides);
@@ -690,30 +678,22 @@ struct InsertSliceOpInterface
     // catastrophically bad scheduling decision.
     // TODO: be very loud about it or even consider failing the pass.
     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
     Location loc = insertSliceOp.getLoc();
+
+    // Get destination buffer.
     FailureOr<Value> dstMemref =
         getBuffer(rewriter, insertSliceOp.getDest(), options);
     if (failed(dstMemref))
       return failure();
 
-    // Expand offsets, sizes and strides to the full rank to handle the
-    // rank-reducing case.
-    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
-    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
-    OffsetSizeAndStrideOpInterface::expandToRank(
-        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
-        [&](Value target, int64_t dim) -> OpFoldResult {
-          auto shapedType = target.getType().cast<ShapedType>();
-          if (shapedType.isDynamicDim(dim))
-            return rewriter.create<memref::DimOp>(loc, target, dim).result();
-          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
-        });
-    // Take a subview of the dst.
+    // Take a subview of the destination buffer.
     auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
     auto subviewMemRefType =
         memref::SubViewOp::inferRankReducedResultType(
-            insertSliceOp.getSourceType().getRank(), dstMemrefType,
+            insertSliceOp.getSourceType().getShape(), dstMemrefType,
             mixedOffsets, mixedSizes, mixedStrides)
             .cast<MemRefType>();
     Value subView = rewriter.create<memref::SubViewOp>(
@@ -946,11 +926,22 @@ struct ParallelInsertSliceOpInterface
         getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
     if (failed(srcBuffer))
       return failure();
+
+    // Take a subview of the destination buffer.
+    auto destBufferType = destBuffer->getType().cast<MemRefType>();
+    auto subviewMemRefType =
+        memref::SubViewOp::inferRankReducedResultType(
+            parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
+            parallelInsertSliceOp.getMixedOffsets(),
+            parallelInsertSliceOp.getMixedSizes(),
+            parallelInsertSliceOp.getMixedStrides())
+            .cast<MemRefType>();
     Value subview = rewriter.create<memref::SubViewOp>(
-        parallelInsertSliceOp.getLoc(), *destBuffer,
+        parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
         parallelInsertSliceOp.getMixedOffsets(),
         parallelInsertSliceOp.getMixedSizes(),
         parallelInsertSliceOp.getMixedStrides());
+
     // This memcpy will fold away if everything bufferizes in-place.
     if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
                                     *srcBuffer, subview)))
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 198ece1..6cddef2 100644 (file)
@@ -216,8 +216,10 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
 static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
                                ArrayRef<int64_t> sizes,
                                ArrayRef<int64_t> strides) {
+  SmallVector<int64_t> targetShape = llvm::to_vector(
+      llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; }));
   Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
-      0, inputType, offsets, sizes, strides);
+      targetShape, inputType, offsets, sizes, strides);
   return canonicalizeStridedLayout(rankReducedType.cast<MemRefType>());
 }
 
mlir/test/Dialect/Tensor/bufferize.mlir
index 6a3c4e1..937588e 100644 (file)
@@ -292,7 +292,7 @@ func.func @tensor.extract_slice_rank_reducing(
 //  CHECK-SAME:     %[[t1:.*]]: tensor<?x?xf32>, %[[t2:.*]]: tensor<?x10xf32>,
 //  CHECK-SAME:     %[[idx1:.*]]: index, %[[idx2:.*]]: index
 func.func @tensor.insert_slice(%t1: tensor<?x?xf32>, %t2: tensor<?x10xf32>,
-                          %idx1: index, %idx2: index) -> tensor<?x?xf32> {
+                               %idx1: index, %idx2: index) -> tensor<?x?xf32> {
   // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
   // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
   // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x?xf32>
@@ -313,6 +313,40 @@ func.func @tensor.insert_slice(%t1: tensor<?x?xf32>, %t2: tensor<?x10xf32>,
 
 // -----
 
+// CHECK: #[[$MAP11:.*]] = affine_map<()[s0] -> (s0)>
+
+// CHECK-LABEL: func @tensor.insert_slice_rank_reducing_1(
+func.func @tensor.insert_slice_rank_reducing_1(
+    %t1: tensor<?x?xf32>, %f: tensor<f32>, %idx1: index, %idx2: index)
+  -> tensor<?x?xf32>
+{
+  // CHECK: %[[alloc:.*]] = memref.alloc{{.*}} : memref<?x?xf32>
+  // CHECK: memref.subview %[[alloc]][%{{.*}}, %{{.*}}] [1, 1] [1, 1] : memref<?x?xf32> to memref<f32, #[[$MAP11]]>
+  // CHECK: memref.copy {{.*}} : memref<f32> to memref<f32, #[[$MAP11]]>
+  %0 = tensor.insert_slice %f into %t1[%idx1, %idx2][1, 1][1, 1]
+      : tensor<f32> into tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP12:.*]] = affine_map<(d0, d1, d2, d3, d4)[s0, s1, s2, s3, s4, s5] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5)>
+
+// CHECK-LABEL: func @tensor.insert_slice_rank_reducing_2(
+func.func @tensor.insert_slice_rank_reducing_2(
+    %t1: tensor<?x?x?x?x?x?x?xf32>, %t2: tensor<2x1x4x1x1xf32>, %i: index)
+  -> tensor<?x?x?x?x?x?x?xf32>
+{
+  // CHECK: %[[alloc:.*]] = memref.alloc{{.*}} : memref<?x?x?x?x?x?x?xf32>
+  // CHECK: memref.subview %[[alloc]][{{.*}}] [1, 2, 1, 4, 1, 1, 1] [1, 1, 1, 1, 1, 1, 1] : memref<?x?x?x?x?x?x?xf32> to memref<2x1x4x1x1xf32, #[[$MAP12]]>
+  // CHECK: memref.copy {{.*}} : memref<2x1x4x1x1xf32> to memref<2x1x4x1x1xf32, #[[$MAP12]]>
+  %0 = tensor.insert_slice %t2 into %t1[%i, %i, %i, %i, %i, %i, %i][1, 2, 1, 4, 1, 1, 1][1, 1, 1, 1, 1, 1, 1]
+      : tensor<2x1x4x1x1xf32> into tensor<?x?x?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?x?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @tensor.insert(
 //  CHECK-SAME:     %[[t1:.*]]: tensor<5xf32>, %[[idx1:.*]]: index,
 //  CHECK-SAME:     %[[f:.*]]: f32
mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index 7249d54..4b462f6 100644 (file)
@@ -193,3 +193,27 @@ func.func @rank_reducing(
   }
   return %5: tensor<?x1x6x8xf32>
 }
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+
+// CHECK-LABEL: func.func @rank_reducing_parallel_insert_slice
+func.func @rank_reducing_parallel_insert_slice(%in: tensor<100xf32>, %out: tensor<200x100xf32>) {
+  %c1 = arith.constant 1 : index
+  %num_threads = arith.constant 100 : index
+
+  // CHECK: scf.foreach_thread {{.*}} {
+  %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<200x100xf32> {
+      %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
+      scf.foreach_thread.perform_concurrently {
+        // CHECK: memref.subview %{{.*}}[%{{.*}}] [1] [1] : memref<100xf32, #[[$MAP0]]> to memref<1xf32, #[[$MAP0]]>
+        // CHECK: memref.subview %{{.*}}[1, %{{.*}}] [1, 1] [1, 1] : memref<200x100xf32, #[[$MAP1]]> to memref<1xf32, #[[$MAP0]]>
+        tensor.parallel_insert_slice %1 into %out[1, %thread_idx][1, 1][1, 1] :
+          tensor<1xf32> into tensor<200x100xf32>
+      }
+  }
+  // CHECK: }
+  return
+}
mlir/unittests/Dialect/MemRef/InferShapeTest.cpp
index 1899755..28dc768 100644 (file)
@@ -21,7 +21,7 @@ TEST(InferShapeTest, inferRankReducedShapeIdentity) {
   OpBuilder b(&ctx);
   auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType());
   auto reducedType = SubViewOp::inferRankReducedResultType(
-      /*resultRank=*/1, sourceMemref, {2, 3}, {1, 2}, {1, 1});
+      /*resultShape=*/{2}, sourceMemref, {2, 3}, {1, 2}, {1, 1});
   AffineExpr dim0;
   bindDims(&ctx, dim0);
   auto expectedType =
@@ -38,7 +38,7 @@ TEST(InferShapeTest, inferRankReducedShapeNonIdentity) {
   auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType(),
                                       AffineMap::get(2, 0, 1000 * dim0 + dim1));
   auto reducedType = SubViewOp::inferRankReducedResultType(
-      /*resultRank=*/1, sourceMemref, {2, 3}, {1, 2}, {1, 1});
+      /*resultShape=*/{2}, sourceMemref, {2, 3}, {1, 2}, {1, 1});
   auto expectedType =
       MemRefType::get({2}, b.getIndexType(), AffineMap::get(1, 0, dim0 + 2003));
   EXPECT_EQ(reducedType, expectedType);
@@ -52,7 +52,7 @@ TEST(InferShapeTest, inferRankReducedShapeToScalar) {
   auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType(),
                                       AffineMap::get(2, 0, 1000 * dim0 + dim1));
   auto reducedType = SubViewOp::inferRankReducedResultType(
-      /*resultRank=*/0, sourceMemref, {2, 3}, {1, 1}, {1, 1});
+      /*resultShape=*/{}, sourceMemref, {2, 3}, {1, 1}, {1, 1});
   auto expectedType =
       MemRefType::get({}, b.getIndexType(),
                       AffineMap::get(0, 0, b.getAffineConstantExpr(2003)));