[mlir][bufferize][NFC] Add function boundary bufferization flag to BufferizationOptions
authorMatthias Springer <springerm@google.com>
Fri, 22 Apr 2022 15:57:25 +0000 (00:57 +0900)
committerMatthias Springer <springerm@google.com>
Fri, 22 Apr 2022 16:11:37 +0000 (01:11 +0900)
This makes the API easier to use. Also allows us to check for incorrect API usage for easier debugging.

Differential Revision: https://reviews.llvm.org/D124265

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp

index c228e26..94504c1 100644 (file)
@@ -78,23 +78,7 @@ struct BufferizationOptions {
   /// unless they are explicitly marked as DENY. If the filter has at least one
   /// ALLOW rule, ops are ignored by default and only bufferized if they match
   /// an ALLOW rule and no DENY rule.
-  bool isOpAllowed(Operation *op) const {
-    bool isAllowed = !filterHasAllowRule();
-    for (const OpFilterEntry &entry : opFilter) {
-      bool filterResult = entry.fn(op);
-      switch (entry.type) {
-      case OpFilterEntry::ALLOW:
-        isAllowed |= filterResult;
-        break;
-      case OpFilterEntry::DENY:
-        if (filterResult)
-          // DENY filter matches. This op is no allowed. (Even if other ALLOW
-          // filters may match.)
-          return false;
-      };
-    }
-    return isAllowed;
-  }
+  bool isOpAllowed(Operation *op) const;
 
   /// Allow the given dialects in the filter.
   ///
@@ -182,6 +166,10 @@ struct BufferizationOptions {
   /// the boundaries.
   bool allowUnknownOps = false;
 
+  /// Specifies whether function boundaries (ops in the func dialect) should be
+  /// bufferized or not.
+  bool bufferizeFunctionBoundaries = false;
+
   /// Specifies whether dealloc ops should be generated along with alloc ops. If
   /// not, new memory allocations will leak.
   bool createDeallocs = true;
@@ -356,6 +344,12 @@ public:
   /// any given tensor.
   virtual bool isTensorYielded(Value tensor) const = 0;
 
+  /// Return `true` if the given dialect state exists.
+  bool hasDialectState(StringRef name) const {
+    auto it = dialectState.find(name);
+    return it != dialectState.end();
+  }
+
   /// Return dialect-specific bufferization state.
   template <typename StateT>
   Optional<const StateT *> getDialectState(StringRef name) const {
@@ -369,7 +363,7 @@ public:
   template <typename StateT>
   StateT &getOrCreateDialectState(StringRef name) {
     // Create state if it does not exist yet.
-    if (!dialectState.count(name))
+    if (!hasDialectState(name))
       dialectState[name] = std::make_unique<StateT>();
     return static_cast<StateT &>(*dialectState[name]);
   }
index 921c6c8..73da6c8 100644 (file)
@@ -51,6 +51,31 @@ static const char *kSkipDeallocAttr = "bufferization.skip_dealloc";
 // Default constructor for BufferizationOptions.
 BufferizationOptions::BufferizationOptions() = default;
 
+bool BufferizationOptions::isOpAllowed(Operation *op) const {
+  // Special case: If function boundary bufferization is deactivated, do not
+  // allow ops that belong to the `func` dialect.
+  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
+  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
+    return false;
+
+  // All other ops: Allow/disallow according to filter.
+  bool isAllowed = !filterHasAllowRule();
+  for (const OpFilterEntry &entry : opFilter) {
+    bool filterResult = entry.fn(op);
+    switch (entry.type) {
+    case OpFilterEntry::ALLOW:
+      isAllowed |= filterResult;
+      break;
+    case OpFilterEntry::DENY:
+      if (filterResult)
+        // DENY filter matches. This op is no allowed. (Even if other ALLOW
+        // filters may match.)
+        return false;
+    };
+  }
+  return isAllowed;
+}
+
 BufferizableOpInterface
 BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
   if (isOpAllowed(op))
index 49aaeb7..0936512 100644 (file)
@@ -175,15 +175,10 @@ struct OneShotBufferizePass
       opt.fullyDynamicLayoutMaps = fullyDynamicLayoutMaps;
       opt.printConflicts = printConflicts;
       opt.testAnalysisOnly = testAnalysisOnly;
+      opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
 
       BufferizationOptions::OpFilterEntry::FilterFn filterFn =
           [&](Operation *op) {
-            // Disallow non-func dialect ops. I.e., no ops related to function
-            // calls. (Unless explicitly activated.)
-            bool isFuncBoundaryOp =
-                isa_and_nonnull<func::FuncDialect>(op->getDialect());
-            if (!this->bufferizeFunctionBoundaries && isFuncBoundaryOp)
-              return false;
             // Filter may be specified via options.
             if (this->dialectFilter.hasValue())
               return llvm::find(this->dialectFilter,
@@ -198,7 +193,7 @@ struct OneShotBufferizePass
     }
 
     ModuleOp moduleOp = getOperation();
-    if (bufferizeFunctionBoundaries) {
+    if (opt.bufferizeFunctionBoundaries) {
       if (failed(runOneShotModuleBufferize(moduleOp, opt))) {
         signalPassFailure();
         return;
@@ -284,6 +279,12 @@ bufferization::finalizeBuffers(Operation *op,
 
 LogicalResult bufferization::bufferizeOp(Operation *op,
                                          const AnalysisState &analysisState) {
+  // Catch incorrect API usage.
+  assert((analysisState.hasDialectState(
+              func::FuncDialect::getDialectNamespace()) ||
+          !analysisState.getOptions().bufferizeFunctionBoundaries) &&
+         "must use ModuleBufferize to bufferize function boundaries");
+
   BufferizationState bufferizationState(analysisState);
   if (failed(bufferizeOp(op, bufferizationState)))
     return failure();
index 8ae5c1c..d1fbb70 100644 (file)
@@ -46,6 +46,7 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Dominance.h"
@@ -864,6 +865,11 @@ LogicalResult bufferization::analyzeOp(Operation *op,
   const auto &options =
       static_cast<const OneShotBufferizationOptions &>(state.getOptions());
 
+  // Catch incorrect API usage.
+  assert((state.hasDialectState(func::FuncDialect::getDialectNamespace()) ||
+          !options.bufferizeFunctionBoundaries) &&
+         "must use ModuleBufferize to bufferize function boundaries");
+
   if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
     return failure();
 
index 7bcde9f..6dc3432 100644 (file)
@@ -417,6 +417,8 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
 
 LogicalResult mlir::bufferization::runOneShotModuleBufferize(
     ModuleOp moduleOp, OneShotBufferizationOptions options) {
+  assert(options.bufferizeFunctionBoundaries &&
+         "expected that function boundary bufferization is activated");
   IRRewriter rewriter(moduleOp.getContext());
   OneShotAnalysisState analysisState(moduleOp, options);
   BufferizationState bufferizationState(analysisState);
index b9163c5..13b760c 100644 (file)
@@ -99,6 +99,7 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
     opt.printConflicts = printConflicts;
     opt.testAnalysisOnly = testAnalysisOnly;
     opt.alwaysAliasingWithDest = alwaysAliasingWithDest;
+    opt.bufferizeFunctionBoundaries = true;
     if (initTensorElimination) {
       opt.addPostAnalysisStep(insertSliceAnchoredInitTensorEliminationStep);
     }