[mlir][ODS] Fix copy ctor for generate Pass classes
authorVladislav Vinogradov <vlad.vinogradov@intel.com>
Tue, 15 Jun 2021 15:09:31 +0000 (18:09 +0300)
committerVladislav Vinogradov <vlad.vinogradov@intel.com>
Mon, 21 Jun 2021 11:07:31 +0000 (14:07 +0300)
Redirect the copy ctor to the actual class instead of
overwriting it with `TypeID` based ctor.

This allows the final Pass classes to have extra fields and logic for their copy.

Reviewed By: lattner

Differential Revision: https://reviews.llvm.org/D104302

mlir/include/mlir/Pass/Pass.h
mlir/tools/mlir-tblgen/PassGen.cpp
mlir/unittests/TableGen/CMakeLists.txt
mlir/unittests/TableGen/PassGenTest.cpp [new file with mode: 0644]
mlir/unittests/TableGen/passes.td [new file with mode: 0644]

index da91ef3..3a9865d 100644 (file)
@@ -339,6 +339,7 @@ private:
 template <typename OpT = void> class OperationPass : public Pass {
 protected:
   OperationPass(TypeID passID) : Pass(passID, OpT::getOperationName()) {}
+  OperationPass(const OperationPass &) = default;
 
   /// Support isa/dyn_cast functionality.
   static bool classof(const Pass *pass) {
@@ -371,6 +372,7 @@ protected:
 template <> class OperationPass<void> : public Pass {
 protected:
   OperationPass(TypeID passID) : Pass(passID) {}
+  OperationPass(const OperationPass &) = default;
 };
 
 /// A model for providing function pass specific utilities.
@@ -409,6 +411,7 @@ public:
 
 protected:
   PassWrapper() : BaseT(TypeID::get<PassT>()) {}
+  PassWrapper(const PassWrapper &) = default;
 
   /// Returns the derived pass name.
   StringRef getName() const override { return llvm::getTypeName<PassT>(); }
index 8f3a19d..e09746b 100644 (file)
@@ -48,7 +48,7 @@ public:
   using Base = {0}Base;
 
   {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
-  {0}Base(const {0}Base &) : {1}(::mlir::TypeID::get<DerivedT>()) {{}
+  {0}Base(const {0}Base &other) : {1}(other) {{}
 
   /// Returns the command-line argument attached to this pass.
   static constexpr ::llvm::StringLiteral getArgumentName() {
index 7cee691..421133b 100644 (file)
@@ -8,15 +8,21 @@ mlir_tablegen(StructAttrGenTest.h.inc -gen-struct-attr-decls)
 mlir_tablegen(StructAttrGenTest.cpp.inc -gen-struct-attr-defs)
 add_public_tablegen_target(MLIRTableGenStructAttrIncGen)
 
+set(LLVM_TARGET_DEFINITIONS passes.td)
+mlir_tablegen(PassGenTest.h.inc -gen-pass-decls -name TableGenTest)
+add_public_tablegen_target(MLIRTableGenTestPassIncGen)
+
 add_mlir_unittest(MLIRTableGenTests
   EnumsGenTest.cpp
   StructsGenTest.cpp
   FormatTest.cpp
   OpBuildGen.cpp
+  PassGenTest.cpp
 )
 
 add_dependencies(MLIRTableGenTests MLIRTableGenEnumsIncGen)
 add_dependencies(MLIRTableGenTests MLIRTableGenStructAttrIncGen)
+add_dependencies(MLIRTableGenTests MLIRTableGenTestPassIncGen)
 add_dependencies(MLIRTableGenTests MLIRTestDialect)
 
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../test/lib/Dialect/Test)
diff --git a/mlir/unittests/TableGen/PassGenTest.cpp b/mlir/unittests/TableGen/PassGenTest.cpp
new file mode 100644 (file)
index 0000000..33bd160
--- /dev/null
@@ -0,0 +1,48 @@
+//===- PassGenTest.cpp - TableGen PassGen Tests ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/Pass.h"
+
+#include "gmock/gmock.h"
+
+std::unique_ptr<mlir::Pass> createTestPass(int v = 0);
+
+#define GEN_PASS_REGISTRATION
+#include "PassGenTest.h.inc"
+
+#define GEN_PASS_CLASSES
+#include "PassGenTest.h.inc"
+
+struct TestPass : public TestPassBase<TestPass> {
+  explicit TestPass(int v) : extraVal(v) {}
+
+  void runOnOperation() override {}
+
+  std::unique_ptr<mlir::Pass> clone() const {
+    return TestPassBase<TestPass>::clone();
+  }
+
+  int extraVal;
+};
+
+std::unique_ptr<mlir::Pass> createTestPass(int v) {
+  return std::make_unique<TestPass>(v);
+}
+
+TEST(PassGenTest, PassClone) {
+  mlir::MLIRContext context;
+
+  const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) {
+    return static_cast<const TestPass *>(pass.get());
+  };
+
+  const auto origPass = createTestPass(10);
+  const auto clonePass = unwrap(origPass)->clone();
+
+  EXPECT_EQ(unwrap(origPass)->extraVal, unwrap(clonePass)->extraVal);
+}
diff --git a/mlir/unittests/TableGen/passes.td b/mlir/unittests/TableGen/passes.td
new file mode 100644 (file)
index 0000000..f730390
--- /dev/null
@@ -0,0 +1,19 @@
+//===-- passes.td - PassGen test definition file -----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+include "mlir/Pass/PassBase.td"
+include "mlir/Pass/PassBase.td"
+include "mlir/Rewrite/PassUtil.td"
+
+def TestPass : Pass<"test"> {
+  let summary = "Test pass";
+
+  let constructor = "::createTestPass()";
+
+  let options = RewritePassUtils.options;
+}