From a5bba98a58b7406f81629d3942e03b1eff1e2b33 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 21 Jul 2023 15:29:16 +0200 Subject: [PATCH] [mlir][linalg] BufferizeToAllocationOp: Add option to materialize buffers for operands Add an option that does not bufferize the targeted op itself, but just materializes a buffer for the destination operands. This is useful for partial bufferization of complex ops such as `scf.forall`, which need special handling (and an analysis of the region). Differential Revision: https://reviews.llvm.org/D155946 --- .../Linalg/TransformOps/LinalgTransformOps.td | 8 ++++-- .../mlir/Dialect/Linalg/Transforms/Transforms.h | 5 ++++ .../Linalg/TransformOps/LinalgTransformOps.cpp | 1 + .../Transforms/ConvertToDestinationStyle.cpp | 33 +++++++++++++++++++++- .../transform-op-bufferize-to-allocation.mlir | 21 ++++++++++++++ 5 files changed, 65 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 5d1f2d8..f5dfdb2 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -133,6 +133,9 @@ def BufferizeToAllocationOp : Op: $memcpy_op, DefaultValuedAttr: - $alloc_op); + $alloc_op, + UnitAttr:$bufferize_destination_only); let hasVerifier = 1; let results = (outs Transform_AnyValue:$allocated_buffer, Transform_AnyOpType:$new_ops); let assemblyFormat = "$target attr-dict `:` type($target)"; - + let builders = [ OpBuilder<(ins "Value":$target, "Attribute":$memorySpace)>, OpBuilder<(ins "Value":$target, "int64_t":$memorySpace)> diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 68fce05..e863cd8 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -52,6 
+52,11 @@ struct BufferizeToAllocationOptions { enum class MemcpyOp { MemrefTensorStore = 0, MemrefCopy = 1, LinalgCopy = 2 }; MemcpyOp memcpyOp = MemcpyOp::MemrefTensorStore; + + /// If set to "true", only the destination tensor operands are bufferized to + /// a new allocation (and wrapped in "bufferization.to_tensor"), but not the + /// targeted op itself. + bool bufferizeDestinationOnly = false; }; /// Materialize a buffer allocation for the given tensor.pad op and lower the diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 859dc78..f653d2d 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -261,6 +261,7 @@ DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply( } else { llvm_unreachable("invalid alloc op"); } + options.bufferizeDestinationOnly = getBufferizeDestinationOnly(); // Bufferize ops. Attribute memorySpace = diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp index 369ff8d..aab4a6f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp @@ -217,6 +217,9 @@ createAllocationForTensor(RewriterBase &rewriter, Location loc, Value value, Value linalg::bufferizeToAllocation( RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options, PadOp padOp, Attribute memorySpace, Operation *insertionPoint) { + // tensor.pad does not have a destination operand. + assert(!options.bufferizeDestinationOnly && "invalid options"); + OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(insertionPoint ? 
insertionPoint : padOp); Location loc = padOp.getLoc(); @@ -266,6 +269,9 @@ Value linalg::bufferizeToAllocation( rewriter, options, maskOp.getMaskableOp(), memorySpace, /*insertionPoint=*/insertionPoint ? insertionPoint : maskOp); + if (options.bufferizeDestinationOnly) + return alloc; + // Bufferize terminator. rewriter.setInsertionPoint(yieldOp); if (failed(cast(yieldOp).bufferize( @@ -454,6 +460,21 @@ Value linalg::bufferizeToAllocation( BufferizationOptions bufferizationOptions; AnalysisState state(bufferizationOptions); +#ifndef NDEBUG + // Ops with nested tensor ops are not supported yet. At the moment, this + // function just bufferizes the given op itself, but not its body. + op->walk([&](Operation *nestedOp) { + if (op == nestedOp) + return; + if (llvm::any_of(nestedOp->getOperands(), + [](Value v) { return v.getType().isa(); })) + llvm_unreachable("ops with nested tensor ops are not supported yet"); + if (llvm::any_of(nestedOp->getResults(), + [](Value v) { return v.getType().isa(); })) + llvm_unreachable("ops with nested tensor ops are not supported yet"); + }); +#endif // NDEBUG + // Gather tensor results. SmallVector tensorResults; for (OpResult result : op->getResults()) { @@ -509,10 +530,20 @@ Value linalg::bufferizeToAllocation( createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options); } rewriter.updateRootInPlace(op, [&]() { - operand->set(rewriter.create(op->getLoc(), alloc)); + auto toTensorOp = rewriter.create(op->getLoc(), alloc); + operand->set(toTensorOp); + if (options.bufferizeDestinationOnly) { + rewriter.updateRootInPlace(toTensorOp, [&]() { + toTensorOp.setRestrict(true); + toTensorOp.setWritable(true); + }); + } }); } + if (options.bufferizeDestinationOnly) + return allocs.front(); + // Bufferize the op. 
rewriter.setInsertionPoint(op); if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions))) diff --git a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir index 36f76d3..7b156a5 100644 --- a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir @@ -197,3 +197,24 @@ transform.sequence failures(propagate) { %0 = transform.structured.match ops{["vector.mask"]} in %arg1 : (!transform.any_op) -> !transform.any_op %2, %new = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op } + +// ----- + +// CHECK-LABEL: func @tensor_insert_destination( +// CHECK-SAME: %[[t:.*]]: tensor +// CHECK: %[[alloc:.*]] = memref.alloc(%{{.*}}) : memref +// CHECK: memref.tensor_store %[[t]], %[[alloc]] +// CHECK: %[[t2:.*]] = bufferization.to_tensor %[[alloc]] restrict writable +// CHECK: %[[inserted:.*]] = tensor.insert %{{.*}} into %[[t2]] +// CHECK: memref.dealloc %[[alloc]] +// CHECK: return %[[inserted]] +func.func @tensor_insert_destination(%t: tensor, %idx: index, %v: index) -> tensor { + %r = tensor.insert %v into %t[%idx, %idx] : tensor + return %r : tensor +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["tensor.insert"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %2, %new = transform.structured.bufferize_to_allocation %0 {memory_space = 4, bufferize_destination_only} : !transform.any_op +} -- 2.7.4