[mlir] Add `test-convergence` option to Canonicalizer tests
authorMatthias Springer <springerm@google.com>
Wed, 4 Jan 2023 10:39:41 +0000 (11:39 +0100)
committerMatthias Springer <springerm@google.com>
Wed, 4 Jan 2023 11:02:21 +0000 (12:02 +0100)
This new option is set to `false` by default. It should  be set only in Canonicalizer tests to detect faulty canonicalization patterns. I.e., patterns that prevent the canonicalizer from converging. The canonicalizer should always convergence on such small unit tests that we have in `canonicalize.mlir`.

Two faulty canonicalization patterns were detected and fixed with this change.

Differential Revision: https://reviews.llvm.org/D140873

28 files changed:
mlir/include/mlir/Transforms/Passes.td
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
mlir/lib/Transforms/Canonicalizer.cpp
mlir/test/Dialect/AMDGPU/canonicalize.mlir
mlir/test/Dialect/Affine/canonicalize.mlir
mlir/test/Dialect/Arith/canonicalize.mlir
mlir/test/Dialect/Bufferization/canonicalize.mlir
mlir/test/Dialect/Builtin/canonicalize.mlir
mlir/test/Dialect/Complex/canonicalize.mlir
mlir/test/Dialect/ControlFlow/canonicalize.mlir
mlir/test/Dialect/GPU/canonicalize.mlir
mlir/test/Dialect/LLVMIR/canonicalize.mlir
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Math/canonicalize.mlir
mlir/test/Dialect/MemRef/canonicalize.mlir
mlir/test/Dialect/OpenACC/canonicalize.mlir
mlir/test/Dialect/OpenMP/canonicalize.mlir
mlir/test/Dialect/PDL/canonicalize.mlir
mlir/test/Dialect/Quant/canonicalize.mlir
mlir/test/Dialect/SCF/canonicalize.mlir
mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
mlir/test/Dialect/Shape/canonicalize.mlir
mlir/test/Dialect/Tensor/canonicalize.mlir
mlir/test/Dialect/Tosa/canonicalize.mlir
mlir/test/Dialect/Vector/canonicalize.mlir
mlir/test/Pass/run-reproducer.mlir
mlir/test/Transforms/canonicalize.mlir

index fd6351d..e2a1593 100644 (file)
@@ -39,7 +39,9 @@ def Canonicalizer : Pass<"canonicalize"> {
            /*default=*/"10",
            "Max. iterations between applying patterns / simplifying regions">,
     Option<"maxNumRewrites", "max-num-rewrites", "int64_t", /*default=*/"-1",
-           "Max. number of pattern rewrites within an iteration">
+           "Max. number of pattern rewrites within an iteration">,
+    Option<"testConvergence", "test-convergence", "bool", /*default=*/"false",
+           "Test only: Fail pass on non-convergence to detect cyclic pattern">
   ] # RewritePassUtils.options;
 }
 
index d687043..8ef5483 100644 (file)
@@ -687,6 +687,8 @@ struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
       // Check if size is trivially one.
       if (!matchPattern(size, m_One()))
         return;
+      if (id.getUses().empty())
+        return;
       if (!simplified) {
         // Create a zero value the first time.
         OpBuilder::InsertionGuard guard(rewriter);
@@ -694,7 +696,7 @@ struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
         zero =
             rewriter.create<arith::ConstantIndexOp>(op.getLoc(), /*value=*/0);
       }
-      id.replaceAllUsesWith(zero);
+      rewriter.replaceAllUsesWith(id, zero);
       simplified = true;
     };
     constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
index 7a5de2f..fb9835a 100644 (file)
@@ -178,16 +178,15 @@ struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
     // Early return if there is no condition.
     Value ifCond = op.getIfCond();
     if (!ifCond)
-      return success();
+      return failure();
 
     IntegerAttr constAttr;
-    if (matchPattern(ifCond, m_Constant(&constAttr))) {
-      if (constAttr.getInt())
-        rewriter.updateRootInPlace(op,
-                                   [&]() { op.getIfCondMutable().erase(0); });
-      else
-        rewriter.eraseOp(op);
-    }
+    if (!matchPattern(ifCond, m_Constant(&constAttr)))
+      return failure();
+    if (constAttr.getInt())
+      rewriter.updateRootInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
+    else
+      rewriter.eraseOp(op);
 
     return success();
   }
index a6fa775..b4ad85c 100644 (file)
@@ -57,8 +57,11 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
     config.enableRegionSimplification = enableRegionSimplification;
     config.maxIterations = maxIterations;
     config.maxNumRewrites = maxNumRewrites;
+    LogicalResult converged =
+        applyPatternsAndFoldGreedily(getOperation(), patterns, config);
     // Canonicalization is best-effort. Non-convergence is not a pass failure.
-    (void)applyPatternsAndFoldGreedily(getOperation(), patterns, config);
+    if (testConvergence && failed(converged))
+      signalPassFailure();
   }
 
   FrozenRewritePatternSet patterns;
