From 4dd5f79f07022dbbff547f4aff13b27134331215 Mon Sep 17 00:00:00 2001 From: Stephan Herhut Date: Thu, 26 Nov 2020 13:26:08 +0100 Subject: [PATCH] [mlir][bufferize] Add argument materialization for bufferization This enables partial bufferization that includes function signatures. To test this, this change also makes the func-bufferize partial and adds a dedicated finalizing-bufferize pass. Differential Revision: https://reviews.llvm.org/D92032 --- .../StandardOps/Transforms/FuncConversions.h | 7 ++ .../mlir/Dialect/StandardOps/Transforms/Passes.td | 30 ++++---- mlir/include/mlir/Transforms/Passes.h | 4 ++ mlir/include/mlir/Transforms/Passes.td | 16 +++++ .../StandardOps/Transforms/FuncBufferize.cpp | 42 +++++++++-- .../StandardOps/Transforms/FuncConversions.cpp | 83 +++++++++++++++++++--- mlir/lib/Transforms/Bufferize.cpp | 51 +++++++++++-- .../Dialect/Standard/func-bufferize-partial.mlir | 59 +++++++++++++++ mlir/test/Dialect/Standard/func-bufferize.mlir | 2 +- 9 files changed, 254 insertions(+), 40 deletions(-) create mode 100644 mlir/test/Dialect/Standard/func-bufferize-partial.mlir diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h index 5a1bc7b..55da3af 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h @@ -26,6 +26,13 @@ void populateCallOpTypeConversionPattern(OwningRewritePatternList &patterns, MLIRContext *ctx, TypeConverter &converter); +/// Add a pattern to the given pattern list to rewrite branch operations and +/// `return` to use operands that have been legalized by the conversion +/// framework. This can only be done if the branch operation implements the +/// BranchOpInterface. Only needed for partial conversions. +void populateBranchOpInterfaceAndReturnOpTypeConversionPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx, + TypeConverter &converter); } // end namespace mlir #endif // MLIR_DIALECT_STANDARDOPS_TRANSFORMS_FUNCCONVERSIONS_H_ diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td index 3be398f..9623dd1 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td @@ -25,28 +25,26 @@ def StdExpandOps : FunctionPass<"std-expand"> { def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> { let summary = "Bufferize func/call/return ops"; let description = [{ - A finalizing bufferize pass that bufferizes std.func and std.call ops. + A bufferize pass that bufferizes std.func and std.call ops. Because this pass updates std.func ops, it must be a module pass. It is useful to keep this pass separate from other bufferizations so that the other ones can be run at function-level in parallel. - This pass must be done atomically for two reasons: - 1. This pass changes func op signatures, which requires atomically updating - calls as well throughout the entire module. - 2. This pass changes the type of block arguments, which requires that all - successor arguments of predecessors be converted. Terminators are not - a closed universe (and need not implement BranchOpInterface), and so we - cannot in general rewrite them. + This pass must be done atomically because it changes func op signatures, + which requires atomically updating calls as well throughout the entire + module. - Note, because this is a "finalizing" bufferize step, it can create - invalid IR because it will not create materializations. To avoid this - situation, the pass must only be run when the only SSA values of - tensor type are: - - block arguments - - the result of tensor_load - Other values of tensor type should be eliminated by earlier - bufferization passes. + This pass also changes the type of block arguments, which requires that all + successor arguments of predecessors be converted. This is achieved by + rewriting terminators based on the information provided by the + `BranchOpInterface`. + As this pass rewrites function operations, it also rewrites the + corresponding return operations. Other return-like operations that + implement the `ReturnLike` trait are not rewritten in general, as they + require that the correspondign parent operation is also rewritten. + Finally, this pass fails for unknown terminators, as we cannot decide + whether they need rewriting. }]; let constructor = "mlir::createFuncBufferizePass()"; } diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 2e3437a..77d98ce 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -46,6 +46,10 @@ std::unique_ptr createPromoteBuffersToStackPass(unsigned maxAllocSizeInBytes = 1024, unsigned bitwidthOfIndexType = 64); +/// Creates a pass that finalizes a partial bufferization by removing remaining +/// tensor_load and tensor_to_memref operations. +std::unique_ptr createFinalizingBufferizePass(); + /// Creates a pass that converts memref function results to out-params. std::unique_ptr createBufferResultsToOutParamsPass(); diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index da4ca24..29fe43f 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -290,6 +290,22 @@ def Inliner : Pass<"inline"> { ]; } +def FinalizingBufferize : FunctionPass<"finalizing-bufferize"> { + let summary = "Finalize a partial bufferization"; + let description = [{ + A bufferize pass that finalizes a partial bufferization by removing + remaining `tensor_load` and `tensor_to_memref` operations. + + The removal of those operations is only possible if the operations only + exist in pairs, i.e., all uses of `tensor_load` operations are + `tensor_to_memref` operations. + + This pass will fail if not all operations can be removed or if any operation + with tensor typed operands remains. + }]; + let constructor = "mlir::createFinalizingBufferizePass()"; +} + def LocationSnapshot : Pass<"snapshot-op-locations"> { let summary = "Generate new locations from the current IR"; let description = [{ diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp index 4aadb72..1aace45 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp @@ -21,6 +21,8 @@ using namespace mlir; namespace { struct FuncBufferizePass : public FuncBufferizeBase { + using FuncBufferizeBase::FuncBufferizeBase; + void runOnOperation() override { auto module = getOperation(); auto *context = &getContext(); @@ -35,14 +37,42 @@ struct FuncBufferizePass : public FuncBufferizeBase { typeConverter.isLegal(&op.getBody()); }); populateCallOpTypeConversionPattern(patterns, context, typeConverter); - populateEliminateBufferizeMaterializationsPatterns(context, typeConverter, - patterns); - target.addIllegalOp(); + target.addDynamicallyLegalOp( + [&](CallOp op) { return typeConverter.isLegal(op); }); - // If all result types are legal, and all block arguments are legal (ensured - // by func conversion above), then all types in the program are legal. + populateBranchOpInterfaceAndReturnOpTypeConversionPattern(patterns, context, + typeConverter); + target.addLegalOp(); + target.addDynamicallyLegalOp( + [&](ReturnOp op) { return typeConverter.isLegal(op); }); + // Mark terminators as legal if they have the ReturnLike trait or + // implement the BranchOpInterface and have valid types. If they do not + // implement the trait or interface, mark them as illegal no matter what. target.markUnknownOpDynamicallyLegal([&](Operation *op) { - return typeConverter.isLegal(op->getResultTypes()); + // If it is not a terminator, ignore it. + if (op->isKnownNonTerminator()) + return true; + // If it is not the last operation in the block, also ignore it. We do + // this to handle unknown operations, as well. + Block *block = op->getBlock(); + if (!block || &block->back() != op) + return true; + // ReturnLike operations have to be legalized with their parent. For + // return this is handled, for other ops they remain as is. + if (op->hasTrait()) + return true; + // All successor operands of branch like operations must be rewritten. + if (auto branchOp = dyn_cast(op)) { + for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) { + auto successorOperands = branchOp.getSuccessorOperands(p); + if (successorOperands.hasValue() && + !typeConverter.isLegal(successorOperands.getValue().getTypes())) + return false; + } + return true; + } + return false; }); if (failed(applyFullConversion(module, target, std::move(patterns)))) diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp index 9d8fceb..07d7c59 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp @@ -13,21 +13,19 @@ using namespace mlir; namespace { -// Converts the operand and result types of the Standard's CallOp, used together -// with the FuncOpSignatureConversion. +/// Converts the operand and result types of the Standard's CallOp, used +/// together with the FuncOpSignatureConversion. struct CallOpSignatureConversion : public OpConversionPattern { - CallOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) - : OpConversionPattern(ctx), converter(converter) {} + using OpConversionPattern::OpConversionPattern; /// Hook for derived classes to implement combined matching and rewriting. LogicalResult matchAndRewrite(CallOp callOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - FunctionType type = callOp.getCalleeType(); - // Convert the original function results. SmallVector convertedResults; - if (failed(converter.convertTypes(type.getResults(), convertedResults))) + if (failed(typeConverter->convertTypes(callOp.getResultTypes(), + convertedResults))) return failure(); // Substitute with the new result types from the corresponding FuncType @@ -36,14 +34,77 @@ struct CallOpSignatureConversion : public OpConversionPattern { convertedResults, operands); return success(); } - - /// The type converter to use when rewriting the signature. - TypeConverter &converter; }; } // end anonymous namespace void mlir::populateCallOpTypeConversionPattern( OwningRewritePatternList &patterns, MLIRContext *ctx, TypeConverter &converter) { - patterns.insert(ctx, converter); + patterns.insert(converter, ctx); +} + +namespace { +/// Only needed to support partial conversion of functions where this pattern +/// ensures that the branch operation arguments matches up with the succesor +/// block arguments. +class BranchOpInterfaceTypeConversion : public ConversionPattern { +public: + BranchOpInterfaceTypeConversion(TypeConverter &typeConverter, + MLIRContext *ctx) + : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto branchOp = dyn_cast(op); + if (!branchOp) + return failure(); + + // For a branch operation, only some operands go to the target blocks, so + // only rewrite those. + SmallVector newOperands(op->operand_begin(), op->operand_end()); + for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors(); + succIdx < succEnd; ++succIdx) { + auto successorOperands = branchOp.getSuccessorOperands(succIdx); + if (!successorOperands) + continue; + for (int idx = successorOperands->getBeginOperandIndex(), + eidx = idx + successorOperands->size(); + idx < eidx; ++idx) { + newOperands[idx] = operands[idx]; + } + } + rewriter.updateRootInPlace( + op, [newOperands, op]() { op->setOperands(newOperands); }); + return success(); + } +}; +} // end anonymous namespace + +namespace { +/// Only needed to support partial conversion of functions where this pattern +/// ensures that the branch operation arguments matches up with the succesor +/// block arguments. +class ReturnOpTypeConversion : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ReturnOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + // For a return, all operands go to the results of the parent, so + // rewrite them all. + Operation *operation = op.getOperation(); + rewriter.updateRootInPlace( + op, [operands, operation]() { operation->setOperands(operands); }); + return success(); + } +}; +} // end anonymous namespace + +void mlir::populateBranchOpInterfaceAndReturnOpTypeConversionPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx, + TypeConverter &typeConverter) { + patterns.insert( + typeConverter, ctx); } diff --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp index ba622335..1811ac8 100644 --- a/mlir/lib/Transforms/Bufferize.cpp +++ b/mlir/lib/Transforms/Bufferize.cpp @@ -7,7 +7,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Transforms/Bufferize.h" +#include "PassDetail.h" #include "mlir/IR/Operation.h" +#include "mlir/Transforms/Passes.h" using namespace mlir; @@ -15,6 +17,13 @@ using namespace mlir; // BufferizeTypeConverter //===----------------------------------------------------------------------===// +static Value materializeTensorLoad(OpBuilder &builder, TensorType type, + ValueRange inputs, Location loc) { + assert(inputs.size() == 1); + assert(inputs[0].getType().isa()); + return builder.create(loc, type, inputs[0]); +} + /// Registers conversions into BufferizeTypeConverter BufferizeTypeConverter::BufferizeTypeConverter() { // Keep all types unchanged. @@ -27,12 +36,8 @@ BufferizeTypeConverter::BufferizeTypeConverter() { addConversion([](UnrankedTensorType type) -> Type { return UnrankedMemRefType::get(type.getElementType(), 0); }); - addSourceMaterialization([](OpBuilder &builder, TensorType type, - ValueRange inputs, Location loc) -> Value { - assert(inputs.size() == 1); - assert(inputs[0].getType().isa()); - return builder.create(loc, type, inputs[0]); - }); + addArgumentMaterialization(materializeTensorLoad); + addSourceMaterialization(materializeTensorLoad); addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); @@ -83,3 +88,37 @@ void mlir::populateEliminateBufferizeMaterializationsPatterns( patterns.insert( typeConverter, context); } + +namespace { +struct FinalizingBufferizePass + : public FinalizingBufferizeBase { + using FinalizingBufferizeBase< + FinalizingBufferizePass>::FinalizingBufferizeBase; + + void runOnFunction() override { + auto func = getFunction(); + auto *context = &getContext(); + + BufferizeTypeConverter typeConverter; + OwningRewritePatternList patterns; + ConversionTarget target(*context); + + populateEliminateBufferizeMaterializationsPatterns(context, typeConverter, + patterns); + target.addIllegalOp(); + + // If all result types are legal, and all block arguments are legal (ensured + // by func conversion above), then all types in the program are legal. + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + return typeConverter.isLegal(op->getResultTypes()); + }); + + if (failed(applyFullConversion(func, target, std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace + +std::unique_ptr mlir::createFinalizingBufferizePass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/Standard/func-bufferize-partial.mlir b/mlir/test/Dialect/Standard/func-bufferize-partial.mlir new file mode 100644 index 0000000..2afa532 --- /dev/null +++ b/mlir/test/Dialect/Standard/func-bufferize-partial.mlir @@ -0,0 +1,59 @@ +// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics --debug-only=dialect-conversion | FileCheck %s + +// CHECK-LABEL: func @block_arguments( +// CHECK-SAME: %[[ARG:.*]]: memref) -> memref { +// CHECK: %[[T1:.*]] = tensor_load %[[ARG]] : memref +// CHECK: %[[M1:.*]] = tensor_to_memref %[[T1]] : memref +// CHECK: br ^bb1(%[[M1]] : memref) +// CHECK: ^bb1(%[[BBARG:.*]]: memref): +// CHECK: %[[T2:.*]] = tensor_load %[[BBARG]] : memref +// CHECK: %[[M2:.*]] = tensor_to_memref %[[T2]] : memref +// CHECK: return %[[M2]] : memref +func @block_arguments(%arg0: tensor) -> tensor { + br ^bb1(%arg0: tensor) +^bb1(%bbarg: tensor): + return %bbarg : tensor +} + +// CHECK-LABEL: func @partial() +// CHECK-SAME: memref +func @partial() -> tensor { + // CHECK-NEXT: %[[SRC:.*]] = "test.source"() : () -> tensor + // CHECK-NEXT: %[[MEM:.*]] = tensor_to_memref %[[SRC]] : memref + %0 = "test.source"() : () -> tensor + // CHECK-NEXT: return %[[MEM]] : memref + return %0 : tensor +} + +// CHECK-LABEL: func @region_op +// CHECK-SAME: (%[[ARG0:.*]]: i1) -> memref +func @region_op(%arg0: i1) -> tensor { + // CHECK-NEXT: %[[IF:.*]] = scf.if %[[ARG0]] -> (tensor) + %0 = scf.if %arg0 -> (tensor) { + // CHECK-NEXT: %[[SRC:.*]] = "test.source"() : () -> tensor + %1 = "test.source"() : () -> tensor + // CHECK-NEXT: scf.yield %[[SRC]] : tensor + scf.yield %1 : tensor + // CHECK-NEXT: else + } else { + // CHECK-NEXT: %[[OSRC:.*]] = "test.other_source"() : () -> tensor + %1 = "test.other_source"() : () -> tensor + // CHECK-NEXT: scf.yield %[[OSRC]] : tensor + scf.yield %1 : tensor + } + // CHECK: %[[MEM:.*]] = tensor_to_memref %[[IF]] : memref + // CHECK: return %[[MEM]] : memref + return %0 : tensor +} + +// ----- + +func @failed_to_legalize(%arg0: tensor) -> tensor { + %0 = constant true + cond_br %0, ^bb1(%arg0: tensor), ^bb2(%arg0: tensor) + ^bb1(%bbarg0: tensor): + // expected-error @+1 {{failed to legalize operation 'test.terminator'}} + "test.terminator"() : () -> () + ^bb2(%bbarg1: tensor): + return %bbarg1 : tensor +} diff --git a/mlir/test/Dialect/Standard/func-bufferize.mlir b/mlir/test/Dialect/Standard/func-bufferize.mlir index 61c5e18..d02db99 100644 --- a/mlir/test/Dialect/Standard/func-bufferize.mlir +++ b/mlir/test/Dialect/Standard/func-bufferize.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics | FileCheck %s +// RUN: mlir-opt %s -func-bufferize -finalizing-bufferize -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: func @identity( // CHECK-SAME: %[[ARG:.*]]: memref) -> memref { -- 2.7.4