From 53a3483cc808a709a1aa8e23f6c99d30c3404b94 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 10 May 2019 15:14:13 -0700 Subject: [PATCH] Ensure that all attributes are registered with a dialect. This is one of the final steps towards allowing dialects to define their own attributes, but there are still several things missing before this is fully supported(e.g. parsing/printing ). -- PiperOrigin-RevId: 247684322 --- mlir/include/mlir/IR/AttributeSupport.h | 18 +++++++++-- mlir/include/mlir/IR/Attributes.h | 3 ++ mlir/include/mlir/IR/Dialect.h | 23 ++++++++------ mlir/include/mlir/IR/TypeSupport.h | 2 +- mlir/lib/IR/Attributes.cpp | 25 +++++----------- mlir/lib/IR/MLIRContext.cpp | 53 +++++++++++++++++++++++---------- 6 files changed, 79 insertions(+), 45 deletions(-) diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h index 88dbf8e..0779461 100644 --- a/mlir/include/mlir/IR/AttributeSupport.h +++ b/mlir/include/mlir/IR/AttributeSupport.h @@ -54,6 +54,12 @@ public: /// Get the type of this attribute. Type getType() const; + /// Get the dialect of this attribute. + const Dialect &getDialect() const { + assert(dialect && "Malformed attribute storage object."); + return *dialect; + } + protected: /// Construct a new attribute storage instance with the given type and a /// boolean that signals if the derived attribute is or contains a function @@ -68,7 +74,14 @@ protected: /// Set the type of this attribute. void setType(Type type); + // Set the dialect for this storage instance. This is used by the + // AttributeUniquer when initializing a newly constructed storage object. + void initializeDialect(const Dialect &newDialect) { dialect = &newDialect; } + private: + /// The dialect for this attribute. + const Dialect *dialect; + /// This field is a pair of: /// - The type of the attribute value. /// - A boolean that is true if this is, or contains, a function attribute. @@ -99,7 +112,7 @@ public: template static T get(MLIRContext *ctx, Kind kind, Args &&... args) { return ctx->getAttributeUniquer().get( - getInitFn(ctx), static_cast(kind), + getInitFn(ctx, T::getClassID()), static_cast(kind), std::forward(args)...); } @@ -112,7 +125,8 @@ public: private: /// Returns a functor used to initialize new attribute storage instances. - static std::function getInitFn(MLIRContext *ctx); + static std::function + getInitFn(MLIRContext *ctx, const ClassID *const attrID); }; } // namespace detail diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 4f4d2c6..14d1569 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -136,6 +136,9 @@ public: /// Return the context this attribute belongs to. MLIRContext *getContext() const; + /// Get the dialect this attribute is registered to. + const Dialect &getDialect() const; + /// Return true if this field is, or contains, a function attribute. bool isOrContainsFunction() const; diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index c279ffd..689c2ab 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -173,33 +173,38 @@ protected: /// This method is used by derived classes to add their types to the set. template void addTypes() { - VariadicTypeAdder::addToSet(*this); + VariadicSymbolAdder::addToSet(*this); + } + + /// This method is used by derived classes to add their attributes to the set. + template void addAttributes() { + VariadicSymbolAdder::addToSet(*this); } // It would be nice to define this as variadic functions instead of a nested // variadic type, but we can't do that: function template partial // specialization is not allowed, and we can't define an overload set // because we don't have any arguments of the types we are pushing around. - template struct VariadicTypeAdder { + template struct VariadicSymbolAdder { static void addToSet(Dialect &dialect) { - VariadicTypeAdder::addToSet(dialect); - VariadicTypeAdder::addToSet(dialect); + VariadicSymbolAdder::addToSet(dialect); + VariadicSymbolAdder::addToSet(dialect); } }; - template struct VariadicTypeAdder { + template struct VariadicSymbolAdder { static void addToSet(Dialect &dialect) { - dialect.addType(First::getClassID()); + dialect.addSymbol(First::getClassID()); } }; - // Register a type with its given unqiue type identifer. - void addType(const ClassID *const typeID); - // Enable support for unregistered operations. void allowUnknownOperations(bool allow = true) { allowUnknownOps = allow; } private: + // Register a symbol(e.g. type) with its given unique class identifier. + void addSymbol(const ClassID *const classID); + Dialect(const Dialect &) = delete; void operator=(Dialect &) = delete; diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h index f5b8b18..684517c 100644 --- a/mlir/include/mlir/IR/TypeSupport.h +++ b/mlir/include/mlir/IR/TypeSupport.h @@ -69,7 +69,7 @@ private: // when initializing a newly constructed type storage object. void initializeDialect(const Dialect &newDialect) { dialect = &newDialect; } - /// The registered information for the current type. + /// The dialect for this type. const Dialect *dialect; /// Space for subclasses to store data. diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 0504c49..bf1972c 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -46,16 +46,6 @@ void AttributeStorage::setType(Type type) { typeAndContainsFunctionAttrPair.setPointer(type.getAsOpaquePointer()); } -/// Returns a functor used to initialize new attribute storage instances. -std::function -AttributeUniquer::getInitFn(MLIRContext *ctx) { - return [ctx](AttributeStorage *storage) { - // If the attribute did not provide a type, then default to NoneType. - if (!storage->getType()) - storage->setType(NoneType::get(ctx)); - }; -} - //===----------------------------------------------------------------------===// // Attribute //===----------------------------------------------------------------------===// @@ -70,6 +60,9 @@ Type Attribute::getType() const { return impl->getType(); } /// Return the context this attribute belongs to. MLIRContext *Attribute::getContext() const { return getType().getContext(); } +/// Get the dialect this attribute is registered to. +const Dialect &Attribute::getDialect() const { return impl->getDialect(); } + bool Attribute::isOrContainsFunction() const { return impl->isOrContainsFunctionCache(); } @@ -359,23 +352,19 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type, ArrayRef data) { assert((type.getSizeInBits() <= data.size() * APInt::APINT_WORD_SIZE) && "Input data bit size should be larger than that type requires"); - - Attribute::Kind kind; switch (type.getElementType().getKind()) { case StandardTypes::BF16: case StandardTypes::F16: case StandardTypes::F32: case StandardTypes::F64: - kind = Attribute::Kind::DenseFPElements; - break; + return AttributeUniquer::get( + type.getContext(), Attribute::Kind::DenseFPElements, type, data); case StandardTypes::Integer: - kind = Attribute::Kind::DenseIntElements; - break; + return AttributeUniquer::get( + type.getContext(), Attribute::Kind::DenseIntElements, type, data); default: llvm_unreachable("unexpected element type"); } - return AttributeUniquer::get(type.getContext(), kind, type, - data); } DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type, diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 8c0e98c..a4cc076 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -132,13 +132,17 @@ safeGetOrCreate(ContainerTy &container, KeyT &&key, } namespace { -/// A builtin dialect to define types/etc that are necessary for the -/// validity of the IR. +/// A builtin dialect to define types/etc that are necessary for the validity of +/// the IR. struct BuiltinDialect : public Dialect { BuiltinDialect(MLIRContext *context) : Dialect(/*name=*/"", context) { - addTypes(); + addAttributes(); + addTypes(); } }; @@ -323,8 +327,9 @@ public: /// operations. llvm::StringMap registeredOperations; - /// This is a mapping from type identifier to Dialect for registered types. - DenseMap registeredTypes; + /// This is a mapping from class identifier to Dialect for registered + /// attributes and types. + DenseMap registeredDialectSymbols; /// These are identifiers uniqued into this MLIRContext. llvm::StringMap identifiers; @@ -552,14 +557,14 @@ void Dialect::addOperation(AbstractOperation opInfo) { } } -/// Register a dialect-specific type with the current context. -void Dialect::addType(const ClassID *const typeID) { +/// Register a dialect-specific symbol(e.g. type) with the current context. +void Dialect::addSymbol(const ClassID *const classID) { auto &impl = context->getImpl(); // Lock access to the context registry. llvm::sys::SmartScopedWriter registryLock(impl.contextMutex); - if (!impl.registeredTypes.insert({typeID, this}).second) { - llvm::errs() << "error: type already registered.\n"; + if (!impl.registeredDialectSymbols.insert({classID, this}).second) { + llvm::errs() << "error: dialect symbol already registered.\n"; abort(); } } @@ -816,6 +821,15 @@ SDBMNegExpr SDBMNegExpr::get(SDBMPositiveExpr var) { // Type uniquing //===----------------------------------------------------------------------===// +static Dialect &lookupDialectForSymbol(MLIRContext *ctx, + const ClassID *const classID) { + auto &impl = ctx->getImpl(); + auto it = impl.registeredDialectSymbols.find(classID); + assert(it != impl.registeredDialectSymbols.end() && + "symbol is not registered."); + return *it->second; +} + /// Returns the storage unqiuer used for constructing type storage instances. /// This should not be used directly. StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; } @@ -823,10 +837,7 @@ StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; } /// Get the dialect that registered the type with the provided typeid. const Dialect &TypeUniquer::lookupDialectForType(MLIRContext *ctx, const ClassID *const typeID) { - auto &impl = ctx->getImpl(); - auto it = impl.registeredTypes.find(typeID); - assert(it != impl.registeredTypes.end() && "typeID is not registered."); - return *it->second; + return lookupDialectForSymbol(ctx, typeID); } //===----------------------------------------------------------------------===// @@ -839,6 +850,18 @@ StorageUniquer &MLIRContext::getAttributeUniquer() { return getImpl().attributeUniquer; } +/// Returns a functor used to initialize new attribute storage instances. +std::function +AttributeUniquer::getInitFn(MLIRContext *ctx, const ClassID *const attrID) { + return [ctx, attrID](AttributeStorage *storage) { + storage->initializeDialect(lookupDialectForSymbol(ctx, attrID)); + + // If the attribute did not provide a type, then default to NoneType. + if (!storage->getType()) + storage->setType(NoneType::get(ctx)); + }; +} + /// Perform a three-way comparison between the names of the specified /// NamedAttributes. static int compareNamedAttributes(const NamedAttribute *lhs, -- 2.7.4