[MLIR][Transform] Expose map layout option in `OneShotBufferizeOp`
authorLorenzo Chelini <l.chelini@icloud.com>
Fri, 11 Nov 2022 12:35:16 +0000 (13:35 +0100)
committerLorenzo Chelini <l.chelini@icloud.com>
Mon, 14 Nov 2022 17:09:54 +0000 (18:09 +0100)
Expose `function-boundary-type-conversion` in `OneShotBufferizeOp`. To
reuse options between passes and transform operations, create a
`BufferizationEnums.td`.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D137833

13 files changed:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td [new file with mode: 0644]
mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h
mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir

index 2c9dd66..a5324e1 100644 (file)
@@ -14,6 +14,8 @@
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/SetVector.h"
 
+#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
+
 namespace mlir {
 class OpBuilder;
 
@@ -187,12 +189,6 @@ struct BufferizationOptions {
   using UnknownTypeConverterFn = std::function<BaseMemRefType(
       Value, unsigned, const BufferizationOptions &)>;
 
-  enum class LayoutMapOption : int8_t {
-    InferLayoutMap = 0,
-    IdentityLayoutMap = 1,
-    FullyDynamicLayoutMap = 2
-  };
-
   BufferizationOptions();
 
   /// Try to cast the given op to BufferizableOpInterface if the op is allow
@@ -585,6 +581,10 @@ bool defaultIsRepetitiveRegion(BufferizableOpInterface bufferizableOp,
 } // namespace bufferization
 } // namespace mlir
 
+//===----------------------------------------------------------------------===//
+// Bufferization Interfaces
+//===----------------------------------------------------------------------===//
+
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h.inc"
 
 #endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td
new file mode 100644 (file)
index 0000000..9242361
--- /dev/null
@@ -0,0 +1,27 @@
+//===- BufferizationEnums.td - Bufferization enums ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the definition file for enums used in Bufferization.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BUFFERIZATION_ENUMS
+#define BUFFERIZATION_ENUMS
+
+include "mlir/IR/EnumAttr.td"
+
+def LayoutMapOption : I32EnumAttr<"LayoutMapOption", 
+                                  "option for map layout", [
+  I32EnumAttrCase<"InferLayoutMap", 0>,
+  I32EnumAttrCase<"IdentityLayoutMap", 1>,
+  I32EnumAttrCase<"FullyDynamicLayoutMap", 2>
+]> {
+  let cppNamespace = "::mlir::bufferization";
+}
+
+#endif // BUFFERIZATION_ENUMS
index 8ddfe5a..aa93534 100644 (file)
@@ -2,3 +2,9 @@ add_mlir_dialect(BufferizationOps bufferization)
 add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc)
 add_mlir_interface(AllocationOpInterface)
 add_mlir_interface(BufferizableOpInterface)
+
+set(LLVM_TARGET_DEFINITIONS BufferizationEnums.td)
+mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
+mlir_tablegen(BufferizationEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRBufferizationEnumsIncGen)
+add_dependencies(mlir-headers MLIRBufferizationEnumsIncGen)
index bc51845..0aab581 100644 (file)
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H
 #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H
 
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/OpImplementation.h"
index 72e6796..e63ecbf 100644 (file)
@@ -9,6 +9,7 @@
 #ifndef BUFFERIZATION_TRANSFORM_OPS
 #define BUFFERIZATION_TRANSFORM_OPS
 
+include "mlir/Dialect/Bufferization/IR/BufferizationEnums.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
@@ -42,6 +43,7 @@ def OneShotBufferizeOp
 
   let arguments = (
       ins PDL_Operation:$target,
+      OptionalAttr<LayoutMapOption>:$function_boundary_type_conversion,
       DefaultValuedAttr<BoolAttr, "false">:$allow_return_allocs,
       DefaultValuedAttr<BoolAttr, "false">:$allow_unknown_ops,
       DefaultValuedAttr<BoolAttr, "false">:$bufferize_function_boundaries,
@@ -52,7 +54,10 @@ def OneShotBufferizeOp
 
   let results = (outs);
 
-  let assemblyFormat = "$target attr-dict";
+  let assemblyFormat = [{
+    (`layout` `{` $function_boundary_type_conversion^ `}`)?
+    $target attr-dict
+  }];
 }
 
 #endif // BUFFERIZATION_TRANSFORM_OPS
index 0c085a4..e774140 100644 (file)
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
   DEPENDS
   MLIRAllocationOpInterfaceIncGen
   MLIRBufferizationOpsIncGen
+  MLIRBufferizationEnumsIncGen
 
   LINK_LIBS PUBLIC
   MLIRAffineDialect
index fc3c386..9415bf7 100644 (file)
@@ -34,6 +34,9 @@ transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
   options.createDeallocs = getCreateDeallocs();
   options.testAnalysisOnly = getTestAnalysisOnly();
   options.printConflicts = getPrintConflicts();
+  if (getFunctionBoundaryTypeConversion().has_value())
+    options.functionBoundaryTypeConversion =
+        *getFunctionBoundaryTypeConversion();
 
   ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
   for (Operation *target : payloadOps) {
@@ -94,6 +97,8 @@ public:
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
 
+#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc"
+
 void mlir::bufferization::registerTransformDialectExtension(
     DialectRegistry &registry) {
   registry.addExtensions<BufferizationTransformDialectExtension>();
index e4355ad..7c3b7c8 100644 (file)
@@ -163,14 +163,13 @@ struct FinalizingBufferizePass
   }
 };
 
