From 491ef84dc445ffdda80d8eaa6f92e8951567b183 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 17 Jul 2019 09:26:57 -0700 Subject: [PATCH] Add support for explicitly marking dialects and operations as illegal. This explicit tag is useful is several ways: *) This simplifies how to mark sub sections of a dialect as explicitly unsupported, e.g. my target supports all operations in the foo dialect except for these select few. This is useful for partial lowerings between dialects. *) Partial conversions will now verify that operations that were explicitly marked as illegal must be converted. This provides some guarantee that the operations that need to be lowered by a specific pass will be. PiperOrigin-RevId: 258582879 --- mlir/include/mlir/Transforms/DialectConversion.h | 27 +++++++++++- mlir/lib/Transforms/DialectConversion.cpp | 52 ++++++++++++++++-------- mlir/test/Transforms/test-legalizer.mlir | 10 ++++- mlir/test/lib/TestDialect/TestOps.td | 1 + mlir/test/lib/TestDialect/TestPatterns.cpp | 1 + 5 files changed, 73 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 79f0d38..9944679 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -255,7 +255,10 @@ public: /// This operation has dynamic legalization constraints that must be checked /// by the target. - Dynamic + Dynamic, + + /// This target explicitly does not support this operation. + Illegal, }; /// The type used to store operation legality information. @@ -301,6 +304,16 @@ public: addDynamicallyLegalOp(); } + /// Register the given operation as illegal, i.e. this operation is known to + /// not be supported by this target. + template void addIllegalOp() { + setOpAction(LegalizationAction::Illegal); + } + template void addIllegalOp() { + addIllegalOp(); + addIllegalOp(); + } + /// Register a legality action for the given dialects. void setDialectAction(ArrayRef dialectNames, LegalizationAction action); @@ -328,6 +341,18 @@ public: setDialectAction(dialectNames, LegalizationAction::Dynamic); } + /// Register the operations of the given dialects as illegal, i.e. + /// operations of this dialect are not supported by the target. + template + void addIllegalDialect(StringRef name, Names... names) { + SmallVector dialectNames({name, names...}); + setDialectAction(dialectNames, LegalizationAction::Illegal); + } + template void addIllegalDialect() { + SmallVector dialectNames({Args::getDialectNamespace()...}); + setDialectAction(dialectNames, LegalizationAction::Illegal); + } + //===--------------------------------------------------------------------===// // Legality Querying //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index fd01ad9..806eb64 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -589,6 +589,8 @@ using LegalizationPatterns = SmallVector; /// This class defines a recursive operation legalizer. class OperationLegalizer { public: + using LegalizationAction = ConversionTarget::LegalizationAction; + OperationLegalizer(ConversionTarget &targetInfo, OwningRewritePatternList &patterns) : target(targetInfo) { @@ -596,6 +598,9 @@ public: computeLegalizationGraphBenefit(); } + /// Returns if the given operation is known to be illegal on the target. + bool isIllegal(Operation *op) const; + /// Attempt to legalize the given operation. Returns success if the operation /// was legalized, failure otherwise. LogicalResult legalize(Operation *op, DialectConversionRewriter &rewriter); @@ -634,6 +639,13 @@ private: }; } // namespace +bool OperationLegalizer::isIllegal(Operation *op) const { + // Check if the target explicitly marked this operation as illegal. + if (auto action = target.getOpAction(op->getName())) + return action == LegalizationAction::Illegal; + return false; +} + LogicalResult OperationLegalizer::legalize(Operation *op, DialectConversionRewriter &rewriter) { @@ -643,13 +655,15 @@ OperationLegalizer::legalize(Operation *op, // Check if this was marked legal by the target. if (auto action = target.getOpAction(op->getName())) { // Check if this operation is always legal. - if (*action == ConversionTarget::LegalizationAction::Legal) + if (*action == LegalizationAction::Legal) return success(); // Otherwise, handle dynamic legalization. - LLVM_DEBUG(llvm::dbgs() << "- Trying dynamic legalization.\n"); - if (target.isDynamicallyLegal(op)) - return success(); + if (*action == LegalizationAction::Dynamic) { + LLVM_DEBUG(llvm::dbgs() << "- Trying dynamic legalization.\n"); + if (target.isDynamicallyLegal(op)) + return success(); + } // Fallthough to see if a pattern can convert this into a legal operation. } @@ -661,8 +675,7 @@ OperationLegalizer::legalize(Operation *op, return failure(); } - // TODO(riverriddle) This currently has no cost model and doesn't prioritize - // specific patterns in any way. + // The patterns are sorted by expected benefit, so try to apply each in-order. for (auto *pattern : it->second) if (succeeded(legalizePattern(op, pattern, rewriter))) return success(); @@ -733,7 +746,7 @@ void OperationLegalizer::buildLegalizationGraph( auto root = pattern->getRootKind(); // Skip operations that are always known to be legal. - if (target.getOpAction(root) == ConversionTarget::LegalizationAction::Legal) + if (target.getOpAction(root) == LegalizationAction::Legal) continue; // Add this pattern to the invalid set for the root op and record this root @@ -751,7 +764,9 @@ void OperationLegalizer::buildLegalizationGraph( // Check to see if any of the generated operations are invalid. if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) { - return !legalizerPatterns.count(op) && !target.getOpAction(op); + auto action = target.getOpAction(op); + return !legalizerPatterns.count(op) && + (!action || action == LegalizationAction::Illegal); })) continue; @@ -950,14 +965,19 @@ OperationConverter::convert(DialectConversionRewriter &rewriter, // Otherwise, legalize the given operation. auto *op = ptr.get(); - auto result = opLegalizer.legalize(op, rewriter); - - // Failed conversions are only important if this is a full conversion. - if (mode == OpConversionMode::Full && failed(result)) - return op->emitError() << "failed to legalize operation '" << op->getName() - << "'"; - - // In any other case, illegal operations are allowed to remain in the IR. + if (failed(opLegalizer.legalize(op, rewriter))) { + // Handle the case of a failed conversion for each of the different modes. + /// Full conversions expect all operations to be converted. + if (mode == OpConversionMode::Full) + return op->emitError() + << "failed to legalize operation '" << op->getName() << "'"; + /// Partial conversions allow conversions to fail iff the operation was not + /// explicitly marked as illegal. + if (mode == OpConversionMode::Partial && opLegalizer.isIllegal(op)) + return op->emitError() + << "failed to legalize operation '" << op->getName() + << "' that was explicitly marked illegal"; + } return success(); } diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 6c38b6e..dbcca99 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-legalize-patterns %s | FileCheck %s +// RUN: mlir-opt -split-input-file -test-legalize-patterns -verify-diagnostics %s | FileCheck %s // CHECK-LABEL: verifyDirectPattern func @verifyDirectPattern() -> i32 { @@ -99,3 +99,11 @@ func @up_to_date_replacement(%arg: i8) -> i8 { %repl_2 = "test.rewrite"(%repl_1) : (i8) -> i8 return %repl_2 : i8 } + +// ----- + +func @fail_to_convert_illegal_op() -> i32 { + // expected-error@+1 {{failed to legalize operation 'test.illegal_op_f'}} + %result = "test.illegal_op_f"() : () -> (i32) + return %result : i32 +} diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index 2640d41..17c170f 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -416,6 +416,7 @@ def ILLegalOpB : TEST_Op<"illegal_op_b">, Results<(outs I32:$res)>; def ILLegalOpC : TEST_Op<"illegal_op_c">, Results<(outs I32:$res)>; def ILLegalOpD : TEST_Op<"illegal_op_d">, Results<(outs I32:$res)>; def ILLegalOpE : TEST_Op<"illegal_op_e">, Results<(outs I32:$res)>; +def ILLegalOpF : TEST_Op<"illegal_op_f">, Results<(outs I32:$res)>; def LegalOpA : TEST_Op<"legal_op_a">, Arguments<(ins Test_LegalizerEnum:$status)>, Results<(outs I32:$res)>; diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp index ed5aea8..8b89a3a 100644 --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -158,6 +158,7 @@ struct TestConversionTarget : public ConversionTarget { TestConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { addLegalOp(); addDynamicallyLegalOp(); + addIllegalOp(); } bool isDynamicallyLegal(Operation *op) const final { // Don't allow F32 operands. -- 2.7.4