From: River Riddle Date: Mon, 3 Jun 2019 19:49:55 +0000 (-0700) Subject: Refactor the dialect conversion framework to support multi-level conversions. Multi... X-Git-Tag: llvmorg-11-init~1466^2~1516 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=95eaca3e0fa31c7c98b0dd704be06e11e6b1db1f;p=platform%2Fupstream%2Fllvm.git Refactor the dialect conversion framework to support multi-level conversions. Multi-level conversions are those that require multiple patterns to be applied before an operation is completely legalized. This essentially means that conversion patterns do not have to directly generate legal operations, and may be chained together to produce legal code. To accomplish this, moving forward users will need to provide a legalization target that defines what operations are legal for the conversion. A target can mark an operation as legal by providing a specific legalization action. The initial actions are: * Legal - This action signals that every instance of the given operation is legal, i.e. any combination of attributes, operands, types, etc. is valid. * Dynamic - This action signals that only some instances of a given operation are legal. This allows for defining fine-tune constraints, like say std.add is only legal when operating on 32-bit integers. An example target is shown below: struct MyTarget : public ConversionTarget { MyTarget(MLIRContext &ctx) : ConversionTarget(ctx) { // All operations in the LLVM dialect are legal. addLegalDialect(); // std.constant op is always legal on this target. addLegalOp(); // std.return op has dynamic legality constraints. addDynamicallyLegalOp(); } /// Implement the custom legalization handler to handle /// std.return. bool isLegal(Operation *op) override { // Process the dynamic handling for a std.return op (and any others that were // marked "dynamic"). ... } }; PiperOrigin-RevId: 251289374 --- diff --git a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp index 84234c3..8cd970c 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp @@ -418,7 +418,10 @@ void linalg::convertToLLVM(mlir::Module &module) { populateStdToLLVMConversionPatterns(converter, patterns); populateLinalg1ToLLVMConversionPatterns(patterns, module.getContext()); - auto r = applyConversionPatterns(module, converter, std::move(patterns)); + ConversionTarget target(*module.getContext()); + target.addLegalDialects(); + auto r = + applyConversionPatterns(module, target, converter, std::move(patterns)); (void)r; assert(succeeded(r) && "conversion failed"); } diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp index db9f496..60fdf60 100644 --- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp @@ -29,6 +29,7 @@ #include "mlir/LLVMIR/LLVMLowering.h" #include "mlir/LLVMIR/Transforms.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/LowerAffine.h" #include "linalg1/ConvertToLLVMDialect.h" #include "linalg1/LLVMIntrinsics.h" @@ -145,12 +146,12 @@ static void populateLinalg3ToLLVMConversionPatterns( } void linalg::convertLinalg3ToLLVM(Module &module) { - // Remove affine constructs if any by using an existing pass. - PassManager pm; - pm.addPass(createLowerAffinePass()); - auto rr = pm.run(&module); - (void)rr; - assert(succeeded(rr) && "affine loop lowering failed"); + // Remove affine constructs. + for (auto &func : module) { + auto rr = lowerAffineConstructs(func); + (void)rr; + assert(succeeded(rr) && "affine loop lowering failed"); + } // Convert Linalg ops to the LLVM IR dialect using the converter defined // above. @@ -160,7 +161,10 @@ void linalg::convertLinalg3ToLLVM(Module &module) { populateLinalg1ToLLVMConversionPatterns(patterns, module.getContext()); populateLinalg3ToLLVMConversionPatterns(patterns, module.getContext()); - auto r = applyConversionPatterns(module, converter, std::move(patterns)); + ConversionTarget target(*module.getContext()); + target.addLegalDialects(); + auto r = + applyConversionPatterns(module, target, converter, std::move(patterns)); (void)r; assert(succeeded(r) && "conversion failed"); } diff --git a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp index f4ac522..45d608d 100644 --- a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp @@ -27,6 +27,7 @@ #include "toy/Dialect.h" +#include "linalg1/Dialect.h" #include "linalg1/Intrinsics.h" #include "linalg1/ViewOp.h" #include "linalg3/TensorOps.h" @@ -124,9 +125,14 @@ public: /// dialect. struct EarlyLoweringPass : public FunctionPass { void runOnFunction() override { + ConversionTarget target(getContext()); + target.addLegalDialects(); + target.addLegalOp(); + OwningRewritePatternList patterns; RewriteListBuilder::build(patterns, &getContext()); - if (failed(applyConversionPatterns(getFunction(), std::move(patterns)))) { + if (failed(applyConversionPatterns(getFunction(), target, + std::move(patterns)))) { getContext().emitError(mlir::UnknownLoc::get(&getContext()), "Error lowering Toy\n"); signalPassFailure(); diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 611d716..d682d12 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -24,6 +24,7 @@ #include "toy/Dialect.h" +#include "linalg1/Dialect.h" #include "linalg1/Intrinsics.h" #include "linalg1/ViewOp.h" #include "linalg3/ConvertToLLVMDialect.h" @@ -338,7 +339,11 @@ struct LateLoweringPass : public ModulePass { ReturnOpConversion>::build(toyPatterns, &getContext()); // Perform Toy specific lowering. - if (failed(applyConversionPatterns(getModule(), typeConverter, + ConversionTarget target(getContext()); + target.addLegalDialects(); + target.addLegalOp(); + if (failed(applyConversionPatterns(getModule(), target, typeConverter, std::move(toyPatterns)))) { getModule().getContext()->emitError( UnknownLoc::get(getModule().getContext()), "Error lowering Toy\n"); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 165e065..ac24252 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -25,6 +25,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/MapVector.h" namespace mlir { @@ -136,18 +137,108 @@ public: SmallVectorImpl &convertedArgAttrs); }; +/// This class describes a specific conversion target. +class ConversionTarget { +public: + /// This enumeration corresponds to the specific action to take when + /// considering an operation legal for this conversion target. + enum class LegalizationAction { + /// The target supports this operation. + Legal, + + /// This operation has dynamic legalization constraints that must be checked + /// by the target. + Dynamic + }; + + /// The type used to store operation legality information. + using LegalityMapTy = llvm::MapVector; + + ConversionTarget(MLIRContext &ctx) : ctx(ctx) {} + virtual ~ConversionTarget() = default; + + /// Runs a custom legalization query for the given operation. This should + /// return true if the given operation is legal, otherwise false. + virtual bool isLegal(Operation *op) const { + llvm_unreachable( + "targets with custom legalization must override 'isLegal'"); + } + + /// Register a legality action for the given operation. + void setOpAction(OperationName op, LegalizationAction action) { + legalOperations[op] = action; + } + template void setOpAction(LegalizationAction action) { + setOpAction(OperationName(OpT::getOperationName(), &ctx), action); + } + + /// Register the given operations as legal. + template void addLegalOp() { + setOpAction(LegalizationAction::Legal); + } + template void addLegalOp() { + addLegalOp(); + addLegalOp(); + } + + /// Register the operations of the given dialects as legal. + void addLegalDialects(ArrayRef dialectNames); + template + void addLegalDialects(StringRef name, Names... names) { + SmallVector dialectNames({name, names...}); + addLegalDialects(dialectNames); + } + template void addLegalDialects() { + SmallVector dialectNames({Args::getDialectNamespace()...}); + addLegalDialects(dialectNames); + } + + /// Register the given operation as dynamically legal, i.e. requiring custom + /// handling by the target via 'isLegal'. + template void addDynamicallyLegalOp() { + setOpAction(LegalizationAction::Dynamic); + } + template + void addDynamicallyLegalOp() { + addDynamicallyLegalOp(); + addDynamicallyLegalOp(); + } + + /// Get the legality action for the given operation. + llvm::Optional getOpAction(OperationName op) const { + auto it = legalOperations.find(op); + if (it != legalOperations.end()) + return it->second; + return llvm::None; + } + + /// Returns a range of operations that this target has defined to be legal in + /// some capacity. + llvm::iterator_range getLegalOps() const { + return llvm::make_range(legalOperations.begin(), legalOperations.end()); + } + +private: + /// A deterministic mapping of operation name to the specific legality action + /// to take. + LegalityMapTy legalOperations; + + /// The current context this target applies to. + MLIRContext &ctx; +}; + /// Convert the given module with the provided conversion patterns and type /// conversion object. If conversion fails for specific functions, those /// functions remains unmodified. -LLVM_NODISCARD -LogicalResult applyConversionPatterns(Module &module, TypeConverter &converter, - OwningRewritePatternList &&patterns); +LLVM_NODISCARD LogicalResult applyConversionPatterns( + Module &module, ConversionTarget &target, TypeConverter &converter, + OwningRewritePatternList &&patterns); /// Convert the given function with the provided conversion patterns. This will /// convert as many of the operations within 'fn' as possible given the set of /// patterns. LLVM_NODISCARD -LogicalResult applyConversionPatterns(Function &fn, +LogicalResult applyConversionPatterns(Function &fn, ConversionTarget &target, OwningRewritePatternList &&patterns); } // end namespace mlir diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index a9717e2..0e30a8e 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -986,7 +986,11 @@ struct LLVMLoweringPass : public ModulePass { LLVMTypeConverter converter(&getContext()); OwningRewritePatternList patterns; populateStdToLLVMConversionPatterns(converter, patterns); - if (failed(applyConversionPatterns(m, converter, std::move(patterns)))) + + ConversionTarget target(getContext()); + target.addLegalDialects(); + if (failed( + applyConversionPatterns(m, target, converter, std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index c686af8..60c0daf 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -677,7 +677,10 @@ void LowerLinalgToLLVMPass::runOnModule() { populateStdToLLVMConversionPatterns(converter, patterns); populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext()); - if (failed(applyConversionPatterns(module, converter, std::move(patterns)))) + ConversionTarget target(getContext()); + target.addLegalDialects(); + if (failed(applyConversionPatterns(module, target, converter, + std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index b6a15f6..4c110d16 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -21,9 +21,14 @@ #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" #include "mlir/Transforms/Utils.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" using namespace mlir; -using namespace mlir::impl; + +#define DEBUG_TYPE "dialect-conversion" //===----------------------------------------------------------------------===// // ArgConverter @@ -136,7 +141,8 @@ struct DialectConversionRewriter final : public PatternRewriter { assert(newValues.size() == op->getNumResults()); // Create mappings for any type changes. for (unsigned i = 0, e = newValues.size(); i < e; ++i) - if (op->getResult(i)->getType() != newValues[i]->getType()) + if (newValues[i] && + op->getResult(i)->getType() != newValues[i]->getType()) mapping.map(op->getResult(i), newValues[i]); // Record the requested operation replacement. @@ -223,17 +229,235 @@ void ConversionPattern::rewrite(Operation *op, } //===----------------------------------------------------------------------===// +// ConversionTarget +//===----------------------------------------------------------------------===// + +/// Register the operations of the given dialects as legal. +void ConversionTarget::addLegalDialects(ArrayRef dialectNames) { + SmallPtrSet dialects; + for (auto dialectName : dialectNames) + if (auto *dialect = ctx.getRegisteredDialect(dialectName)) + dialects.insert(dialect); + + // Set all dialect operations as legal. + for (auto op : ctx.getRegisteredOperations()) + if (dialects.count(&op->dialect)) + setOpAction(op, LegalizationAction::Legal); +} + +//===----------------------------------------------------------------------===// +// OperationLegalizer +//===----------------------------------------------------------------------===// + +namespace { +/// This class represents the information necessary for legalizing an operation +/// kind. +struct OpLegality { + /// This is the legalization action specified by the target, if it provided + /// one. + llvm::Optional targetAction; + + /// The set of patterns to apply to an instance of this operation to legalize + /// it. + SmallVector patterns; +}; + +/// This class defines a recursive operation legalizer. +class OperationLegalizer { +public: + OperationLegalizer(ConversionTarget &targetInfo, + OwningRewritePatternList &patterns) + : target(targetInfo) { + buildLegalizationGraph(patterns); + } + + /// Attempt to legalize the given operation. Returns success if the operation + /// was legalized, failure otherwise. + LogicalResult legalize(Operation *op, DialectConversionRewriter &rewriter); + +private: + /// Attempt to legalize the given operation by applying the provided pattern. + /// Returns success if the operation was legalized, failure otherwise. + LogicalResult legalizePattern(Operation *op, RewritePattern *pattern, + DialectConversionRewriter &rewriter); + + /// Build an optimistic legalization graph given the provided patterns. This + /// function populates 'legalOps' with the operations that are either legal, + /// or transitively legal for the current target given the provided patterns. + void buildLegalizationGraph(OwningRewritePatternList &patterns); + + /// The current set of patterns that have been applied. + llvm::SmallPtrSet appliedPatterns; + + /// The set of legality information for operations transitively supported by + /// the target. + DenseMap legalOps; + + /// The legalization information provided by the target. + ConversionTarget ⌖ +}; +} // namespace + +LogicalResult +OperationLegalizer::legalize(Operation *op, + DialectConversionRewriter &rewriter) { + LLVM_DEBUG(llvm::dbgs() << "Legalizing operation : " << op->getName() + << "\n"); + + auto it = legalOps.find(op->getName()); + if (it == legalOps.end()) { + LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no known legalization path.\n"); + return failure(); + } + + // Check if this was marked legal by the target. + auto &opInfo = it->second; + if (auto action = opInfo.targetAction) { + // Check if this operation is always legal. + if (*action == ConversionTarget::LegalizationAction::Legal) + return success(); + + // Otherwise, handle custom legalization. + LLVM_DEBUG(llvm::dbgs() << "- Trying dynamic legalization.\n"); + if (target.isLegal(op)) + return success(); + + // Fallthough to see if a pattern can convert this into a legal operation. + } + + // Otherwise, we need to apply a legalization pattern to this operation. + // TODO(riverriddle) This currently has no cost model and doesn't prioritize + // specific patterns in any way. + for (auto *pattern : opInfo.patterns) + if (succeeded(legalizePattern(op, pattern, rewriter))) + return success(); + + LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no matched legalization pattern.\n"); + return failure(); +} + +LogicalResult +OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern, + DialectConversionRewriter &rewriter) { + LLVM_DEBUG({ + llvm::dbgs() << "-* Applying rewrite pattern '" << op->getName() << " -> ("; + interleaveComma(pattern->getGeneratedOps(), llvm::dbgs()); + llvm::dbgs() << ")'.\n"; + }); + + // Ensure that we don't cycle by not allowing the same pattern to be + // applied twice in the same recursion stack. + // TODO(riverriddle) We could eventually converge, but that requires more + // complicated analysis. + if (!appliedPatterns.insert(pattern).second) { + LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern was already applied.\n"); + return failure(); + } + + auto curOpCount = rewriter.createdOps.size(); + auto curReplCount = rewriter.replacements.size(); + auto cleanupFailure = [&] { + // Pop all of the newly created operations and replacements. + while (rewriter.createdOps.size() != curOpCount) + rewriter.createdOps.pop_back_val()->erase(); + rewriter.replacements.resize(curReplCount); + appliedPatterns.erase(pattern); + return failure(); + }; + + // Try to rewrite with the given pattern. + rewriter.setInsertionPoint(op); + if (!pattern->matchAndRewrite(op, rewriter)) { + LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern failed to match.\n"); + return cleanupFailure(); + } + + // Recursively legalize each of the new operations. + for (unsigned i = curOpCount, e = rewriter.createdOps.size(); i != e; ++i) { + if (succeeded(legalize(rewriter.createdOps[i], rewriter))) + continue; + LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation was illegal.\n"); + return cleanupFailure(); + } + + appliedPatterns.erase(pattern); + return success(); +} + +void OperationLegalizer::buildLegalizationGraph( + OwningRewritePatternList &patterns) { + // A mapping between an operation and a set of operations that can be used to + // generate it. + DenseMap> parentOps; + // A mapping between an operation and any currently invalid patterns it has. + DenseMap> invalidPatterns; + // A worklist of patterns to consider for legality. + llvm::SetVector patternWorklist; + + // Collect the initial set of valid target ops. + for (auto &opInfoPair : target.getLegalOps()) + legalOps[opInfoPair.first].targetAction = opInfoPair.second; + + // Build the mapping from operations to the parent ops that may generate them. + for (auto &pattern : patterns) { + auto root = pattern->getRootKind(); + + // Skip operations that are known to always be legal. + auto it = legalOps.find(root); + if (it != legalOps.end() && + it->second.targetAction == ConversionTarget::LegalizationAction::Legal) + continue; + + // Add this pattern to the invalid set for the root op and record this root + // as a parent for any generated operations. + invalidPatterns[root].insert(pattern.get()); + for (auto op : pattern->getGeneratedOps()) + parentOps[op].insert(root); + + // If this pattern doesn't generate any operations, optimistically add it to + // the worklist. + if (pattern->getGeneratedOps().empty()) + patternWorklist.insert(pattern.get()); + } + + // Build the initial worklist with the patterns that generate operations that + // are known to be legal. + for (auto &opInfoPair : target.getLegalOps()) + for (auto &parentOp : parentOps[opInfoPair.first]) + patternWorklist.set_union(invalidPatterns[parentOp]); + + while (!patternWorklist.empty()) { + auto *pattern = patternWorklist.pop_back_val(); + + // Check to see if any of the generated operations are invalid. + if (llvm::any_of(pattern->getGeneratedOps(), + [&](OperationName op) { return !legalOps.count(op); })) + continue; + + // Otherwise, if all of the generated operation are valid, this op is now + // legal so add all of the child patterns to the worklist. + legalOps[pattern->getRootKind()].patterns.push_back(pattern); + invalidPatterns[pattern->getRootKind()].erase(pattern); + + // Add any invalid patterns of the parent operations to see if they have now + // become legal. + for (auto op : parentOps[pattern->getRootKind()]) + patternWorklist.set_union(invalidPatterns[op]); + } +} + +//===----------------------------------------------------------------------===// // FunctionConverter //===----------------------------------------------------------------------===// namespace { // This class converts a single function using the given pattern matcher. If a // TypeConverter object is provided, then the types of block arguments will be // converted using the appropriate 'convertType' calls. -class FunctionConverter { -public: - explicit FunctionConverter(MLIRContext *ctx, RewritePatternMatcher &matcher, +struct FunctionConverter { + explicit FunctionConverter(MLIRContext *ctx, ConversionTarget &target, + OwningRewritePatternList &patterns, TypeConverter *conversion = nullptr) - : typeConverter(conversion), matcher(matcher) {} + : typeConverter(conversion), opLegalizer(target, patterns) {} /// Converts the given function to the dialect using hooks defined in /// `typeConverter`. Returns failure on error, success otherwise. @@ -262,8 +486,8 @@ public: /// Pointer to a specific dialect conversion info. TypeConverter *typeConverter; - /// The matcher to use when converting operations. - RewritePatternMatcher &matcher; + /// The legalizer to use when converting operations. + OperationLegalizer opLegalizer; }; } // end anonymous namespace @@ -293,15 +517,14 @@ FunctionConverter::convertBlock(DialectConversionRewriter &rewriter, // Iterate over ops and convert them. for (Operation &op : llvm::make_early_inc_range(*block)) { - rewriter.setInsertionPoint(&op); - if (matcher.matchAndRewrite(&op, rewriter)) - continue; - // Traverse any held regions. for (auto ®ion : op.getRegions()) if (!region.empty() && failed(convertRegion(rewriter, region, op.getLoc()))) return failure(); + + // Legalize the current operation. + (void)opLegalizer.legalize(&op, rewriter); } // Recurse to children that haven't been visited. @@ -416,12 +639,12 @@ struct ConvertedFunction { /// conversion object. If conversion fails for specific functions, those /// functions remains unmodified. LogicalResult -mlir::applyConversionPatterns(Module &module, TypeConverter &converter, +mlir::applyConversionPatterns(Module &module, ConversionTarget &target, + TypeConverter &converter, OwningRewritePatternList &&patterns) { - // Grab the conversion patterns from the converter and create the pattern - // matcher. - MLIRContext *context = module.getContext(); - RewritePatternMatcher matcher(std::move(patterns)); + // Build the function converter. + FunctionConverter funcConverter(module.getContext(), target, patterns, + &converter); // Try to convert each of the functions within the module. Defer updating the // signatures of the functions until after all of the bodies have been @@ -439,7 +662,6 @@ mlir::applyConversionPatterns(Module &module, TypeConverter &converter, return func.emitError("could not convert function type"); // Convert the body of this function. - FunctionConverter funcConverter(context, matcher, &converter); if (failed(funcConverter.convertFunction(&func))) return failure(); @@ -461,10 +683,9 @@ mlir::applyConversionPatterns(Module &module, TypeConverter &converter, /// convert as many of the operations within 'fn' as possible given the set of /// patterns. LogicalResult -mlir::applyConversionPatterns(Function &fn, +mlir::applyConversionPatterns(Function &fn, ConversionTarget &target, OwningRewritePatternList &&patterns) { // Convert the body of this function. - RewritePatternMatcher matcher(std::move(patterns)); - FunctionConverter converter(fn.getContext(), matcher); + FunctionConverter converter(fn.getContext(), target, patterns); return converter.convertFunction(&fn); }