From 6fc11d4d3ea08f2a9e6adf1c1a99c8798904f385 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sat, 5 Mar 2022 05:11:21 +0900 Subject: [PATCH] [mlir][bufferize] Add BufferizationState initializers Such initializer functions can be enqueued in `BufferizationOptions`. They can be used to set up dialect-specific bufferization state. Differential Revision: https://reviews.llvm.org/D120985 --- .../Bufferization/IR/BufferizableOpInterface.h | 21 ++++++++++++++++++++- .../Bufferization/IR/BufferizableOpInterface.cpp | 12 +++++++++++- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index 593057d..1e75872 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -28,8 +28,8 @@ class FuncOp; namespace bufferization { class BufferizableOpInterface; -struct BufferizationOptions; class BufferizationState; +struct DialectBufferizationState; /// Options for ComprehensiveBufferize. struct BufferizationOptions { @@ -44,6 +44,11 @@ struct BufferizationOptions { /// Memcpy function: Generate a memcpy between two buffers. using MemCpyFn = std::function; + /// Initializer function for bufferization state. + using BufferizationStateInitFn = std::function; + /// Initializer function for dialect-specific bufferization state. + using DialectStateInitFn = + std::function()>; /// An op filter entry. Filters can be used to specify which ops should be /// processed by the bufferization. @@ -228,6 +233,14 @@ struct BufferizationOptions { /// DENY-filtered and have at least one matching ALLOW filter are processed. SmallVector opFilter; + /// Initializer functions for bufferization state. These can be used to + /// initialize dialect-specific bufferization state. + SmallVector stateInitializers; + + /// Add a bufferization state initializer that initializes the specified + /// dialect-specific bufferization state. + void addDialectStateInitializer(StringRef name, DialectStateInitFn fn); + private: /// Allow a dialect. template @@ -362,6 +375,12 @@ public: return static_cast(*dialectState[name]); } + void insertDialectState(StringRef name, + std::unique_ptr state) { + assert(!dialectState.count(name) && "dialect state already initialized"); + dialectState[name] = std::move(state); + } + /// Return a reference to the BufferizationOptions. const BufferizationOptions &getOptions() const { return options; } diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index e5d9487..3e0e89a 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -64,6 +64,12 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const { return nullptr; } +void BufferizationOptions::addDialectStateInitializer(StringRef name, + DialectStateInitFn fn) { + stateInitializers.push_back( + [=](BufferizationState &state) { state.insertDialectState(name, fn()); }); +} + //===----------------------------------------------------------------------===// // Helper functions for BufferizableOpInterface //===----------------------------------------------------------------------===// @@ -200,7 +206,11 @@ BufferizationState::findLastPrecedingWrite(Value value) const { } BufferizationState::BufferizationState(const BufferizationOptions &options) - : options(options) {} + : options(options) { + for (const BufferizationOptions::BufferizationStateInitFn &fn : + options.stateInitializers) + fn(*this); +} // bufferization.to_memref is not allowed to change the rank. static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) { -- 2.7.4