index d984f8b..4559e39 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -canonicalize  | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" | FileCheck %s
 
 // CHECK-LABEL: func @known_oob_load
 func.func @known_oob_load(%arg0: memref<4xf32>) -> f32 {
index e47cdde..1dac401 100644 (file)
@@ -1,5 +1,5 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -canonicalize | FileCheck %s
-// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -canonicalize="top-down=0" | FileCheck %s --check-prefix=CHECK-BOTTOM-UP
+// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -canonicalize="test-convergence" | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -canonicalize="test-convergence top-down=0" | FileCheck %s --check-prefix=CHECK-BOTTOM-UP
 
 // -----
 
index 02cbaa2..5806c9c 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -canonicalize --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -canonicalize="test-convergence" --split-input-file | FileCheck %s
 
 // CHECK-LABEL: @select_same_val
 //       CHECK:   return %arg1
index df34039..fae4351 100644 (file)
@@ -1,5 +1,6 @@
-// RUN: mlir-opt %s -canonicalize --split-input-file \
-// RUN:   -allow-unregistered-dialect |\
+// RUN: mlir-opt %s \
+// RUN:   -canonicalize="test-convergence" \
+// RUN:   --split-input-file -allow-unregistered-dialect | \
 // RUN: FileCheck %s
 
 // Basic folding of to_tensor(to_memref(t)) -> t
index 6e29429..2e36b7e 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -canonicalize="test-convergence" | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // UnrealizedConversionCastOp
index 1b85837..f0d287f 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -canonicalize="test-convergence" | FileCheck %s
 
 // CHECK-LABEL: func @create_of_real_and_imag
 // CHECK-SAME: (%[[CPLX:.*]]: complex<f32>)
@@ -154,4 +154,4 @@ func.func @complex_sub_zero() -> complex<f32> {
   // CHECK-NEXT: return %[[CPLX:.*]] : complex<f32>
   %sub = complex.sub %complex1, %complex2 : complex<f32>
   return %sub : complex<f32>
-}
\ No newline at end of file
+}
index 8cef845..0ad6898 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='builtin.module(func.func(canonicalize))' -split-input-file | FileCheck --dump-input-context 20 %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' -split-input-file | FileCheck --dump-input-context 20 %s
 
 /// Test the folding of BranchOp.
 
index eedc238..99633ff 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -canonicalize --split-input-file -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s -canonicalize="test-convergence" --split-input-file -allow-unregistered-dialect | FileCheck %s
 
 // Fold all the gpu.wait ops as they are redundant.
 // CHECK-LABEL: func @fold_wait_op_test1
index 9432eda..9a3309d 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -canonicalize %s -split-input-file | FileCheck %s
+// RUN: mlir-opt -canonicalize="test-convergence" %s -split-input-file | FileCheck %s
 
 // CHECK-LABEL: fold_extractvalue
 llvm.func @fold_extractvalue() -> i32 {
index 4510d20..9e4d886 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @memref_cast(
 func.func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
index f3825cd..7a5194b 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -canonicalize="test-convergence" | FileCheck %s
 
 // CHECK-LABEL: @ceil_fold
 // CHECK: %[[cst:.+]] = arith.constant 1.000000e+00 : f32
index 88d9155..3d9f71e 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -canonicalize --split-input-file -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s -canonicalize="test-convergence" --split-input-file -allow-unregistered-dialect | FileCheck %s
 
 // CHECK-LABEL: func @subview_of_size_memcast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
index 71c388c..10cb19f 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s
 
 func.func @testenterdataop(%a: memref<10xf32>) -> () {
   %ifCond = arith.constant true
index c5ab769..c5d18f3 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s
 
 func.func @update_no_op(%x : memref<i32>) {
   omp.atomic.update %x : memref<i32> {
index 94688a2..ee2a6f7 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -canonicalize %s | FileCheck %s
+// RUN: mlir-opt -canonicalize="test-convergence" %s | FileCheck %s
 
 pdl.pattern @operation_op : benefit(1) {
   %root = operation "foo.op"
index c67f129..36c3eaf 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' | FileCheck %s
 
 // -----
 // CHECK-LABEL: redundant_scast
index e5e2afc..220adc5 100644 (file)
@@ -1,7 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' -split-input-file | FileCheck %s
-
-
-// -----
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' -split-input-file | FileCheck %s
 
 func.func @single_iteration_some(%A: memref<?x?x?xi32>) {
   %c0 = arith.constant 0 : index
index e65f92e..518ad2e 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // spirv.AccessChain
index 1203182..aec5f32 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize="test-convergence" %s | FileCheck %s
 
 // CHECK-LABEL: func @f
 func.func @f(%arg0: tensor<2x3x4xf32>) -> tensor<3xindex> {
index fed2ca7..2b11a33 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" | FileCheck %s
 
 // Checks that NOP casts are removed.
 // CHECK-LABEL: cast_values
index 7eea232..7464334 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --canonicalize %s | FileCheck %s
+// RUN: mlir-opt -canonicalize="test-convergence" %s | FileCheck %s
 
 // CHECK-LABEL: @argmax_nofold
 func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
index 2ebe2d7..1990b89 100644 (file)
@@ -1,6 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' -split-input-file -allow-unregistered-dialect | FileCheck %s
-
-// -----
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
 
 // CHECK-LABEL: create_vector_mask_to_constant_mask
 func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
index 3a958f8..903fd69 100644 (file)
@@ -14,7 +14,7 @@ func.func @bar() {
   external_resources: {
     mlir_reproducer: {
       verify_each: true,
-      // CHECK:  builtin.module(func.func(cse,canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=false top-down=false}))
+      // CHECK:  builtin.module(func.func(cse,canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=false test-convergence=false top-down=false}))
       pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=false top-down=false}))",
       disable_threading: true
     }
index df1555d..5cc0eb5 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(canonicalize))' -split-input-file | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' -split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @test_subi_zero
 func.func @test_subi_zero(%arg0: i32) -> i32 {