#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SetVector.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
+
namespace mlir {
class OpBuilder;
using UnknownTypeConverterFn = std::function<BaseMemRefType(
Value, unsigned, const BufferizationOptions &)>;
- enum class LayoutMapOption : int8_t {
- InferLayoutMap = 0,
- IdentityLayoutMap = 1,
- FullyDynamicLayoutMap = 2
- };
-
BufferizationOptions();
/// Try to cast the given op to BufferizableOpInterface if the op is allow
} // namespace bufferization
} // namespace mlir
+//===----------------------------------------------------------------------===//
+// Bufferization Interfaces
+//===----------------------------------------------------------------------===//
+
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h.inc"
#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_
--- /dev/null
+//===- BufferizationEnums.td - Bufferization enums ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the definition file for enums used in Bufferization.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BUFFERIZATION_ENUMS
+#define BUFFERIZATION_ENUMS
+
+include "mlir/IR/EnumAttr.td"
+
+def LayoutMapOption : I32EnumAttr<"LayoutMapOption",
+ "option for map layout", [
+ I32EnumAttrCase<"InferLayoutMap", 0>,
+ I32EnumAttrCase<"IdentityLayoutMap", 1>,
+ I32EnumAttrCase<"FullyDynamicLayoutMap", 2>
+]> {
+ let cppNamespace = "::mlir::bufferization";
+}
+
+#endif // BUFFERIZATION_ENUMS
add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc)
add_mlir_interface(AllocationOpInterface)
add_mlir_interface(BufferizableOpInterface)
+
+set(LLVM_TARGET_DEFINITIONS BufferizationEnums.td)
+mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
+mlir_tablegen(BufferizationEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRBufferizationEnumsIncGen)
+add_dependencies(mlir-headers MLIRBufferizationEnumsIncGen)
#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"
#ifndef BUFFERIZATION_TRANSFORM_OPS
#define BUFFERIZATION_TRANSFORM_OPS
+include "mlir/Dialect/Bufferization/IR/BufferizationEnums.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformEffects.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
let arguments = (
ins PDL_Operation:$target,
+ OptionalAttr<LayoutMapOption>:$function_boundary_type_conversion,
DefaultValuedAttr<BoolAttr, "false">:$allow_return_allocs,
DefaultValuedAttr<BoolAttr, "false">:$allow_unknown_ops,
DefaultValuedAttr<BoolAttr, "false">:$bufferize_function_boundaries,
let results = (outs);
- let assemblyFormat = "$target attr-dict";
+ let assemblyFormat = [{
+ (`layout` `{` $function_boundary_type_conversion^ `}`)?
+ $target attr-dict
+ }];
}
#endif // BUFFERIZATION_TRANSFORM_OPS
DEPENDS
MLIRAllocationOpInterfaceIncGen
MLIRBufferizationOpsIncGen
+ MLIRBufferizationEnumsIncGen
LINK_LIBS PUBLIC
MLIRAffineDialect
options.createDeallocs = getCreateDeallocs();
options.testAnalysisOnly = getTestAnalysisOnly();
options.printConflicts = getPrintConflicts();
+ if (getFunctionBoundaryTypeConversion().has_value())
+ options.functionBoundaryTypeConversion =
+ *getFunctionBoundaryTypeConversion();
ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
for (Operation *target : payloadOps) {
#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
+#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc"
+
void mlir::bufferization::registerTransformDialectExtension(
DialectRegistry ®istry) {
registry.addExtensions<BufferizationTransformDialectExtension>();
}
};
-static BufferizationOptions::LayoutMapOption
-parseLayoutMapOption(const std::string &s) {
+static LayoutMapOption parseLayoutMapOption(const std::string &s) {
if (s == "fully-dynamic-layout-map")
- return BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap;
+ return LayoutMapOption::FullyDynamicLayoutMap;
if (s == "identity-layout-map")
- return BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
+ return LayoutMapOption::IdentityLayoutMap;
if (s == "infer-layout-map")
- return BufferizationOptions::LayoutMapOption::InferLayoutMap;
+ return LayoutMapOption::InferLayoutMap;
llvm_unreachable("invalid layout map option");
}
opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
// Configure type converter.
- BufferizationOptions::LayoutMapOption unknownTypeConversionOption =
+ LayoutMapOption unknownTypeConversionOption =
parseLayoutMapOption(unknownTypeConversion);
opt.unknownTypeConverterFn = [=](Value value, unsigned memorySpace,
const BufferizationOptions &options) {
auto tensorType = value.getType().cast<TensorType>();
- if (unknownTypeConversionOption ==
- BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
+ if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
return bufferization::getMemRefTypeWithStaticIdentityLayout(
tensorType, memorySpace);
- assert(
- unknownTypeConversionOption ==
- BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap &&
- "invalid layout map option");
+ assert(unknownTypeConversionOption ==
+ LayoutMapOption::FullyDynamicLayoutMap &&
+ "invalid layout map option");
return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
memorySpace);
};
DEPENDS
MLIRBufferizationPassIncGen
+ MLIRBufferizationEnumsIncGen
LINK_LIBS PUBLIC
MLIRBufferizationDialect
BaseMemRefType memrefType;
if (options.functionBoundaryTypeConversion ==
- BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
+ LayoutMapOption::IdentityLayoutMap) {
memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
} else {
// Note: Layout maps on function parameters cannot be inferred. The best we
BaseMemRefType resultType;
if (options.functionBoundaryTypeConversion ==
- BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
+ LayoutMapOption::IdentityLayoutMap) {
resultType = getMemRefTypeWithStaticIdentityLayout(tensorType);
} else {
// Note: If `InferLayoutMap`, cast are later folded away.
return failure();
// Change buffer return types to more precise layout maps.
if (options.functionBoundaryTypeConversion ==
- BufferizationOptions::LayoutMapOption::InferLayoutMap)
+ LayoutMapOption::InferLayoutMap)
foldMemRefCasts(funcOp);
}
// TODO(springerm): To spot memory leaks more easily, returning dense allocs
// should be disallowed.
options.allowReturnAllocs = true;
- options.functionBoundaryTypeConversion =
- BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
+ options.functionBoundaryTypeConversion = LayoutMapOption::IdentityLayoutMap;
options.unknownTypeConverterFn = [](Value value, unsigned memorySpace,
const BufferizationOptions &options) {
return getMemRefTypeWithStaticIdentityLayout(
return %0 : tensor<?xf32>
}
}
+
+// -----
+
+// Test we use identity layout at function boundaries.
+
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !pdl.operation):
+ transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %arg1 {
+ target_is_module = true,
+ bufferize_function_boundaries = true }
+}
+
+// CHECK: func.func @matmul(
+// CHECK-SAME: %[[A:.*]]: memref<12x9xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<9x6xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<12x6xf32>) -> memref<12x6xf32> {
+func.func @matmul(%A: tensor<12x9xf32>, %B: tensor<9x6xf32>, %C: tensor<12x6xf32>) -> tensor<12x6xf32> {
+ // CHECK: linalg.matmul ins(%[[A]], %[[B]] : memref<12x9xf32>, memref<9x6xf32>) outs(%[[C]] : memref<12x6xf32>)
+ %D = linalg.matmul ins(%A, %B: tensor<12x9xf32>, tensor<9x6xf32>) outs(%C: tensor<12x6xf32>) -> tensor<12x6xf32>
+ // CHECK: return %[[C]] : memref<12x6xf32>
+ return %D : tensor<12x6xf32>
+}