[mlir][tensor][bufferize] Support memory_space for tensor.pad
authorMatthias Springer <springerm@google.com>
Thu, 27 Oct 2022 10:24:02 +0000 (12:24 +0200)
committerMatthias Springer <springerm@google.com>
Thu, 27 Oct 2022 10:29:57 +0000 (12:29 +0200)
This change adds memory space support to tensor.pad. (tensor.generate and tensor.from_elements do not support memory spaces yet.)

The memory space is inferred from the buffer of the source tensor.

Instead of lowering tensor.pad to tensor.generate + tensor.insert_slice, it is now lowered to bufferization.alloc_tensor (with the correct memory space) + linalg.map + tensor.insert_slice.

Memory space support for the remaining two tensor ops is left for a later point, as this requires some more design discussions.

Differential Revision: https://reviews.llvm.org/D136265

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Tensor/bufferize.mlir
mlir/test/Dialect/Tensor/one-shot-bufferize.mlir

index 5cddd3c..ea66663 100644 (file)
@@ -779,7 +779,8 @@ struct InsertSliceOpInterface
   }
 };
 
-/// Bufferization of tensor.pad. Replace with tensor.generate + insert_slice.
+/// Bufferization of tensor.pad. Replace with bufferization.alloc_tensor +
+/// linalg.map + insert_slice.
 /// For best performance, vectorize before bufferization (better performance in
 /// case of padding with a constant).
 struct PadOpInterface
@@ -804,6 +805,21 @@ struct PadOpInterface
     return {};
   }
 
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+    // Infer memory space from the source tensor.
+    auto padOp = cast<tensor::PadOp>(op);
+    auto maybeSrcBufferType =
+        bufferization::getBufferType(padOp.getSource(), options, fixedTypes);
+    if (failed(maybeSrcBufferType))
+      return failure();
+    MemRefLayoutAttrInterface layout;
+    return MemRefType::get(padOp.getResultType().getShape(),
+                           padOp.getResultType().getElementType(), layout,
+                           maybeSrcBufferType->getMemorySpace());
+  }
+
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto padOp = cast<tensor::PadOp>(op);
@@ -837,17 +853,22 @@ struct PadOpInterface
       dynamicSizes.push_back(sum);
     }
 
-    // Create tensor::GenerateOp.
-    auto generateOp =
-        rewriter.create<tensor::GenerateOp>(loc, resultType, dynamicSizes);
-    // Move over "escape" attribute if present.
-    if (padOp->hasAttr(BufferizationDialect::kEscapeAttrName))
-      generateOp->setAttr(
-          BufferizationDialect::kEscapeAttrName,
-          padOp->getAttr(BufferizationDialect::kEscapeAttrName));
-    // TODO: Memory space
-    rewriter.inlineRegionBefore(padOp.getRegion(), generateOp.getBody(),
-                                generateOp.getBody().begin());
+    // Should the buffer be deallocated?
+    bool dealloc =
+        shouldDeallocateOpResult(padOp.getResult().cast<OpResult>(), options);
+    // Allocate a buffer for the padded result.
+    FailureOr<Value> tensorAlloc =
+        allocateTensorForShapedValue(rewriter, loc, padOp.getResult(),
+                                     /*escape=*/!dealloc, options,
+                                     /*copy=*/false);
+    if (failed(tensorAlloc))
+      return failure();
+
+    // tensor::PadOp is like tensor::GenerateOp: The only difference is that
+    // only a part of the generated tensor is needed. For simplicity, we reuse
+    // the same functionality here.
+    Value filledBuffer = lowerGenerateLikeOpBody(
+        rewriter, loc, *tensorAlloc, dynamicSizes, padOp.getBodyRegion());
 
     // Create tensor::InsertSliceOp.
     SmallVector<OpFoldResult> sliceSizes =
@@ -855,7 +876,7 @@ struct PadOpInterface
     SmallVector<OpFoldResult> sliceStrides(srcType.getRank(),
                                            rewriter.getIndexAttr(1));
     rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
