[mlir] Allow dense array to be parsed with type elision
authorJeff Niu <jeff@modular.com>
Tue, 30 Aug 2022 19:13:15 +0000 (12:13 -0700)
committerJeff Niu <jeff@modular.com>
Tue, 30 Aug 2022 20:29:25 +0000 (13:29 -0700)
This patch makes parsing dense arrays with type elision work properly.
If a ranked tensor type is supplied to `parseAttribute` on a dense
array, the element type is skipped. Moreover, if type elision is set to
`AttrTypeElision::Must`, the element type is elided.

For example, this allows

```
memref.global @z : memref<3xi32> = array<1, 2, 3>
```

Fixes #57433

Depends on D132758

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D132964

mlir/lib/AsmParser/AttributeParser.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/test/IR/attribute.mlir
mlir/test/IR/invalid-builtin-attributes.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td

index 0bc7323..f4077a5 100644 (file)
@@ -925,33 +925,59 @@ ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
 }
 
 /// Parse a dense array attribute.
-Attribute Parser::parseDenseArrayAttr(Type type) {
+Attribute Parser::parseDenseArrayAttr(Type attrType) {
   consumeToken(Token::kw_array);
   if (parseToken(Token::less, "expected '<' after 'array'"))
     return {};
 
-  // Only bool or integer and floating point elements divisible by bytes are
-  // supported.
   SMLoc typeLoc = getToken().getLoc();
-  if (!type && !(type = parseType()))
+  Type eltType;
+  // If an attribute type was provided, use its element type.
+  if (attrType) {
+    auto tensorType = attrType.dyn_cast<RankedTensorType>();
+    if (!tensorType) {
+      emitError(typeLoc, "dense array attribute expected ranked tensor type");
+      return {};
+    }
+    eltType = tensorType.getElementType();
+
+    // Otherwise, parse a type.
+  } else if (!(eltType = parseType())) {
     return {};
-  if (!type.isIntOrIndexOrFloat()) {
-    emitError(typeLoc, "expected integer or float type, got: ") << type;
+  }
+
+  // Only bool or integer and floating point elements divisible by bytes are
+  // supported.
+  if (!eltType.isIntOrIndexOrFloat()) {
+    emitError(typeLoc, "expected integer or float type, got: ") << eltType;
     return {};
   }
-  if (!type.isInteger(1) && type.getIntOrFloatBitWidth() % 8 != 0) {
+  if (!eltType.isInteger(1) && eltType.getIntOrFloatBitWidth() % 8 != 0) {
     emitError(typeLoc, "element type bitwidth must be a multiple of 8");
     return {};
   }
 
+  // If a type was provided, check that it matches the parsed type.
+  auto checkProvidedType = [&](DenseArrayAttr result) -> Attribute {
+    if (attrType && result.getType() != attrType) {
+      emitError(typeLoc, "expected attribute type ")
+          << attrType << " does not match parsed type " << result.getType();
+      return {};
+    }
+    return result;
+  };
+
   // Check for empty list.
-  if (consumeIf(Token::greater))
-    return DenseArrayAttr::get(RankedTensorType::get(0, type), {});
-  if (parseToken(Token::colon, "expected ':' after dense array type"))
+  if (consumeIf(Token::greater)) {
+    return checkProvidedType(
+        DenseArrayAttr::get(RankedTensorType::get(0, eltType), {}));
+  }
+  if (!attrType &&
+      parseToken(Token::colon, "expected ':' after dense array type"))
     return {};
 
-  DenseArrayElementParser eltParser(type);
-  if (type.isIntOrIndex()) {
+  DenseArrayElementParser eltParser(eltType);
+  if (eltType.isIntOrIndex()) {
     if (parseCommaSeparatedList(
             [&] { return eltParser.parseIntegerElement(*this); }))
       return {};
@@ -962,7 +988,7 @@ Attribute Parser::parseDenseArrayAttr(Type type) {
   }
   if (parseToken(Token::greater, "expected '>' to close an array attribute"))
     return {};
-  return eltParser.getAttr();
+  return checkProvidedType(eltParser.getAttr());
 }
 
 /// Parse a dense elements attribute.
index 3f20a6d..fedb452 100644 (file)
@@ -1864,13 +1864,16 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
   } else if (auto stridedLayoutAttr = attr.dyn_cast<StridedLayoutAttr>()) {
     stridedLayoutAttr.print(os);
   } else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayAttr>()) {
-    typeElision = AttrTypeElision::Must;
-    os << "array<" << denseArrayAttr.getType().getElementType();
+    os << "array<";
+    if (typeElision != AttrTypeElision::Must)
+      printType(denseArrayAttr.getType().getElementType());
     if (!denseArrayAttr.empty()) {
-      os << ": ";
+      if (typeElision != AttrTypeElision::Must)
+        os << ": ";
       printDenseArrayAttr(denseArrayAttr);
     }
     os << ">";
+    return;
   } else if (auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>()) {
     os << "dense_resource<";
     printResourceHandle(resourceAttr.getRawHandle());
index 6405155..7870606 100644 (file)
@@ -589,6 +589,9 @@ func.func @dense_array_attr() attributes {
     x6_bf16 = array<bf16: 1.2, 3.4>,
     x7_f16 = array<f16: 1., 3.>
   }: () -> ()
+
+  // CHECK: test.typed_attr tensor<4xi32> = array<1, 2, 3, 4>
+  test.typed_attr tensor<4xi32> = array<1, 2, 3, 4>
   return
 }
 
index 0444bef..49acce2 100644 (file)
@@ -546,3 +546,18 @@ func.func @duplicate_dictionary_attr_key() {
 
 // expected-error@below {{expected '>' to close an array attribute}}
 #attr = array<i8: 1)
+
+// -----
+
+// expected-error@below {{dense array attribute expected ranked tensor type}}
+test.typed_attr i32 = array<1>
+
+// -----
+
+// expected-error@below {{does not match parsed type}}
+test.typed_attr tensor<1xi32> = array<>
+
+// -----
+
+// expected-error@below {{does not match parsed type}}
+test.typed_attr tensor<0xi32> = array<1>
index e75c7ea..ad15ebf 100644 (file)
@@ -461,6 +461,22 @@ TestDialect::getOperationPrinter(Operation *op) const {
 }
 
 //===----------------------------------------------------------------------===//
+// TypedAttrOp
+//===----------------------------------------------------------------------===//
+
+/// Parse an attribute with a given type.
+static ParseResult parseAttrElideType(AsmParser &parser, TypeAttr type,
+                                      Attribute &attr) {
+  return parser.parseAttribute(attr, type.getValue());
+}
+
+/// Print an attribute without its type.
+static void printAttrElideType(AsmPrinter &printer, Operation *op,
+                               TypeAttr type, Attribute attr) {
+  printer.printAttributeWithoutType(attr);
+}
+
+//===----------------------------------------------------------------------===//
 // TestBranchOp
 //===----------------------------------------------------------------------===//
 
index 18775b7..f0dd4e2 100644 (file)
@@ -270,6 +270,13 @@ def StringElementsAttrOp : TEST_Op<"string_elements_attr"> {
   );
 }
 
+def TypedAttrOp : TEST_Op<"typed_attr"> {
+  let arguments = (ins TypeAttr:$type, AnyAttr:$attr);
+  let assemblyFormat = [{
+    attr-dict $type `=` custom<AttrElideType>(ref($type), $attr)
+  }];
+}
+
 def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> {
   let arguments = (ins
     DenseBoolArrayAttr:$i1attr,