[mlir][bufferize] Add BufferizationState initializers
authorMatthias Springer <springerm@google.com>
Fri, 4 Mar 2022 20:11:21 +0000 (05:11 +0900)
committerMatthias Springer <springerm@google.com>
Fri, 4 Mar 2022 20:20:11 +0000 (05:20 +0900)
Such initializer functions can be enqueued in `BufferizationOptions`. They can be used to set up dialect-specific bufferization state.

Differential Revision: https://reviews.llvm.org/D120985

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

index 593057d..1e75872 100644 (file)
@@ -28,8 +28,8 @@ class FuncOp;
 namespace bufferization {
 
 class BufferizableOpInterface;
-struct BufferizationOptions;
 class BufferizationState;
+struct DialectBufferizationState;
 
 /// Options for ComprehensiveBufferize.
 struct BufferizationOptions {
@@ -44,6 +44,11 @@ struct BufferizationOptions {
   /// Memcpy function: Generate a memcpy between two buffers.
   using MemCpyFn =
       std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
+  /// Initializer function for bufferization state.
+  using BufferizationStateInitFn = std::function<void(BufferizationState &)>;
+  /// Initializer function for dialect-specific bufferization state.
+  using DialectStateInitFn =
+      std::function<std::unique_ptr<DialectBufferizationState>()>;
 
   /// An op filter entry. Filters can be used to specify which ops should be
   /// processed by the bufferization.
@@ -228,6 +233,14 @@ struct BufferizationOptions {
   /// DENY-filtered and have at least one matching ALLOW filter are processed.
   SmallVector<OpFilterEntry> opFilter;
 
+  /// Initializer functions for bufferization state. These can be used to
+  /// initialize dialect-specific bufferization state.
+  SmallVector<BufferizationStateInitFn> stateInitializers;
+
+  /// Add a bufferization state initializer that initializes the specified
+  /// dialect-specific bufferization state.
+  void addDialectStateInitializer(StringRef name, DialectStateInitFn fn);
+
 private:
   /// Allow a dialect.
   template <typename DialectT>
@@ -362,6 +375,12 @@ public:
     return static_cast<StateT &>(*dialectState[name]);
   }
 
+  void insertDialectState(StringRef name,
+                          std::unique_ptr<DialectBufferizationState> state) {
+    assert(!dialectState.count(name) && "dialect state already initialized");
+    dialectState[name] = std::move(state);
+  }
+
   /// Return a reference to the BufferizationOptions.
   const BufferizationOptions &getOptions() const { return options; }
 
index e5d9487..3e0e89a 100644 (file)
@@ -64,6 +64,12 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
   return nullptr;
 }
 
+void BufferizationOptions::addDialectStateInitializer(StringRef name,
+                                                      DialectStateInitFn fn) {
+  stateInitializers.push_back(
+      [=](BufferizationState &state) { state.insertDialectState(name, fn()); });
+}
+
 //===----------------------------------------------------------------------===//
 // Helper functions for BufferizableOpInterface
 //===----------------------------------------------------------------------===//
@@ -200,7 +206,11 @@ BufferizationState::findLastPrecedingWrite(Value value) const {
 }
 
 BufferizationState::BufferizationState(const BufferizationOptions &options)
-    : options(options) {}
+    : options(options) {
+  for (const BufferizationOptions::BufferizationStateInitFn &fn :
+       options.stateInitializers)
+    fn(*this);
+}
 
 // bufferization.to_memref is not allowed to change the rank.
 static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {