[mlir][linalg] BufferizeToAllocationOp: Add option to materialize buffers for operands
authorMatthias Springer <me@m-sp.org>
Fri, 21 Jul 2023 13:29:16 +0000 (15:29 +0200)
committerMatthias Springer <me@m-sp.org>
Fri, 21 Jul 2023 13:29:59 +0000 (15:29 +0200)
Add an option that does not bufferize the targeted op itself, but just materializes a buffer for the destination operands. This is useful for partial bufferization of complex ops such as `scf.forall`, which need special handling (and an analysis if the region).

Differential Revision: https://reviews.llvm.org/D155946

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir

index 5d1f2d8..f5dfdb2 100644 (file)
@@ -133,6 +133,9 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
     `alloc_op`. Currently supported are "memref.alloc" and "memref.alloca". In
     case of a "memref.alloca", the buffer is not deallocated.
 
+    If `bufferize_destination_only` is set, only the destination operands of the
+    op are bufferized to a new memory allocation, but not the op itself.
+
     #### Return modes
 
     This operation consumes the `target` handle and produces the
@@ -144,12 +147,13 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
                        DefaultValuedAttr<StrAttr, "\"memref.tensor_store\"">:
                            $memcpy_op,
                        DefaultValuedAttr<StrAttr, "\"memref.alloc\"">:
-                           $alloc_op);
+                           $alloc_op,
+                       UnitAttr:$bufferize_destination_only);
   let hasVerifier = 1;
   let results = (outs Transform_AnyValue:$allocated_buffer,
                       Transform_AnyOpType:$new_ops);
   let assemblyFormat = "$target attr-dict `:` type($target)";
-  
+
   let builders = [
     OpBuilder<(ins "Value":$target, "Attribute":$memorySpace)>,
     OpBuilder<(ins "Value":$target, "int64_t":$memorySpace)>
index 68fce05..e863cd8 100644 (file)
@@ -52,6 +52,11 @@ struct BufferizeToAllocationOptions {
 
   enum class MemcpyOp { MemrefTensorStore = 0, MemrefCopy = 1, LinalgCopy = 2 };
   MemcpyOp memcpyOp = MemcpyOp::MemrefTensorStore;
+
+  /// If set to "true", only the destination tensor operands are bufferized to
+  /// a new allocation (and wrapped in "bufferization.to_tensor"), but not the
+  /// targeted op itself.
+  bool bufferizeDestinationOnly = false;
 };
 
 /// Materialize a buffer allocation for the given tensor.pad op and lower the
index 859dc78..f653d2d 100644 (file)
@@ -261,6 +261,7 @@ DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
   } else {
     llvm_unreachable("invalid alloc op");
   }
+  options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
 
   // Bufferize ops.
   Attribute memorySpace =
index 369ff8d..aab4a6f 100644 (file)
@@ -217,6 +217,9 @@ createAllocationForTensor(RewriterBase &rewriter, Location loc, Value value,
 Value linalg::bufferizeToAllocation(
     RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
     PadOp padOp, Attribute memorySpace, Operation *insertionPoint) {
+  // tensor.pad does not have a destination operand.
+  assert(!options.bufferizeDestinationOnly && "invalid options");
+
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(insertionPoint ? insertionPoint : padOp);
   Location loc = padOp.getLoc();
@@ -266,6 +269,9 @@ Value linalg::bufferizeToAllocation(
       rewriter, options, maskOp.getMaskableOp(), memorySpace,
       /*insertionPoint=*/insertionPoint ? insertionPoint : maskOp);
 
+  if (options.bufferizeDestinationOnly)
+    return alloc;
+
   // Bufferize terminator.
   rewriter.setInsertionPoint(yieldOp);
   if (failed(cast<bufferization::BufferizableOpInterface>(yieldOp).bufferize(
@@ -454,6 +460,21 @@ Value linalg::bufferizeToAllocation(
   BufferizationOptions bufferizationOptions;
   AnalysisState state(bufferizationOptions);
 
+#ifndef NDEBUG
+  // Ops with nested tensor ops are not supported yet. At the moment, this
+  // function just bufferizes the given op itself, but not its body.
+  op->walk([&](Operation *nestedOp) {
+    if (op == nestedOp)
+      return;
+    if (llvm::any_of(nestedOp->getOperands(),
+                     [](Value v) { return v.getType().isa<TensorType>(); }))
+      llvm_unreachable("ops with nested tensor ops are not supported yet");
+    if (llvm::any_of(nestedOp->getResults(),
+                     [](Value v) { return v.getType().isa<TensorType>(); }))
+      llvm_unreachable("ops with nested tensor ops are not supported yet");
+  });
+#endif // NDEBUG
+
   // Gather tensor results.
   SmallVector<OpResult> tensorResults;
   for (OpResult result : op->getResults()) {
@@ -509,10 +530,20 @@ Value linalg::bufferizeToAllocation(
       createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
     }
     rewriter.updateRootInPlace(op, [&]() {
-      operand->set(rewriter.create<ToTensorOp>(op->getLoc(), alloc));
+      auto toTensorOp = rewriter.create<ToTensorOp>(op->getLoc(), alloc);
+      operand->set(toTensorOp);
+      if (options.bufferizeDestinationOnly) {
+        rewriter.updateRootInPlace(toTensorOp, [&]() {
+          toTensorOp.setRestrict(true);
+          toTensorOp.setWritable(true);
+        });
+      }
     });
   }
 
+  if (options.bufferizeDestinationOnly)
+    return allocs.front();
+
   // Bufferize the op.
   rewriter.setInsertionPoint(op);
   if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions)))
index 36f76d3..7b156a5 100644 (file)
@@ -197,3 +197,24 @@ transform.sequence failures(propagate) {
   %0 = transform.structured.match ops{["vector.mask"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %2, %new = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op
 }
+
+// -----
+
+// CHECK-LABEL: func @tensor_insert_destination(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x10xindex>
+//       CHECK:   %[[alloc:.*]] = memref.alloc(%{{.*}}) : memref<?x10xindex, 4>
+//       CHECK:   memref.tensor_store %[[t]], %[[alloc]]
+//       CHECK:   %[[t2:.*]] = bufferization.to_tensor %[[alloc]] restrict writable
+//       CHECK:   %[[inserted:.*]] = tensor.insert %{{.*}} into %[[t2]]
+//       CHECK:   memref.dealloc %[[alloc]]
+//       CHECK:   return %[[inserted]]
+func.func @tensor_insert_destination(%t: tensor<?x10xindex>, %idx: index, %v: index) -> tensor<?x10xindex> {
+  %r = tensor.insert %v into %t[%idx, %idx] : tensor<?x10xindex>
+  return %r : tensor<?x10xindex>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["tensor.insert"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %2, %new = transform.structured.bufferize_to_allocation %0 {memory_space = 4, bufferize_destination_only} : !transform.any_op
+}