Ensure that all attributes are registered with a dialect. This is one of the...
authorRiver Riddle <riverriddle@google.com>
Fri, 10 May 2019 22:14:13 +0000 (15:14 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sat, 11 May 2019 02:30:23 +0000 (19:30 -0700)
--

PiperOrigin-RevId: 247684322

mlir/include/mlir/IR/AttributeSupport.h
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/IR/TypeSupport.h
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/MLIRContext.cpp

index 88dbf8e..0779461 100644 (file)
@@ -54,6 +54,12 @@ public:
   /// Get the type of this attribute.
   Type getType() const;
 
+  /// Get the dialect of this attribute.
+  const Dialect &getDialect() const {
+    assert(dialect && "Malformed attribute storage object.");
+    return *dialect;
+  }
+
 protected:
   /// Construct a new attribute storage instance with the given type and a
   /// boolean that signals if the derived attribute is or contains a function
@@ -68,7 +74,14 @@ protected:
   /// Set the type of this attribute.
   void setType(Type type);
 
+  // Set the dialect for this storage instance. This is used by the
+  // AttributeUniquer when initializing a newly constructed storage object.
+  void initializeDialect(const Dialect &newDialect) { dialect = &newDialect; }
+
 private:
+  /// The dialect for this attribute.
+  const Dialect *dialect;
+
   /// This field is a pair of:
   ///  - The type of the attribute value.
   ///  - A boolean that is true if this is, or contains, a function attribute.
@@ -99,7 +112,7 @@ public:
   template <typename T, typename Kind, typename... Args>
   static T get(MLIRContext *ctx, Kind kind, Args &&... args) {
     return ctx->getAttributeUniquer().get<typename T::ImplType>(
-        getInitFn(ctx), static_cast<unsigned>(kind),
+        getInitFn(ctx, T::getClassID()), static_cast<unsigned>(kind),
         std::forward<Args>(args)...);
   }
 
@@ -112,7 +125,8 @@ public:
 
 private:
   /// Returns a functor used to initialize new attribute storage instances.
-  static std::function<void(AttributeStorage *)> getInitFn(MLIRContext *ctx);
+  static std::function<void(AttributeStorage *)>
+  getInitFn(MLIRContext *ctx, const ClassID *const attrID);
 };
 } // namespace detail
 
index 4f4d2c6..14d1569 100644 (file)
@@ -136,6 +136,9 @@ public:
   /// Return the context this attribute belongs to.
   MLIRContext *getContext() const;
 
+  /// Get the dialect this attribute is registered to.
+  const Dialect &getDialect() const;
+
   /// Return true if this field is, or contains, a function attribute.
   bool isOrContainsFunction() const;
 
index c279ffd..689c2ab 100644 (file)
@@ -173,33 +173,38 @@ protected:
 
   /// This method is used by derived classes to add their types to the set.
   template <typename... Args> void addTypes() {
-    VariadicTypeAdder<Args...>::addToSet(*this);
+    VariadicSymbolAdder<Args...>::addToSet(*this);
+  }
+
+  /// This method is used by derived classes to add their attributes to the set.
+  template <typename... Args> void addAttributes() {
+    VariadicSymbolAdder<Args...>::addToSet(*this);
   }
 
   // It would be nice to define this as variadic functions instead of a nested
   // variadic type, but we can't do that: function template partial
   // specialization is not allowed, and we can't define an overload set
   // because we don't have any arguments of the types we are pushing around.
