From ec03bbe8a74ae593d0ea5d8bf55c337e395873d1 Mon Sep 17 00:00:00 2001 From: Vladislav Vinogradov Date: Fri, 20 Aug 2021 14:25:42 +0300 Subject: [PATCH] [mlir] Fix bug in partial dialect conversion The discussion on forum: https://llvm.discourse.group/t/bug-in-partial-dialect-conversion/4115 The `applyPartialConversion` didn't handle the operations, that were marked as illegal inside dynamic legality callback. Instead of reporting error, if such operation was not converted to legal set, the method just added it to `unconvertedSet` in the same way as unknown operations. This patch fixes that and handle dynamically illegal operations as well. The patch includes 2 fixes for existing passes: * `tensor-bufferize` - explicitly mark `std.return` as legal. * `convert-parallel-loops-to-gpu` - ugly fix with marking visited operations to avoid recursive legality checks. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D108505 --- mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h | 4 + mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp | 30 ++++++- mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp | 1 + mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp | 1 + mlir/lib/Transforms/Utils/DialectConversion.cpp | 8 +- mlir/test/Transforms/test-legalizer-full.mlir | 105 +++++++++++++++-------- mlir/test/Transforms/test-legalizer.mlir | 67 ++++++++++----- mlir/test/lib/Dialect/Test/TestOps.td | 3 + mlir/test/lib/Dialect/Test/TestPatterns.cpp | 40 +++++++-- 9 files changed, 194 insertions(+), 65 deletions(-) diff --git a/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h b/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h index ac1ba0e..4838679 100644 --- a/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h +++ b/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h @@ -16,6 +16,7 @@ class ConversionTarget; struct LogicalResult; class MLIRContext; class Value; +class Operation; class RewritePatternSet; using OwningRewritePatternList = RewritePatternSet; @@ -49,6 +50,9 @@ void populateParallelLoopToGPUPatterns(RewritePatternSet &patterns); /// are not rewritten by the provided patterns are legal. void configureParallelLoopToGPULegality(ConversionTarget &target); +/// Clean up after applyPartialConversion/applyFullConversion call. +void finalizeParallelLoopToGPUConversion(Operation *op); + } // namespace mlir #endif // MLIR_CONVERSION_SCFTOGPU_SCFTOGPU_H_ diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp index d13cebe..9770299 100644 --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -37,6 +37,24 @@ using namespace mlir; using namespace mlir::scf; +// Name of internal attribute to mark visited operations during conversion. +// +// NOTE: The conversion originally used the following legality criteria: +// `!parallelOp->hasAttr(gpu::getMappingAttrName())` +// But the provided pattern might reject some cases based on more detailed +// analysis of the `mapping` attribute. +// To avoid dialect conversion failure due to non-converted illegal operation +// we use this extra Unit attribute as a marker, that the operation was checked +// by the pattern and is should be considered as legal in the following legality +// checks. The `finalizeParallelLoopToGPUConversion` function performs clean up +// of this extra attributes ans is supposed to be called after the dialect +// conversion. +// +// TODO: Implement a cleaner solution, factoring out the "matching" logic +// from the pattern and its callees into a separate function that can be called +// from both the pattern and the op legality check. +static constexpr StringLiteral kVisitedAttrName = "SCFToGPU_visited"; + // Extract an indexed value from KernelDim3. static Value getDim3Value(const gpu::KernelDim3 &dim3, unsigned pos) { switch (pos) { @@ -567,6 +585,9 @@ static LogicalResult processParallelLoop( LogicalResult ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp, PatternRewriter &rewriter) const { + // Mark the operation as visited for recursive legality check. + parallelOp->setAttr(kVisitedAttrName, rewriter.getUnitAttr()); + // We can only transform starting at the outer-most loop. Launches inside of // parallel loops are not supported. if (auto parentLoop = parallelOp->getParentOfType()) @@ -649,6 +670,13 @@ void mlir::populateParallelLoopToGPUPatterns(RewritePatternSet &patterns) { void mlir::configureParallelLoopToGPULegality(ConversionTarget &target) { target.addLegalDialect(); target.addDynamicallyLegalOp([](scf::ParallelOp parallelOp) { - return !parallelOp->getAttr(gpu::getMappingAttrName()); + return !parallelOp->hasAttr(gpu::getMappingAttrName()) || + parallelOp->hasAttr(kVisitedAttrName); + }); +} + +void mlir::finalizeParallelLoopToGPUConversion(Operation *op) { + op->walk([](scf::ParallelOp parallelOp) { + parallelOp->removeAttr(kVisitedAttrName); }); } diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp index 43c6798..e9a8df0 100644 --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp @@ -55,6 +55,7 @@ struct ParallelLoopToGpuPass if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); + finalizeParallelLoopToGPUConversion(getOperation()); } }; diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp index f9faba0..f5f7b0f 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp @@ -175,6 +175,7 @@ struct TensorBufferizePass : public TensorBufferizeBase { target.addLegalDialect(); target.addDynamicallyLegalDialect( [&](Operation *op) { return typeConverter.isLegal(op); }); + target.addLegalOp(); target.addLegalDialect(); if (failed( diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 6aa42f6..4f1c8cf 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1650,7 +1650,13 @@ OperationLegalizer::OperationLegalizer(ConversionTarget &targetInfo, bool OperationLegalizer::isIllegal(Operation *op) const { // Check if the target explicitly marked this operation as illegal. - return target.getOpAction(op->getName()) == LegalizationAction::Illegal; + if (auto info = target.getOpAction(op->getName())) { + if (*info == LegalizationAction::Dynamic) + return !target.isLegal(op); + return *info == LegalizationAction::Illegal; + } + + return false; } LogicalResult diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir index 3cbc173..5480d3d 100644 --- a/mlir/test/Transforms/test-legalizer-full.mlir +++ b/mlir/test/Transforms/test-legalizer-full.mlir @@ -47,55 +47,88 @@ func @recursively_legal_invalid_op() { // ----- -// Test that region cloning can be properly undone. -func @test_undo_region_clone() { - "test.region"() ({ - ^bb1(%i0: i64): - "test.invalid"(%i0) : (i64) -> () - }) {legalizer.should_clone} : () -> () - - // expected-error@+1 {{failed to legalize operation 'test.illegal_op_f'}} - %ignored = "test.illegal_op_f"() : () -> (i32) - "test.return"() : () -> () +// expected-remark@+1 {{applyFullConversion failed}} +builtin.module { + + // Test that region cloning can be properly undone. + func @test_undo_region_clone() { + "test.region"() ({ + ^bb1(%i0: i64): + "test.invalid"(%i0) : (i64) -> () + }) {legalizer.should_clone} : () -> () + + // expected-error@+1 {{failed to legalize operation 'test.illegal_op_f'}} + %ignored = "test.illegal_op_f"() : () -> (i32) + "test.return"() : () -> () + } + } // ----- -// Test that unknown operations can be dynamically legal. -func @test_unknown_dynamically_legal() { - "foo.unknown_op"() {test.dynamically_legal} : () -> () +// expected-remark@+1 {{applyFullConversion failed}} +builtin.module { + + // Test that unknown operations can be dynamically legal. + func @test_unknown_dynamically_legal() { + "foo.unknown_op"() {test.dynamically_legal} : () -> () + + // expected-error@+1 {{failed to legalize operation 'foo.unknown_op'}} + "foo.unknown_op"() {} : () -> () + "test.return"() : () -> () + } - // expected-error@+1 {{failed to legalize operation 'foo.unknown_op'}} - "foo.unknown_op"() {} : () -> () - "test.return"() : () -> () } // ----- -// Test that region inlining can be properly undone. -func @test_undo_region_inline() { - "test.region"() ({ - ^bb1(%i0: i64): - // expected-error@+1 {{failed to legalize operation 'std.br'}} - br ^bb2(%i0 : i64) - ^bb2(%i1: i64): - "test.invalid"(%i1) : (i64) -> () - }) {} : () -> () +// expected-remark@+1 {{applyFullConversion failed}} +builtin.module { + + // Test that region inlining can be properly undone. + func @test_undo_region_inline() { + "test.region"() ({ + ^bb1(%i0: i64): + // expected-error@+1 {{failed to legalize operation 'std.br'}} + br ^bb2(%i0 : i64) + ^bb2(%i1: i64): + "test.invalid"(%i1) : (i64) -> () + }) {} : () -> () + + "test.return"() : () -> () + } - "test.return"() : () -> () } // ----- -// Test that multiple block erases can be properly undone. -func @test_undo_block_erase() { - // expected-error@+1 {{failed to legalize operation 'test.region'}} - "test.region"() ({ - ^bb1(%i0: i64): - br ^bb2(%i0 : i64) - ^bb2(%i1: i64): - "test.invalid"(%i1) : (i64) -> () - }) {legalizer.should_clone, legalizer.erase_old_blocks} : () -> () +// expected-remark@+1 {{applyFullConversion failed}} +builtin.module { + + // Test that multiple block erases can be properly undone. + func @test_undo_block_erase() { + // expected-error@+1 {{failed to legalize operation 'test.region'}} + "test.region"() ({ + ^bb1(%i0: i64): + br ^bb2(%i0 : i64) + ^bb2(%i1: i64): + "test.invalid"(%i1) : (i64) -> () + }) {legalizer.should_clone, legalizer.erase_old_blocks} : () -> () + + "test.return"() : () -> () + } + +} + +// ----- + +// expected-remark@+1 {{applyFullConversion failed}} +builtin.module { + + func @create_unregistered_op_in_pattern() -> i32 { + // expected-error@+1 {{failed to legalize operation 'test.illegal_op_g'}} + %0 = "test.illegal_op_g"() : () -> (i32) + "test.return"(%0) : (i32) -> () + } - "test.return"() : () -> () } diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 0603883..25c3eb3 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -173,18 +173,28 @@ func @bounded_recursion() { // ----- -func @fail_to_convert_illegal_op() -> i32 { - // expected-error@+1 {{failed to legalize operation 'test.illegal_op_f'}} - %result = "test.illegal_op_f"() : () -> (i32) - return %result : i32 +// expected-remark@+1 {{applyPartialConversion failed}} +builtin.module { + + func @fail_to_convert_illegal_op() -> i32 { + // expected-error@+1 {{failed to legalize operation 'test.illegal_op_f'}} + %result = "test.illegal_op_f"() : () -> (i32) + return %result : i32 + } + } // ----- -func @fail_to_convert_illegal_op_in_region() { - // expected-error@+1 {{failed to legalize operation 'test.region_builder'}} - "test.region_builder"() : () -> () - return +// expected-remark@+1 {{applyPartialConversion failed}} +builtin.module { + + func @fail_to_convert_illegal_op_in_region() { + // expected-error@+1 {{failed to legalize operation 'test.region_builder'}} + "test.region_builder"() : () -> () + return + } + } // ----- @@ -192,17 +202,21 @@ func @fail_to_convert_illegal_op_in_region() { // Check that the entry block arguments of a region are untouched in the case // of failure. -// CHECK-LABEL: func @fail_to_convert_region -func @fail_to_convert_region() { - // CHECK-NEXT: "test.region" - // CHECK-NEXT: ^bb{{.*}}(%{{.*}}: i64): - "test.region"() ({ - ^bb1(%i0: i64): - // expected-error@+1 {{failed to legalize operation 'test.region_builder'}} - "test.region_builder"() : () -> () - "test.valid"() : () -> () - }) : () -> () - return +// expected-remark@+1 {{applyPartialConversion failed}} +builtin.module { + + func @fail_to_convert_region() { + // CHECK: "test.region" + // CHECK-NEXT: ^bb{{.*}}(%{{.*}}: i64): + "test.region"() ({ + ^bb1(%i0: i64): + // expected-error@+1 {{failed to legalize operation 'test.region_builder'}} + "test.region_builder"() : () -> () + "test.valid"() : () -> () + }) : () -> () + return + } + } // ----- @@ -271,10 +285,8 @@ func @undo_child_created_before_parent() { return } - // ----- - // Check that a conversion pattern on `test.blackhole` can mark the producer // for deletion. // CHECK-LABEL: @blackhole @@ -284,3 +296,16 @@ func @blackhole() { // expected-remark@+1 {{op 'std.return' is not legalizable}} return } + +// ----- + +// expected-remark@+1 {{applyPartialConversion failed}} +builtin.module { + + func @create_unregistered_op_in_pattern() -> i32 { + // expected-error@+1 {{failed to legalize operation 'test.illegal_op_g'}} + %0 = "test.illegal_op_g"() : () -> (i32) + "test.return"(%0) : (i32) -> () + } + +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index e17a76b..a887adb 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1415,9 +1415,12 @@ def ILLegalOpC : TEST_Op<"illegal_op_c">, Results<(outs I32)>; def ILLegalOpD : TEST_Op<"illegal_op_d">, Results<(outs I32)>; def ILLegalOpE : TEST_Op<"illegal_op_e">, Results<(outs I32)>; def ILLegalOpF : TEST_Op<"illegal_op_f">, Results<(outs I32)>; +def ILLegalOpG : TEST_Op<"illegal_op_g">, Results<(outs I32)>; def LegalOpA : TEST_Op<"legal_op_a">, Arguments<(ins Test_LegalizerEnum:$status)>, Results<(outs I32)>; def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs I32)>; +def LegalOpC : TEST_Op<"legal_op_c">, + Arguments<(ins I32)>, Results<(outs I32)>; // Check that the conversion infrastructure can properly undo the creation of // operations where an operation was created before its parent, in this case, diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index de141dc..d51cf5e 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -562,6 +562,20 @@ struct TestReplaceEraseOp : public OpRewritePattern { return success(); }; }; + +// This pattern replaces explicitly illegal op with explicitly legal op, +// but in addition creates unregistered operation. +struct TestCreateUnregisteredOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ILLegalOpG op, + PatternRewriter &rewriter) const final { + IntegerAttr attr = rewriter.getI32IntegerAttr(0); + Value val = rewriter.create(op->getLoc(), attr); + rewriter.replaceOpWithNewOp(op, val); + return success(); + }; +}; } // namespace namespace { @@ -632,6 +646,10 @@ struct TestLegalizePatternDriver TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override { TestTypeConverter converter; mlir::RewritePatternSet patterns(&getContext()); @@ -643,8 +661,8 @@ struct TestLegalizePatternDriver TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, TestNonRootReplacement, TestBoundedRecursiveRewrite, - TestNestedOpCreationUndoRewrite, TestReplaceEraseOp>( - &getContext()); + TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, + TestCreateUnregisteredOp>(&getContext()); patterns.add(&getContext(), converter); mlir::populateFuncOpTypeConversionPattern(patterns, converter); mlir::populateCallOpTypeConversionPattern(patterns, converter); @@ -652,7 +670,7 @@ struct TestLegalizePatternDriver // Define the conversion target used for the test. ConversionTarget target(getContext()); target.addLegalOp(); - target.addLegalOp(); target .addIllegalOp(); @@ -666,6 +684,11 @@ struct TestLegalizePatternDriver converter.isLegal(&op.getBody()); }); + // TestCreateUnregisteredOp creates `std.constant` operation, + // which was not added to target intentionally to test + // correct error code from conversion driver. + target.addDynamicallyLegalOp([](ILLegalOpG) { return false; }); + // Expect the type_producer/type_consumer operations to only operate on f64. target.addDynamicallyLegalOp( [](TestTypeProducerOp op) { return op.getType().isF64(); }); @@ -686,8 +709,10 @@ struct TestLegalizePatternDriver // Handle a partial conversion. if (mode == ConversionMode::Partial) { DenseSet unlegalizedOps; - (void)applyPartialConversion(getOperation(), target, std::move(patterns), - &unlegalizedOps); + if (failed(applyPartialConversion( + getOperation(), target, std::move(patterns), &unlegalizedOps))) { + getOperation()->emitRemark() << "applyPartialConversion failed"; + } // Emit remarks for each legalizable operation. for (auto *op : unlegalizedOps) op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; @@ -701,7 +726,10 @@ struct TestLegalizePatternDriver return (bool)op->getAttrOfType("test.dynamically_legal"); }); - (void)applyFullConversion(getOperation(), target, std::move(patterns)); + if (failed(applyFullConversion(getOperation(), target, + std::move(patterns)))) { + getOperation()->emitRemark() << "applyFullConversion failed"; + } return; } -- 2.7.4