[mlir][linalg] BufferizeToAllocationOp: Return handle to buffer
authorMatthias Springer <me@m-sp.org>
Tue, 27 Jun 2023 12:29:04 +0000 (14:29 +0200)
committerMatthias Springer <me@m-sp.org>
Tue, 27 Jun 2023 12:55:44 +0000 (14:55 +0200)
Add an additional result handle to the op. This new handle is mapped to the newly allocated buffer.

Differential Revision: https://reviews.llvm.org/D153514

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir

index 0157fc2..1924355 100644 (file)
@@ -85,8 +85,9 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
   let description = [{
     This transform materializes an allocation for the targeted tensor value. It
     replaces all original uses of the target with the newly allocated buffer,
-    wrapped in a `bufferization.to_tensor` op. It returns a handle to the result
-    of the `to_tensor` op.
+    wrapped in a `bufferization.to_tensor` op. It returns a handle to the newly
+    allocated buffer. Furthermore, it returns a handle to the result of the
+    `to_tensor` op.
 
     Example:
     ```
@@ -116,13 +117,14 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
 
     #### Return modes
 
-    This operation consumes the `target` handle and produces the `transformed`
-    handle. It always succeeds.
+    This operation consumes the `target` handle and produces the `replacement`
+    and `allocated_buffer` handles. It always succeeds.
   }];
 
   let arguments = (ins Transform_AnyValue:$target,
                        OptionalAttr<AnyAttr>:$memory_space);
-  let results = (outs Transform_AnyValue:$transformed);
+  let results = (outs Transform_AnyValue:$allocated_buffer,
+                      Transform_AnyValue:$replacement);
   let assemblyFormat = "$target attr-dict";
 }
 
index 7ca0c22..86a6fe2 100644 (file)
@@ -321,10 +321,12 @@ using LinalgLoops = SmallVector<Operation *, 4>;
 /// memref.tensor_store %t, %subview
 /// %0 = bufferization.to_tensor %alloc restrict writable
 ///
-/// In addition to rewriting the IR as shown above, the result of the
-/// bufferization.to_tensor op is returned.
+/// In addition to rewriting the IR as shown above, this function returns the
+/// newly allocated buffer. Furthermore, the result of the
+/// bufferization.to_tensor op is optionally returned via `replacement`.
 Value bufferizeToAllocation(RewriterBase &rewriter, tensor::PadOp padOp,
-                            Attribute memorySpace = {});
+                            Attribute memorySpace = {},
+                            Value *replacement = nullptr);
 
 /// Materialize a buffer allocation for the given tensor value. E.g.:
 ///
@@ -334,8 +336,13 @@ Value bufferizeToAllocation(RewriterBase &rewriter, tensor::PadOp padOp,
 ///
 /// In case `value` is a tensor.pad result, the corresponding overload is used
 /// internally to produce a better bufferization.
+///
+/// In addition to rewriting the IR as shown above, this function returns the
+/// newly allocated buffer. Furthermore, the result of the
+/// bufferization.to_tensor op is optionally returned via `replacement`.
 Value bufferizeToAllocation(RewriterBase &rewriter, Value value,
-                            Attribute memorySpace = {});
+                            Attribute memorySpace = {},
+                            Value *replacement = nullptr);
 
 /// Fuse two `linalg.generic` operations that have a producer-consumer
 /// relationship captured through `fusedOperand`. The method expects
index 875b96d..bf70c54 100644 (file)
@@ -174,18 +174,25 @@ DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   Attribute memorySpace =
       getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
-  auto transformed = llvm::to_vector(
-      llvm::map_range(state.getPayloadValues(getTarget()), [&](Value v) {
-        return linalg::bufferizeToAllocation(rewriter, v, memorySpace);
-      }));
-  results.setValues(cast<OpResult>(getTransformed()), transformed);
+  SmallVector<Value> replacements;
+  SmallVector<Value> allocatedBuffers;
+  for (Value value : state.getPayloadValues(getTarget())) {
+    Value replacement;
+    Value buffer = linalg::bufferizeToAllocation(rewriter, value, memorySpace,
+                                                 &replacement);
+    replacements.push_back(replacement);
+    allocatedBuffers.push_back(buffer);
+  }
+  results.setValues(cast<OpResult>(getReplacement()), replacements);
+  results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
   return DiagnosedSilenceableFailure::success();
 }
 
 void transform::BufferizeToAllocationOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   consumesHandle(getTarget(), effects);
-  producesHandle(getTransformed(), effects);
+  producesHandle(getReplacement(), effects);
+  producesHandle(getAllocatedBuffer(), effects);
   modifiesPayload(effects);
 }
 
index a81a48d..01a79cc 100644 (file)
@@ -170,7 +170,7 @@ static Value createAllocationForTensor(RewriterBase &rewriter, Location loc,
 }
 
 Value linalg::bufferizeToAllocation(RewriterBase &rewriter, PadOp padOp,
-                                    Attribute memorySpace) {
+                                    Attribute memorySpace, Value *replacement) {
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(padOp);
   Location loc = padOp.getLoc();
@@ -198,7 +198,10 @@ Value linalg::bufferizeToAllocation(RewriterBase &rewriter, PadOp padOp,
   Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
       loc, alloc, /*restrict=*/true, /*writable=*/true);
   rewriter.replaceOp(padOp, toTensorOp);
-  return toTensorOp;
+
+  if (replacement)
+    *replacement = toTensorOp;
+  return alloc;
 }
 
 /// Lower tensor.from_elements to a sequence of chained tensor.insert.
@@ -329,10 +332,10 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
 }
 
 Value linalg::bufferizeToAllocation(RewriterBase &rewriter, Value value,
-                                    Attribute memorySpace) {
+                                    Attribute memorySpace, Value *replacement) {
   // Call specialized overload for certain ops.
   if (auto padOp = value.getDefiningOp<PadOp>())
-    return bufferizeToAllocation(rewriter, padOp, memorySpace);
+    return bufferizeToAllocation(rewriter, padOp, memorySpace, replacement);
 
   // Collect all uses.
   SmallVector<OpOperand *> uses = llvm::to_vector(
@@ -362,7 +365,9 @@ Value linalg::bufferizeToAllocation(RewriterBase &rewriter, Value value,
                                [&]() { use->set(toTensorOp); });
   }
 
-  return toTensorOp;
+  if (replacement)
+    *replacement = toTensorOp;
+  return alloc;
 }
 
 namespace {
index 6a108ba..d2b27c0 100644 (file)
@@ -33,7 +33,7 @@ transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = transform.get_result %0[0] : (!transform.any_op) -> !transform.any_value
-  %2 = transform.structured.bufferize_to_allocation %1
+  %2, %3 = transform.structured.bufferize_to_allocation %1
 }
 
 // -----
@@ -59,9 +59,9 @@ transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = transform.get_result %0[0] : (!transform.any_op) -> !transform.any_value
-  %2 = transform.structured.bufferize_to_allocation %1
+  %2, %3 = transform.structured.bufferize_to_allocation %1
   // Make sure that One-Shot Bufferize can bufferize the rest.
-  %3 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op
+  %4 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op
 }
 
 // -----
@@ -85,7 +85,7 @@ transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["tensor.extract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = test_produce_value_handle_to_argument_of_parent_block %0, 0 : (!transform.any_op) -> !transform.any_value
-  %2 = transform.structured.bufferize_to_allocation %1 {memory_space = 4}
+  %2, %3 = transform.structured.bufferize_to_allocation %1 {memory_space = 4}
 }
 
 // -----
@@ -106,9 +106,9 @@ transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["tensor.extract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = test_produce_value_handle_to_argument_of_parent_block %0, 0 : (!transform.any_op) -> !transform.any_value
-  %2 = transform.structured.bufferize_to_allocation %1 {memory_space = 4}
+  %2, %3 = transform.structured.bufferize_to_allocation %1 {memory_space = 4}
   // Make sure that One-Shot Bufferize can bufferize the rest.
-  %3 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op
+  %4 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op
 }
 
 // -----
@@ -128,7 +128,7 @@ transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["dummy.some_op"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = transform.get_result %0[0] : (!transform.any_op) -> !transform.any_value
-  %2 = transform.structured.bufferize_to_allocation %1 {memory_space = 4}
+  %2, %3 = transform.structured.bufferize_to_allocation %1 {memory_space = 4}
 }