[mlir] Add isa/dyn_cast support for dialect interfaces
authorRiver Riddle <riddleriver@gmail.com>
Fri, 21 Jan 2022 08:38:30 +0000 (00:38 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Tue, 1 Feb 2022 03:24:34 +0000 (19:24 -0800)
This matches the same API usage as attributes/ops/types. For example:

```c++
Dialect *dialect = ...;

// Instead of this:
if (auto *interface = dialect->getRegisteredInterface<DialectInlinerInterface>())

// You can do this:
if (auto *interface = dyn_cast<DialectInlinerInterface>(dialect))
```

Differential Revision: https://reviews.llvm.org/D117859

mlir/docs/Interfaces.md
mlir/include/mlir/IR/Dialect.h
mlir/lib/Dialect/DLTI/DLTI.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/Interfaces/DataLayoutInterfaces.cpp
mlir/unittests/IR/DialectTest.cpp

index c181a60..b51aec9 100644 (file)
@@ -77,8 +77,7 @@ or transformation without the need to determine the specific dialect subclass:
 
 ```c++
 Dialect *dialect = ...;
-if (DialectInlinerInterface *interface
-      = dialect->getRegisteredInterface<DialectInlinerInterface>()) {
+if (DialectInlinerInterface *interface = dyn_cast<DialectInlinerInterface>(dialect)) {
   // The dialect has provided an implementation of this interface.
   ...
 }
index 7fb298c..798d66f 100644 (file)
@@ -440,11 +440,58 @@ private:
 
 namespace llvm {
 /// Provide isa functionality for Dialects.
-template <typename T> struct isa_impl<T, ::mlir::Dialect> {
+template <typename T>
+struct isa_impl<T, ::mlir::Dialect,
+                std::enable_if_t<std::is_base_of<::mlir::Dialect, T>::value>> {
   static inline bool doit(const ::mlir::Dialect &dialect) {
     return mlir::TypeID::get<T>() == dialect.getTypeID();
   }
 };
+template <typename T>
+struct isa_impl<
+    T, ::mlir::Dialect,
+    std::enable_if_t<std::is_base_of<::mlir::DialectInterface, T>::value>> {
+  static inline bool doit(const ::mlir::Dialect &dialect) {
+    return const_cast<::mlir::Dialect &>(dialect).getRegisteredInterface<T>();
+  }
+};
+template <typename T>
+struct cast_retty_impl<T, ::mlir::Dialect *> {
+  using ret_type =
+      std::conditional_t<std::is_base_of<::mlir::Dialect, T>::value, T *,
+                         const T *>;
+};
+template <typename T>
+struct cast_retty_impl<T, ::mlir::Dialect> {
+  using ret_type =
+      std::conditional_t<std::is_base_of<::mlir::Dialect, T>::value, T &,
+                         const T &>;
+};
+
+template <typename T>
+struct cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect> {
+  template <typename To>
+  static std::enable_if_t<std::is_base_of<::mlir::Dialect, To>::value, To &>
+  doitImpl(::mlir::Dialect &dialect) {
+    return static_cast<To &>(dialect);
+  }
+  template <typename To>
+  static std::enable_if_t<std::is_base_of<::mlir::DialectInterface, To>::value,
+                          const To &>
+  doitImpl(::mlir::Dialect &dialect) {
+    return *dialect.getRegisteredInterface<To>();
+  }
+
+  static auto &doit(::mlir::Dialect &dialect) { return doitImpl<T>(dialect); }
+};
+template <class T>
+struct cast_convert_val<T, ::mlir::Dialect *, ::mlir::Dialect *> {
+  static auto doit(::mlir::Dialect *dialect) {
+    return &cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect>::doit(
+        *dialect);
+  }
+};
+
 } // namespace llvm
 
 #endif
index cf1573d..7382fba 100644 (file)
@@ -231,8 +231,8 @@ combineOneSpec(DataLayoutSpecInterface spec,
     // dialect is not loaded for some reason, use the default combinator
     // that conservatively accepts identical entries only.
     entriesForID[id] =
-        dialect ? dialect->getRegisteredInterface<DataLayoutDialectInterface>()
-                      ->combine(entriesForID[id], kvp.second)
+        dialect ? cast<DataLayoutDialectInterface>(dialect)->combine(
+                      entriesForID[id], kvp.second)
                 : DataLayoutDialectInterface::defaultCombine(entriesForID[id],
                                                              kvp.second);
     if (!entriesForID[id])
index 802df2d..79e80f7 100644 (file)
@@ -1236,8 +1236,7 @@ bool OpaqueElementsAttr::decode(ElementsAttr &result) {
   Dialect *dialect = getContext()->getLoadedDialect(getDialect());
   if (!dialect)
     return true;
-  auto *interface =
-      dialect->getRegisteredInterface<DialectDecodeAttributesInterface>();
+  auto *interface = llvm::dyn_cast<DialectDecodeAttributesInterface>(dialect);
   if (!interface)
     return true;
   return failed(interface->decode(*this, result));
index 1ca4d68..e679337 100644 (file)
@@ -506,7 +506,7 @@ LogicalResult Operation::fold(ArrayRef<Attribute> operands,
   if (!dialect)
     return failure();
 
-  auto *interface = dialect->getRegisteredInterface<DialectFoldInterface>();
+  auto *interface = dyn_cast<DialectFoldInterface>(dialect);
   if (!interface)
     return failure();
 
index 2b7ff5e..ac6397c 100644 (file)
@@ -438,8 +438,7 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
     if (!dialect)
       continue;
 
-    const auto *iface =
-        dialect->getRegisteredInterface<DataLayoutDialectInterface>();
+    const auto *iface = dyn_cast<DataLayoutDialectInterface>(dialect);
     if (!iface) {
       return emitError(loc)
              << "the '" << dialect->getNamespace()
index ca89e82..b4fd697 100644 (file)
@@ -68,18 +68,17 @@ TEST(Dialect, DelayedInterfaceRegistration) {
   MLIRContext context(registry);
 
   // Load the TestDialect and check that the interface got registered for it.
-  auto *testDialect = context.getOrLoadDialect<TestDialect>();
+  Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
   ASSERT_TRUE(testDialect != nullptr);
-  auto *testDialectInterface =
-      testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
+  auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
   EXPECT_TRUE(testDialectInterface != nullptr);
 
   // Load the SecondTestDialect and check that the interface is not registered
   // for it.
-  auto *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
+  Dialect *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
   ASSERT_TRUE(secondTestDialect != nullptr);
   auto *secondTestDialectInterface =
-      secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
+      dyn_cast<SecondTestDialectInterface>(secondTestDialect);
   EXPECT_TRUE(secondTestDialectInterface == nullptr);
 
   // Use the same mechanism as for delayed registration but for an already
@@ -90,7 +89,7 @@ TEST(Dialect, DelayedInterfaceRegistration) {
       .addDialectInterface<SecondTestDialect, SecondTestDialectInterface>();
   context.appendDialectRegistry(secondRegistry);
   secondTestDialectInterface =
-      secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
+      dyn_cast<SecondTestDialectInterface>(secondTestDialect);
   EXPECT_TRUE(secondTestDialectInterface != nullptr);
 }
 
@@ -102,10 +101,9 @@ TEST(Dialect, RepeatedDelayedRegistration) {
   MLIRContext context(registry);
 
   // Load the TestDialect and check that the interface got registered for it.
-  auto *testDialect = context.getOrLoadDialect<TestDialect>();
+  Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
   ASSERT_TRUE(testDialect != nullptr);
-  auto *testDialectInterface =
-      testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
+  auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
   EXPECT_TRUE(testDialectInterface != nullptr);
 
   // Try adding the same dialect interface again and check that we don't crash
@@ -114,8 +112,7 @@ TEST(Dialect, RepeatedDelayedRegistration) {
   secondRegistry.insert<TestDialect>();
   secondRegistry.addDialectInterface<TestDialect, TestDialectInterface>();
   context.appendDialectRegistry(secondRegistry);
-  testDialectInterface =
-      testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
+  testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
   EXPECT_TRUE(testDialectInterface != nullptr);
 }