Add support to ConversionTarget for storing legalization actions for entire dialects...
authorRiver Riddle <riverriddle@google.com>
Thu, 6 Jun 2019 22:48:14 +0000 (15:48 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 9 Jun 2019 23:21:32 +0000 (16:21 -0700)
PiperOrigin-RevId: 251943590

mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Transforms/DialectConversion.cpp

index af08a1f..e958426 100644 (file)
@@ -175,6 +175,10 @@ public:
         "targets with custom legalization must override 'isLegal'");
   }
 
+  //===--------------------------------------------------------------------===//
+  // Legality Registration
+  //===--------------------------------------------------------------------===//
+
   /// Register a legality action for the given operation.
   void setOpAction(OperationName op, LegalizationAction action) {
     legalOperations[op] = action;
@@ -193,7 +197,10 @@ public:
   }
 
   /// Register the operations of the given dialects as legal.
-  void addLegalDialects(ArrayRef<StringRef> dialectNames);
+  void addLegalDialects(ArrayRef<StringRef> dialectNames) {
+    for (auto &dialect : dialectNames)
+      legalDialects[dialect] = LegalizationAction::Legal;
+  }
   template <typename... Names>
   void addLegalDialects(StringRef name, Names... names) {
     SmallVector<StringRef, 2> dialectNames({name, names...});
@@ -215,25 +222,32 @@ public:
     addDynamicallyLegalOp<OpT2, OpTs...>();
   }
 
+  //===--------------------------------------------------------------------===//
+  // Legality Querying
+  //===--------------------------------------------------------------------===//
+
   /// Get the legality action for the given operation.
   llvm::Optional<LegalizationAction> getOpAction(OperationName op) const {
+    // Check for an action for this specific operation.
     auto it = legalOperations.find(op);
     if (it != legalOperations.end())
       return it->second;
+    // Otherwise, default to checking for an action on the parent dialect.
+    auto dialectIt = legalDialects.find(op.getDialect());
+    if (dialectIt != legalDialects.end())
+      return dialectIt->second;
     return llvm::None;
   }
 
-  /// Returns a range of operations that this target has defined to be legal in
-  /// some capacity.
-  llvm::iterator_range<LegalityMapTy::const_iterator> getLegalOps() const {
-    return llvm::make_range(legalOperations.begin(), legalOperations.end());
-  }
-
 private:
   /// A deterministic mapping of operation name to the specific legality action
   /// to take.
   LegalityMapTy legalOperations;
 
+  /// A deterministic mapping of dialect name to the specific legality action to
+  /// take.
+  llvm::StringMap<LegalizationAction> legalDialects;
+
   /// The current context this target applies to.
   MLIRContext &ctx;
 };
index 00ae605..7ef8a46 100644 (file)
@@ -263,38 +263,12 @@ ConversionPattern::matchAndRewrite(Operation *op,
 }
 
 //===----------------------------------------------------------------------===//
-// ConversionTarget
-//===----------------------------------------------------------------------===//
-
-/// Register the operations of the given dialects as legal.
-void ConversionTarget::addLegalDialects(ArrayRef<StringRef> dialectNames) {
-  SmallPtrSet<Dialect *, 2> dialects;
-  for (auto dialectName : dialectNames)
-    if (auto *dialect = ctx.getRegisteredDialect(dialectName))
-      dialects.insert(dialect);
-
-  // Set all dialect operations as legal.
-  for (auto op : ctx.getRegisteredOperations())
-    if (dialects.count(&op->dialect))
-      setOpAction(op, LegalizationAction::Legal);
-}
-
-//===----------------------------------------------------------------------===//
 // OperationLegalizer
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// This class represents the information necessary for legalizing an operation
-/// kind.
-struct OpLegality {
-  /// This is the legalization action specified by the target, if it provided
-  /// one.
-  llvm::Optional<ConversionTarget::LegalizationAction> targetAction;
-
-  /// The set of patterns to apply to an instance of this operation to legalize
-  /// it.
-  SmallVector<RewritePattern *, 1> patterns;
-};
+/// A set of rewrite patterns that can be used to legalize a given operation.
+using LegalizationPatterns = SmallVector<RewritePattern *, 1>;
 
 /// This class defines a recursive operation legalizer.
 class OperationLegalizer {
@@ -316,8 +290,9 @@ private:
                                 DialectConversionRewriter &rewriter);
 
   /// Build an optimistic legalization graph given the provided patterns. This
