[mlir] clean up transform dialect definitions, NFC
authorAlex Zinenko <zinenko@google.com>
Tue, 4 Oct 2022 15:49:30 +0000 (15:49 +0000)
committerAlex Zinenko <zinenko@google.com>
Tue, 11 Oct 2022 09:55:10 +0000 (09:55 +0000)
Refactor the definition of the Transform dialect to move non-trivial
method implementations out of the .td file, and detemplatize functions
when possible while moving their implementations to a .cpp.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D135165

mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
mlir/lib/Dialect/Transform/IR/TransformDialect.cpp

index 055cf71..80ff80e 100644 (file)
@@ -9,7 +9,6 @@
 #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
 #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
 
-#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
@@ -23,21 +22,7 @@ namespace detail {
 /// Asserts that the operations provided as template arguments implement the
 /// TransformOpInterface and MemoryEffectsOpInterface. This must be a dynamic
 /// assertion since interface implementations may be registered at runtime.
-template <typename OpTy>
-static inline void checkImplementsTransformInterface(MLIRContext *context) {
-  // Since the operation is being inserted into the Transform dialect and the
-  // dialect does not implement the interface fallback, only check for the op
-  // itself having the interface implementation.
-  RegisteredOperationName opName =
-      *RegisteredOperationName::lookup(OpTy::getOperationName(), context);
-  assert((opName.hasInterface<TransformOpInterface>() ||
-          opName.hasTrait<OpTrait::IsTerminator>()) &&
-         "non-terminator ops injected into the transform dialect must "
-         "implement TransformOpInterface");
-  assert(opName.hasInterface<MemoryEffectOpInterface>() &&
-         "ops injected into the transform dialect must implement "
-         "MemoryEffectsOpInterface");
-}
+void checkImplementsTransformOpInterface(StringRef name, MLIRContext *context);
 
 /// Asserts that the type provided as template argument implements the
 /// TransformTypeInterface. This must be a dynamic assertion since interface
@@ -200,6 +185,25 @@ private:
   bool buildOnly;
 };
 
