[mlir] Add unit test for disabling canonicalizer patterns (NFC)
authorMogball <jeffniu22@gmail.com>
Wed, 22 Dec 2021 21:06:01 +0000 (21:06 +0000)
committerMogball <jeffniu22@gmail.com>
Wed, 22 Dec 2021 21:07:06 +0000 (21:07 +0000)
mlir/unittests/Transforms/CMakeLists.txt
mlir/unittests/Transforms/Canonicalizer.cpp [new file with mode: 0644]
utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel

index 9636f93..b78f3cd 100644 (file)
@@ -1,4 +1,5 @@
 add_mlir_unittest(MLIRTransformsTests
+  Canonicalizer.cpp
   DialectConversion.cpp
 )
 target_link_libraries(MLIRTransformsTests
diff --git a/mlir/unittests/Transforms/Canonicalizer.cpp b/mlir/unittests/Transforms/Canonicalizer.cpp
new file mode 100644 (file)
index 0000000..de96832
--- /dev/null
@@ -0,0 +1,84 @@
+//===- DialectConversion.cpp - Dialect conversion unit tests --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Parser.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace {
+
+struct DisabledPattern : public RewritePattern {
+  DisabledPattern(MLIRContext *context)
+      : RewritePattern("test.foo", /*benefit=*/0, context,
+                       /*generatedNamed=*/{}) {
+    setDebugName("DisabledPattern");
+  }
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getNumResults() != 1)
+      return failure();
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct EnabledPattern : public RewritePattern {
+  EnabledPattern(MLIRContext *context)
+      : RewritePattern("test.foo", /*benefit=*/0, context,
+                       /*generatedNamed=*/{}) {
+    setDebugName("EnabledPattern");
+  }
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getNumResults() == 1)
+      return failure();
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct TestDialect : public Dialect {
+  static StringRef getDialectNamespace() { return "test"; }
+
+  TestDialect(MLIRContext *context)
+      : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {
+    allowUnknownOperations();
+  }
+
+  void getCanonicalizationPatterns(RewritePatternSet &results) const override {
+    results.insert<DisabledPattern, EnabledPattern>(results.getContext());
+  }
+};
+
+TEST(CanonicalizerTest, TestDisablePatterns) {
+  MLIRContext context;
+  context.getOrLoadDialect<TestDialect>();
+  PassManager mgr(&context);
+  mgr.addPass(
+      createCanonicalizerPass(GreedyRewriteConfig(), {"DisabledPattern"}));
+
+  const char *const code = R"mlir(
+    %0:2 = "test.foo"() {sym_name = "A"} : () -> (i32, i32)
+    %1 = "test.foo"() {sym_name = "B"} : () -> (f32)
+  )mlir";
+
+  OwningModuleRef module = mlir::parseSourceString(code, &context);
+  ASSERT_TRUE(succeeded(mgr.run(*module)));
+
+  EXPECT_TRUE(module->lookupSymbol("B"));
+  EXPECT_FALSE(module->lookupSymbol("A"));
+}
+
+} // end anonymous namespace
index 5ee06e7..73560d6 100644 (file)
@@ -266,7 +266,11 @@ cc_test(
     ]),
     deps = [
         "//llvm:gtest_main",
+        "//mlir:IR",
+        "//mlir:Parser",
+        "//mlir:Pass",
         "//mlir:TransformUtils",
+        "//mlir:Transforms",
     ],
 )