-  /// function populates 'legalOps' with the operations that are either legal,
-  /// or transitively legal for the current target given the provided patterns.
+  /// function populates 'legalizerPatterns' with the operations that are not
+  /// directly legal, but may be transitively legal for the current target given
+  /// the provided patterns.
   void buildLegalizationGraph(OwningRewritePatternList &patterns);
 
   /// The current set of patterns that have been applied.
@@ -325,7 +300,7 @@ private:
 
   /// The set of legality information for operations transitively supported by
   /// the target.
-  DenseMap<OperationName, OpLegality> legalOps;
+  DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
 
   /// The legalization information provided by the target.
   ConversionTarget &target;
@@ -338,20 +313,13 @@ OperationLegalizer::legalize(Operation *op,
   LLVM_DEBUG(llvm::dbgs() << "Legalizing operation : " << op->getName()
                           << "\n");
 
-  auto it = legalOps.find(op->getName());
-  if (it == legalOps.end()) {
-    LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no known legalization path.\n");
-    return failure();
-  }
-
   // Check if this was marked legal by the target.
-  auto &opInfo = it->second;
-  if (auto action = opInfo.targetAction) {
+  if (auto action = target.getOpAction(op->getName())) {
     // Check if this operation is always legal.
     if (*action == ConversionTarget::LegalizationAction::Legal)
       return success();
 
-    // Otherwise, handle custom legalization.
+    // Otherwise, handle dynamic legalization.
     LLVM_DEBUG(llvm::dbgs() << "- Trying dynamic legalization.\n");
     if (target.isLegal(op))
       return success();
@@ -360,9 +328,15 @@ OperationLegalizer::legalize(Operation *op,
   }
 
   // Otherwise, we need to apply a legalization pattern to this operation.
+  auto it = legalizerPatterns.find(op->getName());
+  if (it == legalizerPatterns.end()) {
+    LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no known legalization path.\n");
+    return failure();
+  }
+
   // TODO(riverriddle) This currently has no cost model and doesn't prioritize
   // specific patterns in any way.
-  for (auto *pattern : opInfo.patterns)
+  for (auto *pattern : it->second)
     if (succeeded(legalizePattern(op, pattern, rewriter)))
       return success();
 
@@ -427,18 +401,12 @@ void OperationLegalizer::buildLegalizationGraph(
   // A worklist of patterns to consider for legality.
   llvm::SetVector<RewritePattern *> patternWorklist;
 
-  // Collect the initial set of valid target ops.
-  for (auto &opInfoPair : target.getLegalOps())
-    legalOps[opInfoPair.first].targetAction = opInfoPair.second;
-
   // Build the mapping from operations to the parent ops that may generate them.
   for (auto &pattern : patterns) {
     auto root = pattern->getRootKind();
 
-    // Skip operations that are known to always be legal.
-    auto it = legalOps.find(root);
-    if (it != legalOps.end() &&
-        it->second.targetAction == ConversionTarget::LegalizationAction::Legal)
+    // Skip operations that are always known to be legal.
+    if (target.getOpAction(root) == ConversionTarget::LegalizationAction::Legal)
       continue;
 
     // Add this pattern to the invalid set for the root op and record this root
@@ -447,29 +415,22 @@ void OperationLegalizer::buildLegalizationGraph(
     for (auto op : pattern->getGeneratedOps())
       parentOps[op].insert(root);
 
-    // If this pattern doesn't generate any operations, optimistically add it to
-    // the worklist.
-    if (pattern->getGeneratedOps().empty())
-      patternWorklist.insert(pattern.get());
+    // Add this pattern to the worklist.
+    patternWorklist.insert(pattern.get());
   }
 
-  // Build the initial worklist with the patterns that generate operations that
-  // are known to be legal.
-  for (auto &opInfoPair : target.getLegalOps())
-    for (auto &parentOp : parentOps[opInfoPair.first])
-      patternWorklist.set_union(invalidPatterns[parentOp]);
-
   while (!patternWorklist.empty()) {
     auto *pattern = patternWorklist.pop_back_val();
 
     // Check to see if any of the generated operations are invalid.
-    if (llvm::any_of(pattern->getGeneratedOps(),
-                     [&](OperationName op) { return !legalOps.count(op); }))
+    if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
+          return !legalizerPatterns.count(op) && !target.getOpAction(op);
+        }))
       continue;
 
     // Otherwise, if all of the generated operation are valid, this op is now
     // legal so add all of the child patterns to the worklist.
-    legalOps[pattern->getRootKind()].patterns.push_back(pattern);
+    legalizerPatterns[pattern->getRootKind()].push_back(pattern);
     invalidPatterns[pattern->getRootKind()].erase(pattern);
 
     // Add any invalid patterns of the parent operations to see if they have now