From b62b21b98019b46af91365cb6415f8e740cab898 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 26 Nov 2021 22:13:28 +0900 Subject: [PATCH] [mlir][linalg][bufferize][NFC] InsertSliceOp no-copy detection as PostAnalysis There is special logic for InsertSliceOp to check if a memcpy is needed. This change extracts that piece of code and makes it a PostAnalysisStep. The purpose of this change is to untangle `bufferize` from BufferizationAliasInfo. (Not fully there yet.) Differential Revision: https://reviews.llvm.org/D114513 --- .../ComprehensiveBufferize/TensorInterfaceImpl.h | 7 ++++ .../ComprehensiveBufferize/TensorInterfaceImpl.cpp | 46 +++++++++++++++++----- .../Transforms/ComprehensiveBufferizePass.cpp | 3 ++ 3 files changed, 47 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h index 29355ef..dbda537 100644 --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h @@ -9,6 +9,8 @@ #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" + namespace mlir { class DialectRegistry; @@ -17,6 +19,11 @@ namespace linalg { namespace comprehensive_bufferize { namespace tensor_ext { +struct InplaceInsertSliceOpAnalysis : public PostAnalysisStep { + LogicalResult run(FuncOp funcOp, BufferizationState &state, + SmallVector<Operation *> &newOps) override; +}; + void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry); } // namespace tensor_ext diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp index 13cc7d7..0cecb1d 100644 --- 
a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -13,6 +13,8 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" +using namespace mlir; + namespace mlir { namespace linalg { namespace comprehensive_bufferize { @@ -21,6 +23,20 @@ namespace tensor_ext { using tensor::ExtractSliceOp; using tensor::InsertSliceOp; +namespace { +/// Extra bufferization state that is required for bufferization of tensor ops. +struct TensorBufferizationState : public DialectBufferizationState { + /// InsertSliceOps that bufferize inplace and do not require a copy. + DenseSet<Operation *> insertSliceOpsWithoutCopy; +}; +} // namespace + +static TensorBufferizationState & +getTensorBufferizationState(BufferizationState &state) { + return state.getDialectState<TensorBufferizationState>( + tensor::TensorDialect::getDialectNamespace()); +} + struct CastOpInterface : public BufferizableOpInterface::ExternalModel<CastOpInterface, tensor::CastOp> { @@ -374,6 +390,7 @@ struct InsertSliceOpInterface // catastrophically bad scheduling decision. // TODO: be very loud about it or even consider failing the pass. auto insertSliceOp = cast<InsertSliceOp>(op); + TensorBufferizationState &tensorState = getTensorBufferizationState(state); // Take a guard before anything else. OpBuilder::InsertionGuard g(b); @@ -385,15 +402,8 @@ struct InsertSliceOpInterface if (!dstMemref) return failure(); - // A copy of the source buffer is needed if either: - // - The producer of `source` is not inplace. This is the case where a - // slice is computed out of place into the inplace full tensor. - // - The result is not inplace. This is the case where the whole tensor is - // cloned and the clone needs to be updated. - // TODO: Is this necessary? 
- bool needCopy = !isSourceEquivalentToAMatchingInplaceExtractSliceOp( - state.aliasInfo, insertSliceOp) || - !state.aliasInfo.isInPlace(insertSliceOp->getResult(0)); + bool needCopy = + !tensorState.insertSliceOpsWithoutCopy.contains(insertSliceOp); if (needCopy) { // Take a subview of the dst. auto dstMemrefType = dstMemref.getType().cast<MemRefType>(); @@ -424,6 +434,24 @@ struct InsertSliceOpInterface } // namespace linalg } // namespace mlir +LogicalResult mlir::linalg::comprehensive_bufferize::tensor_ext:: + InplaceInsertSliceOpAnalysis::run(FuncOp funcOp, BufferizationState &state, + SmallVector<Operation *> &newOps) { + auto &tensorState = getTensorBufferizationState(state); + funcOp.walk([&](InsertSliceOp insertSliceOp) { + // A copy of the source buffer is needed if either: + // - The producer of `source` is not inplace. This is the case where a + // slice is computed out of place into the inplace full tensor. + // - The result is not inplace. This is the case where the whole tensor is + // cloned and the clone needs to be updated. + if (isSourceEquivalentToAMatchingInplaceExtractSliceOp(state.aliasInfo, + insertSliceOp) && + state.aliasInfo.isInPlace(insertSliceOp->getResult(0))) + tensorState.insertSliceOpsWithoutCopy.insert(insertSliceOp); + }); + return success(); +} + void mlir::linalg::comprehensive_bufferize::tensor_ext:: registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) { registry.addOpInterface(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp index 31c7e4c..ca713d1 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -90,6 +90,9 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() { // Enable InitTensorOp elimination. 
options.addPostAnalysisStep< linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>(); + // TODO: Find a way to enable this step automatically when bufferizing tensor + // dialect ops. + options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>(); ModuleOp moduleOp = getOperation(); applyEnablingTransformations(moduleOp); -- 2.7.4