/// a registered Attribute.
class AbstractAttribute {
public:
+ using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
+
/// Look up the specified abstract attribute in the MLIRContext and return a
/// reference to it.
static const AbstractAttribute &lookup(TypeID typeID, MLIRContext *context);
/// This method is used by Dialect objects when they register the list of
/// attributes they contain.
- template <typename T> static AbstractAttribute get(Dialect &dialect) {
- return AbstractAttribute(dialect, T::getInterfaceMap(), T::getTypeID());
+ template <typename T>
+ static AbstractAttribute get(Dialect &dialect) {
+ return AbstractAttribute(dialect, T::getInterfaceMap(), T::getHasTraitFn(),
+ T::getTypeID());
}
/// Return the dialect this attribute was registered to.
/// Returns an instance of the concept object for the given interface if it
/// was registered to this attribute, null otherwise. This should not be used
/// directly.
- template <typename T> typename T::Concept *getInterface() const {
+ template <typename T>
+ typename T::Concept *getInterface() const {
return interfaceMap.lookup<T>();
}
return interfaceMap.contains(interfaceID);
}
+ /// Returns true if the attribute has a particular trait.
+ template <template <typename T> class Trait>
+ bool hasTrait() const {
+ return hasTraitFn(TypeID::get<Trait>());
+ }
+
+ /// Returns true if the attribute has a particular trait.
+ bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); }
+
/// Return the unique identifier representing the concrete attribute class.
TypeID getTypeID() const { return typeID; }
private:
AbstractAttribute(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
- TypeID typeID)
+ HasTraitFn &&hasTrait, TypeID typeID)
: dialect(dialect), interfaceMap(std::move(interfaceMap)),
- typeID(typeID) {}
+ hasTraitFn(std::move(hasTrait)), typeID(typeID) {}
/// Give StorageUserBase access to the mutable lookup.
template <typename ConcreteT, typename BaseT, typename StorageT,
/// This is a collection of the interfaces registered to this attribute.
detail::InterfaceMap interfaceMap;
+ /// Function to check if the attribute has a particular trait.
+ HasTraitFn hasTraitFn;
+
/// The unique identifier of the derived Attribute class.
const TypeID typeID;
};
friend ::llvm::hash_code hash_value(Attribute arg);
+ /// Returns true if the type was registered with a particular trait.
+ template <template <typename T> class Trait>
+ bool hasTrait() {
+ return getAbstractAttribute().hasTrait<Trait>();
+ }
+
/// Return the abstract descriptor for this attribute.
const AbstractTy &getAbstractAttribute() const {
return impl->getAbstractAttribute();
}
//===----------------------------------------------------------------------===//
+// TypeTrait definitions
+//===----------------------------------------------------------------------===//
+
+// TypeTrait represents a trait regarding a type.
+// TODO: Remove this class in favor of using Trait.
+class TypeTrait;
+
+// These classes are used to define type specific traits.
+class NativeTypeTrait<string name> : NativeTrait<name, "Type">, TypeTrait;
+class ParamNativeTypeTrait<string prop, string params>
+ : ParamNativeTrait<prop, params, "Type">, TypeTrait;
+class GenInternalTypeTrait<string prop>
+ : GenInternalTrait<prop, "Type">, TypeTrait;
+class PredTypeTrait<string descr, Pred pred>
+ : PredTrait<descr, pred>, TypeTrait;
+
+//===----------------------------------------------------------------------===//
+// AttrTrait definitions
+//===----------------------------------------------------------------------===//
+
+// AttrTrait represents a trait regarding an attribute.
+// TODO: Remove this class in favor of using Trait.
+class AttrTrait;
+
+// These classes are used to define attribute specific traits.
+class NativeAttrTrait<string name> : NativeTrait<name, "Attribute">, AttrTrait;
+class ParamNativeAttrTrait<string prop, string params>
+ : ParamNativeTrait<prop, params, "Attribute">, AttrTrait;
+class GenInternalAttrTrait<string prop>
+ : GenInternalTrait<prop, "Attribute">, AttrTrait;
+class PredAttrTrait<string descr, Pred pred>
+ : PredTrait<descr, pred>, AttrTrait;
+
+//===----------------------------------------------------------------------===//
// OpTrait definitions
//===----------------------------------------------------------------------===//
// StorageUserBase
//===----------------------------------------------------------------------===//
+namespace storage_user_base_impl {
+/// Returns true if this given Trait ID matches the IDs of any of the provided
+/// trait types `Traits`.
+template <template <typename T> class... Traits>
+bool hasTrait(TypeID traitID) {
+ TypeID traitIDs[] = {TypeID::get<Traits>()...};
+ for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i)
+ if (traitIDs[i] == traitID)
+ return true;
+ return false;
+}
+
+// We specialize for the empty case to not define an empty array.
+template <>
+inline bool hasTrait(TypeID traitID) {
+ return false;
+}
+} // namespace storage_user_base_impl
+
/// Utility class for implementing users of storage classes uniqued by a
/// StorageUniquer. Clients are not expected to interact with this class
/// directly.
/// Utility declarations for the concrete attribute class.
using Base = StorageUserBase<ConcreteT, BaseT, StorageT, UniquerT, Traits...>;
using ImplType = StorageT;
+ using HasTraitFn = bool (*)(TypeID);
/// Return a unique identifier for the concrete type.
static TypeID getTypeID() { return TypeID::get<ConcreteT>(); }
return detail::InterfaceMap::template get<Traits<ConcreteT>...>();
}
+ /// Returns the function that returns true if the given Trait ID matches the
+ /// IDs of any of the traits defined by the storage user.
+ static HasTraitFn getHasTraitFn() {
+ return [](TypeID id) {
+ return storage_user_base_impl::hasTrait<Traits...>(id);
+ };
+ }
+
/// Attach the given models as implementations of the corresponding interfaces
/// for the concrete storage user class. The type must be registered with the
/// context, i.e. the dialect to which the type belongs must be loaded. The
/// a registered Type.
class AbstractType {
public:
+ using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
+
/// Look up the specified abstract type in the MLIRContext and return a
/// reference to it.
static const AbstractType &lookup(TypeID typeID, MLIRContext *context);
/// This method is used by Dialect objects when they register the list of
/// types they contain.
- template <typename T> static AbstractType get(Dialect &dialect) {
- return AbstractType(dialect, T::getInterfaceMap(), T::getTypeID());
+ template <typename T>
+ static AbstractType get(Dialect &dialect) {
+ return AbstractType(dialect, T::getInterfaceMap(), T::getHasTraitFn(),
+ T::getTypeID());
}
/// This method is used by Dialect objects to register types with
/// The use of this method is in general discouraged in favor of
/// 'get<CustomType>(dialect)';
static AbstractType get(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
- TypeID typeID) {
- return AbstractType(dialect, std::move(interfaceMap), typeID);
+ HasTraitFn &&hasTrait, TypeID typeID) {
+ return AbstractType(dialect, std::move(interfaceMap), std::move(hasTrait),
+ typeID);
}
/// Return the dialect this type was registered to.
return interfaceMap.contains(interfaceID);
}
+ /// Returns true if the type has a particular trait.
+ template <template <typename T> class Trait>
+ bool hasTrait() const {
+ return hasTraitFn(TypeID::get<Trait>());
+ }
+
+ /// Returns true if the type has a particular trait.
+ bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); }
+
/// Return the unique identifier representing the concrete type class.
TypeID getTypeID() const { return typeID; }
private:
AbstractType(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
- TypeID typeID)
+ HasTraitFn &&hasTrait, TypeID typeID)
: dialect(dialect), interfaceMap(std::move(interfaceMap)),
- typeID(typeID) {}
+ hasTraitFn(std::move(hasTrait)), typeID(typeID) {}
/// Give StorageUserBase access to the mutable lookup.
template <typename ConcreteT, typename BaseT, typename StorageT,
/// This is a collection of the interfaces registered to this type.
detail::InterfaceMap interfaceMap;
+ /// Function to check if the type has a particular trait.
+ HasTraitFn hasTraitFn;
+
/// The unique identifier of the derived Type class.
const TypeID typeID;
};
return Type(reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
}
+ /// Returns true if the type was registered with a particular trait.
+ template <template <typename T> class Trait>
+ bool hasTrait() {
+ return getAbstractType().hasTrait<Trait>();
+ }
+
/// Return the abstract type descriptor for this type.
const AbstractTy &getAbstractType() { return impl->getAbstractType(); }
"terminator"() : () -> ()
}
}
+
+// -----
+
+// Check that we can query traits in types
+func @succeeded_type_traits() {
+ // CHECK: "test.result_type_with_trait"() : () -> !test.test_type_with_trait
+ "test.result_type_with_trait"() : () -> !test.test_type_with_trait
+ return
+}
+
+// -----
+
+// Check that we can query traits in types
+func @failed_type_traits() {
+ // expected-error@+1 {{result type should have trait 'TestTypeTrait'}}
+ "test.result_type_with_trait"() : () -> i32
+ return
+}
+
+// -----
+
+// Check that we can query traits in attributes
+func @succeeded_attr_traits() {
+ // CHECK: "test.attr_with_trait"() {attr = #test.attr_with_trait} : () -> ()
+ "test.attr_with_trait"() {attr = #test.attr_with_trait} : () -> ()
+ return
+}
+
+// -----
+
+// Check that we can query traits in attributes
+func @failed_attr_traits() {
+ // expected-error@+1 {{'attr' attribute should have trait 'TestAttrTrait'}}
+ "test.attr_with_trait"() {attr = 42 : i32} : () -> ()
+ return
+}
\ No newline at end of file
include "TestOps.td"
// All of the attributes will extend this class.
-class Test_Attr<string name> : AttrDef<Test_Dialect, name>;
+class Test_Attr<string name, list<Trait> traits = []>
+ : AttrDef<Test_Dialect, name, traits>;
def SimpleAttrA : Test_Attr<"SimpleA"> {
let mnemonic = "smpla";
let typeBuilder = "$_attr.getType()";
}
+def TestAttrTrait : NativeAttrTrait<"TestAttrTrait">;
+
+// The definition of a singleton attribute that has a trait.
+def AttrWithTrait : Test_Attr<"AttrWithTrait", [TestAttrTrait]> {
+ let mnemonic = "attr_with_trait";
+ let parameters = (ins );
+}
+
#endif // TEST_ATTRDEFS
#include <tuple>
+#include "TestTraits.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
let assemblyFormat = "regions attr-dict-with-keyword";
}
+// This operation requires its return type to have the trait 'TestTypeTrait'.
+def ResultTypeWithTraitOp : TEST_Op<"result_type_with_trait", []> {
+ let results = (outs AnyType);
+
+ let verifier = [{
+ if((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
+ return success();
+ return this->emitError("result type should have trait 'TestTypeTrait'");
+ }];
+}
+
+// This operation requires its "attr" attribute to have the
+// trait 'TestAttrTrait'.
+def AttrWithTraitOp : TEST_Op<"attr_with_trait", []> {
+ let arguments = (ins AnyAttr:$attr);
+
+ let verifier = [{
+ if (this->attr().hasTrait<AttributeTrait::TestAttrTrait>())
+ return success();
+ return this->emitError("'attr' attribute should have trait 'TestAttrTrait'");
+ }];
+}
+
+
//===----------------------------------------------------------------------===//
// Test Locations
//===----------------------------------------------------------------------===//
--- /dev/null
+//===- TestTraits.h - MLIR Test Traits --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains traits defined by the TestDialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTTRAITS_H
+#define MLIR_TESTTRAITS_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace TypeTrait {
+
+/// A trait defined on types for testing purposes.
+template <typename ConcreteType>
+class TestTypeTrait : public TypeTrait::TraitBase<ConcreteType, TestTypeTrait> {
+};
+
+} // namespace TypeTrait
+
+namespace AttributeTrait {
+
+/// A trait defined on attributes for testing purposes.
+template <typename ConcreteType>
+class TestAttrTrait
+ : public AttributeTrait::TraitBase<ConcreteType, TestAttrTrait> {};
+
+} // namespace AttributeTrait
+} // namespace mlir
+
+#endif // MLIR_TESTTRAITS_H
let mnemonic = "memref_element";
}
+def TestTypeTrait : NativeTypeTrait<"TestTypeTrait">;
+
+// The definition of a singleton type that has a trait.
+def TestTypeWithTrait : Test_Type<"TestTypeWithTrait", [TestTypeTrait]> {
+ let mnemonic = "test_type_with_trait";
+}
+
#endif // TEST_TYPEDEFS
#include <tuple>
+#include "TestTraits.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"