-        padOp, padOp.getSource(), generateOp.getResult(),
+        padOp, padOp.getSource(), filledBuffer,
         /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
 
     return success();
index a2a1fb0..82226bc 100644 (file)
@@ -539,7 +539,8 @@ func.func @tensor.reshape(%t1: tensor<?x10xf32>) -> tensor<2x2x5xf32> {
 
 // -----
 
-// CHECK:       #[[$sum_map:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
+// CHECK:       #[[$sum_map_1:.+]] = affine_map<()[s0, s1] -> (s1 + s0 + 5)>
+// CHECK:       #[[$sum_map_2:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 10)>
 // CHECK-LABEL: func @tensor.pad(
 //  CHECK-SAME:   %[[t1:.*]]: tensor<?x10xindex>, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index
 func.func @tensor.pad(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
@@ -547,11 +548,10 @@ func.func @tensor.pad(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
   // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10xindex>
   // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
   // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
-  // CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index
   // CHECK-DAG: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]]
   // CHECK-DAG: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]]
-  // CHECK-DAG: %[[size0:.*]] = affine.apply #[[$sum_map]]()[%[[dim0]], %[[c5]], %[[h1]]]
-  // CHECK-DAG: %[[size1:.*]] = affine.apply #[[$sum_map]]()[%[[dim1]], %[[l2]], %[[h2]]]
+  // CHECK-DAG: %[[size0:.*]] = affine.apply #[[$sum_map_1]]()[%[[h1]], %[[dim0]]]
+  // CHECK-DAG: %[[size1:.*]] = affine.apply #[[$sum_map_2]]()[%[[l2]], %[[h2]]]
   // CHECK:     %[[alloc:.*]] = memref.alloc(%[[size0]], %[[size1]]) {{.*}} : memref<?x?xindex>
   // CHECK:     %[[alloc_t:.*]] = bufferization.to_tensor %[[alloc]]
   // CHECK:     %[[mapped:.*]] = linalg.map
index 0d9ade7..ab80f0e 100644 (file)
@@ -251,3 +251,31 @@ func.func @insert_equivalent_tensor(%t: tensor<10xf32>) -> tensor<10xf32> {
   %1 = tensor.insert_slice %0 into %t[0][10][1] : tensor<10xf32> into tensor<10xf32>
   return %1 : tensor<10xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @pad_memory_space(
+//  CHECK-SAME:     %[[t:.*]]: memref<?xf32, strided<[?], offset: ?>>
+func.func @pad_memory_space(%t: tensor<?xf32>, %h1: index, %f: f32, %pos: index) -> f32
+{
+  // CHECK: %[[alloc_tensor:.*]] = memref.alloc{{.*}} : memref<?xf32, 3>
+  // CHECK: memref.copy %[[t]], %[[alloc_tensor]]
+  %0 = bufferization.alloc_tensor() copy(%t)
+      {memory_space = 3 : ui64} : tensor<?xf32>
+  // CHECK: %[[padded_alloc:.*]] = memref.alloc() {{.*}} : memref<15xf32, 3>
+  // CHECK: linalg.map
+  // CHECK:     outs(%[[padded_alloc]] : memref<15xf32, 3>)
+  // CHECK:   linalg.yield %{{.*}}
+  // CHECK: }
+  // CHECK: %[[subview:.*]] = memref.subview {{.*}} : memref<15xf32, 3> to memref<?xf32, strided<[1], offset: 2>, 3>
+  // CHECK: memref.copy %[[alloc_tensor]], %[[subview]]
+  %1 = tensor.pad %0 low[2] high[%h1] {
+  ^bb0(%arg0: index):
+    tensor.yield %f : f32
+  } : tensor<?xf32> to tensor<15xf32>
+  // CHECK: memref.load {{.*}} : memref<15xf32, 3>
+  %2 = tensor.extract %1[%pos] : tensor<15xf32>
+  // CHECK-DAG: memref.dealloc %[[alloc_tensor]]
+  // CHECK-DAG: memref.dealloc %[[padded_alloc]]
+  return %2 : f32
+}