[mlir][bufferization] Make function boundary type convertion logic dynamic.
authorOleg Shyshkov <shyshkov@google.com>
Wed, 12 Apr 2023 08:48:14 +0000 (10:48 +0200)
committerOleg Shyshkov <shyshkov@google.com>
Wed, 12 Apr 2023 09:02:43 +0000 (11:02 +0200)
Having to choose from only static or dynamic layout for all function is limiting.

Differential Revision: https://reviews.llvm.org/D148074

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp

index b664496..2dbd113 100644 (file)
@@ -19,6 +19,9 @@
 
 namespace mlir {
 class OpBuilder;
+namespace func {
+class FuncOp;
+}
 
 namespace bufferization {
 
@@ -250,6 +253,11 @@ struct BufferizationOptions {
   /// Initializer function for analysis state.
   using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
   /// Tensor -> MemRef type converter.
+  /// Parameters: Value, memory space, func op, bufferization options
+  using FunctionArgTypeConverterFn =
+      std::function<BaseMemRefType(TensorType, Attribute memorySpace,
+                                   func::FuncOp, const BufferizationOptions &)>;
+  /// Tensor -> MemRef type converter.
   /// Parameters: Value, memory space, bufferization options
   using UnknownTypeConverterFn = std::function<BaseMemRefType(
       Value, Attribute memorySpace, const BufferizationOptions &)>;
@@ -313,7 +321,8 @@ struct BufferizationOptions {
   /// OpOperands out-of-place.
   bool enforceAliasingInvariants = true;
 
-  /// This flag controls buffer types on function signatures.
+  /// This function controls buffer types on function signatures. Sets
+  /// `functionArgTypeConverterFn` and `inferFunctionResultLayout` accordingly.
   ///
   /// * InferLayoutMap: All function parameter types have a fully dynamic layout
   ///   map, but function result types are inferred from the body of the
@@ -326,13 +335,25 @@ struct BufferizationOptions {
   ///   additional buffer allocs and copies because layout maps cannot be casted
   ///   away.
   ///
-  /// If `bufferizeFunctionBoundaries` is not set, this flag has no effect.
-  ///
   /// Note: Inferred layout maps may not be desireable when interacting with
   /// external functions, because the generated function signatures will be less
   /// predictable.
-  LayoutMapOption functionBoundaryTypeConversion =
-      LayoutMapOption::InferLayoutMap;
+  void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption);
+
+  /// Type converter from tensors to memrefs. This type converter is used to
+  /// determine bufferized function argument types. By default, a type
+  /// converter that returns a memref type with a fully dynamic layout map is
+  /// used.
+  ///
+  /// If `bufferizeFunctionBoundaries` is not set, this function isn't used.
+  FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr;
+
+  /// If true, function result types are inferred from the body of the function.
+  /// Otherwise, function result type is determined by
+  /// `functionArgTypeConverterFn`.
+  ///
+  /// If `bufferizeFunctionBoundaries` is not set, this flag has no effect.
+  bool inferFunctionResultLayout = true;
 
   /// Type converter from tensors to memrefs. This type converter is used if no
   /// memref type could be inferred during bufferization. By default, a type
index 3b965cf..70d857b 100644 (file)
@@ -322,17 +322,29 @@ bool OpFilter::isOpAllowed(Operation *op) const {
 // BufferizationOptions
 //===----------------------------------------------------------------------===//
 
+namespace {
+
+/// Default function arg type converter: Use a fully dynamic layout map.
+BaseMemRefType
+defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
+                                func::FuncOp funcOp,
+                                const BufferizationOptions &options) {
+  return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
+}
 /// Default unknown type converter: Use a fully dynamic layout map.
-static BaseMemRefType
+BaseMemRefType
 defaultUnknownTypeConverter(Value value, Attribute memorySpace,
                             const BufferizationOptions &options) {
   return getMemRefTypeWithFullyDynamicLayout(value.getType().cast<TensorType>(),
                                              memorySpace);
 }
 
+}; // namespace
+
 // Default constructor for BufferizationOptions.
 BufferizationOptions::BufferizationOptions()
-    : unknownTypeConverterFn(defaultUnknownTypeConverter) {}
+    : functionArgTypeConverterFn(defaultFunctionArgTypeConverter),
+      unknownTypeConverterFn(defaultUnknownTypeConverter) {}
 
 bool BufferizationOptions::isOpAllowed(Operation *op) const {
   // Special case: If function boundary bufferization is deactivated, do not
@@ -362,6 +374,21 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
   return nullptr;
 }
 
+void BufferizationOptions::setFunctionBoundaryTypeConversion(
+    LayoutMapOption layoutMapOption) {
+  functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
+                                   func::FuncOp funcOp,
+                                   const BufferizationOptions &options) {
+    if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
+      return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
+                                                                  memorySpace);
+    return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
+                                                              memorySpace);
+  };
+  inferFunctionResultLayout =
+      layoutMapOption == LayoutMapOption::InferLayoutMap;
+}
+
 //===----------------------------------------------------------------------===//
 // Helper functions for BufferizableOpInterface
 //===----------------------------------------------------------------------===//
index 58766e8..ed95a62 100644 (file)
@@ -38,8 +38,8 @@ transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
   options.testAnalysisOnly = getTestAnalysisOnly();
   options.printConflicts = getPrintConflicts();
   if (getFunctionBoundaryTypeConversion().has_value())
-    options.functionBoundaryTypeConversion =
-        *getFunctionBoundaryTypeConversion();
+    options.setFunctionBoundaryTypeConversion(
+        *getFunctionBoundaryTypeConversion());
 
   ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
   for (Operation *target : payloadOps) {
index e5e125f..4eabfcc 100644 (file)
@@ -208,8 +208,8 @@ struct OneShotBufferizePass
       opt.analysisHeuristic = parseHeuristicOption(analysisHeuristic);
       opt.copyBeforeWrite = copyBeforeWrite;
       opt.createDeallocs = createDeallocs;
-      opt.functionBoundaryTypeConversion =
-          parseLayoutMapOption(functionBoundaryTypeConversion);
+      opt.setFunctionBoundaryTypeConversion(
+          parseLayoutMapOption(functionBoundaryTypeConversion));
       if (mustInferMemorySpace)
         opt.defaultMemorySpace = std::nullopt;
       opt.printConflicts = printConflicts;
index 1c57679..bf14e46 100644 (file)
@@ -55,8 +55,7 @@ static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
 
 /// Return the index-th bufferized function argument type. This assumes that the
 /// specified argument is a tensor. If the tensor is ranked, a layout map may be
-/// specified by the user. If no layout map is specified, the default layout map
-/// (as per `options.functionBoundaryTypeConversion`) is used.
+/// specified by the user (as per `options.functionArgTypeConverterFn`).
 static BaseMemRefType
 getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                              const BufferizationOptions &options) {
@@ -64,17 +63,8 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
       funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
   assert(tensorType && "expected TensorType");
 
-  BaseMemRefType memrefType;
-  if (options.functionBoundaryTypeConversion ==
-      LayoutMapOption::IdentityLayoutMap) {
-    memrefType = getMemRefTypeWithStaticIdentityLayout(
-        tensorType, *options.defaultMemorySpace);
-  } else {
-    // Note: Layout maps on function parameters cannot be inferred. The best we
-    // can do at the moment is "fully dynamic".
-    memrefType = getMemRefTypeWithFullyDynamicLayout(
-        tensorType, *options.defaultMemorySpace);
-  }
+  BaseMemRefType memrefType = options.functionArgTypeConverterFn(
+      tensorType, *options.defaultMemorySpace, funcOp, options);
 
   auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
       index, BufferizationDialect::kBufferLayoutAttrName);
@@ -423,16 +413,10 @@ struct FuncOpInterface
         continue;
       }
 
