Add support for explicitly marking dialects and operations as illegal.
authorRiver Riddle <riverriddle@google.com>
Wed, 17 Jul 2019 16:26:57 +0000 (09:26 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Fri, 19 Jul 2019 18:38:25 +0000 (11:38 -0700)
This explicit tag is useful is several ways:
*) This simplifies how to mark sub sections of a dialect as explicitly unsupported, e.g. my target supports all operations in the foo dialect except for these select few. This is useful for partial lowerings between dialects.
*) Partial conversions will now verify that operations that were explicitly marked as illegal must be converted. This provides some guarantee that the operations that need to be lowered by a specific pass will be.

PiperOrigin-RevId: 258582879

mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Transforms/DialectConversion.cpp
mlir/test/Transforms/test-legalizer.mlir
mlir/test/lib/TestDialect/TestOps.td
mlir/test/lib/TestDialect/TestPatterns.cpp

index 79f0d38..9944679 100644 (file)
@@ -255,7 +255,10 @@ public:
 
     /// This operation has dynamic legalization constraints that must be checked
     /// by the target.
-    Dynamic
+    Dynamic,
+
+    /// This target explicitly does not support this operation.
+    Illegal,
   };
 
   /// The type used to store operation legality information.
@@ -301,6 +304,16 @@ public:
     addDynamicallyLegalOp<OpT2, OpTs...>();
   }
 
+  /// Register the given operation as illegal, i.e. this operation is known to
+  /// not be supported by this target.
+  template <typename OpT> void addIllegalOp() {
+    setOpAction<OpT>(LegalizationAction::Illegal);
+  }
+  template <typename OpT, typename OpT2, typename... OpTs> void addIllegalOp() {
+    addIllegalOp<OpT>();
+    addIllegalOp<OpT2, OpTs...>();
+  }
+
   /// Register a legality action for the given dialects.
   void setDialectAction(ArrayRef<StringRef> dialectNames,
                         LegalizationAction action);
@@ -328,6 +341,18 @@ public:
     setDialectAction(dialectNames, LegalizationAction::Dynamic);
   }
 
+  /// Register the operations of the given dialects as illegal, i.e.
+  /// operations of this dialect are not supported by the target.
+  template <typename... Names>
+  void addIllegalDialect(StringRef name, Names... names) {
+    SmallVector<StringRef, 2> dialectNames({name, names...});
+    setDialectAction(dialectNames, LegalizationAction::Illegal);
+  }
+  template <typename... Args> void addIllegalDialect() {
+    SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
+    setDialectAction(dialectNames, LegalizationAction::Illegal);
+  }
+
   //===--------------------------------------------------------------------===//
   // Legality Querying
   //===--------------------------------------------------------------------===//
index fd01ad9..806eb64 100644 (file)
@@ -589,6 +589,8 @@ using LegalizationPatterns = SmallVector<RewritePattern *, 1>;
 /// This class defines a recursive operation legalizer.
 class OperationLegalizer {
 public:
+  using LegalizationAction = ConversionTarget::LegalizationAction;
+
   OperationLegalizer(ConversionTarget &targetInfo,
                      OwningRewritePatternList &patterns)
       : target(targetInfo) {
@@ -596,6 +598,9 @@ public:
     computeLegalizationGraphBenefit();
   }
 
+  /// Returns if the given operation is known to be illegal on the target.
+  bool isIllegal(Operation *op) const;
+
   /// Attempt to legalize the given operation. Returns success if the operation
   /// was legalized, failure otherwise.
   LogicalResult legalize(Operation *op, DialectConversionRewriter &rewriter);
@@ -634,6 +639,13 @@ private:
 };
 } // namespace
 
