[mlir][sparse][bufferization] cleanup bufferization attributes after SparsificationAn...
authorPeiming Liu <peiming@google.com>
Fri, 2 Dec 2022 23:00:52 +0000 (23:00 +0000)
committerPeiming Liu <peiming@google.com>
Fri, 2 Dec 2022 23:03:54 +0000 (23:03 +0000)
Reviewed By: aartbik, springerm

Differential Revision: https://reviews.llvm.org/D139218

mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp

index aafa002..f6402b0 100644 (file)
@@ -30,6 +30,9 @@ LogicalResult analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state);
 LogicalResult bufferizeModuleOp(ModuleOp moduleOp,
                                 const OneShotBufferizationOptions &options);
 
+/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
+void removeBufferizationAttributesInModule(ModuleOp moduleOp);
+
 /// Run One-Shot Module Bufferization on the given module. Performs a simple
 /// function call analysis to determine which function arguments are
 /// inplaceable. Then analyzes and bufferizes FuncOps one-by-one with One-Shot
index 8b70c2a..e2878b2 100644 (file)
@@ -377,6 +377,14 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
   return success();
 }
 
+void mlir::bufferization::removeBufferizationAttributesInModule(
+    ModuleOp moduleOp) {
+  moduleOp.walk([&](func::FuncOp op) {
+    for (BlockArgument bbArg : op.getArguments())
+      removeBufferizationAttributes(bbArg);
+  });
+}
+
 LogicalResult mlir::bufferization::bufferizeModuleOp(
     ModuleOp moduleOp, const OneShotBufferizationOptions &options) {
   assert(options.bufferizeFunctionBoundaries &&
@@ -405,10 +413,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   }
 
   // Post-pass cleanup of function argument attributes.
-  moduleOp.walk([&](func::FuncOp op) {
-    for (BlockArgument bbArg : op.getArguments())
-      removeBufferizationAttributes(bbArg);
-  });
+  removeBufferizationAttributesInModule(moduleOp);
 
   return success();
 }
index b3d1081..6066a17 100644 (file)
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -75,9 +76,14 @@ public:
       }
       return true;
     });
-    return bufferization::bufferizeOp(getOperation(), bufferizationOptions,
-                                      /*copyBeforeWrite=*/false,
-                                      &denseOpFilter);
+
+    if (failed(bufferization::bufferizeOp(getOperation(), bufferizationOptions,
+                                          /*copyBeforeWrite=*/false,
+                                          &denseOpFilter)))
+      return failure();
+
+    bufferization::removeBufferizationAttributesInModule(getOperation());
+    return success();
   }
 
   void runOnOperation() override {