```c++
Dialect *dialect = ...;
-if (DialectInlinerInterface *interface
- = dialect->getRegisteredInterface<DialectInlinerInterface>()) {
+if (DialectInlinerInterface *interface = dyn_cast<DialectInlinerInterface>(dialect)) {
// The dialect has provided an implementation of this interface.
...
}
namespace llvm {
/// Provide isa functionality for Dialects.
-template <typename T> struct isa_impl<T, ::mlir::Dialect> {
+template <typename T>
+struct isa_impl<T, ::mlir::Dialect,
+ std::enable_if_t<std::is_base_of<::mlir::Dialect, T>::value>> {
static inline bool doit(const ::mlir::Dialect &dialect) {
return mlir::TypeID::get<T>() == dialect.getTypeID();
}
};
+template <typename T>
+struct isa_impl<
+ T, ::mlir::Dialect,
+ std::enable_if_t<std::is_base_of<::mlir::DialectInterface, T>::value>> {
+ static inline bool doit(const ::mlir::Dialect &dialect) {
+ return const_cast<::mlir::Dialect &>(dialect).getRegisteredInterface<T>();
+ }
+};
+template <typename T>
+struct cast_retty_impl<T, ::mlir::Dialect *> {
+ using ret_type =
+ std::conditional_t<std::is_base_of<::mlir::Dialect, T>::value, T *,
+ const T *>;
+};
+template <typename T>
+struct cast_retty_impl<T, ::mlir::Dialect> {
+ using ret_type =
+ std::conditional_t<std::is_base_of<::mlir::Dialect, T>::value, T &,
+ const T &>;
+};
+
+template <typename T>
+struct cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect> {
+ template <typename To>
+ static std::enable_if_t<std::is_base_of<::mlir::Dialect, To>::value, To &>
+ doitImpl(::mlir::Dialect &dialect) {
+ return static_cast<To &>(dialect);
+ }
+ template <typename To>
+ static std::enable_if_t<std::is_base_of<::mlir::DialectInterface, To>::value,
+ const To &>
+ doitImpl(::mlir::Dialect &dialect) {
+ return *dialect.getRegisteredInterface<To>();
+ }
+
+ static auto &doit(::mlir::Dialect &dialect) { return doitImpl<T>(dialect); }
+};
+template <class T>
+struct cast_convert_val<T, ::mlir::Dialect *, ::mlir::Dialect *> {
+ static auto doit(::mlir::Dialect *dialect) {
+ return &cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect>::doit(
+ *dialect);
+ }
+};
+
} // namespace llvm
#endif
// dialect is not loaded for some reason, use the default combinator
// that conservatively accepts identical entries only.
entriesForID[id] =
- dialect ? dialect->getRegisteredInterface<DataLayoutDialectInterface>()
- ->combine(entriesForID[id], kvp.second)
+ dialect ? cast<DataLayoutDialectInterface>(dialect)->combine(
+ entriesForID[id], kvp.second)
: DataLayoutDialectInterface::defaultCombine(entriesForID[id],
kvp.second);
if (!entriesForID[id])
Dialect *dialect = getContext()->getLoadedDialect(getDialect());
if (!dialect)
return true;
- auto *interface =
- dialect->getRegisteredInterface<DialectDecodeAttributesInterface>();
+ auto *interface = llvm::dyn_cast<DialectDecodeAttributesInterface>(dialect);
if (!interface)
return true;
return failed(interface->decode(*this, result));
if (!dialect)
return failure();
- auto *interface = dialect->getRegisteredInterface<DialectFoldInterface>();
+ auto *interface = dyn_cast<DialectFoldInterface>(dialect);
if (!interface)
return failure();
if (!dialect)
continue;
- const auto *iface =
- dialect->getRegisteredInterface<DataLayoutDialectInterface>();
+ const auto *iface = dyn_cast<DataLayoutDialectInterface>(dialect);
if (!iface) {
return emitError(loc)
<< "the '" << dialect->getNamespace()
MLIRContext context(registry);
// Load the TestDialect and check that the interface got registered for it.
- auto *testDialect = context.getOrLoadDialect<TestDialect>();
+ Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
ASSERT_TRUE(testDialect != nullptr);
- auto *testDialectInterface =
- testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
+ auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
EXPECT_TRUE(testDialectInterface != nullptr);
// Load the SecondTestDialect and check that the interface is not registered
// for it.
- auto *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
+ Dialect *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
ASSERT_TRUE(secondTestDialect != nullptr);
auto *secondTestDialectInterface =
- secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
+ dyn_cast<SecondTestDialectInterface>(secondTestDialect);
EXPECT_TRUE(secondTestDialectInterface == nullptr);
// Use the same mechanism as for delayed registration but for an already
.addDialectInterface<SecondTestDialect, SecondTestDialectInterface>();
context.appendDialectRegistry(secondRegistry);
secondTestDialectInterface =
- secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
+ dyn_cast<SecondTestDialectInterface>(secondTestDialect);
EXPECT_TRUE(secondTestDialectInterface != nullptr);
}
MLIRContext context(registry);
// Load the TestDialect and check that the interface got registered for it.
- auto *testDialect = context.getOrLoadDialect<TestDialect>();
+ Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
ASSERT_TRUE(testDialect != nullptr);
- auto *testDialectInterface =
- testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
+ auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
EXPECT_TRUE(testDialectInterface != nullptr);
// Try adding the same dialect interface again and check that we don't crash
secondRegistry.insert<TestDialect>();
secondRegistry.addDialectInterface<TestDialect, TestDialectInterface>();
context.appendDialectRegistry(secondRegistry);
- testDialectInterface =
- testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
+ testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
EXPECT_TRUE(testDialectInterface != nullptr);
}