From ce674b131b66105ecd3918a11de4eb0205b50f99 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 27 Jan 2020 19:04:55 -0800 Subject: [PATCH] [mlir] Add support for marking 'unknown' operations as dynamically legal. Summary: This allows for providing a default "catchall" legality check that is not dependent on specific operations or dialects. For example, this can be useful to check legality based on the specific types of operation operands or results. Differential Revision: https://reviews.llvm.org/D73379 --- mlir/docs/DialectConversion.md | 5 +++ mlir/include/mlir/Transforms/DialectConversion.h | 26 ++++++++++++--- mlir/lib/Transforms/DialectConversion.cpp | 41 +++++++++++++----------- mlir/test/Transforms/test-legalizer-full.mlir | 11 +++++++ mlir/test/lib/TestDialect/TestPatterns.cpp | 5 +++ 5 files changed, 66 insertions(+), 22 deletions(-) diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md index e6b652f..d064394 100644 --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -100,6 +100,11 @@ struct MyTarget : public ConversionTarget { /// callback. addDynamicallyLegalOp([](ReturnOp op) { ... }); + /// Treat unknown operations, i.e. those without a legalization action + /// directly set, as dynamically legal. + markUnknownOpDynamicallyLegal(); + markUnknownOpDynamicallyLegal([](Operation *op) { ... }); + //-------------------------------------------------------------------------- // Marking an operation as illegal. diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index aadb592..cd148f2 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -416,7 +416,8 @@ public: /// dynamically legal on the target. using DynamicLegalityCallbackFn = std::function; - ConversionTarget(MLIRContext &ctx) : ctx(ctx) {} + ConversionTarget(MLIRContext &ctx) + : unknownOpsDynamicallyLegal(false), ctx(ctx) {} virtual ~ConversionTarget() = default; //===--------------------------------------------------------------------===// @@ -532,6 +533,16 @@ public: setLegalityCallback(dialectNames, *callback); } + /// Register unknown operations as dynamically legal. For operations(and + /// dialects) that do not have a set legalization action, treat them as + /// dynamically legal and invoke the given callback if valid or + /// 'isDynamicallyLegal'. + void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn) { + unknownOpsDynamicallyLegal = true; + unknownLegalityFn = fn; + } + void markUnknownOpDynamicallyLegal() { unknownOpsDynamicallyLegal = true; } + /// Register the operations of the given dialects as illegal, i.e. /// operations of this dialect are not supported by the target. template @@ -585,6 +596,9 @@ private: /// If some legal instances of this operation may also be recursively legal. bool isRecursivelyLegal; + + /// The legality callback if this operation is dynamically legal. + Optional legalityFn; }; /// Get the legalization information for the given operation. @@ -594,9 +608,6 @@ private: /// information. llvm::MapVector legalOperations; - /// A set of dynamic legality callbacks for given operation names. - DenseMap opLegalityFns; - /// A set of legality callbacks for given operation names that are used to /// check if an operation instance is recursively legal. DenseMap opRecursiveLegalityFns; @@ -608,6 +619,13 @@ private: /// A set of dynamic legality callbacks for given dialect names. llvm::StringMap dialectLegalityFns; + /// An optional legality callback for unknown operations. + Optional unknownLegalityFn; + + /// Flag indicating if unknown operations should be treated as dynamically + /// legal. + bool unknownOpsDynamicallyLegal; + /// The current context this target applies to. MLIRContext &ctx; }; diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index e2cd12e..c6e7f9b 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -1704,19 +1704,11 @@ auto ConversionTarget::isLegal(Operation *op) const // Returns true if this operation instance is known to be legal. auto isOpLegal = [&] { - // Handle dynamic legality. - if (info->action == LegalizationAction::Dynamic) { - // Check for callbacks on the operation or dialect. - auto opFn = opLegalityFns.find(op->getName()); - if (opFn != opLegalityFns.end()) - return opFn->second(op); - auto dialectFn = dialectLegalityFns.find(op->getName().getDialect()); - if (dialectFn != dialectLegalityFns.end()) - return dialectFn->second(op); - - // Otherwise, invoke the hook on the derived instance. - return isDynamicallyLegal(op); - } + // Handle dynamic legality either with the provided legality function, or + // the default hook on the derived instance. + if (info->action == LegalizationAction::Dynamic) + return info->legalityFn ? (*info->legalityFn)(op) + : isDynamicallyLegal(op); // Otherwise, the operation is only legal if it was marked 'Legal'. return info->action == LegalizationAction::Legal; @@ -1726,7 +1718,6 @@ auto ConversionTarget::isLegal(Operation *op) const // This operation is legal, compute any additional legality information. LegalOpDetails legalityDetails; - if (info->isRecursivelyLegal) { auto legalityFnIt = opRecursiveLegalityFns.find(op->getName()); if (legalityFnIt != opRecursiveLegalityFns.end()) @@ -1741,7 +1732,11 @@ auto ConversionTarget::isLegal(Operation *op) const void ConversionTarget::setLegalityCallback( OperationName name, const DynamicLegalityCallbackFn &callback) { assert(callback && "expected valid legality callback"); - opLegalityFns[name] = callback; + auto infoIt = legalOperations.find(name); + assert(infoIt != legalOperations.end() && + infoIt->second.action == LegalizationAction::Dynamic && + "expected operation to already be marked as dynamically legal"); + infoIt->second.legalityFn = callback; } /// Set the recursive legality callback for the given operation and mark the @@ -1774,10 +1769,20 @@ auto ConversionTarget::getOpInfo(OperationName op) const auto it = legalOperations.find(op); if (it != legalOperations.end()) return it->second; - // Otherwise, default to checking on the parent dialect. + // Check for info for the parent dialect. auto dialectIt = legalDialects.find(op.getDialect()); - if (dialectIt != legalDialects.end()) - return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false}; + if (dialectIt != legalDialects.end()) { + Optional callback; + auto dialectFn = dialectLegalityFns.find(op.getDialect()); + if (dialectFn != dialectLegalityFns.end()) + callback = dialectFn->second; + return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false, + callback}; + } + // Otherwise, check if we mark unknown operations as dynamic. + if (unknownOpsDynamicallyLegal) + return LegalizationInfo{LegalizationAction::Dynamic, + /*isRecursivelyLegal=*/false, unknownLegalityFn}; return llvm::None; } diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir index d0fc4c9..6bbda4a 100644 --- a/mlir/test/Transforms/test-legalizer-full.mlir +++ b/mlir/test/Transforms/test-legalizer-full.mlir @@ -58,3 +58,14 @@ func @test_undo_region_clone() { %ignored = "test.illegal_op_f"() : () -> (i32) "test.return"() : () -> () } + +// ----- + +// Test that unknown operations can be dynamically legal. +func @test_unknown_dynamically_legal() { + "foo.unknown_op"() {test.dynamically_legal} : () -> () + + // expected-error@+1 {{failed to legalize operation 'foo.unknown_op'}} + "foo.unknown_op"() {} : () -> () + "test.return"() : () -> () +} diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp index d977748..d34181c 100644 --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -399,6 +399,11 @@ struct TestLegalizePatternDriver // Handle a full conversion. if (mode == ConversionMode::Full) { + // Check support for marking unknown operations as dynamically legal. + target.markUnknownOpDynamicallyLegal([](Operation *op) { + return (bool)op->getAttrOfType("test.dynamically_legal"); + }); + (void)applyFullConversion(getModule(), target, patterns, &converter); return; } -- 2.7.4