-      BaseMemRefType resultType;
-      if (options.functionBoundaryTypeConversion ==
-          LayoutMapOption::IdentityLayoutMap) {
-        resultType = getMemRefTypeWithStaticIdentityLayout(
-            tensorType, *options.defaultMemorySpace);
-      } else {
-        // Note: If `InferLayoutMap`, cast are later folded away.
-        resultType = getMemRefTypeWithFullyDynamicLayout(
-            tensorType, *options.defaultMemorySpace);
-      }
+      // Note: If `inferFunctionResultLayout = true`, cast are later folded
+      // away.
+      BaseMemRefType resultType = options.functionArgTypeConverterFn(
+          tensorType, *options.defaultMemorySpace, funcOp, options);
       Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
           loc, resultType, returnVal);
       returnValues.push_back(toMemrefOp);
index c96a507..27b560a 100644 (file)
@@ -433,8 +433,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
                            /*opFilter=*/nullptr, statistics)))
       return failure();
     // Change buffer return types to more precise layout maps.
-    if (options.functionBoundaryTypeConversion ==
-        LayoutMapOption::InferLayoutMap)
+    if (options.inferFunctionResultLayout)
       foldMemRefCasts(funcOp);
   }
 
index a2fa480..99a619c 100644 (file)
@@ -37,7 +37,7 @@ getBufferizationOptions(bool analysisOnly) {
   // TODO(springerm): To spot memory leaks more easily, returning dense allocs
   // should be disallowed.
   options.allowReturnAllocs = true;
-  options.functionBoundaryTypeConversion = LayoutMapOption::IdentityLayoutMap;
+  options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
   options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
                                       const BufferizationOptions &options) {
     return getMemRefTypeWithStaticIdentityLayout(