From 481b254e458bc195af16fef9625cf856ef87fced Mon Sep 17 00:00:00 2001
From: Matthias Springer
Date: Mon, 22 May 2023 14:13:08 +0200
Subject: [PATCH] [mlir][tensor][bufferize] Bufferize tensor.splat op

The op bufferizes similarly to tensor.generate: it is lowered to a
linalg.map, which may then lower to a loop nest that fills the buffer.
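
For illustration, bufferizing

  %t = tensor.splat %f : tensor<10x2x4xf32>

produces roughly the IR below (SSA value names here are made up; the
exact output is pinned down by the FileCheck test in this patch):

  %alloc = memref.alloc() : memref<10x2x4xf32>
  %0 = bufferization.to_tensor %alloc : memref<10x2x4xf32>
  %mapped = linalg.map outs(%0 : tensor<10x2x4xf32>)
    () {
      linalg.yield %f : f32
    }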

Differential Revision: https://reviews.llvm.org/D150952
---
 .../Transforms/BufferizableOpInterfaceImpl.cpp | 49 ++++++++++++++++++++++
 mlir/test/Dialect/Tensor/bufferize.mlir        | 17 ++++++++
 2 files changed, 66 insertions(+)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 9253bc2..935a1b9 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1087,6 +1087,54 @@ struct ParallelInsertSliceOpInterface
   }
 };
 
+/// Bufferization of tensor.splat. Bufferizes to a new allocation that is
+/// filled with a linalg.map. Similar to tensor.generate.
+struct SplatOpInterface
+    : public BufferizableOpInterface::ExternalModel<SplatOpInterface,
+                                                    tensor::SplatOp> {
+
+  bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
+    return true;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto splatOp = cast<tensor::SplatOp>(op);
+
+    // Should the buffer be deallocated?
+    bool dealloc =
+        shouldDeallocateOpResult(cast<OpResult>(splatOp.getResult()), options);
+
+    // TODO: Implement memory space for this op.
+    if (options.defaultMemorySpace != Attribute())
+      return op->emitError("memory space not implemented yet");
+
+    // Allocate memory.
+    Location loc = op->getLoc();
+    FailureOr<Value> tensorAlloc =
+        allocateTensorForShapedValue(rewriter, loc, splatOp.getResult(),
+                                     /*escape=*/!dealloc, options,
+                                     /*copy=*/false);
+    if (failed(tensorAlloc))
+      return failure();
+
+    // Create linalg::MapOp.
+    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
+    auto linalgOp =
+        rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
+                                       /*init=*/*tensorAlloc);
+    Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+
+    // Create the linalg::YieldOp that yields the splat value.
+    rewriter.setInsertionPointToStart(&linalgBody);
+    rewriter.create<linalg::YieldOp>(loc, splatOp.getInput());
+    rewriter.replaceOp(splatOp, linalgOp.getResult()[0]);
+
+    return success();
+  }
+};
+
 } // namespace
 } // namespace tensor
 } // namespace mlir
@@ -1110,6 +1158,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
         *ctx);
     RankOp::attachInterface<RankOpInterface>(*ctx);
     ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
+    SplatOp::attachInterface<SplatOpInterface>(*ctx);
 
     // Load additional dialects of which ops may get created.
     ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index fe665a3..b9382b9 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -582,3 +582,20 @@ func.func @tensor.pad(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
 //       CHECK: return %[[r]] : tensor<?x?xindex>
   return %0 : tensor<?x?xindex>
 }
+
+// -----
+
+// CHECK-LABEL: func @tensor.splat(
+//  CHECK-SAME:     %[[F:.*]]: f32)
+//   CHECK-DAG:   %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<10x2x4xf32>
+//       CHECK:   %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
+//       CHECK:   %[[MAPPED:.*]] = linalg.map
+//       CHECK:     outs(%[[ALLOC_T]] : tensor<10x2x4xf32>)
+//       CHECK:     linalg.yield %[[F]]
+//       CHECK:   }
+//       CHECK:   return %[[MAPPED]] : tensor<10x2x4xf32>
+//       CHECK: }
+func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
+  %t = tensor.splat %f : tensor<10x2x4xf32>
+  return %t : tensor<10x2x4xf32>
+}
-- 
2.7.4