+bool OperationLegalizer::isIllegal(Operation *op) const {
+  // Check if the target explicitly marked this operation as illegal.
+  if (auto action = target.getOpAction(op->getName()))
+    return action == LegalizationAction::Illegal;
+  return false;
+}
+
 LogicalResult
 OperationLegalizer::legalize(Operation *op,
                              DialectConversionRewriter &rewriter) {
@@ -643,13 +655,15 @@ OperationLegalizer::legalize(Operation *op,
   // Check if this was marked legal by the target.
   if (auto action = target.getOpAction(op->getName())) {
     // Check if this operation is always legal.
-    if (*action == ConversionTarget::LegalizationAction::Legal)
+    if (*action == LegalizationAction::Legal)
       return success();
 
     // Otherwise, handle dynamic legalization.
-    LLVM_DEBUG(llvm::dbgs() << "- Trying dynamic legalization.\n");
-    if (target.isDynamicallyLegal(op))
-      return success();
+    if (*action == LegalizationAction::Dynamic) {
+      LLVM_DEBUG(llvm::dbgs() << "- Trying dynamic legalization.\n");
+      if (target.isDynamicallyLegal(op))
+        return success();
+    }
 
     // Fallthough to see if a pattern can convert this into a legal operation.
   }
@@ -661,8 +675,7 @@ OperationLegalizer::legalize(Operation *op,
     return failure();
   }
 
-  // TODO(riverriddle) This currently has no cost model and doesn't prioritize
-  // specific patterns in any way.
+  // The patterns are sorted by expected benefit, so try to apply each in-order.
   for (auto *pattern : it->second)
     if (succeeded(legalizePattern(op, pattern, rewriter)))
       return success();
@@ -733,7 +746,7 @@ void OperationLegalizer::buildLegalizationGraph(
     auto root = pattern->getRootKind();
 
     // Skip operations that are always known to be legal.
-    if (target.getOpAction(root) == ConversionTarget::LegalizationAction::Legal)
+    if (target.getOpAction(root) == LegalizationAction::Legal)
       continue;
 
     // Add this pattern to the invalid set for the root op and record this root
@@ -751,7 +764,9 @@ void OperationLegalizer::buildLegalizationGraph(
 
     // Check to see if any of the generated operations are invalid.
     if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
-          return !legalizerPatterns.count(op) && !target.getOpAction(op);
+          auto action = target.getOpAction(op);
+          return !legalizerPatterns.count(op) &&
+                 (!action || action == LegalizationAction::Illegal);
         }))
       continue;
 
@@ -950,14 +965,19 @@ OperationConverter::convert(DialectConversionRewriter &rewriter,
 
   // Otherwise, legalize the given operation.
   auto *op = ptr.get<Operation *>();
-  auto result = opLegalizer.legalize(op, rewriter);
-
-  // Failed conversions are only important if this is a full conversion.
-  if (mode == OpConversionMode::Full && failed(result))
-    return op->emitError() << "failed to legalize operation '" << op->getName()
-                           << "'";
-
-  // In any other case, illegal operations are allowed to remain in the IR.
+  if (failed(opLegalizer.legalize(op, rewriter))) {
+    // Handle the case of a failed conversion for each of the different modes.
+    /// Full conversions expect all operations to be converted.
+    if (mode == OpConversionMode::Full)
+      return op->emitError()
+             << "failed to legalize operation '" << op->getName() << "'";
+    /// Partial conversions allow conversions to fail iff the operation was not
+    /// explicitly marked as illegal.
+    if (mode == OpConversionMode::Partial && opLegalizer.isIllegal(op))
+      return op->emitError()
+             << "failed to legalize operation '" << op->getName()
+             << "' that was explicitly marked illegal";
+  }
   return success();
 }
 
index 6c38b6e..dbcca99 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-legalize-patterns %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -test-legalize-patterns -verify-diagnostics %s | FileCheck %s
 
 // CHECK-LABEL: verifyDirectPattern
 func @verifyDirectPattern() -> i32 {
@@ -99,3 +99,11 @@ func @up_to_date_replacement(%arg: i8) -> i8 {
   %repl_2 = "test.rewrite"(%repl_1) : (i8) -> i8
   return %repl_2 : i8
 }
+
+// -----
+
+func @fail_to_convert_illegal_op() -> i32 {
+  // expected-error@+1 {{failed to legalize operation 'test.illegal_op_f'}}
+  %result = "test.illegal_op_f"() : () -> (i32)
+  return %result : i32
+}
index 2640d41..17c170f 100644 (file)
@@ -416,6 +416,7 @@ def ILLegalOpB : TEST_Op<"illegal_op_b">, Results<(outs I32:$res)>;
 def ILLegalOpC : TEST_Op<"illegal_op_c">, Results<(outs I32:$res)>;
 def ILLegalOpD : TEST_Op<"illegal_op_d">, Results<(outs I32:$res)>;
 def ILLegalOpE : TEST_Op<"illegal_op_e">, Results<(outs I32:$res)>;
+def ILLegalOpF : TEST_Op<"illegal_op_f">, Results<(outs I32:$res)>;
 def LegalOpA : TEST_Op<"legal_op_a">,
   Arguments<(ins Test_LegalizerEnum:$status)>, Results<(outs I32:$res)>;
 
index ed5aea8..8b89a3a 100644 (file)
@@ -158,6 +158,7 @@ struct TestConversionTarget : public ConversionTarget {
   TestConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) {
     addLegalOp<LegalOpA, TestValidOp>();
     addDynamicallyLegalOp<TestReturnOp>();
+    addIllegalOp<ILLegalOpF>();
   }
   bool isDynamicallyLegal(Operation *op) const final {
     // Don't allow F32 operands.