/// Creates an instance of std bufferization pass.
std::unique_ptr<Pass> createStdBufferizePass();
+/// Creates an instance of func bufferization pass.
+std::unique_ptr<Pass> createFuncBufferizePass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
let dependentDialects = ["scf::SCFDialect"];
}
+def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {
+ let summary = "Bufferize func/call/return ops";
+ let description = [{
+ A finalizing bufferize pass that bufferizes std.func and std.call ops.
+
+ Because this pass updates std.func ops, it must be a module pass. It is
+ useful to keep this pass separate from other bufferizations so that the
+ other ones can be run at function-level in parallel.
+
+ This pass must be done atomically for two reasons:
+ 1. This pass changes func op signatures, which requires atomically updating
+ calls as well throughout the entire module.
+ 2. This pass changes the type of block arguments, which requires that all
+ successor arguments of predecessors be converted. Terminators are not
+ a closed universe (and need not implement BranchOpInterface), and so we
+ cannot in general rewrite them.
+
+ Note, because this is a "finalizing" bufferize step, it can create
+ invalid IR because it will not create materializations. To avoid this
+ situation, the pass must only be run when the only SSA values of
+ tensor type are:
+ - block arguments
+ - the result of tensor_load
+ Other values of tensor type should be eliminated by earlier
+ bufferization passes.
+ }];
+ let constructor = "mlir::createFuncBufferizePass()";
+}
+
#endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
/// This function should be called by all bufferization passes using
/// BufferizeTypeConverter so that materializations work proprely. One exception
/// is bufferization passes doing "full" conversions, where it can be desirable
-/// for even the materializations to remain illegal so that they are eliminated.
+/// for even the materializations to remain illegal so that they are eliminated,
+/// such as via the patterns in
+/// populateEliminateBufferizeMaterializationsPatterns.
void populateBufferizeMaterializationLegality(ConversionTarget &target);
+/// Populate patterns to eliminate bufferize materializations.
+///
+/// In particular, these are the tensor_load/tensor_to_memref ops.
+void populateEliminateBufferizeMaterializationsPatterns(
+ MLIRContext *context, BufferizeTypeConverter &typeConverter,
+ OwningRewritePatternList &patterns);
+
/// Helper conversion pattern that encapsulates a BufferizeTypeConverter
/// instance.
template <typename SourceOp>
ExpandAtomic.cpp
ExpandMemRefReshape.cpp
ExpandTanh.cpp
+ FuncBufferize.cpp
FuncConversions.cpp
ADDITIONAL_HEADER_DIRS
--- /dev/null
+//===- Bufferize.cpp - Bufferization for std ops --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements bufferization of std.func's and std.call's.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/Transforms/Bufferize.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
+ void runOnOperation() override {
+ auto module = getOperation();
+ auto *context = &getContext();
+
+ BufferizeTypeConverter typeConverter;
+ OwningRewritePatternList patterns;
+ ConversionTarget target(*context);
+
+ populateFuncOpTypeConversionPattern(patterns, context, typeConverter);
+ target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+ return typeConverter.isSignatureLegal(op.getType()) &&
+ typeConverter.isLegal(&op.getBody());
+ });
+ populateCallOpTypeConversionPattern(patterns, context, typeConverter);
+ populateEliminateBufferizeMaterializationsPatterns(context, typeConverter,
+ patterns);
+ target.addIllegalOp<TensorLoadOp, TensorToMemrefOp>();
+
+ // If all result types are legal, and all block arguments are legal (ensured
+ // by func conversion above), then all types in the program are legal.
+ target.markUnknownOpDynamicallyLegal([&](Operation *op) {
+ return typeConverter.isLegal(op->getResultTypes());
+ });
+
+ if (failed(applyFullConversion(module, target, std::move(patterns))))
+ signalPassFailure();
+ }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createFuncBufferizePass() {
+ return std::make_unique<FuncBufferizePass>();
+}
target.addLegalOp<TensorLoadOp, TensorToMemrefOp>();
}
+namespace {
+// In a finalizing bufferize conversion, we know that all tensors have been
+// converted to memrefs, thus, this op becomes an identity.
+class BufferizeTensorLoadOp : public OpConversionPattern<TensorLoadOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(TensorLoadOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ TensorLoadOp::Adaptor adaptor(operands);
+ rewriter.replaceOp(op, adaptor.memref());
+ return success();
+ }
+};
+} // namespace
+
+namespace {
+// In a finalizing bufferize conversion, we know that all tensors have been
+// converted to memrefs, thus, this op becomes an identity.
+class BufferizeTensorToMemrefOp : public OpConversionPattern<TensorToMemrefOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(TensorToMemrefOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ TensorToMemrefOp::Adaptor adaptor(operands);
+ rewriter.replaceOp(op, adaptor.tensor());
+ return success();
+ }
+};
+} // namespace
+
+void mlir::populateEliminateBufferizeMaterializationsPatterns(
+ MLIRContext *context, BufferizeTypeConverter &typeConverter,
+ OwningRewritePatternList &patterns) {
+ patterns.insert<BufferizeTensorLoadOp, BufferizeTensorToMemrefOp>(
+ typeConverter, context);
+}
+
//===----------------------------------------------------------------------===//
// BufferizeFuncOpConverter
//===----------------------------------------------------------------------===//
--- /dev/null
+// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func @identity(
+// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK: return %[[ARG]] : memref<f32>
+func @identity(%arg0: tensor<f32>) -> tensor<f32> {
+ return %arg0 : tensor<f32>
+}
+
+// CHECK-LABEL: func @block_arguments(
+// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK: br ^bb1(%[[ARG]] : memref<f32>)
+// CHECK: ^bb1(%[[BBARG:.*]]: memref<f32>):
+// CHECK: return %[[BBARG]] : memref<f32>
+func @block_arguments(%arg0: tensor<f32>) -> tensor<f32> {
+ br ^bb1(%arg0: tensor<f32>)
+^bb1(%bbarg: tensor<f32>):
+ return %bbarg : tensor<f32>
+}
+
+// CHECK-LABEL: func @eliminate_target_materialization(
+// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK: return %[[ARG]] : memref<f32>
+func @eliminate_target_materialization(%arg0: tensor<f32>) -> memref<f32> {
+ %0 = tensor_to_memref %arg0 : memref<f32>
+ return %0 : memref<f32>
+}
+
+// CHECK-LABEL: func @eliminate_source_materialization(
+// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK: return %[[ARG]] : memref<f32>
+func @eliminate_source_materialization(%arg0: memref<f32>) -> tensor<f32> {
+ %0 = tensor_load %arg0 : memref<f32>
+ return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: func @source() -> memref<f32>
+// CHECK-LABEL: func @call_source() -> memref<f32> {
+// CHECK: %[[RET:.*]] = call @source() : () -> memref<f32>
+// CHECK: return %[[RET]] : memref<f32>
+func @source() -> tensor<f32>
+func @call_source() -> tensor<f32> {
+ %0 = call @source() : () -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: func @sink(memref<f32>)
+// CHECK-LABEL: func @call_sink(
+// CHECK-SAME: %[[ARG:.*]]: memref<f32>) {
+// CHECK: call @sink(%[[ARG]]) : (memref<f32>) -> ()
+// CHECK: return
+func @sink(tensor<f32>)
+func @call_sink(%arg0: tensor<f32>) {
+ call @sink(%arg0) : (tensor<f32>) -> ()
+ return
+}
+
+// -----
+
+func @failed_to_legalize() -> tensor<f32> {
+ // expected-error @+1 {{failed to legalize operation 'test.source'}}
+ %0 = "test.source"() : () -> (tensor<f32>)
+ return %0 : tensor<f32>
+}