Add support for parsing/printing the trailing type of a dialect attribute.
authorRiver Riddle <riverriddle@google.com>
Wed, 17 Jul 2019 23:05:32 +0000 (16:05 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Fri, 19 Jul 2019 18:39:04 +0000 (11:39 -0700)
This cl standardizes the printing of the type of dialect attributes to work the same as other attribute kinds. The type of dialect attributes will trail the dialect specific portion:

`#` dialect-namespace `<` attr-data `>` `:` type

The attribute parsing hooks on Dialect have been updated to take an optionally null expected type for the attribute. This matches the respective parseAttribute hooks in the OpAsmParser.

PiperOrigin-RevId: 258661298

mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/Dialect.h
mlir/lib/IR/AttributeDetail.h
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/Dialect.cpp
mlir/lib/Parser/Parser.cpp
mlir/test/IR/parser.mlir

index dfe3f99..323473f 100644 (file)
@@ -352,14 +352,14 @@ public:
   using Base::Base;
 
   /// Get or create a new OpaqueAttr with the provided dialect and string data.
-  static OpaqueAttr get(Identifier dialect, StringRef attrData,
+  static OpaqueAttr get(Identifier dialect, StringRef attrData, Type type,
                         MLIRContext *context);
 
   /// Get or create a new OpaqueAttr with the provided dialect and string data.
   /// If the given identifier is not a valid namespace for a dialect, then a
   /// null attribute is returned.
   static OpaqueAttr getChecked(Identifier dialect, StringRef attrData,
-                               MLIRContext *context, Location location);
+                               Type type, Location location);
 
   /// Returns the dialect namespace of the opaque attribute.
   Identifier getDialectNamespace() const;
@@ -371,7 +371,7 @@ public:
   static LogicalResult
   verifyConstructionInvariants(llvm::Optional<Location> loc,
                                MLIRContext *context, Identifier dialect,
-                               StringRef attrData);
+                               StringRef attrData, Type type);
 
   static bool kindof(unsigned kind) {
     return kind == StandardAttributes::Opaque;
index d6efcde..84a0331 100644 (file)
@@ -107,10 +107,14 @@ public:
   // Parsing Hooks
   //===--------------------------------------------------------------------===//
 
-  /// Parse an attribute registered to this dialect.
-  virtual Attribute parseAttribute(StringRef attrData, Location loc) const;
-
-  /// Print an attribute registered to this dialect.
+  /// Parse an attribute registered to this dialect. If 'type' is nonnull, it
+  /// refers to the expected type of the attribute.
+  virtual Attribute parseAttribute(StringRef attrData, Type type,
+                                   Location loc) const;
+
+  /// Print an attribute registered to this dialect. Note: The type of the
+  /// attribute need not be printed by this method as it is always printed by
+  /// the caller.
   virtual void printAttribute(Attribute, raw_ostream &) const {
     llvm_unreachable("dialect has no registered attribute printing hook");
   }
index 8e75736..a226a4c 100644 (file)
@@ -273,19 +273,23 @@ struct IntegerSetAttributeStorage : public AttributeStorage {
 
 /// Opaque Attribute Storage and Uniquing.
 struct OpaqueAttributeStorage : public AttributeStorage {
-  OpaqueAttributeStorage(Identifier dialectNamespace, StringRef attrData)
-      : dialectNamespace(dialectNamespace), attrData(attrData) {}
+  OpaqueAttributeStorage(Identifier dialectNamespace, StringRef attrData,
+                         Type type)
+      : AttributeStorage(type), dialectNamespace(dialectNamespace),
+        attrData(attrData) {}
 
   /// The hash key used for uniquing.
-  using KeyTy = std::pair<Identifier, StringRef>;
+  using KeyTy = std::tuple<Identifier, StringRef, Type>;
   bool operator==(const KeyTy &key) const {
-    return key == KeyTy(dialectNamespace, attrData);
+    return key == KeyTy(dialectNamespace, attrData, getType());
   }
 
   static OpaqueAttributeStorage *construct(AttributeStorageAllocator &allocator,
                                            const KeyTy &key) {
     return new (allocator.allocate<OpaqueAttributeStorage>())
-        OpaqueAttributeStorage(key.first, allocator.copyInto(key.second));
+        OpaqueAttributeStorage(std::get<0>(key),
+                               allocator.copyInto(std::get<1>(key)),
+                               std::get<2>(key));
   }
 
   // The dialect namespace.
index 047d71c..b6f4d13 100644 (file)
@@ -292,15 +292,16 @@ IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
 // OpaqueAttr
 //===----------------------------------------------------------------------===//
 
-OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData,
+OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
                            MLIRContext *context) {
-  return Base::get(context, StandardAttributes::Opaque, dialect, attrData);
+  return Base::get(context, StandardAttributes::Opaque, dialect, attrData,
+                   type);
 }
 
 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
-                                  MLIRContext *context, Location location) {
-  return Base::getChecked(location, context, StandardAttributes::Opaque,
-                          dialect, attrData);
+                                  Type type, Location location) {
+  return Base::getChecked(location, type.getContext(),
+                          StandardAttributes::Opaque, dialect, attrData, type);
 }
 
 /// Returns the dialect namespace of the opaque attribute.
@@ -314,7 +315,7 @@ StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
 /// Verify the construction of an opaque attribute.
 LogicalResult OpaqueAttr::verifyConstructionInvariants(
     llvm::Optional<Location> loc, MLIRContext *context, Identifier dialect,
-    StringRef attrData) {
+    StringRef attrData, Type type) {
   if (!Dialect::isValidNamespace(dialect.strref())) {
     if (loc)
       emitError(*loc) << "invalid dialect namespace '" << dialect << "'";
index ce3eb82..17dea1f 100644 (file)
@@ -79,7 +79,8 @@ LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
 }
 
 /// Parse an attribute registered to this dialect.
-Attribute Dialect::parseAttribute(StringRef attrData, Location loc) const {
+Attribute Dialect::parseAttribute(StringRef attrData, Type type,
+                                  Location loc) const {
   emitError(loc) << "dialect '" << getNamespace()
                  << "' provides no attribute parsing hook";
   return Attribute();
index 2045203..d227abd 100644 (file)
@@ -1108,18 +1108,23 @@ Attribute Parser::parseExtendedAttr(Type type) {
       *this, Token::hash_identifier, state.attributeAliasDefinitions,
       [&](StringRef dialectName, StringRef symbolData,
           Location loc) -> Attribute {
+        // Parse an optional trailing colon type.
+        Type attrType = type;
+        if (consumeIf(Token::colon) && !(attrType = parseType()))
+          return Attribute();
+
         // If we found a registered dialect, then ask it to parse the attribute.
         if (auto *dialect = state.context->getRegisteredDialect(dialectName))
-          return dialect->parseAttribute(symbolData, loc);
+          return dialect->parseAttribute(symbolData, attrType, loc);
 
         // Otherwise, form a new opaque attribute.
         return OpaqueAttr::getChecked(
             Identifier::get(dialectName, state.context), symbolData,
-            state.context, loc);
+            attrType ? attrType : NoneType::get(state.context), loc);
       });
 
   // Ensure that the attribute has the same type as requested.
-  if (type && attr.getType() != type) {
+  if (attr && type && attr.getType() != type) {
     emitError("attribute type different than expected: expected ")
         << type << ", but got " << attr.getType();
     return nullptr;
index 1dea2d4..5c13364 100644 (file)
@@ -939,3 +939,9 @@ func @scoped_names() {
 
 // CHECK-LABEL: func @loc_attr(i1 {foo.loc_attr = loc(callsite("foo" at "mysource.cc":10:8))})
 func @loc_attr(i1 {foo.loc_attr = loc(callsite("foo" at "mysource.cc":10:8))})
+
+// CHECK-LABEL: func @dialect_attribute_with_type
+func @dialect_attribute_with_type() {
+  // CHECK-NEXT: foo = #foo.attr : i32
+  "foo.unknown_op"() {foo = #foo.attr : i32} : () -> ()
+}