From f470f8cbcefd4a74c654b0e2a2005a62c94047de Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sat, 28 May 2022 04:42:47 +0200 Subject: [PATCH] [mlir][bufferize][NFC] Split analysis+bufferization of ModuleBufferization Analysis and bufferization can now be run separately. Differential Revision: https://reviews.llvm.org/D126572 --- .../Transforms/OneShotModuleBufferize.h | 11 ++++ .../Transforms/OneShotModuleBufferize.cpp | 58 +++++++++++++++++----- 2 files changed, 56 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h index 5f42e14..367edde 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h @@ -15,8 +15,19 @@ struct LogicalResult; class ModuleOp; namespace bufferization { +struct BufferizationState; +class OneShotAnalysisState; struct OneShotBufferizationOptions; +/// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in +/// `state`. +LogicalResult analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state); + +/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`. +/// Whether buffer copies are needed or not is queried from the given state. +LogicalResult bufferizeModuleOp(ModuleOp moduleOp, + const OneShotAnalysisState &analysisState); + /// Run One-Shot Module Bufferization on the given module. Performs a simple /// function call analysis to determine which function arguments are /// inplaceable. Then analyzes and bufferizes FuncOps one-by-one with One-Shot diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index b93022f..e358979 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -380,15 +380,15 @@ static void foldMemRefCasts(func::FuncOp funcOp) { funcOp.setType(newFuncType); } -LogicalResult mlir::bufferization::runOneShotModuleBufferize( - ModuleOp moduleOp, OneShotBufferizationOptions options) { +LogicalResult +mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, + OneShotAnalysisState &state) { + OneShotBufferizationOptions options = + static_cast(state.getOptions()); assert(options.bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); - IRRewriter rewriter(moduleOp.getContext()); - OneShotAnalysisState analysisState(moduleOp, options); - BufferizationState bufferizationState(analysisState); - FuncAnalysisState &funcState = getFuncAnalysisState(analysisState); - BufferizationAliasInfo &aliasInfo = analysisState.getAliasInfo(); + FuncAnalysisState &funcState = getFuncAnalysisState(state); + BufferizationAliasInfo &aliasInfo = state.getAliasInfo(); // A list of functions in the order in which they are analyzed + bufferized. SmallVector orderedFuncOps; @@ -412,12 +412,12 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize( equivalenceAnalysis(funcOp, aliasInfo, funcState); // Analyze funcOp. - if (failed(analyzeOp(funcOp, analysisState))) + if (failed(analyzeOp(funcOp, state))) return failure(); // Run some extra function analyses. - if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, analysisState)) || - failed(funcOpBbArgReadWriteAnalysis(funcOp, analysisState))) + if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state)) || + failed(funcOpBbArgReadWriteAnalysis(funcOp, state))) return failure(); // Mark op as fully analyzed. @@ -425,11 +425,29 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize( // Add annotations to function arguments. if (options.testAnalysisOnly) - annotateOpsWithBufferizationMarkers(funcOp, analysisState); + annotateOpsWithBufferizationMarkers(funcOp, state); } - if (options.testAnalysisOnly) - return success(); + return success(); +} + +LogicalResult mlir::bufferization::bufferizeModuleOp( + ModuleOp moduleOp, const OneShotAnalysisState &analysisState) { + auto const &options = static_cast( + analysisState.getOptions()); + assert(options.bufferizeFunctionBoundaries && + "expected that function boundary bufferization is activated"); + IRRewriter rewriter(moduleOp.getContext()); + BufferizationState bufferizationState(analysisState); + + // A list of functions in the order in which they are analyzed + bufferized. + SmallVector orderedFuncOps; + + // A mapping of FuncOps to their callers. + FuncCallerMap callerMap; + + if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap))) + return failure(); // Bufferize functions. for (func::FuncOp funcOp : orderedFuncOps) { @@ -466,3 +484,17 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize( return success(); } + +LogicalResult mlir::bufferization::runOneShotModuleBufferize( + ModuleOp moduleOp, OneShotBufferizationOptions options) { + assert(options.bufferizeFunctionBoundaries && + "expected that function boundary bufferization is activated"); + OneShotAnalysisState analysisState(moduleOp, options); + if (failed(analyzeModuleOp(moduleOp, analysisState))) + return failure(); + if (options.testAnalysisOnly) + return success(); + if (failed(bufferizeModuleOp(moduleOp, analysisState))) + return failure(); + return success(); +} -- 2.7.4