From e7ccfb2ae847249abf230b08638cec7d1a2ee5d9 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 6 Jun 2019 15:48:14 -0700 Subject: [PATCH] Add support to ConversionTarget for storing legalization actions for entire dialects as opposed to individual operations. This allows for better support of unregistered operations, as well as removing the need to collect all of the operations for a given dialect(which may be very expensive). PiperOrigin-RevId: 251943590 --- mlir/include/mlir/Transforms/DialectConversion.h | 28 ++++++-- mlir/lib/Transforms/DialectConversion.cpp | 85 +++++++----------------- 2 files changed, 44 insertions(+), 69 deletions(-) diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index af08a1f..e958426 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -175,6 +175,10 @@ public: "targets with custom legalization must override 'isLegal'"); } + //===--------------------------------------------------------------------===// + // Legality Registration + //===--------------------------------------------------------------------===// + /// Register a legality action for the given operation. void setOpAction(OperationName op, LegalizationAction action) { legalOperations[op] = action; @@ -193,7 +197,10 @@ public: } /// Register the operations of the given dialects as legal. - void addLegalDialects(ArrayRef dialectNames); + void addLegalDialects(ArrayRef dialectNames) { + for (auto &dialect : dialectNames) + legalDialects[dialect] = LegalizationAction::Legal; + } template void addLegalDialects(StringRef name, Names... names) { SmallVector dialectNames({name, names...}); @@ -215,25 +222,32 @@ public: addDynamicallyLegalOp(); } + //===--------------------------------------------------------------------===// + // Legality Querying + //===--------------------------------------------------------------------===// + /// Get the legality action for the given operation. llvm::Optional getOpAction(OperationName op) const { + // Check for an action for this specific operation. auto it = legalOperations.find(op); if (it != legalOperations.end()) return it->second; + // Otherwise, default to checking for an action on the parent dialect. + auto dialectIt = legalDialects.find(op.getDialect()); + if (dialectIt != legalDialects.end()) + return dialectIt->second; return llvm::None; } - /// Returns a range of operations that this target has defined to be legal in - /// some capacity. - llvm::iterator_range getLegalOps() const { - return llvm::make_range(legalOperations.begin(), legalOperations.end()); - } - private: /// A deterministic mapping of operation name to the specific legality action /// to take. LegalityMapTy legalOperations; + /// A deterministic mapping of dialect name to the specific legality action to + /// take. + llvm::StringMap legalDialects; + /// The current context this target applies to. MLIRContext &ctx; }; diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 00ae605..7ef8a46 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -263,38 +263,12 @@ ConversionPattern::matchAndRewrite(Operation *op, } //===----------------------------------------------------------------------===// -// ConversionTarget -//===----------------------------------------------------------------------===// - -/// Register the operations of the given dialects as legal. -void ConversionTarget::addLegalDialects(ArrayRef dialectNames) { - SmallPtrSet dialects; - for (auto dialectName : dialectNames) - if (auto *dialect = ctx.getRegisteredDialect(dialectName)) - dialects.insert(dialect); - - // Set all dialect operations as legal. - for (auto op : ctx.getRegisteredOperations()) - if (dialects.count(&op->dialect)) - setOpAction(op, LegalizationAction::Legal); -} - -//===----------------------------------------------------------------------===// // OperationLegalizer //===----------------------------------------------------------------------===// namespace { -/// This class represents the information necessary for legalizing an operation -/// kind. -struct OpLegality { - /// This is the legalization action specified by the target, if it provided - /// one. - llvm::Optional targetAction; - - /// The set of patterns to apply to an instance of this operation to legalize - /// it. - SmallVector patterns; -}; +/// A set of rewrite patterns that can be used to legalize a given operation. +using LegalizationPatterns = SmallVector; /// This class defines a recursive operation legalizer. class OperationLegalizer { @@ -316,8 +290,9 @@ private: DialectConversionRewriter &rewriter); /// Build an optimistic legalization graph given the provided patterns. This - /// function populates 'legalOps' with the operations that are either legal, - /// or transitively legal for the current target given the provided patterns. + /// function populates 'legalizerPatterns' with the operations that are not + /// directly legal, but may be transitively legal for the current target given + /// the provided patterns. void buildLegalizationGraph(OwningRewritePatternList &patterns); /// The current set of patterns that have been applied. @@ -325,7 +300,7 @@ private: /// The set of legality information for operations transitively supported by /// the target. - DenseMap legalOps; + DenseMap legalizerPatterns; /// The legalization information provided by the target. ConversionTarget ⌖ @@ -338,20 +313,13 @@ OperationLegalizer::legalize(Operation *op, LLVM_DEBUG(llvm::dbgs() << "Legalizing operation : " << op->getName() << "\n"); - auto it = legalOps.find(op->getName()); - if (it == legalOps.end()) { - LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no known legalization path.\n"); - return failure(); - } - // Check if this was marked legal by the target. - auto &opInfo = it->second; - if (auto action = opInfo.targetAction) { + if (auto action = target.getOpAction(op->getName())) { // Check if this operation is always legal. if (*action == ConversionTarget::LegalizationAction::Legal) return success(); - // Otherwise, handle custom legalization. + // Otherwise, handle dynamic legalization. LLVM_DEBUG(llvm::dbgs() << "- Trying dynamic legalization.\n"); if (target.isLegal(op)) return success(); @@ -360,9 +328,15 @@ OperationLegalizer::legalize(Operation *op, } // Otherwise, we need to apply a legalization pattern to this operation. + auto it = legalizerPatterns.find(op->getName()); + if (it == legalizerPatterns.end()) { + LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no known legalization path.\n"); + return failure(); + } + // TODO(riverriddle) This currently has no cost model and doesn't prioritize // specific patterns in any way. - for (auto *pattern : opInfo.patterns) + for (auto *pattern : it->second) if (succeeded(legalizePattern(op, pattern, rewriter))) return success(); @@ -427,18 +401,12 @@ void OperationLegalizer::buildLegalizationGraph( // A worklist of patterns to consider for legality. llvm::SetVector patternWorklist; - // Collect the initial set of valid target ops. - for (auto &opInfoPair : target.getLegalOps()) - legalOps[opInfoPair.first].targetAction = opInfoPair.second; - // Build the mapping from operations to the parent ops that may generate them. for (auto &pattern : patterns) { auto root = pattern->getRootKind(); - // Skip operations that are known to always be legal. - auto it = legalOps.find(root); - if (it != legalOps.end() && - it->second.targetAction == ConversionTarget::LegalizationAction::Legal) + // Skip operations that are always known to be legal. + if (target.getOpAction(root) == ConversionTarget::LegalizationAction::Legal) continue; // Add this pattern to the invalid set for the root op and record this root @@ -447,29 +415,22 @@ void OperationLegalizer::buildLegalizationGraph( for (auto op : pattern->getGeneratedOps()) parentOps[op].insert(root); - // If this pattern doesn't generate any operations, optimistically add it to - // the worklist. - if (pattern->getGeneratedOps().empty()) - patternWorklist.insert(pattern.get()); + // Add this pattern to the worklist. + patternWorklist.insert(pattern.get()); } - // Build the initial worklist with the patterns that generate operations that - // are known to be legal. - for (auto &opInfoPair : target.getLegalOps()) - for (auto &parentOp : parentOps[opInfoPair.first]) - patternWorklist.set_union(invalidPatterns[parentOp]); - while (!patternWorklist.empty()) { auto *pattern = patternWorklist.pop_back_val(); // Check to see if any of the generated operations are invalid. - if (llvm::any_of(pattern->getGeneratedOps(), - [&](OperationName op) { return !legalOps.count(op); })) + if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) { + return !legalizerPatterns.count(op) && !target.getOpAction(op); + })) continue; // Otherwise, if all of the generated operation are valid, this op is now // legal so add all of the child patterns to the worklist. - legalOps[pattern->getRootKind()].patterns.push_back(pattern); + legalizerPatterns[pattern->getRootKind()].push_back(pattern); invalidPatterns[pattern->getRootKind()].erase(pattern); // Add any invalid patterns of the parent operations to see if they have now -- 2.7.4