Allow attaching a type to StringAttr.
authorRiver Riddle <riverriddle@google.com>
Thu, 27 Jun 2019 16:12:19 +0000 (09:12 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 27 Jun 2019 16:13:44 +0000 (09:13 -0700)
Some dialects allow for string types, and this allows for reusing StringAttr for constants of these types.

PiperOrigin-RevId: 255413948

mlir/g3doc/LangRef.md
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/Builders.h
mlir/lib/IR/AttributeDetail.h
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/SPIRV/SPIRVOps.cpp
mlir/test/IR/parser.mlir
mlir/test/SPIRV/ops.mlir

index 7e7c025..5e5d00c 100644 (file)
@@ -871,7 +871,7 @@ the given function.
 Syntax:
 
 ``` {.ebnf}
-string-attribute ::= string-literal
+string-attribute ::= string-literal (`:` type)?
 ```
 
 A string attribute is an attribute that represents a string literal value.
index 215dff1..5b9bfca 100644 (file)
@@ -405,8 +405,12 @@ public:
   using Base::Base;
   using ValueType = StringRef;
 
+  /// Get an instance of a StringAttr with the given string.
   static StringAttr get(StringRef bytes, MLIRContext *context);
 
+  /// Get an instance of a StringAttr with the given string and Type.
+  static StringAttr get(StringRef bytes, Type type);
+
   StringRef getValue() const;
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
index b99f091..27d0c28 100644 (file)
@@ -107,6 +107,7 @@ public:
   FloatAttr getFloatAttr(Type type, double value);
   FloatAttr getFloatAttr(Type type, const APFloat &value);
   StringAttr getStringAttr(StringRef bytes);
+  StringAttr getStringAttr(StringRef bytes, Type type);
   ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
   AffineMapAttr getAffineMapAttr(AffineMap map);
   IntegerSetAttr getIntegerSetAttr(IntegerSet set);
index 0fe07a9..8e75736 100644 (file)
@@ -297,18 +297,21 @@ struct OpaqueAttributeStorage : public AttributeStorage {
 
 /// An attribute representing a string value.
 struct StringAttributeStorage : public AttributeStorage {
-  using KeyTy = StringRef;
+  using KeyTy = std::pair<StringRef, Type>;
 
-  StringAttributeStorage(StringRef value) : value(value) {}
+  StringAttributeStorage(StringRef value, Type type)
+      : AttributeStorage(type), value(value) {}
 
   /// Key equality function.
-  bool operator==(const KeyTy &key) const { return key == value; }
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(value, getType());
+  }
 
   /// Construct a new storage instance.
   static StringAttributeStorage *construct(AttributeStorageAllocator &allocator,
                                            const KeyTy &key) {
     return new (allocator.allocate<StringAttributeStorage>())
-        StringAttributeStorage(allocator.copyInto(key));
+        StringAttributeStorage(allocator.copyInto(key.first), key.second);
   }
 
   StringRef value;
index 37ed96b..01f9a06 100644 (file)
@@ -255,7 +255,8 @@ FunctionAttr FunctionAttr::get(Function *value) {
 }
 
 FunctionAttr FunctionAttr::get(StringRef value, MLIRContext *ctx) {
-  return Base::get(ctx, StandardAttributes::Function, value);
+  return Base::get(ctx, StandardAttributes::Function, value,
+                   NoneType::get(ctx));
 }
 
 StringRef FunctionAttr::getValue() const { return getImpl()->value; }
@@ -332,7 +333,12 @@ LogicalResult OpaqueAttr::verifyConstructionInvariants(
 //===----------------------------------------------------------------------===//
 
 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
-  return Base::get(context, StandardAttributes::String, bytes);
+  return get(bytes, NoneType::get(context));
+}
+
+/// Get an instance of a StringAttr with the given string and Type.
+StringAttr StringAttr::get(StringRef bytes, Type type) {
+  return Base::get(type.getContext(), StandardAttributes::String, bytes, type);
 }
 
 StringRef StringAttr::getValue() const { return getImpl()->value; }
index 43e6a44..9b30205 100644 (file)
@@ -159,6 +159,10 @@ StringAttr Builder::getStringAttr(StringRef bytes) {
   return StringAttr::get(bytes, context);
 }
 
+StringAttr Builder::getStringAttr(StringRef bytes, Type type) {
+  return StringAttr::get(bytes, type);
+}
+
 ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
   return ArrayAttr::get(value, context);
 }
index d244308..3e0f4e8 100644 (file)
@@ -926,7 +926,7 @@ ParseResult Parser::parseXInDimensionList() {
 ///                    | bool-literal
 ///                    | integer-literal (`:` (index-type | integer-type))?
 ///                    | float-literal (`:` float-type)?
-///                    | string-literal
+///                    | string-literal (`:` type)?
 ///                    | type
 ///                    | `[` (attribute-value (`,` attribute-value)*)? `]`
 ///                    | `{` (attribute-entry (`,` attribute-entry)*)? `}`
@@ -1034,6 +1034,13 @@ Attribute Parser::parseAttribute(Type type) {
   case Token::string: {
     auto val = getToken().getStringValue();
     consumeToken(Token::string);
+
+    // Parse the optional trailing colon type.
+    if (!type && consumeIf(Token::colon)) {
+      Type stringType = parseType();
+      return stringType ? StringAttr::get(val, stringType) : Attribute();
+    }
+
     return builder.getStringAttr(val);
   }
 
index a32af84..1fd27c4 100644 (file)
@@ -43,7 +43,9 @@ static ParseResult parseStorageClassAttribute(spirv::StorageClass &storageClass,
   Attribute storageClassAttr;
   SmallVector<NamedAttribute, 1> storageAttr;
   auto loc = parser->getCurrentLocation();
-  if (parser->parseAttribute(storageClassAttr, "storage_class", storageAttr)) {
+  if (parser->parseAttribute(storageClassAttr,
+                             parser->getBuilder().getNoneType(),
+                             "storage_class", storageAttr)) {
     return failure();
   }
   if (!storageClassAttr.isa<StringAttr>()) {
index 8a81be2..c3ed9da 100644 (file)
@@ -512,6 +512,9 @@ func @stringquote() -> () {
 ^bb0:
   // CHECK: "foo"() {bar = "a\22quoted\22string"} : () -> ()
   "foo"(){bar = "a\"quoted\"string"} : () -> ()
+
+  // CHECK-NEXT: "typed_string" : !foo.string
+  "foo"(){bar = "typed_string" : !foo.string} : () -> ()
   return
 }
 
index 947622d..407ce4f 100644 (file)
@@ -111,7 +111,7 @@ func @volatile_load_missing_lbrace() -> () {
 func @volatile_load_missing_rbrace() -> () {
   %0 = spv.Variable : !spv.ptr<f32, Function>
   // expected-error @+1 {{expected ']'}}
-  %1 = spv.Load "Function" %0 ["Volatile" : f32
+  %1 = spv.Load "Function" %0 ["Volatile"} : f32
   return
 }
 
@@ -247,7 +247,7 @@ func @volatile_store_missing_lbrace(%arg0 : f32) -> () {
 func @volatile_store_missing_rbrace(%arg0 : f32) -> () {
   %0 = spv.Variable : !spv.ptr<f32, Function>
   // expected-error @+1 {{expected ']'}}
-  spv.Store  "Function" %0, %arg0 ["Volatile" : f32
+  spv.Store "Function" %0, %arg0 ["Volatile"} : f32
   return
 }