-  template <typename First, typename... Rest> struct VariadicTypeAdder {
+  template <typename First, typename... Rest> struct VariadicSymbolAdder {
     static void addToSet(Dialect &dialect) {
-      VariadicTypeAdder<First>::addToSet(dialect);
-      VariadicTypeAdder<Rest...>::addToSet(dialect);
+      VariadicSymbolAdder<First>::addToSet(dialect);
+      VariadicSymbolAdder<Rest...>::addToSet(dialect);
     }
   };
 
-  template <typename First> struct VariadicTypeAdder<First> {
+  template <typename First> struct VariadicSymbolAdder<First> {
     static void addToSet(Dialect &dialect) {
-      dialect.addType(First::getClassID());
+      dialect.addSymbol(First::getClassID());
     }
   };
 
-  // Register a type with its given unqiue type identifer.
-  void addType(const ClassID *const typeID);
-
   // Enable support for unregistered operations.
   void allowUnknownOperations(bool allow = true) { allowUnknownOps = allow; }
 
 private:
+  // Register a symbol(e.g. type) with its given unique class identifier.
+  void addSymbol(const ClassID *const classID);
+
   Dialect(const Dialect &) = delete;
   void operator=(Dialect &) = delete;
 
index f5b8b18..684517c 100644 (file)
@@ -69,7 +69,7 @@ private:
   // when initializing a newly constructed type storage object.
   void initializeDialect(const Dialect &newDialect) { dialect = &newDialect; }
 
-  /// The registered information for the current type.
+  /// The dialect for this type.
   const Dialect *dialect;
 
   /// Space for subclasses to store data.
index 0504c49..bf1972c 100644 (file)
@@ -46,16 +46,6 @@ void AttributeStorage::setType(Type type) {
   typeAndContainsFunctionAttrPair.setPointer(type.getAsOpaquePointer());
 }
 
-/// Returns a functor used to initialize new attribute storage instances.
-std::function<void(AttributeStorage *)>
-AttributeUniquer::getInitFn(MLIRContext *ctx) {
-  return [ctx](AttributeStorage *storage) {
-    // If the attribute did not provide a type, then default to NoneType.
-    if (!storage->getType())
-      storage->setType(NoneType::get(ctx));
-  };
-}
-
 //===----------------------------------------------------------------------===//
 // Attribute
 //===----------------------------------------------------------------------===//
@@ -70,6 +60,9 @@ Type Attribute::getType() const { return impl->getType(); }
 /// Return the context this attribute belongs to.
 MLIRContext *Attribute::getContext() const { return getType().getContext(); }
 
+/// Get the dialect this attribute is registered to.
+const Dialect &Attribute::getDialect() const { return impl->getDialect(); }
+
 bool Attribute::isOrContainsFunction() const {
   return impl->isOrContainsFunctionCache();
 }
@@ -359,23 +352,19 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
                                          ArrayRef<char> data) {
   assert((type.getSizeInBits() <= data.size() * APInt::APINT_WORD_SIZE) &&
          "Input data bit size should be larger than that type requires");
-
-  Attribute::Kind kind;
   switch (type.getElementType().getKind()) {
   case StandardTypes::BF16:
   case StandardTypes::F16:
   case StandardTypes::F32:
   case StandardTypes::F64:
-    kind = Attribute::Kind::DenseFPElements;
-    break;
+    return AttributeUniquer::get<DenseFPElementsAttr>(
+        type.getContext(), Attribute::Kind::DenseFPElements, type, data);
   case StandardTypes::Integer:
-    kind = Attribute::Kind::DenseIntElements;
-    break;
+    return AttributeUniquer::get<DenseIntElementsAttr>(
+        type.getContext(), Attribute::Kind::DenseIntElements, type, data);
   default:
     llvm_unreachable("unexpected element type");
   }
-  return AttributeUniquer::get<DenseElementsAttr>(type.getContext(), kind, type,
-                                                  data);
 }
 
 DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
index 8c0e98c..a4cc076 100644 (file)
@@ -132,13 +132,17 @@ safeGetOrCreate(ContainerTy &container, KeyT &&key,
 }
 
 namespace {
-/// A builtin dialect to define types/etc that are necessary for the
-/// validity of the IR.
+/// A builtin dialect to define types/etc that are necessary for the validity of
+/// the IR.
 struct BuiltinDialect : public Dialect {
   BuiltinDialect(MLIRContext *context) : Dialect(/*name=*/"", context) {
-    addTypes<FunctionType, OpaqueType, FloatType, IndexType, IntegerType,
-             VectorType, RankedTensorType, UnrankedTensorType, MemRefType,
-             ComplexType, TupleType, NoneType>();
+    addAttributes<AffineMapAttr, ArrayAttr, BoolAttr, DenseIntElementsAttr,
+                  DenseFPElementsAttr, FloatAttr, FunctionAttr, IntegerAttr,
+                  IntegerSetAttr, OpaqueElementsAttr, SparseElementsAttr,
+                  SplatElementsAttr, StringAttr, TypeAttr, UnitAttr>();
+    addTypes<ComplexType, FloatType, FunctionType, IndexType, IntegerType,
+             MemRefType, NoneType, OpaqueType, RankedTensorType, TupleType,
+             UnrankedTensorType, VectorType>();
   }
 };
 
@@ -323,8 +327,9 @@ public:
   /// operations.
   llvm::StringMap<AbstractOperation> registeredOperations;
 
-  /// This is a mapping from type identifier to Dialect for registered types.
-  DenseMap<const ClassID *, Dialect *> registeredTypes;
+  /// This is a mapping from class identifier to Dialect for registered
+  /// attributes and types.
+  DenseMap<const ClassID *, Dialect *> registeredDialectSymbols;
 
   /// These are identifiers uniqued into this MLIRContext.
   llvm::StringMap<char, llvm::BumpPtrAllocator &> identifiers;
@@ -552,14 +557,14 @@ void Dialect::addOperation(AbstractOperation opInfo) {
   }
 }
 
-/// Register a dialect-specific type with the current context.
-void Dialect::addType(const ClassID *const typeID) {
+/// Register a dialect-specific symbol(e.g. type) with the current context.
+void Dialect::addSymbol(const ClassID *const classID) {
   auto &impl = context->getImpl();
 
   // Lock access to the context registry.
   llvm::sys::SmartScopedWriter<true> registryLock(impl.contextMutex);
-  if (!impl.registeredTypes.insert({typeID, this}).second) {
-    llvm::errs() << "error: type already registered.\n";
+  if (!impl.registeredDialectSymbols.insert({classID, this}).second) {
+    llvm::errs() << "error: dialect symbol already registered.\n";
     abort();
   }
 }
@@ -816,6 +821,15 @@ SDBMNegExpr SDBMNegExpr::get(SDBMPositiveExpr var) {
 // Type uniquing
 //===----------------------------------------------------------------------===//
 
+static Dialect &lookupDialectForSymbol(MLIRContext *ctx,
+                                       const ClassID *const classID) {
+  auto &impl = ctx->getImpl();
+  auto it = impl.registeredDialectSymbols.find(classID);
+  assert(it != impl.registeredDialectSymbols.end() &&
+         "symbol is not registered.");
+  return *it->second;
+}
+
 /// Returns the storage unqiuer used for constructing type storage instances.
 /// This should not be used directly.
 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
@@ -823,10 +837,7 @@ StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
 /// Get the dialect that registered the type with the provided typeid.
 const Dialect &TypeUniquer::lookupDialectForType(MLIRContext *ctx,
                                                  const ClassID *const typeID) {
-  auto &impl = ctx->getImpl();
-  auto it = impl.registeredTypes.find(typeID);
-  assert(it != impl.registeredTypes.end() && "typeID is not registered.");
-  return *it->second;
+  return lookupDialectForSymbol(ctx, typeID);
 }
 
 //===----------------------------------------------------------------------===//
@@ -839,6 +850,18 @@ StorageUniquer &MLIRContext::getAttributeUniquer() {
   return getImpl().attributeUniquer;
 }
 
+/// Returns a functor used to initialize new attribute storage instances.
+std::function<void(AttributeStorage *)>
+AttributeUniquer::getInitFn(MLIRContext *ctx, const ClassID *const attrID) {
+  return [ctx, attrID](AttributeStorage *storage) {
+    storage->initializeDialect(lookupDialectForSymbol(ctx, attrID));
+
+    // If the attribute did not provide a type, then default to NoneType.
+    if (!storage->getType())
+      storage->setType(NoneType::get(ctx));
+  };
+}
+
 /// Perform a three-way comparison between the names of the specified
 /// NamedAttributes.
 static int compareNamedAttributes(const NamedAttribute *lhs,