/// The definition of a dynamic op. A dynamic op is an op that is defined at
/// runtime, and that can be registered at runtime by an extensible dialect (a
-/// dialect inheriting ExtensibleDialect). This class stores the functions that
-/// are in the OperationName class, and in addition defines the TypeID of the op
-/// that will be defined.
-/// Each dynamic operation definition refers to one instance of this class.
-class DynamicOpDefinition {
+/// dialect inheriting ExtensibleDialect). This class implements the method
+/// exposed by the OperationName class, and in addition defines the TypeID of
+/// the op that will be defined. Each dynamic operation definition refers to one
+/// instance of this class.
+class DynamicOpDefinition : public OperationName::Impl {
public:
+ using GetCanonicalizationPatternsFn =
+ llvm::unique_function<void(RewritePatternSet &, MLIRContext *) const>;
+
/// Create a new op at runtime. The op is registered only after passing it to
/// the dialect using registerDynamicOp.
static std::unique_ptr<DynamicOpDefinition>
OperationName::ParseAssemblyFn &&parseFn,
OperationName::PrintAssemblyFn &&printFn,
OperationName::FoldHookFn &&foldHookFn,
- OperationName::GetCanonicalizationPatternsFn
- &&getCanonicalizationPatternsFn,
+ GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn,
OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn);
/// Returns the op typeID.
/// Set the hook returning any canonicalization pattern rewrites that the op
/// supports, for use by the canonicalization pass.
- void
- setGetCanonicalizationPatternsFn(OperationName::GetCanonicalizationPatternsFn
- &&getCanonicalizationPatterns) {
+ void setGetCanonicalizationPatternsFn(
+ GetCanonicalizationPatternsFn &&getCanonicalizationPatterns) {
getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns);
}
populateDefaultAttrsFn = std::move(populateDefaultAttrs);
}
+ LogicalResult foldHook(Operation *op, ArrayRef<Attribute> attrs,
+ SmallVectorImpl<OpFoldResult> &results) final {
+ return foldHookFn(op, attrs, results);
+ }
+ void getCanonicalizationPatterns(RewritePatternSet &set,
+ MLIRContext *context) final {
+ getCanonicalizationPatternsFn(set, context);
+ }
+ bool hasTrait(TypeID id) final { return false; }
+ OperationName::ParseAssemblyFn getParseAssemblyFn() final { return parseFn; }
+ void populateDefaultAttrs(const OperationName &name,
+ NamedAttrList &attrs) final {
+ populateDefaultAttrsFn(name, attrs);
+ }
+ void printAssembly(Operation *op, OpAsmPrinter &printer,
+ StringRef name) final {
+ printFn(op, printer, name);
+ }
+ LogicalResult verifyInvariants(Operation *op) final { return verifyFn(op); }
+ LogicalResult verifyRegionInvariants(Operation *op) final {
+ return verifyRegionFn(op);
+ }
+
private:
DynamicOpDefinition(
StringRef name, ExtensibleDialect *dialect,
OperationName::ParseAssemblyFn &&parseFn,
OperationName::PrintAssemblyFn &&printFn,
OperationName::FoldHookFn &&foldHookFn,
- OperationName::GetCanonicalizationPatternsFn
- &&getCanonicalizationPatternsFn,
+ GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn,
OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn);
- /// Unique identifier for this operation.
- TypeID typeID;
-
- /// Name of the operation.
- /// The name is prefixed with the dialect name.
- std::string name;
-
/// Dialect defining this operation.
- ExtensibleDialect *dialect;
+ ExtensibleDialect *getdialect();
OperationName::VerifyInvariantsFn verifyFn;
OperationName::VerifyRegionInvariantsFn verifyRegionFn;
OperationName::ParseAssemblyFn parseFn;
OperationName::PrintAssemblyFn printFn;
OperationName::FoldHookFn foldHookFn;
- OperationName::GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
+ GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
OperationName::PopulateDefaultAttrsFn populateDefaultAttrsFn;
friend ExtensibleDialect;
#include "mlir/Support/InterfaceSupport.h"
#include "llvm/ADT/BitmaskEnum.h"
#include "llvm/ADT/PointerUnion.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
#include "llvm/Support/TrailingObjects.h"
#include <memory>
class OperationName {
public:
- using GetCanonicalizationPatternsFn =
- llvm::unique_function<void(RewritePatternSet &, MLIRContext *) const>;
using FoldHookFn = llvm::unique_function<LogicalResult(
Operation *, ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &) const>;
using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
using ParseAssemblyFn =
- llvm::unique_function<ParseResult(OpAsmParser &, OperationState &) const>;
+ llvm::function_ref<ParseResult(OpAsmParser &, OperationState &)>;
// Note: RegisteredOperationName is passed as reference here as the derived
// class is defined below.
- using PopulateDefaultAttrsFn = llvm::unique_function<void(
- const RegisteredOperationName &, NamedAttrList &) const>;
+ using PopulateDefaultAttrsFn =
+ llvm::unique_function<void(const OperationName &, NamedAttrList &) const>;
using PrintAssemblyFn =
llvm::unique_function<void(Operation *, OpAsmPrinter &, StringRef) const>;
using VerifyInvariantsFn =
using VerifyRegionInvariantsFn =
llvm::unique_function<LogicalResult(Operation *) const>;
-protected:
/// This class represents a type erased version of an operation. It contains
/// all of the components necessary for opaquely interacting with an
/// operation. If the operation is not registered, some of these components
/// may not be populated.
- struct Impl {
- Impl(StringAttr name)
- : name(name), dialect(nullptr), interfaceMap(std::nullopt) {}
+ struct InterfaceConcept {
+ virtual ~InterfaceConcept() = default;
+ virtual LogicalResult foldHook(Operation *, ArrayRef<Attribute>,
+ SmallVectorImpl<OpFoldResult> &) = 0;
+ virtual void getCanonicalizationPatterns(RewritePatternSet &,
+ MLIRContext *) = 0;
+ virtual bool hasTrait(TypeID) = 0;
+ virtual OperationName::ParseAssemblyFn getParseAssemblyFn() = 0;
+ virtual void populateDefaultAttrs(const OperationName &,
+ NamedAttrList &) = 0;
+ virtual void printAssembly(Operation *, OpAsmPrinter &, StringRef) = 0;
+ virtual LogicalResult verifyInvariants(Operation *) = 0;
+ virtual LogicalResult verifyRegionInvariants(Operation *) = 0;
+ };
+
+public:
+ class Impl : public InterfaceConcept {
+ public:
+ Impl(StringRef, Dialect *dialect, TypeID typeID,
+ detail::InterfaceMap interfaceMap);
+ Impl(StringAttr name, Dialect *dialect, TypeID typeID,
+ detail::InterfaceMap interfaceMap)
+ : name(name), typeID(typeID), dialect(dialect),
+ interfaceMap(std::move(interfaceMap)) {}
+
+ /// Returns true if this is a registered operation.
+ bool isRegistered() const { return typeID != TypeID::get<void>(); }
+ detail::InterfaceMap &getInterfaceMap() { return interfaceMap; }
+ Dialect *getDialect() const { return dialect; }
+ StringAttr getName() const { return name; }
+ TypeID getTypeID() const { return typeID; }
+ ArrayRef<StringAttr> getAttributeNames() const { return attributeNames; }
+
+ protected:
+ //===------------------------------------------------------------------===//
+ // Registered Operation Info
/// The name of the operation.
StringAttr name;
- //===------------------------------------------------------------------===//
- // Registered Operation Info
+ /// The unique identifier of the derived Op class.
+ TypeID typeID;
/// The following fields are only populated when the operation is
/// registered.
- /// Returns true if the operation has been registered, i.e. if the
- /// registration info has been populated.
- bool isRegistered() const { return dialect; }
-
/// This is the dialect that this operation belongs to.
Dialect *dialect;
- /// The unique identifier of the derived Op class.
- TypeID typeID;
-
/// A map of interfaces that were registered to this operation.
detail::InterfaceMap interfaceMap;
- /// Internal callback hooks provided by the op implementation.
- FoldHookFn foldHookFn;
- GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
- HasTraitFn hasTraitFn;
- ParseAssemblyFn parseAssemblyFn;
- PopulateDefaultAttrsFn populateDefaultAttrsFn;
- PrintAssemblyFn printAssemblyFn;
- VerifyInvariantsFn verifyInvariantsFn;
- VerifyRegionInvariantsFn verifyRegionInvariantsFn;
-
/// A list of attribute names registered to this operation in StringAttr
/// form. This allows for operation classes to use StringAttr for attribute
/// lookup/creation/etc., as opposed to raw strings.
ArrayRef<StringAttr> attributeNames;
+
+ friend class RegisteredOperationName;
+ };
+
+protected:
+ /// Default implementation for unregistered operations.
+ struct UnregisteredOpModel : public Impl {
+ using Impl::Impl;
+ LogicalResult foldHook(Operation *, ArrayRef<Attribute>,
+ SmallVectorImpl<OpFoldResult> &) final;
+ void getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) final;
+ bool hasTrait(TypeID) final;
+ virtual OperationName::ParseAssemblyFn getParseAssemblyFn() final;
+ void populateDefaultAttrs(const OperationName &, NamedAttrList &) final;
+ void printAssembly(Operation *, OpAsmPrinter &, StringRef) final;
+ LogicalResult verifyInvariants(Operation *) final;
+ LogicalResult verifyRegionInvariants(Operation *) final;
};
public:
OperationName(StringRef name, MLIRContext *context);
/// Return if this operation is registered.
- bool isRegistered() const { return impl->isRegistered(); }
+ bool isRegistered() const { return getImpl()->isRegistered(); }
+
+ /// Return the unique identifier of the derived Op class, or null if not
+ /// registered.
+ TypeID getTypeID() const { return getImpl()->getTypeID(); }
/// If this operation is registered, returns the registered information,
/// std::nullopt otherwise.
Optional<RegisteredOperationName> getRegisteredInfo() const;
+ /// This hook implements a generalized folder for this operation. Operations
+ /// can implement this to provide simplifications rules that are applied by
+ /// the Builder::createOrFold API and the canonicalization pass.
+ ///
+ /// This is an intentionally limited interface - implementations of this
+ /// hook can only perform the following changes to the operation:
+ ///
+ /// 1. They can leave the operation alone and without changing the IR, and
+ /// return failure.
+ /// 2. They can mutate the operation in place, without changing anything
+ /// else
+ /// in the IR. In this case, return success.
+ /// 3. They can return a list of existing values that can be used instead
+ /// of
+ /// the operation. In this case, fill in the results list and return
+ /// success. The caller will remove the operation and use those results
+ /// instead.
+ ///
+ /// This allows expression of some simple in-place canonicalizations (e.g.
+ /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
+ /// generalized constant folding.
+ LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) const {
+ return getImpl()->foldHook(op, operands, results);
+ }
+
+ /// This hook returns any canonicalization pattern rewrites that the
+ /// operation supports, for use by the canonicalization pass.
+ void getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) const {
+ return getImpl()->getCanonicalizationPatterns(results, context);
+ }
+
/// Returns true if the operation was registered with a particular trait, e.g.
/// hasTrait<OperandsAreSignlessIntegerLike>(). Returns false if the operation
/// is unregistered.
bool hasTrait() const {
return hasTrait(TypeID::get<Trait>());
}
- bool hasTrait(TypeID traitID) const {
- return isRegistered() && impl->hasTraitFn(traitID);
- }
+ bool hasTrait(TypeID traitID) const { return getImpl()->hasTrait(traitID); }
/// Returns true if the operation *might* have the provided trait. This
/// means that either the operation is unregistered, or it was registered with
return mightHaveTrait(TypeID::get<Trait>());
}
bool mightHaveTrait(TypeID traitID) const {
- return !isRegistered() || impl->hasTraitFn(traitID);
+ return !isRegistered() || getImpl()->hasTrait(traitID);
+ }
+
+ /// Return the static hook for parsing this operation assembly.
+ ParseAssemblyFn getParseAssemblyFn() const {
+ return getImpl()->getParseAssemblyFn();
+ }
+
+ /// This hook implements the method to populate defaults attributes that are
+ /// unset.
+ void populateDefaultAttrs(NamedAttrList &attrs) const {
+ getImpl()->populateDefaultAttrs(*this, attrs);
+ }
+
+ /// This hook implements the AsmPrinter for this operation.
+ void printAssembly(Operation *op, OpAsmPrinter &p,
+ StringRef defaultDialect) const {
+ return getImpl()->printAssembly(op, p, defaultDialect);
+ }
+
+ /// These hooks implement the verifiers for this operation. It should emits
+ /// an error message and returns failure if a problem is detected, or
+ /// returns success if everything is ok.
+ LogicalResult verifyInvariants(Operation *op) const {
+ return getImpl()->verifyInvariants(op);
+ }
+ LogicalResult verifyRegionInvariants(Operation *op) const {
+ return getImpl()->verifyRegionInvariants(op);
+ }
+
+ /// Return the list of cached attribute names registered to this operation.
+ /// The order of attributes cached here is unique to each type of operation,
+ /// and the interpretation of this attribute list should generally be driven
+ /// by the respective operation. In many cases, this caching removes the
+ /// need to use the raw string name of a known attribute.
+ ///
+ /// For example the ODS generator, with an op defining the following
+ /// attributes:
+ ///
+ /// let arguments = (ins I32Attr:$attr1, I32Attr:$attr2);
+ ///
+ /// ... may produce an order here of ["attr1", "attr2"]. This allows for the
+ /// ODS generator to directly access the cached name for a known attribute,
+ /// greatly simplifying the cost and complexity of attribute usage produced
+ /// by the generator.
+ ///
+ ArrayRef<StringAttr> getAttributeNames() const {
+ return getImpl()->getAttributeNames();
}
/// Returns an instance of the concept object for the given interface if it
/// directly.
template <typename T>
typename T::Concept *getInterface() const {
- return impl->interfaceMap.lookup<T>();
+ return getImpl()->getInterfaceMap().lookup<T>();
+ }
+
+ /// Attach the given models as implementations of the corresponding
+ /// interfaces for the concrete operation.
+ template <typename... Models> void attachInterface() {
+ getImpl()->getInterfaceMap().insert<Models...>();
}
/// Returns true if this operation has the given interface registered to it.
return hasInterface(TypeID::get<T>());
}
bool hasInterface(TypeID interfaceID) const {
- return impl->interfaceMap.contains(interfaceID);
+ return getImpl()->getInterfaceMap().contains(interfaceID);
}
/// Returns true if the operation *might* have the provided interface. This
/// Return the dialect this operation is registered to if the dialect is
/// loaded in the context, or nullptr if the dialect isn't loaded.
Dialect *getDialect() const {
- return isRegistered() ? impl->dialect : impl->name.getReferencedDialect();
+ return isRegistered() ? getImpl()->getDialect()
+ : getImpl()->getName().getReferencedDialect();
}
/// Return the name of the dialect this operation is registered to.
StringRef getStringRef() const { return getIdentifier(); }
/// Return the name of this operation as a StringAttr.
- StringAttr getIdentifier() const { return impl->name; }
+ StringAttr getIdentifier() const { return getImpl()->getName(); }
void print(raw_ostream &os) const;
void dump() const;
protected:
OperationName(Impl *impl) : impl(impl) {}
+ Impl *getImpl() const { return impl; }
+ void setImpl(Impl *rhs) { impl = rhs; }
+private:
/// The internal implementation of the operation name.
- Impl *impl;
+ Impl *impl = nullptr;
/// Allow access to the Impl struct.
friend MLIRContextImpl;
+ friend DenseMapInfo<mlir::OperationName>;
+ friend DenseMapInfo<mlir::RegisteredOperationName>;
};
inline raw_ostream &operator<<(raw_ostream &os, OperationName info) {
/// the concrete operation types.
class RegisteredOperationName : public OperationName {
public:
+ /// Implementation of the InterfaceConcept for operation APIs that forwarded
+ /// to a concrete op implementation.
+ template <typename ConcreteOp> struct Model : public Impl {
+ Model(Dialect *dialect)
+ : Impl(ConcreteOp::getOperationName(), dialect,
+ TypeID::get<ConcreteOp>(), ConcreteOp::getInterfaceMap()) {}
+ LogicalResult foldHook(Operation *op, ArrayRef<Attribute> attrs,
+ SmallVectorImpl<OpFoldResult> &results) final {
+ return ConcreteOp::getFoldHookFn()(op, attrs, results);
+ }
+ void getCanonicalizationPatterns(RewritePatternSet &set,
+ MLIRContext *context) final {
+ ConcreteOp::getCanonicalizationPatterns(set, context);
+ }
+ bool hasTrait(TypeID id) final { return ConcreteOp::getHasTraitFn()(id); }
+ OperationName::ParseAssemblyFn getParseAssemblyFn() final {
+ return ConcreteOp::parse;
+ }
+ void populateDefaultAttrs(const OperationName &name,
+ NamedAttrList &attrs) final {
+ ConcreteOp::populateDefaultAttrs(name, attrs);
+ }
+ void printAssembly(Operation *op, OpAsmPrinter &printer,
+ StringRef name) final {
+ ConcreteOp::getPrintAssemblyFn()(op, printer, name);
+ }
+ LogicalResult verifyInvariants(Operation *op) final {
+ return ConcreteOp::getVerifyInvariantsFn()(op);
+ }
+ LogicalResult verifyRegionInvariants(Operation *op) final {
+ return ConcreteOp::getVerifyRegionInvariantsFn()(op);
+ }
+ };
+
/// Lookup the registered operation information for the given operation.
/// Returns std::nullopt if the operation isn't registered.
static Optional<RegisteredOperationName> lookup(StringRef name,
MLIRContext *ctx);
/// Register a new operation in a Dialect object.
- /// This constructor is used by Dialect objects when they register the list of
- /// operations they contain.
- template <typename T>
- static void insert(Dialect &dialect) {
- insert(T::getOperationName(), dialect, TypeID::get<T>(),
- T::getParseAssemblyFn(), T::getPrintAssemblyFn(),
- T::getVerifyInvariantsFn(), T::getVerifyRegionInvariantsFn(),
- T::getFoldHookFn(), T::getGetCanonicalizationPatternsFn(),
- T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames(),
- T::getPopulateDefaultAttrsFn());
+ /// This constructor is used by Dialect objects when they register the list
+ /// of operations they contain.
+ template <typename T> static void insert(Dialect &dialect) {
+ insert(std::make_unique<Model<T>>(&dialect), T::getAttributeNames());
}
/// The use of this method is in general discouraged in favor of
/// 'insert<CustomOp>(dialect)'.
- static void
- insert(StringRef name, Dialect &dialect, TypeID typeID,
- ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
- VerifyInvariantsFn &&verifyInvariants,
- VerifyRegionInvariantsFn &&verifyRegionInvariants,
- FoldHookFn &&foldHook,
- GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
- detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
- ArrayRef<StringRef> attrNames,
- PopulateDefaultAttrsFn &&populateDefaultAttrs);
+ static void insert(std::unique_ptr<OperationName::Impl> ownedImpl,
+ ArrayRef<StringRef> attrNames);
/// Return the dialect this operation is registered to.
- Dialect &getDialect() const { return *impl->dialect; }
-
- /// Return the unique identifier of the derived Op class.
- TypeID getTypeID() const { return impl->typeID; }
+ Dialect &getDialect() const { return *getImpl()->getDialect(); }
/// Use the specified object to parse this ops custom assembly format.
ParseResult parseAssembly(OpAsmParser &parser, OperationState &result) const;
- /// Return the static hook for parsing this operation assembly.
- const ParseAssemblyFn &getParseAssemblyFn() const {
- return impl->parseAssemblyFn;
- }
-
- /// This hook implements the AsmPrinter for this operation.
- void printAssembly(Operation *op, OpAsmPrinter &p,
- StringRef defaultDialect) const {
- return impl->printAssemblyFn(op, p, defaultDialect);
- }
-
- /// These hooks implement the verifiers for this operation. It should emits
- /// an error message and returns failure if a problem is detected, or returns
- /// success if everything is ok.
- LogicalResult verifyInvariants(Operation *op) const {
- return impl->verifyInvariantsFn(op);
- }
- LogicalResult verifyRegionInvariants(Operation *op) const {
- return impl->verifyRegionInvariantsFn(op);
- }
-
- /// This hook implements a generalized folder for this operation. Operations
- /// can implement this to provide simplifications rules that are applied by
- /// the Builder::createOrFold API and the canonicalization pass.
- ///
- /// This is an intentionally limited interface - implementations of this hook
- /// can only perform the following changes to the operation:
- ///
- /// 1. They can leave the operation alone and without changing the IR, and
- /// return failure.
- /// 2. They can mutate the operation in place, without changing anything else
- /// in the IR. In this case, return success.
- /// 3. They can return a list of existing values that can be used instead of
- /// the operation. In this case, fill in the results list and return
- /// success. The caller will remove the operation and use those results
- /// instead.
- ///
- /// This allows expression of some simple in-place canonicalizations (e.g.
- /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
- /// generalized constant folding.
- LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
- SmallVectorImpl<OpFoldResult> &results) const {
- return impl->foldHookFn(op, operands, results);
- }
-
- /// This hook returns any canonicalization pattern rewrites that the operation
- /// supports, for use by the canonicalization pass.
- void getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) const {
- return impl->getCanonicalizationPatternsFn(results, context);
- }
-
- /// Attach the given models as implementations of the corresponding interfaces
- /// for the concrete operation.
- template <typename... Models>
- void attachInterface() {
- impl->interfaceMap.insert<Models...>();
- }
-
- /// Returns true if the operation has a particular trait.
- template <template <typename T> class Trait>
- bool hasTrait() const {
- return hasTrait(TypeID::get<Trait>());
- }
-
- /// Returns true if the operation has a particular trait.
- bool hasTrait(TypeID traitID) const { return impl->hasTraitFn(traitID); }
-
- /// Return the list of cached attribute names registered to this operation.
- /// The order of attributes cached here is unique to each type of operation,
- /// and the interpretation of this attribute list should generally be driven
- /// by the respective operation. In many cases, this caching removes the need
- /// to use the raw string name of a known attribute.
- ///
- /// For example the ODS generator, with an op defining the following
- /// attributes:
- ///
- /// let arguments = (ins I32Attr:$attr1, I32Attr:$attr2);
- ///
- /// ... may produce an order here of ["attr1", "attr2"]. This allows for the
- /// ODS generator to directly access the cached name for a known attribute,
- /// greatly simplifying the cost and complexity of attribute usage produced by
- /// the generator.
- ///
- ArrayRef<StringAttr> getAttributeNames() const {
- return impl->attributeNames;
- }
-
- /// This hook implements the method to populate defaults attributes that are
- /// unset.
- void populateDefaultAttrs(NamedAttrList &attrs) const;
-
/// Represent the operation name as an opaque pointer. (Used to support
/// PointerLikeTypeTraits).
static RegisteredOperationName getFromOpaquePointer(const void *pointer) {
llvm::BumpPtrAllocator abstractDialectSymbolAllocator;
/// This is a mapping from operation name to the operation info describing it.
- llvm::StringMap<OperationName::Impl> operations;
+ llvm::StringMap<std::unique_ptr<OperationName::Impl>> operations;
/// A vector of operation info specifically for registered operations.
llvm::StringMap<RegisteredOperationName> registeredOperations;
// OperationName
//===----------------------------------------------------------------------===//
+OperationName::Impl::Impl(StringRef name, Dialect *dialect, TypeID typeID,
+ detail::InterfaceMap interfaceMap)
+ : Impl(StringAttr::get(dialect->getContext(), name), dialect, typeID,
+ std::move(interfaceMap)) {}
+
OperationName::OperationName(StringRef name, MLIRContext *context) {
MLIRContextImpl &ctxImpl = context->getImpl();
llvm::sys::SmartScopedReader<true> contextLock(ctxImpl.operationInfoMutex);
auto it = ctxImpl.operations.find(name);
if (it != ctxImpl.operations.end()) {
- impl = &it->second;
+ impl = it->second.get();
return;
}
}
// Acquire a writer-lock so that we can safely create the new instance.
ScopedWriterLock lock(ctxImpl.operationInfoMutex, isMultithreadingEnabled);
- auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)});
- if (it.second)
- it.first->second.name = StringAttr::get(context, name);
- impl = &it.first->second;
+ auto it = ctxImpl.operations.insert({name, nullptr});
+ if (it.second) {
+ auto nameAttr = StringAttr::get(context, name);
+ it.first->second = std::make_unique<UnregisteredOpModel>(
+ nameAttr, nameAttr.getReferencedDialect(), TypeID::get<void>(),
+ detail::InterfaceMap(std::nullopt));
+ }
+ impl = it.first->second.get();
}
StringRef OperationName::getDialectNamespace() const {
return getStringRef().split('.').first;
}
+LogicalResult
+OperationName::UnregisteredOpModel::foldHook(Operation *, ArrayRef<Attribute>,
+ SmallVectorImpl<OpFoldResult> &) {
+ return failure();
+}
+void OperationName::UnregisteredOpModel::getCanonicalizationPatterns(
+ RewritePatternSet &, MLIRContext *) {}
+bool OperationName::UnregisteredOpModel::hasTrait(TypeID) { return false; }
+
+OperationName::ParseAssemblyFn
+OperationName::UnregisteredOpModel::getParseAssemblyFn() {
+ llvm::report_fatal_error("getParseAssemblyFn hook called on unregistered op");
+}
+void OperationName::UnregisteredOpModel::populateDefaultAttrs(
+ const OperationName &, NamedAttrList &) {}
+void OperationName::UnregisteredOpModel::printAssembly(
+ Operation *op, OpAsmPrinter &p, StringRef defaultDialect) {
+ p.printGenericOp(op);
+}
+LogicalResult
+OperationName::UnregisteredOpModel::verifyInvariants(Operation *) {
+ return success();
+}
+LogicalResult
+OperationName::UnregisteredOpModel::verifyRegionInvariants(Operation *) {
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// RegisteredOperationName
//===----------------------------------------------------------------------===//
return std::nullopt;
}
-ParseResult
-RegisteredOperationName::parseAssembly(OpAsmParser &parser,
- OperationState &result) const {
- return impl->parseAssemblyFn(parser, result);
-}
-
-void RegisteredOperationName::populateDefaultAttrs(NamedAttrList &attrs) const {
- impl->populateDefaultAttrsFn(*this, attrs);
-}
-
void RegisteredOperationName::insert(
- StringRef name, Dialect &dialect, TypeID typeID,
- ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
- VerifyInvariantsFn &&verifyInvariants,
- VerifyRegionInvariantsFn &&verifyRegionInvariants, FoldHookFn &&foldHook,
- GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
- detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
- ArrayRef<StringRef> attrNames,
- PopulateDefaultAttrsFn &&populateDefaultAttrs) {
- MLIRContext *ctx = dialect.getContext();
+ std::unique_ptr<RegisteredOperationName::Impl> ownedImpl,
+ ArrayRef<StringRef> attrNames) {
+ RegisteredOperationName::Impl *impl = ownedImpl.get();
+ MLIRContext *ctx = impl->getDialect()->getContext();
auto &ctxImpl = ctx->getImpl();
assert(ctxImpl.multiThreadedExecutionContext == 0 &&
"registering a new operation kind while in a multi-threaded execution "
attrNames.size());
for (unsigned i : llvm::seq<unsigned>(0, attrNames.size()))
new (&cachedAttrNames[i]) StringAttr(StringAttr::get(ctx, attrNames[i]));
+ impl->attributeNames = cachedAttrNames;
}
-
+ StringRef name = impl->getName().strref();
// Insert the operation info if it doesn't exist yet.
- auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)});
- if (it.second)
- it.first->second.name = StringAttr::get(ctx, name);
- OperationName::Impl &impl = it.first->second;
-
- if (impl.isRegistered()) {
- llvm::errs() << "error: operation named '" << name
- << "' is already registered.\n";
- abort();
- }
+ auto it = ctxImpl.operations.insert({name, nullptr});
+ it.first->second = std::move(ownedImpl);
+
+ // Update the registered info for this operation.
auto emplaced = ctxImpl.registeredOperations.try_emplace(
- name, RegisteredOperationName(&impl));
+ name, RegisteredOperationName(impl));
assert(emplaced.second && "operation name registration must be successful");
// Add emplaced operation name to the sorted operations container.
rhs.getIdentifier());
}),
value);
-
- // Update the registered info for this operation.
- impl.dialect = &dialect;
- impl.typeID = typeID;
- impl.interfaceMap = std::move(interfaceMap);
- impl.foldHookFn = std::move(foldHook);
- impl.getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns);
- impl.hasTraitFn = std::move(hasTrait);
- impl.parseAssemblyFn = std::move(parseAssembly);
- impl.printAssemblyFn = std::move(printAssembly);
- impl.verifyInvariantsFn = std::move(verifyInvariants);
- impl.verifyRegionInvariantsFn = std::move(verifyRegionInvariants);
- impl.attributeNames = cachedAttrNames;
- impl.populateDefaultAttrsFn = std::move(populateDefaultAttrs);
}
//===----------------------------------------------------------------------===//