+template <typename OpTy>
+void TransformDialect::addOperationIfNotRegistered() {
+  StringRef name = OpTy::getOperationName();
+  Optional<RegisteredOperationName> opName =
+      RegisteredOperationName::lookup(name, getContext());
+  if (!opName) {
+    addOperations<OpTy>();
+#ifndef NDEBUG
+    detail::checkImplementsTransformOpInterface(name, getContext());
+#endif // NDEBUG
+    return;
+  }
+
+  if (opName->getTypeID() == TypeID::get<OpTy>())
+    return;
+
+  reportDuplicateOpRegistration(name);
+}
+
 template <typename Type>
 void TransformDialect::addTypeIfNotRegistered() {
   // Use the address of the parse method as a proxy for identifying whether we
@@ -210,6 +214,8 @@ void TransformDialect::addTypeIfNotRegistered() {
     const ExtensionTypeParsingHook &parsingHook = it->getValue();
     if (*parsingHook.target<mlir::Type (*)(AsmParser &)>() != &Type::parse)
       reportDuplicateTypeRegistration(mnemonic);
+    else
+      return;
   }
   typePrintingHooks.try_emplace(
       TypeID::get<Type>(), +[](mlir::Type type, AsmPrinter &printer) {
@@ -217,6 +223,11 @@ void TransformDialect::addTypeIfNotRegistered() {
         cast<Type>(type).print(printer);
       });
   addTypes<Type>();
+
+#ifndef NDEBUG
+  detail::checkImplementsTransformTypeInterface(TypeID::get<Type>(),
+                                                getContext());
+#endif // NDEBUG
 }
 
 /// A wrapper for transform dialect extensions that forces them to be
index dd2b55f..18fa686 100644 (file)
@@ -333,33 +333,17 @@ def Transform_Dialect : Dialect {
           std::function<void (::mlir::Type, ::mlir::AsmPrinter &)>;
 
     private:
-      template <typename OpTy>
-      void addOperationIfNotRegistered() {
-        Optional<RegisteredOperationName> opName =
-            RegisteredOperationName::lookup(OpTy::getOperationName(),
-                                            getContext());
-        if (!opName)
-          return addOperations<OpTy>();
-
-        if (opName->getTypeID() == TypeID::get<OpTy>())
-          return;
-
-        llvm::errs() << "error: extensible dialect operation '"
-                     << OpTy::getOperationName()
-                     << "' is already registered with a mismatching TypeID";
-        abort();
-      }
-
       /// Registers operations specified as template parameters with this
       /// dialect. Checks that they implement the required interfaces.
       template <typename... OpTys>
       void addOperationsChecked() {
-        (addOperationIfNotRegistered<OpTys>(),...);
-
-        #ifndef NDEBUG
-        (detail::checkImplementsTransformInterface<OpTys>(getContext()),...);
-        #endif // NDEBUG
+        (addOperationIfNotRegistered<OpTys>(), ...);
       }
+      template <typename OpTy>
+      void addOperationIfNotRegistered();
+
+      /// Reports a repeated registration error of an op with the given name.
+      [[noreturn]] void reportDuplicateOpRegistration(StringRef opName);
 
       /// Registers the types specified as template parameters with the
       /// Transform dialect. Checks that they meet the requirements for
@@ -367,15 +351,7 @@ def Transform_Dialect : Dialect {
       template <typename... TypeTys>
       void addTypesChecked() {
         (addTypeIfNotRegistered<TypeTys>(), ...);
-
-        #ifndef NDEBUG
-        (detail::checkImplementsTransformTypeInterface(
-            TypeID::get<TypeTys>(), getContext()), ...);
-        #endif // NDEBUG
       }
-
-      /// Implementation of the type registration for a single type, should
-      /// not be called directly, use addTypesChecked instead.
       template <typename Type>
       void addTypeIfNotRegistered();
 
index adf9f0f..49e3924 100644 (file)
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -18,6 +19,22 @@ using namespace mlir;
 #include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
 
 #ifndef NDEBUG
+void transform::detail::checkImplementsTransformOpInterface(
+    StringRef name, MLIRContext *context) {
+  // Since the operation is being inserted into the Transform dialect and the
+  // dialect does not implement the interface fallback, only check for the op
+  // itself having the interface implementation.
+  RegisteredOperationName opName =
+      *RegisteredOperationName::lookup(name, context);
+  assert((opName.hasInterface<TransformOpInterface>() ||
+          opName.hasTrait<OpTrait::IsTerminator>()) &&
+         "non-terminator ops injected into the transform dialect must "
+         "implement TransformOpInterface");
+  assert(opName.hasInterface<MemoryEffectOpInterface>() &&
+         "ops injected into the transform dialect must implement "
+         "MemoryEffectsOpInterface");
+}
+
 void transform::detail::checkImplementsTransformTypeInterface(
     TypeID typeID, MLIRContext *context) {
   const auto &abstractType = AbstractType::lookup(typeID, context);
@@ -76,10 +93,20 @@ void transform::TransformDialect::reportDuplicateTypeRegistration(
     StringRef mnemonic) {
   std::string buffer;
   llvm::raw_string_ostream msg(buffer);
-  msg << "error: extensible dialect type '" << mnemonic
+  msg << "extensible dialect type '" << mnemonic
       << "' is already registered with a different implementation";
   msg.flush();
   llvm::report_fatal_error(StringRef(buffer));
 }
 
+void transform::TransformDialect::reportDuplicateOpRegistration(
+    StringRef opName) {
+  std::string buffer;
+  llvm::raw_string_ostream msg(buffer);
+  msg << "extensible dialect operation '" << opName
+      << "' is already registered with a mismatching TypeID";
+  msg.flush();
+  llvm::report_fatal_error(StringRef(buffer));
+}
+
 #include "mlir/Dialect/Transform/IR/TransformDialectEnums.cpp.inc"