-static BufferizationOptions::LayoutMapOption
-parseLayoutMapOption(const std::string &s) {
+static LayoutMapOption parseLayoutMapOption(const std::string &s) {
   if (s == "fully-dynamic-layout-map")
-    return BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap;
+    return LayoutMapOption::FullyDynamicLayoutMap;
   if (s == "identity-layout-map")
-    return BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
+    return LayoutMapOption::IdentityLayoutMap;
   if (s == "infer-layout-map")
-    return BufferizationOptions::LayoutMapOption::InferLayoutMap;
+    return LayoutMapOption::InferLayoutMap;
   llvm_unreachable("invalid layout map option");
 }
 
@@ -216,19 +215,17 @@ struct OneShotBufferizePass
       opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
 
       // Configure type converter.
-      BufferizationOptions::LayoutMapOption unknownTypeConversionOption =
+      LayoutMapOption unknownTypeConversionOption =
           parseLayoutMapOption(unknownTypeConversion);
       opt.unknownTypeConverterFn = [=](Value value, unsigned memorySpace,
                                        const BufferizationOptions &options) {
         auto tensorType = value.getType().cast<TensorType>();
-        if (unknownTypeConversionOption ==
-            BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
+        if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
           return bufferization::getMemRefTypeWithStaticIdentityLayout(
               tensorType, memorySpace);
-        assert(
-            unknownTypeConversionOption ==
-                BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap &&
-            "invalid layout map option");
+        assert(unknownTypeConversionOption ==
+                   LayoutMapOption::FullyDynamicLayoutMap &&
+               "invalid layout map option");
         return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
                                                                   memorySpace);
       };
index 453d71f..e23c5c3 100644 (file)
@@ -18,6 +18,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
 
   DEPENDS
   MLIRBufferizationPassIncGen
+  MLIRBufferizationEnumsIncGen
 
   LINK_LIBS PUBLIC
   MLIRBufferizationDialect
index 49c57f4..91060dd 100644 (file)
@@ -69,7 +69,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
 
   BaseMemRefType memrefType;
   if (options.functionBoundaryTypeConversion ==
-      BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
+      LayoutMapOption::IdentityLayoutMap) {
     memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
   } else {
     // Note: Layout maps on function parameters cannot be inferred. The best we
@@ -471,7 +471,7 @@ struct FuncOpInterface
 
       BaseMemRefType resultType;
       if (options.functionBoundaryTypeConversion ==
-          BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
+          LayoutMapOption::IdentityLayoutMap) {
         resultType = getMemRefTypeWithStaticIdentityLayout(tensorType);
       } else {
         // Note: If `InferLayoutMap`, cast are later folded away.
index badcf29..fb1d50c 100644 (file)
@@ -423,7 +423,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
       return failure();
     // Change buffer return types to more precise layout maps.
     if (options.functionBoundaryTypeConversion ==
-        BufferizationOptions::LayoutMapOption::InferLayoutMap)
+        LayoutMapOption::InferLayoutMap)
       foldMemRefCasts(funcOp);
   }
 
index 478bac5..0e1fbaf 100644 (file)
@@ -32,8 +32,7 @@ getBufferizationOptions(bool analysisOnly) {
   // TODO(springerm): To spot memory leaks more easily, returning dense allocs
   // should be disallowed.
   options.allowReturnAllocs = true;
-  options.functionBoundaryTypeConversion =
-      BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
+  options.functionBoundaryTypeConversion = LayoutMapOption::IdentityLayoutMap;
   options.unknownTypeConverterFn = [](Value value, unsigned memorySpace,
                                       const BufferizationOptions &options) {
     return getMemRefTypeWithStaticIdentityLayout(
index 151c8e6..4ff8a23 100644 (file)
@@ -96,3 +96,25 @@ module {
     return %0 : tensor<?xf32>
   }
 }
+
+// -----
+
+// Test we use identity layout at function boundaries.
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+  transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %arg1 {
+    target_is_module = true,
+    bufferize_function_boundaries = true }
+}
+
+// CHECK: func.func @matmul(
+// CHECK-SAME:  %[[A:.*]]: memref<12x9xf32>,
+// CHECK-SAME:  %[[B:.*]]: memref<9x6xf32>,
+// CHECK-SAME:  %[[C:.*]]: memref<12x6xf32>) -> memref<12x6xf32> {
+func.func @matmul(%A: tensor<12x9xf32>, %B: tensor<9x6xf32>, %C: tensor<12x6xf32>) -> tensor<12x6xf32> {
+  // CHECK: linalg.matmul ins(%[[A]], %[[B]] : memref<12x9xf32>, memref<9x6xf32>) outs(%[[C]] : memref<12x6xf32>)
+  %D = linalg.matmul ins(%A, %B: tensor<12x9xf32>, tensor<9x6xf32>) outs(%C: tensor<12x6xf32>) -> tensor<12x6xf32>
+  // CHECK: return %[[C]] : memref<12x6xf32>
+  return %D : tensor<12x6xf32>
+}