From 8d67d187ba1bdb201f83ce25725e9be59b0141a7 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 18 Jun 2020 15:45:43 -0700 Subject: [PATCH] [mlir][DialectConversion] Refactor how block argument types get converted This revision removes the TypeConverter parameter passed to the apply* methods, and instead moves the responsibility of region type conversion to patterns. The types of a region can be converted using the 'convertRegionTypes' method, which acts similarly to the existing 'applySignatureConversion'. This method ensures that all blocks within, and including those moved into, a region will have the block argument types converted using the provided converter. This has the benefit of making more of the legalization logic controlled by patterns, instead of being handled explicitly by the driver. It also opens up the possibility to support multiple type conversions at some point in the future. This revision also adds a new utility class `FailureOr` that provides a LogicalResult friendly facility for returning a failure or a valid result value. Differential Revision: https://reviews.llvm.org/D81681 --- mlir/docs/DialectConversion.md | 28 +- mlir/docs/Tutorials/Toy/Ch-6.md | 3 +- mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp | 2 +- mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp | 2 +- mlir/include/mlir/IR/PatternMatch.h | 31 +- mlir/include/mlir/Support/LogicalResult.h | 27 +- mlir/include/mlir/Transforms/BufferPlacement.h | 4 +- mlir/include/mlir/Transforms/DialectConversion.h | 97 +++-- .../AVX512ToLLVM/ConvertAVX512ToLLVM.cpp | 5 +- mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h | 5 +- .../Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 2 +- .../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 2 +- .../Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp | 12 +- .../GPUToSPIRV/ConvertGPUToSPIRVPass.cpp | 4 +- mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp | 4 +- .../Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp | 6 +- .../SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp | 3 +- .../Conversion/ShapeToStandard/ShapeToStandard.cpp | 5 +- .../Conversion/StandardToLLVM/StandardToLLVM.cpp | 14 +- .../VectorToLLVM/ConvertVectorToLLVM.cpp | 5 +- .../lib/Conversion/VectorToROCDL/VectorToROCDL.cpp | 4 +- .../Dialect/Linalg/Transforms/TensorsToBuffers.cpp | 8 +- mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp | 4 +- .../SPIRV/Transforms/LowerABIAttributesPass.cpp | 8 +- mlir/lib/Transforms/DialectConversion.cpp | 430 +++++++++++++-------- mlir/test/Transforms/test-legalizer.mlir | 29 +- mlir/test/lib/Dialect/Test/TestPatterns.cpp | 25 +- mlir/test/lib/Transforms/TestBufferPlacement.cpp | 7 +- 28 files changed, 463 insertions(+), 313 deletions(-) diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md index 7995099..c717414 100644 --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -262,14 +262,21 @@ patterns used in dialect conversion. ### Region Signature Conversion -From the perspective of type conversion, the entry block to a region is often -special. The types of the entry block arguments are often tied semantically to -details on the operation, e.g. FuncOp, AffineForOp, etc. Given this, the -conversion of the types for this block must be done explicitly via a conversion -pattern. To convert the signature of a region entry block, a custom hook on the -ConversionPatternRewriter must be invoked `applySignatureConversion`. A -signature conversion, `TypeConverter::SignatureConversion`, can be built -programmatically: +From the perspective of type conversion, the types of block arguments are a bit +special. Throughout the conversion process, blocks may move between regions of +different operations. Given this, the conversion of the types for blocks must be +done explicitly via a conversion pattern. To convert the types of block +arguments within a Region, a custom hook on the `ConversionPatternRewriter` must +be invoked; `convertRegionTypes`. This hook uses a provided type converter to +apply type conversions to all blocks within the region, and all blocks that move +into that region. This hook also takes an optional +`TypeConverter::SignatureConversion` parameter that applies a custom conversion +to the entry block of the region. The types of the entry block arguments are +often tied semantically to details on the operation, e.g. FuncOp, AffineForOp, +etc. To convert the signature of just the region entry block, and not any other +blocks within the region, the `applySignatureConversion` hook may be used +instead. A signature conversion, `TypeConverter::SignatureConversion`, can be +built programmatically: ```c++ class SignatureConversion { @@ -293,5 +300,6 @@ public: }; ``` -The `TypeConverter` provides several default utilities for signature conversion: -`convertSignatureArg`/`convertBlockSignature`. +The `TypeConverter` provides several default utilities for signature conversion +and legality checking: +`convertSignatureArgs`/`convertBlockSignature`/`isLegal(Region *|Type)`. diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md index 734eafb..06f5cd6 100644 --- a/mlir/docs/Tutorials/Toy/Ch-6.md +++ b/mlir/docs/Tutorials/Toy/Ch-6.md @@ -106,8 +106,7 @@ that only legal operations will remain after the conversion. ```c++ mlir::ModuleOp module = getOperation(); - if (mlir::failed(mlir::applyFullConversion(module, target, patterns, - &typeConverter))) + if (mlir::failed(mlir::applyFullConversion(module, target, patterns))) signalPassFailure(); ``` diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp index 43b4c10..af4130c 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -203,7 +203,7 @@ void ToyToLLVMLoweringPass::runOnOperation() { // We want to completely lower to LLVM, so we use a `FullConversion`. This // ensures that only legal operations will remain after the conversion. auto module = getOperation(); - if (failed(applyFullConversion(module, target, patterns, &typeConverter))) + if (failed(applyFullConversion(module, target, patterns))) signalPassFailure(); } diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp index 43b4c10..af4130c 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -203,7 +203,7 @@ void ToyToLLVMLoweringPass::runOnOperation() { // We want to completely lower to LLVM, so we use a `FullConversion`. This // ensures that only legal operations will remain after the conversion. auto module = getOperation(); - if (failed(applyFullConversion(module, target, patterns, &typeConverter))) + if (failed(applyFullConversion(module, target, patterns))) signalPassFailure(); } diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 0f0228a..f1c7c39 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -99,7 +99,7 @@ protected: /// pattern. Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context); - /// This contructor is used when a pattern may match against multiple + /// This constructor is used when a pattern may match against multiple /// different types of operations. The `benefit` is the expected benefit of /// matching this pattern. `MatchAnyOpTypeTag` is just a tag to ensure that /// the "match any" behavior is what the user actually desired, @@ -163,28 +163,27 @@ public: ArrayRef getGeneratedOps() const { return generatedOps; } protected: - /// Patterns must specify the root operation name they match against, and can - /// also specify the benefit of the pattern matching. + /// Construct a rewrite pattern with a certain benefit that matches the + /// operation with the given root name. RewritePattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context) : Pattern(rootName, benefit, context) {} - /// Patterns must specify the root operation name they match against, and can - /// also specify the benefit of the pattern matching. `MatchAnyOpTypeTag` - /// is just a tag to ensure that the "match any" behavior is what the user - /// actually desired, `MatchAnyOpTypeTag()` should always be supplied here. + /// Construct a rewrite pattern with a certain benefit that matches any + /// operation type. `MatchAnyOpTypeTag` is just a tag to ensure that the + /// "match any" behavior is what the user actually desired, + /// `MatchAnyOpTypeTag()` should always be supplied here. RewritePattern(PatternBenefit benefit, MatchAnyOpTypeTag tag) : Pattern(benefit, tag) {} - /// Patterns must specify the root operation name they match against, and can - /// also specify the benefit of the pattern matching. They can also specify - /// the names of operations that may be generated during a successful rewrite. + /// Construct a rewrite pattern with a certain benefit that matches the + /// operation with the given root name. `generatedNames` contains the names of + /// operations that may be generated during a successful rewrite. RewritePattern(StringRef rootName, ArrayRef generatedNames, PatternBenefit benefit, MLIRContext *context); - /// Patterns must specify the root operation name they match against, and can - /// also specify the benefit of the pattern matching. They can also specify - /// the names of operations that may be generated during a successful rewrite. - /// `MatchAnyOpTypeTag` is just a tag to ensure that the "match any" - /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should - /// always be supplied here. + /// Construct a rewrite pattern that may match any operation type. + /// `generatedNames` contains the names of operations that may be generated + /// during a successful rewrite. `MatchAnyOpTypeTag` is just a tag to ensure + /// that the "match any" behavior is what the user actually desired, + /// `MatchAnyOpTypeTag()` should always be supplied here. RewritePattern(ArrayRef generatedNames, PatternBenefit benefit, MLIRContext *context, MatchAnyOpTypeTag tag); diff --git a/mlir/include/mlir/Support/LogicalResult.h b/mlir/include/mlir/Support/LogicalResult.h index 216e056..3e30e9e 100644 --- a/mlir/include/mlir/Support/LogicalResult.h +++ b/mlir/include/mlir/Support/LogicalResult.h @@ -10,11 +10,12 @@ #define MLIR_SUPPORT_LOGICAL_RESULT_H #include "mlir/Support/LLVM.h" +#include "llvm/ADT/Optional.h" namespace mlir { -// Values that can be used to signal success/failure. This should be used in -// conjunction with the utility functions below. +/// Values that can be used to signal success/failure. This should be used in +/// conjunction with the utility functions below. struct LogicalResult { enum ResultEnum { Success, Failure } value; LogicalResult(ResultEnum v) : value(v) {} @@ -46,6 +47,28 @@ inline bool failed(LogicalResult result) { return result.value == LogicalResult::Failure; } +/// This class provides support for representing a failure result, or a valid +/// value of type `T`. This allows for integrating with LogicalResult, while +/// also providing a value on the success path. +template class LLVM_NODISCARD FailureOr : public Optional { +public: + /// Allow constructing from a LogicalResult. The result *must* be a failure. + /// Success results should use a proper instance of type `T`. + FailureOr(LogicalResult result) { + assert(failed(result) && + "success should be constructed with an instance of 'T'"); + } + FailureOr() : FailureOr(failure()) {} + FailureOr(T &&y) : Optional(std::forward(y)) {} + + operator LogicalResult() const { return success(this->hasValue()); } + +private: + /// Hide the bool conversion as it easily creates confusion. + using Optional::operator bool; + using Optional::hasValue; +}; + } // namespace mlir #endif // MLIR_SUPPORT_LOGICAL_RESULT_H diff --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h index 547db48..f8559a9 100644 --- a/mlir/include/mlir/Transforms/BufferPlacement.h +++ b/mlir/include/mlir/Transforms/BufferPlacement.h @@ -141,12 +141,14 @@ public: else newResultTypes.push_back(convertedType); } + if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter, + &conversion))) + return failure(); // Update the signature of the function. rewriter.updateRootInPlace(funcOp, [&] { funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(), newResultTypes)); - rewriter.applySignatureConversion(&funcOp.getBody(), conversion); }); return success(); } diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 2ce95b1..d862823 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -160,6 +160,9 @@ public: /// Return true if the given operation has legal operand and result types. bool isLegal(Operation *op); + /// Return true if the types of block arguments within the region are legal. + bool isLegal(Region *region); + /// Return true if the inputs and outputs of the given function type are /// legal. bool isSignatureLegal(FunctionType ty); @@ -268,16 +271,15 @@ private: // Conversion Patterns //===----------------------------------------------------------------------===// -/// Base class for the conversion patterns that require type changes. Specific -/// conversions must derive this class and implement least one `rewrite` method. -/// NOTE: These conversion patterns can only be used with the 'apply*' methods -/// below. +/// Base class for the conversion patterns. This pattern class enables type +/// conversions, and other uses specific to the conversion framework. As such, +/// patterns of this type can only be used with the 'apply*' methods below. class ConversionPattern : public RewritePattern { public: /// Hook for derived classes to implement rewriting. `op` is the (first) - /// operation matched by the pattern, `operands` is a list of rewritten values - /// that are passed to this operation, `rewriter` can be used to emit the new - /// operations. This function should not fail. If some specific cases of + /// operation matched by the pattern, `operands` is a list of the rewritten + /// operand values that are passed to `op`, `rewriter` can be used to emit the + /// new operations. This function should not fail. If some specific cases of /// the operation are not supported, these cases should not be matched. virtual void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { @@ -298,8 +300,32 @@ public: LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final; + /// Return the type converter held by this pattern, or nullptr if the pattern + /// does not require type conversion. + TypeConverter *getTypeConverter() const { return typeConverter; } + protected: + /// See `RewritePattern::RewritePattern` for information on the other + /// available constructors. using RewritePattern::RewritePattern; + /// Construct a conversion pattern that matches an operation with the given + /// root name. This constructor allows for providing a type converter to use + /// within the pattern. + ConversionPattern(StringRef rootName, PatternBenefit benefit, + TypeConverter &typeConverter, MLIRContext *ctx) + : RewritePattern(rootName, benefit, ctx), typeConverter(&typeConverter) {} + /// Construct a conversion pattern that matches any operation type. This + /// constructor allows for providing a type converter to use within the + /// pattern. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any" + /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should + /// always be supplied here. + ConversionPattern(PatternBenefit benefit, TypeConverter &typeConverter, + MatchAnyOpTypeTag tag) + : RewritePattern(benefit, tag), typeConverter(&typeConverter) {} + +protected: + /// An optional type converter for use by this pattern. + TypeConverter *typeConverter; private: using RewritePattern::rewrite; @@ -312,6 +338,10 @@ template struct OpConversionPattern : public ConversionPattern { OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) : ConversionPattern(SourceOp::getOperationName(), benefit, context) {} + OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context, + PatternBenefit benefit = 1) + : ConversionPattern(SourceOp::getOperationName(), benefit, typeConverter, + context) {} /// Wrappers around the ConversionPattern methods that pass the derived op /// type. @@ -367,7 +397,7 @@ struct ConversionPatternRewriterImpl; /// hooks. class ConversionPatternRewriter final : public PatternRewriter { public: - ConversionPatternRewriter(MLIRContext *ctx, TypeConverter *converter); + ConversionPatternRewriter(MLIRContext *ctx); ~ConversionPatternRewriter() override; /// Apply a signature conversion to the entry block of the given region. This @@ -377,6 +407,15 @@ public: applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion); + /// Convert the types of block arguments within the given region. This + /// replaces each block with a new block containing the updated signature. The + /// entry block may have a special conversion if `entryConversion` is + /// provided. On success, the new entry block to the region is returned for + /// convenience. Otherwise, failure is returned. + FailureOr convertRegionTypes( + Region *region, TypeConverter &converter, + TypeConverter::SignatureConversion *entryConversion = nullptr); + /// Replace all the uses of the block argument `from` with value `to`. void replaceUsesOfBlockArgument(BlockArgument from, Value to); @@ -721,36 +760,30 @@ private: /// Apply a partial conversion on the given operations and all nested /// operations. This method converts as many operations to the target as /// possible, ignoring operations that failed to legalize. This method only -/// returns failure if there ops explicitly marked as illegal. If `converter` is -/// provided, the signatures of blocks and regions are also converted. -/// If an `unconvertedOps` set is provided, all operations that are found not -/// to be legalizable to the given `target` are placed within that set. (Note -/// that if there is an op explicitly marked as illegal, the conversion -/// terminates and the `unconvertedOps` set will not necessarily be complete.) +/// returns failure if there ops explicitly marked as illegal. If an +/// `unconvertedOps` set is provided, all operations that are found not to be +/// legalizable to the given `target` are placed within that set. (Note that if +/// there is an op explicitly marked as illegal, the conversion terminates and +/// the `unconvertedOps` set will not necessarily be complete.) LLVM_NODISCARD LogicalResult applyPartialConversion(ArrayRef ops, ConversionTarget &target, const OwningRewritePatternList &patterns, - TypeConverter *converter = nullptr, DenseSet *unconvertedOps = nullptr); LLVM_NODISCARD LogicalResult applyPartialConversion(Operation *op, ConversionTarget &target, const OwningRewritePatternList &patterns, - TypeConverter *converter = nullptr, DenseSet *unconvertedOps = nullptr); /// Apply a complete conversion on the given operations, and all nested /// operations. This method returns failure if the conversion of any operation /// fails, or if there are unreachable blocks in any of the regions nested -/// within 'ops'. If 'converter' is provided, the signatures of blocks and -/// regions are also converted. +/// within 'ops'. LLVM_NODISCARD LogicalResult applyFullConversion(ArrayRef ops, ConversionTarget &target, - const OwningRewritePatternList &patterns, - TypeConverter *converter = nullptr); + const OwningRewritePatternList &patterns); LLVM_NODISCARD LogicalResult applyFullConversion(Operation *op, ConversionTarget &target, - const OwningRewritePatternList &patterns, - TypeConverter *converter = nullptr); + const OwningRewritePatternList &patterns); /// Apply an analysis conversion on the given operations, and all nested /// operations. This method analyzes which operations would be successfully @@ -759,17 +792,15 @@ applyFullConversion(Operation *op, ConversionTarget &target, /// provided 'convertedOps' set; note that no actual rewrites are applied to the /// operations on success and only pre-existing operations are added to the set. /// This method only returns failure if there are unreachable blocks in any of -/// the regions nested within 'ops', or if a type conversion failed. If -/// 'converter' is provided, the signatures of blocks and regions are also -/// considered for conversion. -LLVM_NODISCARD LogicalResult applyAnalysisConversion( - ArrayRef ops, ConversionTarget &target, - const OwningRewritePatternList &patterns, - DenseSet &convertedOps, TypeConverter *converter = nullptr); -LLVM_NODISCARD LogicalResult applyAnalysisConversion( - Operation *op, ConversionTarget &target, - const OwningRewritePatternList &patterns, - DenseSet &convertedOps, TypeConverter *converter = nullptr); +/// the regions nested within 'ops'. +LLVM_NODISCARD LogicalResult +applyAnalysisConversion(ArrayRef ops, ConversionTarget &target, + const OwningRewritePatternList &patterns, + DenseSet &convertedOps); +LLVM_NODISCARD LogicalResult +applyAnalysisConversion(Operation *op, ConversionTarget &target, + const OwningRewritePatternList &patterns, + DenseSet &convertedOps); } // end namespace mlir #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_ diff --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp index 0753e38..b65118b 100644 --- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp +++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp @@ -179,10 +179,7 @@ void ConvertAVX512ToLLVMPass::runOnOperation() { target.addLegalDialect(); target.addLegalDialect(); target.addIllegalDialect(); - target.addDynamicallyLegalOp( - [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); - if (failed(applyPartialConversion(getOperation(), target, patterns, - &converter))) { + if (failed(applyPartialConversion(getOperation(), target, patterns))) { signalPassFailure(); } } diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h index 4a1fe1a..f6aede4 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h @@ -145,8 +145,9 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern { // Move the region to the new function, update the entry block signature. rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(), llvmFuncOp.end()); - rewriter.applySignatureConversion(&llvmFuncOp.getBody(), - signatureConversion); + if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), typeConverter, + &signatureConversion))) + return failure(); rewriter.eraseOp(gpuFuncOp); return success(); diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 0fe767a..e4fabe4 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -133,7 +133,7 @@ public: target.addLegalDialect(); // TODO(csigg): Remove once we support replacing non-root ops. target.addLegalOp(); - if (failed(applyPartialConversion(m, target, patterns, &converter))) + if (failed(applyPartialConversion(m, target, patterns))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index 5707075..2381d61 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -67,7 +67,7 @@ public: target.addLegalDialect(); // TODO(whchung): Remove once we support replacing non-root ops. target.addLegalOp(); - if (failed(applyPartialConversion(m, target, patterns, &converter))) + if (failed(applyPartialConversion(m, target, patterns))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp index 322d08a..2b4829a 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp @@ -164,8 +164,11 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef operands, TypeConverter::SignatureConversion signatureConverter( body->getNumArguments()); signatureConverter.remapInput(0, newIndVar); - body = rewriter.applySignatureConversion(&forOp.getLoopBody(), - signatureConverter); + FailureOr newBody = rewriter.convertRegionTypes( + &forOp.getLoopBody(), typeConverter, &signatureConverter); + if (failed(newBody)) + return failure(); + body = *newBody; // Delete the loop terminator. rewriter.eraseOp(body->getTerminator()); @@ -356,9 +359,12 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, SPIRVTypeConverter &typeConverter, continue; newFuncOp.setAttr(namedAttr.first, namedAttr.second); } + rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); - rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter); + if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, + &signatureConverter))) + return nullptr; rewriter.eraseOp(funcOp); spirv::setABIAttrs(newFuncOp, entryPointInfo, argABIInfo); diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp index 3147eed..1f486b9 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp @@ -61,10 +61,8 @@ void GPUToSPIRVPass::runOnOperation() { populateGPUToSPIRVPatterns(context, typeConverter, patterns); populateStandardToSPIRVPatterns(context, typeConverter, patterns); - if (failed(applyFullConversion(kernelModules, *target, patterns, - &typeConverter))) { + if (failed(applyFullConversion(kernelModules, *target, patterns))) return signalPassFailure(); - } } std::unique_ptr> mlir::createConvertGPUToSPIRVPass() { diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp index f603510..b92ab13 100644 --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -383,10 +383,8 @@ void ConvertLinalgToLLVMPass::runOnOperation() { populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext()); LLVMConversionTarget target(getContext()); - target.addDynamicallyLegalOp( - [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); target.addLegalOp(); - if (failed(applyFullConversion(module, target, patterns, &converter))) + if (failed(applyFullConversion(module, target, patterns))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp index d81e269..cc938c8 100644 --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp @@ -36,8 +36,10 @@ void LinalgToSPIRVPass::runOnOperation() { // Allow builtin ops. target->addLegalOp(); - target->addDynamicallyLegalOp( - [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); }); + target->addDynamicallyLegalOp([&](FuncOp op) { + return typeConverter.isSignatureLegal(op.getType()) && + typeConverter.isLegal(&op.getBody()); + }); if (failed(applyFullConversion(module, *target, patterns))) return signalPassFailure(); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp index c8e2d73..8f30054 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp @@ -44,8 +44,7 @@ void ConvertSPIRVToLLVMPass::runOnOperation() { ConversionTarget target(getContext()); target.addIllegalDialect(); target.addLegalDialect(); - - if (failed(applyPartialConversion(module, target, patterns, &converter))) + if (failed(applyPartialConversion(module, target, patterns))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp index e774114..d02f5e3 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -82,7 +82,8 @@ class ConvertShapeToStandardPass target.addLegalDialect(); target.addLegalOp(); target.addDynamicallyLegalOp([&](FuncOp op) { - return typeConverter.isSignatureLegal(op.getType()); + return typeConverter.isSignatureLegal(op.getType()) && + typeConverter.isLegal(&op.getBody()); }); // Setup conversion patterns. @@ -92,7 +93,7 @@ class ConvertShapeToStandardPass // Apply conversion. auto module = getOperation(); - if (failed(applyFullConversion(module, target, patterns, &typeConverter))) + if (failed(applyFullConversion(module, target, patterns))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index a316f2e..19c451f 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -398,7 +398,7 @@ ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, LLVMTypeConverter &typeConverter_, PatternBenefit benefit) - : ConversionPattern(rootOpName, benefit, context), + : ConversionPattern(rootOpName, benefit, typeConverter_, context), typeConverter(typeConverter_) {} /*============================================================================*/ @@ -1038,8 +1038,9 @@ protected: attributes); rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); - // Tell the rewriter to convert the region signature. - rewriter.applySignatureConversion(&newFuncOp.getBody(), result); + if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, + &result))) + return nullptr; return newFuncOp; } @@ -1059,6 +1060,9 @@ struct FuncOpConversion : public FuncOpConversionBase { auto funcOp = cast(op); auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); + if (!newFuncOp) + return failure(); + if (emitWrappers || funcOp.getAttrOfType(kEmitIfaceAttrName)) { if (newFuncOp.isExternal()) wrapExternalFunction(rewriter, op->getLoc(), typeConverter, funcOp, @@ -1095,6 +1099,8 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase { getMemRefArgIndicesAndTypes(funcOp.getType(), promotedArgsInfo); auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); + if (!newFuncOp) + return failure(); if (newFuncOp.getBody().empty()) { rewriter.eraseOp(op); return success(); @@ -3172,7 +3178,7 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase { emitCWrappers, useAlignedAlloc); LLVMConversionTarget target(getContext()); - if (failed(applyPartialConversion(m, target, patterns, &typeConverter))) + if (failed(applyPartialConversion(m, target, patterns))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 4d9734a..bd9ec93 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1195,10 +1195,7 @@ void LowerVectorToLLVMPass::runOnOperation() { populateStdToLLVMConversionPatterns(converter, patterns); LLVMConversionTarget target(getContext()); - target.addDynamicallyLegalOp( - [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); - if (failed(applyPartialConversion(getOperation(), target, patterns, - &converter))) { + if (failed(applyPartialConversion(getOperation(), target, patterns))) { signalPassFailure(); } } diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp index 665a32c..37e314f2 100644 --- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp +++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -173,10 +173,8 @@ void LowerVectorToROCDLPass::runOnOperation() { LLVMConversionTarget target(getContext()); target.addLegalDialect(); - if (failed(applyPartialConversion(getOperation(), target, patterns, - &converter))) { + if (failed(applyPartialConversion(getOperation(), target, patterns))) signalPassFailure(); - } } std::unique_ptr> diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp index 2357062..afd94cc 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp @@ -136,19 +136,19 @@ struct ConvertLinalgOnTensorsToBuffers target.addDynamicallyLegalOp([&](FuncOp funcOp) { return converter.isSignatureLegal(funcOp.getType()) && llvm::none_of(funcOp.getType().getResults(), - [&](Type type) { return type.isa(); }); + [&](Type type) { return type.isa(); }) && + converter.isLegal(&funcOp.getBody()); }); // Walk over all the functions to apply buffer assignment. - getOperation().walk([&](FuncOp function) { + getOperation().walk([&](FuncOp function) -> WalkResult { OwningRewritePatternList patterns; BufferAssignmentPlacer placer(function); populateConvertLinalgOnTensorsToBuffersPattern(&context, &placer, &converter, &patterns); // Applying full conversion - return WalkResult( - applyFullConversion(function, target, patterns, &converter)); + return applyFullConversion(function, target, patterns); }); } }; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp index 7df2be9..6bb07b2 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -489,7 +489,9 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef operands, rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); - rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter); + if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, + &signatureConverter))) + return failure(); rewriter.eraseOp(funcOp); return success(); } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index 139b6bc..5bd425a 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -201,12 +201,14 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite( } signatureConverter.remapInput(argType.index(), replacement); } + if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), typeConverter, + &signatureConverter))) + return failure(); // Creates a new function with the update signature. rewriter.updateRootInPlace(funcOp, [&] { funcOp.setType(rewriter.getFunctionType( signatureConverter.getConvertedTypes(), llvm::None)); - rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter); }); return success(); } @@ -237,10 +239,8 @@ void LowerABIAttributesPass::runOnOperation() { return op->getDialect()->getNamespace() == spirv::SPIRVDialect::getDialectNamespace(); }); - if (failed( - applyPartialConversion(module, target, patterns, &typeConverter))) { + if (failed(applyPartialConversion(module, target, patterns))) return signalPassFailure(); - } // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point // attributes. diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index b065247..ecebe61 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -98,7 +98,7 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, } //===----------------------------------------------------------------------===// -// Multi-Level Value Mapper +// ConversionValueMapping //===----------------------------------------------------------------------===// namespace { @@ -140,9 +140,7 @@ namespace { /// types and extracting the block that contains the old illegal types to allow /// for undoing pending rewrites in the case of failure. struct ArgConverter { - ArgConverter(TypeConverter *typeConverter, PatternRewriter &rewriter) - : loc(rewriter.getUnknownLoc()), typeConverter(typeConverter), - rewriter(rewriter) {} + ArgConverter(PatternRewriter &rewriter) : rewriter(rewriter) {} /// This structure contains the information pertaining to an argument that has /// been converted. @@ -166,7 +164,8 @@ struct ArgConverter { /// This structure contains information pertaining to a block that has had its /// signature converted. struct ConvertedBlockInfo { - ConvertedBlockInfo(Block *origBlock) : origBlock(origBlock) {} + ConvertedBlockInfo(Block *origBlock, TypeConverter &converter) + : origBlock(origBlock), converter(&converter) {} /// The original block that was requested to have its signature converted. Block *origBlock; @@ -174,11 +173,26 @@ struct ArgConverter { /// The conversion information for each of the arguments. The information is /// None if the argument was dropped during conversion. SmallVector, 1> argInfo; + + /// The type converter used to convert the arguments. + TypeConverter *converter; }; /// Return if the signature of the given block has already been converted. bool hasBeenConverted(Block *block) const { - return conversionInfo.count(block); + return conversionInfo.count(block) || convertedBlocks.count(block); + } + + /// Set the type converter to use for the given region. + void setConverter(Region *region, TypeConverter *typeConverter) { + assert(typeConverter && "expected valid type converter"); + regionToConverter[region] = typeConverter; + } + + /// Return the type converter to use for the given region, or null if there + /// isn't one. + TypeConverter *getConverter(Region *region) { + return regionToConverter.lookup(region); } //===--------------------------------------------------------------------===// @@ -204,32 +218,39 @@ struct ArgConverter { //===--------------------------------------------------------------------===// /// Attempt to convert the signature of the given block, if successful a new - /// block is returned containing the new arguments. On failure, nullptr is - /// returned. - Block *convertSignature(Block *block, ConversionValueMapping &mapping); + /// block is returned containing the new arguments. Returns `block` if it did + /// not require conversion. + FailureOr convertSignature(Block *block, TypeConverter &converter, + ConversionValueMapping &mapping); /// Apply the given signature conversion on the given block. The new block - /// containing the updated signature is returned. + /// containing the updated signature is returned. If no conversions were + /// necessary, e.g. if the block has no arguments, `block` is returned. + /// `converter` is used to generate any necessary cast operations that + /// translate between the origin argument types and those specified in the + /// signature conversion. Block *applySignatureConversion( - Block *block, TypeConverter::SignatureConversion &signatureConversion, + Block *block, TypeConverter &converter, + TypeConverter::SignatureConversion &signatureConversion, ConversionValueMapping &mapping); /// Insert a new conversion into the cache. void insertConversion(Block *newBlock, ConvertedBlockInfo &&info); - /// A collection of blocks that have had their arguments converted. + /// A collection of blocks that have had their arguments converted. This is a + /// map from the new replacement block, back to the original block. llvm::MapVector conversionInfo; + /// The set of original blocks that were converted. + DenseSet convertedBlocks; + /// A mapping from valid regions, to those containing the original blocks of a /// conversion. DenseMap> regionMapping; - /// An instance of the unknown location that is used when materializing - /// conversions. - Location loc; - - /// The type converter to use when changing types. - TypeConverter *typeConverter; + /// A mapping of regions to type converters that should be used when + /// converting the arguments of blocks within that region. + DenseMap regionToConverter; /// The pattern rewriter to use when materializing conversions. PatternRewriter &rewriter; @@ -240,6 +261,9 @@ struct ArgConverter { // Rewrite Application void ArgConverter::notifyOpRemoved(Operation *op) { + if (conversionInfo.empty()) + return; + for (Region ®ion : op->getRegions()) { for (Block &block : region) { // Drop any rewrites from within. @@ -277,6 +301,7 @@ void ArgConverter::discardRewrites(Block *block) { origBlock->moveBefore(block); block->erase(); + convertedBlocks.erase(origBlock); conversionInfo.erase(it); } @@ -305,8 +330,8 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) { // persist in the IR after conversion. if (!origArg.use_empty()) { rewriter.setInsertionPointToStart(newBlock); - Value newArg = typeConverter->materializeConversion( - rewriter, loc, origArg.getType(), llvm::None); + Value newArg = blockInfo.converter->materializeConversion( + rewriter, origArg.getLoc(), origArg.getType(), llvm::None); assert(newArg && "Couldn't materialize a block argument after 1->0 conversion"); origArg.replaceAllUsesWith(newArg); @@ -333,15 +358,23 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) { //===----------------------------------------------------------------------===// // Conversion -Block *ArgConverter::convertSignature(Block *block, - ConversionValueMapping &mapping) { - if (auto conversion = typeConverter->convertBlockSignature(block)) - return applySignatureConversion(block, *conversion, mapping); - return nullptr; +FailureOr +ArgConverter::convertSignature(Block *block, TypeConverter &converter, + ConversionValueMapping &mapping) { + // Check if the block was already converted. If the block is detached, + // conservatively assume it is going to be deleted. + if (hasBeenConverted(block) || !block->getParent()) + return block; + + // Try to convert the signature for the block with the provided converter. + if (auto conversion = converter.convertBlockSignature(block)) + return applySignatureConversion(block, converter, *conversion, mapping); + return failure(); } Block *ArgConverter::applySignatureConversion( - Block *block, TypeConverter::SignatureConversion &signatureConversion, + Block *block, TypeConverter &converter, + TypeConverter::SignatureConversion &signatureConversion, ConversionValueMapping &mapping) { // If no arguments are being changed or added, there is nothing to do. unsigned origArgCount = block->getNumArguments(); @@ -359,7 +392,7 @@ Block *ArgConverter::applySignatureConversion( // Remap each of the original arguments as determined by the signature // conversion. - ConvertedBlockInfo info(block); + ConvertedBlockInfo info(block, converter); info.argInfo.resize(origArgCount); OpBuilder::InsertionGuard guard(rewriter); @@ -384,10 +417,8 @@ Block *ArgConverter::applySignatureConversion( // to pack the new values. For 1->1 mappings, if there is no materialization // provided, use the argument directly instead. auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size); - Value newArg; - if (typeConverter) - newArg = typeConverter->materializeConversion( - rewriter, loc, origArg.getType(), replArgs); + Value newArg = converter.materializeConversion(rewriter, origArg.getLoc(), + origArg.getType(), replArgs); if (!newArg) { assert(replArgs.size() == 1 && "couldn't materialize the result of 1->N conversion"); @@ -414,6 +445,7 @@ void ArgConverter::insertConversion(Block *newBlock, // Move the original block to the mapped region and emplace the conversion. mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(), info.origBlock->getIterator()); + convertedBlocks.insert(info.origBlock); conversionInfo.insert({newBlock, std::move(info)}); } @@ -548,9 +580,8 @@ struct ConversionPatternRewriterImpl { }; }; - ConversionPatternRewriterImpl(PatternRewriter &rewriter, - TypeConverter *converter) - : argConverter(converter, rewriter) {} + ConversionPatternRewriterImpl(PatternRewriter &rewriter) + : argConverter(rewriter) {} /// Return the current state of the rewriter. RewriterState getCurrentState(); @@ -575,13 +606,20 @@ struct ConversionPatternRewriterImpl { void applyRewrites(); /// Convert the signature of the given block. - LogicalResult convertBlockSignature(Block *block); + FailureOr convertBlockSignature( + Block *block, TypeConverter &converter, + TypeConverter::SignatureConversion *conversion = nullptr); /// Apply a signature conversion on the given region. Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion); + /// Convert the types of block arguments within the given region. + FailureOr + convertRegionTypes(Region *region, TypeConverter &converter, + TypeConverter::SignatureConversion *entryConversion); + /// PatternRewriter hook for replacing the results of an operation. void replaceOp(Operation *op, ValueRange newValues); @@ -654,6 +692,10 @@ struct ConversionPatternRewriterImpl { /// A logger used to emit diagnostics during the conversion process. llvm::ScopedPrinter logger{llvm::dbgs()}; #endif + + /// A default type converter, used when block conversions do not have one + /// explicitly provided. + TypeConverter defaultTypeConverter; }; } // end namespace detail } // end namespace mlir @@ -791,7 +833,7 @@ void ConversionPatternRewriterImpl::applyRewrites() { // If this operation defines any regions, drop any pending argument // rewrites. - if (argConverter.typeConverter && repl.op->getNumRegions()) + if (repl.op->getNumRegions()) argConverter.notifyOpRemoved(repl.op); } @@ -826,34 +868,45 @@ void ConversionPatternRewriterImpl::applyRewrites() { eraseDanglingBlocks(); } -LogicalResult -ConversionPatternRewriterImpl::convertBlockSignature(Block *block) { - // Check to see if this block should not be converted: - // * There is no type converter. - // * The block has already been converted. - // * This is an entry block, these are converted explicitly via patterns. - if (!argConverter.typeConverter || argConverter.hasBeenConverted(block) || - !block->getParent() || block->isEntryBlock()) - return success(); - - // Otherwise, try to convert the block signature. - Block *newBlock = argConverter.convertSignature(block, mapping); - if (newBlock) - blockActions.push_back(BlockAction::getTypeConversion(newBlock)); - return success(newBlock); +FailureOr ConversionPatternRewriterImpl::convertBlockSignature( + Block *block, TypeConverter &converter, + TypeConverter::SignatureConversion *conversion) { + FailureOr result = + conversion ? argConverter.applySignatureConversion(block, converter, + *conversion, mapping) + : argConverter.convertSignature(block, converter, mapping); + if (Block *newBlock = result.getValue()) { + if (newBlock != block) + blockActions.push_back(BlockAction::getTypeConversion(newBlock)); + } + return result; } Block *ConversionPatternRewriterImpl::applySignatureConversion( Region *region, TypeConverter::SignatureConversion &conversion) { if (!region->empty()) { - Block *newEntry = argConverter.applySignatureConversion( - ®ion->front(), conversion, mapping); - blockActions.push_back(BlockAction::getTypeConversion(newEntry)); - return newEntry; + return *convertBlockSignature(®ion->front(), defaultTypeConverter, + &conversion); } return nullptr; } +FailureOr ConversionPatternRewriterImpl::convertRegionTypes( + Region *region, TypeConverter &converter, + TypeConverter::SignatureConversion *entryConversion) { + argConverter.setConverter(region, &converter); + if (region->empty()) + return nullptr; + + // Convert the arguments of each block within the region. + FailureOr newEntry = + convertBlockSignature(®ion->front(), converter, entryConversion); + for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) + if (failed(convertBlockSignature(&block, converter))) + return failure(); + return newEntry; +} + void ConversionPatternRewriterImpl::replaceOp(Operation *op, ValueRange newValues) { assert(newValues.size() == op->getNumResults()); @@ -938,10 +991,9 @@ void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) { // ConversionPatternRewriter //===----------------------------------------------------------------------===// -ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx, - TypeConverter *converter) +ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx) : PatternRewriter(ctx), - impl(new detail::ConversionPatternRewriterImpl(*this, converter)) {} + impl(new detail::ConversionPatternRewriterImpl(*this)) {} ConversionPatternRewriter::~ConversionPatternRewriter() {} /// PatternRewriter hook for replacing the results of an operation. @@ -979,12 +1031,17 @@ void ConversionPatternRewriter::eraseBlock(Block *block) { block->getParent()->getBlocks().remove(block); } -/// Apply a signature conversion to the entry block of the given region. Block *ConversionPatternRewriter::applySignatureConversion( Region *region, TypeConverter::SignatureConversion &conversion) { return impl->applySignatureConversion(region, conversion); } +FailureOr ConversionPatternRewriter::convertRegionTypes( + Region *region, TypeConverter &converter, + TypeConverter::SignatureConversion *entryConversion) { + return impl->convertRegionTypes(region, converter, entryConversion); +} + void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, Value to) { LLVM_DEBUG({ @@ -1163,6 +1220,20 @@ private: ConversionPatternRewriter &rewriter, RewriterState &curState); + /// Legalizes the actions registered during the execution of a pattern. + LogicalResult legalizePatternBlockActions(Operation *op, + ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &impl, + RewriterState &state, + RewriterState &newState); + LogicalResult legalizePatternCreatedOperations( + ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, + RewriterState &state, RewriterState &newState); + LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &impl, + RewriterState &state, + RewriterState &newState); + /// Build an optimistic legalization graph given the provided patterns. This /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with /// patterns for operations that are not directly legal, but may be @@ -1402,50 +1473,29 @@ bool OperationLegalizer::canApplyPattern(Operation *op, LogicalResult OperationLegalizer::legalizePatternResult( Operation *op, const RewritePattern &pattern, ConversionPatternRewriter &rewriter, RewriterState &curState) { - auto &rewriterImpl = rewriter.getImpl(); + auto &impl = rewriter.getImpl(); #ifndef NDEBUG - assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); + assert(impl.pendingRootUpdates.empty() && "dangling root updates"); #endif - // If the pattern moved or created any blocks, try to legalize their types. - // This ensures that the types of the block arguments are legal for the region - // they were moved into. - for (unsigned i = curState.numBlockActions, - e = rewriterImpl.blockActions.size(); - i != e; ++i) { - auto &action = rewriterImpl.blockActions[i]; - if (action.kind == - ConversionPatternRewriterImpl::BlockActionKind::TypeConversion || - action.kind == ConversionPatternRewriterImpl::BlockActionKind::Erase) - continue; - - // Convert the block signature. - if (failed(rewriterImpl.convertBlockSignature(action.block))) { - LLVM_DEBUG(logFailure(rewriterImpl.logger, - "failed to convert types of moved block")); - return failure(); - } - } - // Check all of the replacements to ensure that the pattern actually replaced // the root operation. We also mark any other replaced ops as 'dead' so that // we don't try to legalize them later. bool replacedRoot = false; - for (unsigned i = curState.numReplacements, - e = rewriterImpl.replacements.size(); + for (unsigned i = curState.numReplacements, e = impl.replacements.size(); i != e; ++i) { - Operation *replacedOp = rewriterImpl.replacements[i].op; + Operation *replacedOp = impl.replacements[i].op; if (replacedOp == op) replacedRoot = true; else - rewriterImpl.ignoredOps.insert(replacedOp); + impl.ignoredOps.insert(replacedOp); } // Check that the root was either updated or replace. auto updatedRootInPlace = [&] { return llvm::any_of( - llvm::drop_begin(rewriterImpl.rootUpdates, curState.numRootUpdates), + llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates), [op](auto &state) { return state.getOperation() == op; }); }; (void)replacedRoot; @@ -1453,32 +1503,99 @@ LogicalResult OperationLegalizer::legalizePatternResult( assert((replacedRoot || updatedRootInPlace()) && "expected pattern to replace the root operation"); - // Recursively legalize each of the operations updated in place. - for (unsigned i = curState.numRootUpdates, - e = rewriterImpl.rootUpdates.size(); - i != e; ++i) { - auto &state = rewriterImpl.rootUpdates[i]; - if (failed(legalize(state.getOperation(), rewriter))) { - LLVM_DEBUG(logFailure(rewriterImpl.logger, - "operation updated in-place '{0}' was illegal", - op->getName())); + // Legalize each of the actions registered during application. + RewriterState newState = impl.getCurrentState(); + if (failed(legalizePatternBlockActions(op, rewriter, impl, curState, + newState)) || + failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) || + failed(legalizePatternCreatedOperations(rewriter, impl, curState, + newState))) { + return failure(); + } + + LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully")); + return success(); +} + +LogicalResult OperationLegalizer::legalizePatternBlockActions( + Operation *op, ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &impl, RewriterState &state, + RewriterState &newState) { + SmallPtrSet operationsToIgnore; + + // If the pattern moved or created any blocks, make sure the types of block + // arguments get legalized. + for (int i = state.numBlockActions, e = newState.numBlockActions; i != e; + ++i) { + auto &action = impl.blockActions[i]; + if (action.kind == + ConversionPatternRewriterImpl::BlockActionKind::TypeConversion || + action.kind == ConversionPatternRewriterImpl::BlockActionKind::Erase) + continue; + // Only check blocks outside of the current operation. + Operation *parentOp = action.block->getParentOp(); + if (!parentOp || parentOp == op || action.block->getNumArguments() == 0) + continue; + + // If the region of the block has a type converter, try to convert the block + // directly. + if (auto *converter = + impl.argConverter.getConverter(action.block->getParent())) { + if (failed(impl.convertBlockSignature(action.block, *converter))) { + LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved " + "block")); + return failure(); + } + continue; + } + + // Otherwise, check that this operation isn't one generated by this pattern. + // This is because we will attempt to legalize the parent operation, and + // blocks in regions created by this pattern will already be legalized later + // on. If we haven't built the set yet, build it now. + if (operationsToIgnore.empty()) { + auto createdOps = ArrayRef(impl.createdOps) + .drop_front(state.numCreatedOps); + operationsToIgnore.insert(createdOps.begin(), createdOps.end()); + } + + // If this operation should be considered for re-legalization, try it. + if (operationsToIgnore.insert(parentOp).second && + failed(legalize(parentOp, rewriter))) { + LLVM_DEBUG(logFailure( + impl.logger, "operation '{0}'({1}) became illegal after block action", + parentOp->getName(), parentOp)); return failure(); } } - - // Recursively legalize each of the new operations. - for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size(); - i != e; ++i) { - Operation *op = rewriterImpl.createdOps[i]; + return success(); +} +LogicalResult OperationLegalizer::legalizePatternCreatedOperations( + ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, + RewriterState &state, RewriterState &newState) { + for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) { + Operation *op = impl.createdOps[i]; if (failed(legalize(op, rewriter))) { - LLVM_DEBUG(logFailure(rewriterImpl.logger, + LLVM_DEBUG(logFailure(impl.logger, "generated operation '{0}'({1}) was illegal", op->getName(), op)); return failure(); } } - - LLVM_DEBUG(logSuccess(rewriterImpl.logger, "pattern applied successfully")); + return success(); +} +LogicalResult OperationLegalizer::legalizePatternRootUpdates( + ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, + RewriterState &state, RewriterState &newState) { + for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) { + Operation *op = impl.rootUpdates[i].getOperation(); + if (failed(legalize(op, rewriter))) { + LLVM_DEBUG(logFailure(impl.logger, + "operation updated in-place '{0}' was illegal", + op->getName())); + return failure(); + } + } return success(); } @@ -1699,17 +1816,12 @@ struct OperationConverter { : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {} /// Converts the given operations to the conversion target. - LogicalResult convertOperations(ArrayRef ops, - TypeConverter *typeConverter); + LogicalResult convertOperations(ArrayRef ops); private: /// Converts an operation with the given rewriter. LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op); - /// Converts the type signatures of the blocks nested within 'op'. - LogicalResult convertBlockSignatures(ConversionPatternRewriter &rewriter, - Operation *op); - /// The legalizer to use when converting operations. OperationLegalizer opLegalizer; @@ -1724,21 +1836,6 @@ private: }; } // end anonymous namespace -LogicalResult -OperationConverter::convertBlockSignatures(ConversionPatternRewriter &rewriter, - Operation *op) { - // Check to see if type signatures need to be converted. - if (!rewriter.getImpl().argConverter.typeConverter) - return success(); - - for (auto ®ion : op->getRegions()) { - for (auto &block : llvm::make_early_inc_range(region)) - if (failed(rewriter.getImpl().convertBlockSignature(&block))) - return failure(); - } - return success(); -} - LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, Operation *op) { // Legalize the given operation. @@ -1759,24 +1856,16 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, if (trackedOps) trackedOps->insert(op); } - } else { + } else if (mode == OpConversionMode::Analysis) { // Analysis conversions don't fail if any operations fail to legalize, // they are only interested in the operations that were successfully // legalized. - if (mode == OpConversionMode::Analysis) - trackedOps->insert(op); - - // If legalization succeeded, convert the types any of the blocks within - // this operation. - if (failed(convertBlockSignatures(rewriter, op))) - return failure(); + trackedOps->insert(op); } return success(); } -LogicalResult -OperationConverter::convertOperations(ArrayRef ops, - TypeConverter *typeConverter) { +LogicalResult OperationConverter::convertOperations(ArrayRef ops) { if (ops.empty()) return success(); ConversionTarget &target = opLegalizer.getTarget(); @@ -1792,7 +1881,7 @@ OperationConverter::convertOperations(ArrayRef ops, } // Convert each operation and discard rewrites on failure. - ConversionPatternRewriter rewriter(ops.front()->getContext(), typeConverter); + ConversionPatternRewriter rewriter(ops.front()->getContext()); for (auto *op : toConvert) if (failed(convert(rewriter, op))) return rewriter.getImpl().discardRewrites(), failure(); @@ -1913,6 +2002,13 @@ bool TypeConverter::isLegal(Operation *op) { return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes()); } +/// Return true if the types of block arguments within the region are legal. +bool TypeConverter::isLegal(Region *region) { + return llvm::all_of(*region, [this](Block &block) { + return isLegal(block.getArgumentTypes()); + }); +} + /// Return true if the inputs and outputs of the given function type are /// legal. bool TypeConverter::isSignatureLegal(FunctionType ty) { @@ -1969,7 +2065,7 @@ auto TypeConverter::convertBlockSignature(Block *block) namespace { struct FuncOpSignatureConversion : public OpConversionPattern { FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) - : OpConversionPattern(ctx), converter(converter) {} + : OpConversionPattern(converter, ctx) {} /// Hook for derived classes to implement combined matching and rewriting. LogicalResult @@ -1979,22 +2075,20 @@ struct FuncOpSignatureConversion : public OpConversionPattern { // Convert the original function types. TypeConverter::SignatureConversion result(type.getNumInputs()); - SmallVector convertedResults; - if (failed(converter.convertSignatureArgs(type.getInputs(), result)) || - failed(converter.convertTypes(type.getResults(), convertedResults))) + SmallVector newResults; + if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) || + failed(typeConverter->convertTypes(type.getResults(), newResults)) || + failed(rewriter.convertRegionTypes(&funcOp.getBody(), *typeConverter, + &result))) return failure(); // Update the function signature in-place. rewriter.updateRootInPlace(funcOp, [&] { - funcOp.setType(FunctionType::get(result.getConvertedTypes(), - convertedResults, funcOp.getContext())); - rewriter.applySignatureConversion(&funcOp.getBody(), result); + funcOp.setType(FunctionType::get(result.getConvertedTypes(), newResults, + funcOp.getContext())); }); return success(); } - - /// The type converter to use when rewriting the signature. - TypeConverter &converter; }; } // end anonymous namespace @@ -2128,27 +2222,26 @@ auto ConversionTarget::getOpInfo(OperationName op) const /// Apply a partial conversion on the given operations and all nested /// operations. This method converts as many operations to the target as /// possible, ignoring operations that failed to legalize. This method only -/// returns failure if there ops explicitly marked as illegal. If `converter` is -/// provided, the signatures of blocks and regions are also converted. +/// returns failure if there ops explicitly marked as illegal. /// If an `unconvertedOps` set is provided, all operations that are found not /// to be legalizable to the given `target` are placed within that set. (Note /// that if there is an op explicitly marked as illegal, the conversion /// terminates and the `unconvertedOps` set will not necessarily be complete.) -LogicalResult mlir::applyPartialConversion( - ArrayRef ops, ConversionTarget &target, - const OwningRewritePatternList &patterns, TypeConverter *converter, - DenseSet *unconvertedOps) { +LogicalResult +mlir::applyPartialConversion(ArrayRef ops, + ConversionTarget &target, + const OwningRewritePatternList &patterns, + DenseSet *unconvertedOps) { OperationConverter opConverter(target, patterns, OpConversionMode::Partial, unconvertedOps); - return opConverter.convertOperations(ops, converter); + return opConverter.convertOperations(ops); } LogicalResult mlir::applyPartialConversion(Operation *op, ConversionTarget &target, const OwningRewritePatternList &patterns, - TypeConverter *converter, DenseSet *unconvertedOps) { return applyPartialConversion(llvm::makeArrayRef(op), target, patterns, - converter, unconvertedOps); + unconvertedOps); } /// Apply a complete conversion on the given operations, and all nested @@ -2156,17 +2249,14 @@ mlir::applyPartialConversion(Operation *op, ConversionTarget &target, /// operation fails. LogicalResult mlir::applyFullConversion(ArrayRef ops, ConversionTarget &target, - const OwningRewritePatternList &patterns, - TypeConverter *converter) { + const OwningRewritePatternList &patterns) { OperationConverter opConverter(target, patterns, OpConversionMode::Full); - return opConverter.convertOperations(ops, converter); + return opConverter.convertOperations(ops); } LogicalResult mlir::applyFullConversion(Operation *op, ConversionTarget &target, - const OwningRewritePatternList &patterns, - TypeConverter *converter) { - return applyFullConversion(llvm::makeArrayRef(op), target, patterns, - converter); + const OwningRewritePatternList &patterns) { + return applyFullConversion(llvm::makeArrayRef(op), target, patterns); } /// Apply an analysis conversion on the given operations, and all nested @@ -2175,19 +2265,19 @@ mlir::applyFullConversion(Operation *op, ConversionTarget &target, /// were found to be legalizable to the given 'target' are placed within the /// provided 'convertedOps' set; note that no actual rewrites are applied to the /// operations on success and only pre-existing operations are added to the set. -LogicalResult mlir::applyAnalysisConversion( - ArrayRef ops, ConversionTarget &target, - const OwningRewritePatternList &patterns, - DenseSet &convertedOps, TypeConverter *converter) { +LogicalResult +mlir::applyAnalysisConversion(ArrayRef ops, + ConversionTarget &target, + const OwningRewritePatternList &patterns, + DenseSet &convertedOps) { OperationConverter opConverter(target, patterns, OpConversionMode::Analysis, &convertedOps); - return opConverter.convertOperations(ops, converter); + return opConverter.convertOperations(ops); } LogicalResult mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target, const OwningRewritePatternList &patterns, - DenseSet &convertedOps, - TypeConverter *converter) { + DenseSet &convertedOps) { return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns, - convertedOps, converter); + convertedOps); } diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 284c38b..5637fa8 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -153,15 +153,11 @@ func @remove_foldable_op(%arg0 : i32) -> (i32) { // CHECK-LABEL: @create_block func @create_block() { - // expected-remark@+1 {{op 'test.container' is not legalizable}} - "test.container"() ({ - // Check that we created a block with arguments. - // CHECK-NOT: test.create_block - // CHECK: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32): - // CHECK: test.finish - "test.create_block"() : () -> () - "test.finish"() : () -> () - }) : () -> () + // Check that we created a block with arguments. + // CHECK-NOT: test.create_block + // CHECK: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32): + "test.create_block"() : () -> () + // expected-remark@+1 {{op 'std.return' is not legalizable}} return } @@ -212,15 +208,12 @@ func @fail_to_convert_region() { // CHECK-LABEL: @create_illegal_block func @create_illegal_block() { - // expected-remark@+1 {{op 'test.container' is not legalizable}} - "test.container"() ({ - // Check that we can undo block creation, i.e. that the block was removed. - // CHECK: test.create_illegal_block - // CHECK-NOT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32): - // expected-remark@+1 {{op 'test.create_illegal_block' is not legalizable}} - "test.create_illegal_block"() : () -> () - "test.finish"() : () -> () - }) : () -> () + // Check that we can undo block creation, i.e. that the block was removed. + // CHECK: test.create_illegal_block + // CHECK-NOT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32): + // expected-remark@+1 {{op 'test.create_illegal_block' is not legalizable}} + "test.create_illegal_block"() : () -> () + // expected-remark@+1 {{op 'std.return' is not legalizable}} return } diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index cbab7d7..60f663f 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -304,8 +304,7 @@ struct TestUndoBlockErase : public ConversionPattern { /// This patterns erases a region operation that has had a type conversion. struct TestDropOpSignatureConversion : public ConversionPattern { TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) - : ConversionPattern("test.drop_region_op", 1, ctx), converter(converter) { - } + : ConversionPattern("test.drop_region_op", 1, converter, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { @@ -313,19 +312,17 @@ struct TestDropOpSignatureConversion : public ConversionPattern { Block *entry = ®ion.front(); // Convert the original entry arguments. + TypeConverter &converter = *getTypeConverter(); TypeConverter::SignatureConversion result(entry->getNumArguments()); - if (failed( - converter.convertSignatureArgs(entry->getArgumentTypes(), result))) + if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), + result)) || + failed(rewriter.convertRegionTypes(®ion, converter, &result))) return failure(); // Convert the region signature and just drop the operation. - rewriter.applySignatureConversion(®ion, result); rewriter.eraseOp(op); return success(); } - - /// The type converter to use when rewriting the signature. - TypeConverter &converter; }; /// This pattern simply updates the operands of the given operation. struct TestPassthroughInvalidOp : public ConversionPattern { @@ -568,8 +565,10 @@ struct TestLegalizePatternDriver return llvm::none_of(op.getOperandTypes(), [](Type type) { return type.isF32(); }); }); - target.addDynamicallyLegalOp( - [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); + target.addDynamicallyLegalOp([&](FuncOp op) { + return converter.isSignatureLegal(op.getType()) && + converter.isLegal(&op.getBody()); + }); // Expect the type_producer/type_consumer operations to only operate on f64. target.addDynamicallyLegalOp( @@ -591,7 +590,7 @@ struct TestLegalizePatternDriver // Handle a partial conversion. if (mode == ConversionMode::Partial) { DenseSet unlegalizedOps; - (void)applyPartialConversion(getOperation(), target, patterns, &converter, + (void)applyPartialConversion(getOperation(), target, patterns, &unlegalizedOps); // Emit remarks for each legalizable operation. for (auto *op : unlegalizedOps) @@ -606,7 +605,7 @@ struct TestLegalizePatternDriver return (bool)op->getAttrOfType("test.dynamically_legal"); }); - (void)applyFullConversion(getOperation(), target, patterns, &converter); + (void)applyFullConversion(getOperation(), target, patterns); return; } @@ -616,7 +615,7 @@ struct TestLegalizePatternDriver // Analyze the convertible operations. DenseSet legalizedOps; if (failed(applyAnalysisConversion(getOperation(), target, patterns, - legalizedOps, &converter))) + legalizedOps))) return signalPassFailure(); // Emit remarks for each legalizable operation. diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp index 0976f71..2fbdfe9 100644 --- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp +++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp @@ -1,4 +1,4 @@ -//===- TestBufferPlacement.cpp - Test for buffer placement 0----*- C++ -*-===// +//===- TestBufferPlacement.cpp - Test for buffer placement ------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -140,7 +140,8 @@ struct TestBufferPlacementPreparationPass // Mark the function whose arguments are in tensor-type illegal. target.addDynamicallyLegalOp([&](FuncOp funcOp) { - return converter.isSignatureLegal(funcOp.getType()); + return converter.isSignatureLegal(funcOp.getType()) && + converter.isLegal(&funcOp.getBody()); }); // Walk over all the functions to apply buffer assignment. @@ -151,7 +152,7 @@ struct TestBufferPlacementPreparationPass &context, &placer, &converter, &patterns); // Applying full conversion - return applyFullConversion(function, target, patterns, &converter); + return applyFullConversion(function, target, patterns); }); }; }; -- 2.7.4