From: Peiming Liu Date: Fri, 2 Dec 2022 23:00:52 +0000 (+0000) Subject: [mlir][sparse][bufferization] cleanup bufferization attributes after SparsificationAn... X-Git-Tag: upstream/17.0.6~25509 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c7a9e5e5d28ded8cd0f7162956d0df679c341392;p=platform%2Fupstream%2Fllvm.git [mlir][sparse][bufferization] cleanup bufferization attributes after SparsificationAndBufferizationPass Reviewed By: aartbik, springerm Differential Revision: https://reviews.llvm.org/D139218 --- diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h index aafa002..f6402b0 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h @@ -30,6 +30,9 @@ LogicalResult analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state); LogicalResult bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options); +/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp. +void removeBufferizationAttributesInModule(ModuleOp moduleOp); + /// Run One-Shot Module Bufferization on the given module. Performs a simple /// function call analysis to determine which function arguments are /// inplaceable. Then analyzes and bufferizes FuncOps one-by-one with One-Shot diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index 8b70c2a..e2878b2e 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -377,6 +377,14 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, return success(); } +void mlir::bufferization::removeBufferizationAttributesInModule( + ModuleOp moduleOp) { + moduleOp.walk([&](func::FuncOp op) { + for (BlockArgument bbArg : op.getArguments()) + removeBufferizationAttributes(bbArg); + }); +} + LogicalResult mlir::bufferization::bufferizeModuleOp( ModuleOp moduleOp, const OneShotBufferizationOptions &options) { assert(options.bufferizeFunctionBoundaries && @@ -405,10 +413,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( } // Post-pass cleanup of function argument attributes. - moduleOp.walk([&](func::FuncOp op) { - for (BlockArgument bbArg : op.getArguments()) - removeBufferizationAttributes(bbArg); - }); + removeBufferizationAttributesInModule(moduleOp); return success(); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp index b3d1081..6066a17 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" @@ -75,9 +76,14 @@ public: } return true; }); - return bufferization::bufferizeOp(getOperation(), bufferizationOptions, - /*copyBeforeWrite=*/false, - &denseOpFilter); + + if (failed(bufferization::bufferizeOp(getOperation(), bufferizationOptions, + /*copyBeforeWrite=*/false, + &denseOpFilter))) + return failure(); + + bufferization::removeBufferizationAttributesInModule(getOperation()); + return success(); } void runOnOperation() override {