[mlir] Make DenseArrayAttr generic
authorJeff Niu <jeff@modular.com>
Thu, 25 Aug 2022 23:21:28 +0000 (16:21 -0700)
committerJeff Niu <jeff@modular.com>
Tue, 30 Aug 2022 20:29:24 +0000 (13:29 -0700)
This patch turns `DenseArrayBaseAttr` into a fully-functional attribute by
adding a generic parser and printer, supporting bool or integer and floating
point element types with bitwidths divisible by 8. It has been renamed
to `DenseArrayAttr`. The patch maintains the specialized subclasses,
e.g. `DenseI32ArrayAttr`, which remain the preferred API for accessing
elements in C++.

This allows `DenseArrayAttr` to hold signed and unsigned integer elements:

```
array<si8: -128, 127>
array<ui8: 255>
```

"Exotic" floating point elements:

```
array<bf16: 1.2, 3.4>
```

And integers of other bitwidths:

```
array<i24: 8388607>
```

Reviewed By: rriddle, lattner

Differential Revision: https://reviews.llvm.org/D132758

mlir/include/mlir/IR/BuiltinAttributes.h
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/lib/AsmParser/AttributeParser.cpp
mlir/lib/CAPI/IR/BuiltinAttributes.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/test/IR/attribute.mlir
mlir/test/IR/invalid-builtin-attributes.mlir
mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp

index 4f2b415..a102311 100644 (file)
@@ -761,9 +761,9 @@ namespace detail {
 /// Base class for DenseArrayAttr that is instantiated and specialized for each
 /// supported element type below.
 template <typename T>
-class DenseArrayAttr : public DenseArrayBaseAttr {
+class DenseArrayAttrImpl : public DenseArrayAttr {
 public:
-  using DenseArrayBaseAttr::DenseArrayBaseAttr;
+  using DenseArrayAttr::DenseArrayAttr;
 
   /// Implicit conversion to ArrayRef<T>.
   operator ArrayRef<T>() const;
@@ -773,7 +773,7 @@ public:
   T operator[](std::size_t index) const { return asArrayRef()[index]; }
 
   /// Builder from ArrayRef<T>.
-  static DenseArrayAttr get(MLIRContext *context, ArrayRef<T> content);
+  static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef<T> content);
 
   /// Print the short form `[42, 100, -1]` without any type prefix.
   void print(AsmPrinter &printer) const;
@@ -791,23 +791,23 @@ public:
   static bool classof(Attribute attr);
 };
 
-extern template class DenseArrayAttr<bool>;
-extern template class DenseArrayAttr<int8_t>;
-extern template class DenseArrayAttr<int16_t>;
-extern template class DenseArrayAttr<int32_t>;
-extern template class DenseArrayAttr<int64_t>;
-extern template class DenseArrayAttr<float>;
-extern template class DenseArrayAttr<double>;
+extern template class DenseArrayAttrImpl<bool>;
+extern template class DenseArrayAttrImpl<int8_t>;
+extern template class DenseArrayAttrImpl<int16_t>;
+extern template class DenseArrayAttrImpl<int32_t>;
+extern template class DenseArrayAttrImpl<int64_t>;
+extern template class DenseArrayAttrImpl<float>;
+extern template class DenseArrayAttrImpl<double>;
 } // namespace detail
 
 // Public name for all the supported DenseArrayAttr
-using DenseBoolArrayAttr = detail::DenseArrayAttr<bool>;
-using DenseI8ArrayAttr = detail::DenseArrayAttr<int8_t>;
-using DenseI16ArrayAttr = detail::DenseArrayAttr<int16_t>;
-using DenseI32ArrayAttr = detail::DenseArrayAttr<int32_t>;
-using DenseI64ArrayAttr = detail::DenseArrayAttr<int64_t>;
-using DenseF32ArrayAttr = detail::DenseArrayAttr<float>;
-using DenseF64ArrayAttr = detail::DenseArrayAttr<double>;
+using DenseBoolArrayAttr = detail::DenseArrayAttrImpl<bool>;
+using DenseI8ArrayAttr = detail::DenseArrayAttrImpl<int8_t>;
+using DenseI16ArrayAttr = detail::DenseArrayAttrImpl<int16_t>;
+using DenseI32ArrayAttr = detail::DenseArrayAttrImpl<int32_t>;
+using DenseI64ArrayAttr = detail::DenseArrayAttrImpl<int64_t>;
+using DenseF32ArrayAttr = detail::DenseArrayAttrImpl<float>;
+using DenseF64ArrayAttr = detail::DenseArrayAttrImpl<double>;
 
 //===----------------------------------------------------------------------===//
 // DenseResourceElementsAttr
index 96ab4f2..2c4b92b 100644 (file)
@@ -140,7 +140,7 @@ def Builtin_ArrayAttr : Builtin_Attr<"Array", [
 }
 
 //===----------------------------------------------------------------------===//
-// DenseArrayBaseAttr
+// DenseArrayAttr
 //===----------------------------------------------------------------------===//
 
 def Builtin_DenseArrayRawDataParameter : ArrayRefParameter<
@@ -155,23 +155,28 @@ def Builtin_DenseArrayRawDataParameter : ArrayRefParameter<
   }];
 }
 
-def Builtin_DenseArrayBase : Builtin_Attr<
-    "DenseArrayBase", [ElementsAttrInterface, TypedAttrInterface]> {
-  let summary = "A dense array of i8, i16, i32, i64, f32, or f64.";
+def Builtin_DenseArray : Builtin_Attr<
+    "DenseArray", [ElementsAttrInterface, TypedAttrInterface]> {
+  let summary = "A dense array of integer or floating point elements.";
   let description = [{
     A dense array attribute is an attribute that represents a dense array of
     primitive element types. Contrary to DenseIntOrFPElementsAttr this is a
     flat unidimensional array which does not have a storage optimization for
     splat. This allows to expose the raw array through a C++ API as
-    `ArrayRef<T>`. This is the base class attribute, the actual access is
-    intended to be managed through the subclasses `DenseI8ArrayAttr`,
-    `DenseI16ArrayAttr`, `DenseI32ArrayAttr`, `DenseI64ArrayAttr`,
-    `DenseF32ArrayAttr`, and `DenseF64ArrayAttr`.
+    `ArrayRef<T>` for compatible types. The element type must be bool or an
+    integer or float whose bitwidth is a multiple of 8. Bool elements are stored
+    as bytes.
+
+    This is the base class attribute. Access to C++ types is intended to be
+    managed through the subclasses `DenseI8ArrayAttr`, `DenseI16ArrayAttr`,
+    `DenseI32ArrayAttr`, `DenseI64ArrayAttr`, `DenseF32ArrayAttr`,
+    and `DenseF64ArrayAttr`.
 
     Syntax:
 
     ```
-    dense-array-attribute ::= `[` `:` (integer-type | float-type) tensor-literal `]`
+    dense-array-attribute ::= `array` `<` (integer-type | float-type)
+                                          (`:` tensor-literal)? `>`
     ```
     Examples:
 
@@ -181,16 +186,26 @@ def Builtin_DenseArrayBase : Builtin_Attr<
     array<f64: 42., 12.>
     ```
 
-    when a specific subclass is used as argument of an operation, the declarative
-    assembly will omit the type and print directly:
-    ```
+    When a specific subclass is used as argument of an operation, the
+    declarative assembly will omit the type and print directly:
+
+    ```mlir
     [1, 2, 3]
     ```
   }];
+
   let parameters = (ins
     AttributeSelfTypeParameter<"", "RankedTensorType">:$type,
     Builtin_DenseArrayRawDataParameter:$rawData
   );
+
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "RankedTensorType":$type,
+                                        "ArrayRef<char>":$rawData), [{
+      return $_get(type.getContext(), type, rawData);
+    }]>,
+  ];
+
   let extraClassDeclaration = [{
     /// Allow implicit conversion to ElementsAttr.
     operator ElementsAttr() const {
@@ -207,13 +222,9 @@ def Builtin_DenseArrayBase : Builtin_Attr<
     const int64_t *value_begin_impl(OverloadToken<int64_t>) const;
     const float *value_begin_impl(OverloadToken<float>) const;
     const double *value_begin_impl(OverloadToken<double>) const;
-
-    /// Printer for the short form: will dispatch to the appropriate subclass.
-    void print(AsmPrinter &printer) const;
-    void print(raw_ostream &os) const;
-    /// Print the short form `42, 100, -1` without any braces or prefix.
-    void printWithoutBraces(raw_ostream &os) const;
   }];
+
+  let genVerifyDecl = 1;
 }
 
 //===----------------------------------------------------------------------===//
index 8fadc29..0bc7323 100644 (file)
@@ -827,96 +827,142 @@ ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
 }
 
 //===----------------------------------------------------------------------===//
-// ElementsAttr Parser
+// DenseArrayAttr Parser
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// This class provides an implementation of AsmParser, allowing to call back
-/// into the libMLIRIR-provided APIs for invoking attribute parsing code defined
-/// in libMLIRIR.
-class CustomAsmParser : public AsmParserImpl<AsmParser> {
+/// A generic dense array element parser. It parsers integer and floating point
+/// elements.
+class DenseArrayElementParser {
 public:
-  CustomAsmParser(Parser &parser)
-      : AsmParserImpl<AsmParser>(parser.getToken().getLoc(), parser) {}
+  explicit DenseArrayElementParser(Type type) : type(type) {}
+
+  /// Parse an integer element.
+  ParseResult parseIntegerElement(Parser &p);
+
+  /// Parse a floating point element.
+  ParseResult parseFloatElement(Parser &p);
+
+  /// Convert the current contents to a dense array.
+  DenseArrayAttr getAttr() {
+    return DenseArrayAttr::get(RankedTensorType::get(size, type), rawData);
+  }
+
+private:
+  /// Append the raw data of an APInt to the result.
+  void append(const APInt &data);
+
+  /// The array element type.
+  Type type;
+  /// The resultant byte array representing the contents of the array.
+  std::vector<char> rawData;
+  /// The number of elements in the array.
+  int64_t size = 0;
 };
 } // namespace
 
+void DenseArrayElementParser::append(const APInt &data) {
+  unsigned byteSize = data.getBitWidth() / 8;
+  size_t offset = rawData.size();
+  rawData.insert(rawData.end(), byteSize, 0);
+  llvm::StoreIntToMemory(
+      data, reinterpret_cast<uint8_t *>(rawData.data() + offset), byteSize);
+  ++size;
+}
+
+ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) {
+  bool isNegative = p.consumeIf(Token::minus);
+
+  // Parse an integer literal as an APInt.
+  Optional<APInt> value;
+  StringRef spelling = p.getToken().getSpelling();
+  if (p.getToken().isAny(Token::kw_true, Token::kw_false)) {
+    if (!type.isInteger(1))
+      return p.emitError("expected i1 type for 'true' or 'false' values");
+    value = APInt(/*numBits=*/8, p.getToken().is(Token::kw_true),
+                  !type.isUnsignedInteger());
+    p.consumeToken();
+  } else if (p.consumeIf(Token::integer)) {
+    value = buildAttributeAPInt(type, isNegative, spelling);
+    if (!value)
+      return p.emitError("integer constant out of range");
+  } else {
+    return p.emitError("expected integer literal");
+  }
+  append(*value);
+  return success();
+}
+
+ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
+  bool isNegative = p.consumeIf(Token::minus);
+
+  Token token = p.getToken();
+  Optional<APFloat> result;
+  auto floatType = type.cast<FloatType>();
+  if (p.consumeIf(Token::integer)) {
+    // Parse an integer literal as a float.
+    if (p.parseFloatFromIntegerLiteral(result, token, isNegative,
+                                       floatType.getFloatSemantics(),
+                                       floatType.getWidth()))
+      return failure();
+  } else if (p.consumeIf(Token::floatliteral)) {
+    // Parse a floating point literal.
+    Optional<double> val = token.getFloatingPointValue();
+    if (!val)
+      return failure();
+    result = APFloat(isNegative ? -*val : *val);
+    if (!type.isF64()) {
+      bool unused;
+      result->convert(floatType.getFloatSemantics(),
+                      APFloat::rmNearestTiesToEven, &unused);
+    }
+  } else {
+    return p.emitError("expected integer or floating point literal");
+  }
+
+  append(result->bitcastToAPInt());
+  return success();
+}
+
 /// Parse a dense array attribute.
 Attribute Parser::parseDenseArrayAttr(Type type) {
   consumeToken(Token::kw_array);
+  if (parseToken(Token::less, "expected '<' after 'array'"))
+    return {};
+
+  // Only bool or integer and floating point elements divisible by bytes are
+  // supported.
   SMLoc typeLoc = getToken().getLoc();
-  if (parseToken(Token::less, "expected '<' after 'array'") ||
-      (!type && !(type = parseType())))
+  if (!type && !(type = parseType()))
+    return {};
+  if (!type.isIntOrIndexOrFloat()) {
+    emitError(typeLoc, "expected integer or float type, got: ") << type;
+    return {};
+  }
+  if (!type.isInteger(1) && type.getIntOrFloatBitWidth() % 8 != 0) {
+    emitError(typeLoc, "element type bitwidth must be a multiple of 8");
     return {};
-  CustomAsmParser parser(*this);
-  Attribute result;
+  }
+
   // Check for empty list.
-  bool isEmptyList = getToken().is(Token::greater);
-  if (!isEmptyList &&
-      parseToken(Token::colon, "expected ':' after dense array type"))
+  if (consumeIf(Token::greater))
+    return DenseArrayAttr::get(RankedTensorType::get(0, type), {});
+  if (parseToken(Token::colon, "expected ':' after dense array type"))
     return {};
 
-  if (auto intType = type.dyn_cast<IntegerType>()) {
-    switch (type.getIntOrFloatBitWidth()) {
-    case 1:
-      if (isEmptyList)
-        result = DenseBoolArrayAttr::get(parser.getContext(), {});
-      else
-        result = DenseBoolArrayAttr::parseWithoutBraces(parser, Type{});
-      break;
-    case 8:
-      if (isEmptyList)
-        result = DenseI8ArrayAttr::get(parser.getContext(), {});
-      else
-        result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{});
-      break;
-    case 16:
-      if (isEmptyList)
-        result = DenseI16ArrayAttr::get(parser.getContext(), {});
-      else
-        result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{});
-      break;
-    case 32:
-      if (isEmptyList)
-        result = DenseI32ArrayAttr::get(parser.getContext(), {});
-      else
-        result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{});
-      break;
-    case 64:
-      if (isEmptyList)
-        result = DenseI64ArrayAttr::get(parser.getContext(), {});
-      else
-        result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
-      break;
-    default:
-      emitError(typeLoc, "expected i1, i8, i16, i32, or i64 but got: ") << type;
-      return {};
-    }
-  } else if (auto floatType = type.dyn_cast<FloatType>()) {
-    switch (type.getIntOrFloatBitWidth()) {
-    case 32:
-      if (isEmptyList)
-        result = DenseF32ArrayAttr::get(parser.getContext(), {});
-      else
-        result = DenseF32ArrayAttr::parseWithoutBraces(parser, Type{});
-      break;
-    case 64:
-      if (isEmptyList)
-        result = DenseF64ArrayAttr::get(parser.getContext(), {});
-      else
-        result = DenseF64ArrayAttr::parseWithoutBraces(parser, Type{});
-      break;
-    default:
-      emitError(typeLoc, "expected f32 or f64 but got: ") << type;
+  DenseArrayElementParser eltParser(type);
+  if (type.isIntOrIndex()) {
+    if (parseCommaSeparatedList(
+            [&] { return eltParser.parseIntegerElement(*this); }))
       return {};
-    }
   } else {
-    emitError(typeLoc, "expected integer or float type, got: ") << type;
-    return {};
+    if (parseCommaSeparatedList(
+            [&] { return eltParser.parseFloatElement(*this); }))
+      return {};
   }
   if (parseToken(Token::greater, "expected '>' to close an array attribute"))
     return {};
-  return result;
+  return eltParser.getAttr();
 }
 
 /// Parse a dense elements attribute.
index c50096b..b02484a 100644 (file)
@@ -383,7 +383,7 @@ MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size,
 // Accessors.
 
 intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) {
-  return unwrap(attr).cast<DenseArrayBaseAttr>().size();
+  return unwrap(attr).cast<DenseArrayAttr>().size();
 }
 
 //===----------------------------------------------------------------------===//
index 3bb67fc..3f20a6d 100644 (file)
@@ -1476,6 +1476,9 @@ protected:
   void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
                                      bool allowHex);
 
+  /// Print a dense array attribute.
+  void printDenseArrayAttr(DenseArrayAttr attr);
+
   void printDialectAttribute(Attribute attr);
   void printDialectType(Type type);
 
@@ -1860,12 +1863,13 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
     }
   } else if (auto stridedLayoutAttr = attr.dyn_cast<StridedLayoutAttr>()) {
     stridedLayoutAttr.print(os);
-  } else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayBaseAttr>()) {
+  } else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayAttr>()) {
     typeElision = AttrTypeElision::Must;
     os << "array<" << denseArrayAttr.getType().getElementType();
-    if (!denseArrayAttr.empty())
+    if (!denseArrayAttr.empty()) {
       os << ": ";
-    denseArrayAttr.printWithoutBraces(os);
+      printDenseArrayAttr(denseArrayAttr);
+    }
     os << ">";
   } else if (auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>()) {
     os << "dense_resource<";
@@ -1890,11 +1894,11 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
 
 /// Print the integer element of a DenseElementsAttr.
 static void printDenseIntElement(const APInt &value, raw_ostream &os,
-                                 bool isSigned) {
-  if (value.getBitWidth() == 1)
+                                 Type type) {
+  if (type.isInteger(1))
     os << (value.getBoolValue() ? "true" : "false");
   else
-    value.print(os, isSigned);
+    value.print(os, !type.isUnsignedInteger());
 }
 
 static void
@@ -1988,14 +1992,13 @@ void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
     // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
     // and hence was replaced.
     if (complexElementType.isa<IntegerType>()) {
-      bool isSigned = !complexElementType.isUnsignedInteger();
       auto valueIt = attr.value_begin<std::complex<APInt>>();
       printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
         auto complexValue = *(valueIt + index);
         os << "(";
-        printDenseIntElement(complexValue.real(), os, isSigned);
+        printDenseIntElement(complexValue.real(), os, complexElementType);
         os << ",";
-        printDenseIntElement(complexValue.imag(), os, isSigned);
+        printDenseIntElement(complexValue.imag(), os, complexElementType);
         os << ")";
       });
     } else {
@@ -2010,10 +2013,9 @@ void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
       });
     }
   } else if (elementType.isIntOrIndex()) {
-    bool isSigned = !elementType.isUnsignedInteger();
     auto valueIt = attr.value_begin<APInt>();
     printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
-      printDenseIntElement(*(valueIt + index), os, isSigned);
+      printDenseIntElement(*(valueIt + index), os, elementType);
     });
   } else {
     assert(elementType.isa<FloatType>() && "unexpected element type");
@@ -2031,6 +2033,29 @@ void AsmPrinter::Impl::printDenseStringElementsAttr(
   printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
 }
 
+void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) {
+  Type type = attr.getElementType();
+  unsigned bitwidth = type.isInteger(1) ? 8 : type.getIntOrFloatBitWidth();
+  unsigned byteSize = bitwidth / 8;
+  ArrayRef<char> data = attr.getRawData();
+
+  auto printElementAt = [&](unsigned i) {
+    APInt value(bitwidth, 0);
+    llvm::LoadIntFromMemory(
+        value, reinterpret_cast<const uint8_t *>(data.begin() + byteSize * i),
+        byteSize);
+    // Print the data as-is or as a float.
+    if (type.isIntOrIndex()) {
+      printDenseIntElement(value, getStream(), type);
+    } else {
+      APFloat fltVal(type.cast<FloatType>().getFloatSemantics(), value);
+      printFloatValue(fltVal, getStream());
+    }
+  };
+  llvm::interleaveComma(llvm::seq<unsigned>(0, attr.size()), getStream(),
+                        printElementAt);
+}
+
 void AsmPrinter::Impl::printType(Type type) {
   if (!type) {
     os << "<<NULL TYPE>>";
index dab80f1..8d060ef 100644 (file)
@@ -741,50 +741,50 @@ DenseElementsAttr::ComplexIntElementIterator::operator*() const {
 // DenseArrayAttr
 //===----------------------------------------------------------------------===//
 
-const bool *DenseArrayBaseAttr::value_begin_impl(OverloadToken<bool>) const {
+LogicalResult
+DenseArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                       RankedTensorType type, ArrayRef<char> rawData) {
+  if (type.getRank() != 1)
+    return emitError() << "expected rank 1 tensor type";
+  if (!type.getElementType().isIntOrIndexOrFloat())
+    return emitError() << "expected integer or floating point element type";
+  int64_t dataSize = rawData.size();
+  int64_t size = type.getShape().front();
+  if (type.getElementType().isInteger(1)) {
+    if (size != dataSize)
+      return emitError() << "expected " << size
+                         << " bytes for i1 array but got " << dataSize;
+  } else if (size * type.getElementTypeBitWidth() != dataSize * 8) {
+    return emitError() << "expected data size (" << size << " elements, "
+                       << type.getElementTypeBitWidth()
+                       << " bits each) does not match: " << dataSize
+                       << " bytes";
+  }
+  return success();
+}
+
+const bool *DenseArrayAttr::value_begin_impl(OverloadToken<bool>) const {
   return cast<DenseBoolArrayAttr>().asArrayRef().begin();
 }
-const int8_t *
-DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const {
+const int8_t *DenseArrayAttr::value_begin_impl(OverloadToken<int8_t>) const {
   return cast<DenseI8ArrayAttr>().asArrayRef().begin();
 }
-const int16_t *
-DenseArrayBaseAttr::value_begin_impl(OverloadToken<int16_t>) const {
+const int16_t *DenseArrayAttr::value_begin_impl(OverloadToken<int16_t>) const {
   return cast<DenseI16ArrayAttr>().asArrayRef().begin();
 }
-const int32_t *
-DenseArrayBaseAttr::value_begin_impl(OverloadToken<int32_t>) const {
+const int32_t *DenseArrayAttr::value_begin_impl(OverloadToken<int32_t>) const {
   return cast<DenseI32ArrayAttr>().asArrayRef().begin();
 }
-const int64_t *
-DenseArrayBaseAttr::value_begin_impl(OverloadToken<int64_t>) const {
+const int64_t *DenseArrayAttr::value_begin_impl(OverloadToken<int64_t>) const {
   return cast<DenseI64ArrayAttr>().asArrayRef().begin();
 }
-const float *DenseArrayBaseAttr::value_begin_impl(OverloadToken<float>) const {
+const float *DenseArrayAttr::value_begin_impl(OverloadToken<float>) const {
   return cast<DenseF32ArrayAttr>().asArrayRef().begin();
 }
-const double *
-DenseArrayBaseAttr::value_begin_impl(OverloadToken<double>) const {
+const double *DenseArrayAttr::value_begin_impl(OverloadToken<double>) const {
   return cast<DenseF64ArrayAttr>().asArrayRef().begin();
 }
 
-void DenseArrayBaseAttr::print(AsmPrinter &printer) const {
-  print(printer.getStream());
-}
-
-void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const {
-  llvm::TypeSwitch<DenseArrayBaseAttr>(*this)
-      .Case<DenseBoolArrayAttr, DenseI8ArrayAttr, DenseI16ArrayAttr,
-            DenseI32ArrayAttr, DenseI64ArrayAttr, DenseF32ArrayAttr,
-            DenseF64ArrayAttr>([&](auto attr) { attr.printWithoutBraces(os); });
-}
-
-void DenseArrayBaseAttr::print(raw_ostream &os) const {
-  os << "[";
-  printWithoutBraces(os);
-  os << "]";
-}
-
 namespace {
 /// Instantiations of this class provide utilities for interacting with native
 /// data types in the context of DenseArrayAttr.
@@ -869,19 +869,19 @@ struct DenseArrayAttrUtil<double> {
 } // namespace
 
 template <typename T>
-void DenseArrayAttr<T>::print(AsmPrinter &printer) const {
+void DenseArrayAttrImpl<T>::print(AsmPrinter &printer) const {
   print(printer.getStream());
 }
 
 template <typename T>
-void DenseArrayAttr<T>::printWithoutBraces(raw_ostream &os) const {
+void DenseArrayAttrImpl<T>::printWithoutBraces(raw_ostream &os) const {
   llvm::interleaveComma(asArrayRef(), os, [&](T value) {
     DenseArrayAttrUtil<T>::printElement(os, value);
   });
 }
 
 template <typename T>
-void DenseArrayAttr<T>::print(raw_ostream &os) const {
+void DenseArrayAttrImpl<T>::print(raw_ostream &os) const {
   os << "[";
   printWithoutBraces(os);
   os << "]";
@@ -889,8 +889,8 @@ void DenseArrayAttr<T>::print(raw_ostream &os) const {
 
 /// Parse a DenseArrayAttr without the braces: `1, 2, 3`
 template <typename T>
-Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser,
-                                                Type odsType) {
+Attribute DenseArrayAttrImpl<T>::parseWithoutBraces(AsmParser &parser,
+                                                    Type odsType) {
   SmallVector<T> data;
   if (failed(parser.parseCommaSeparatedList([&]() {
         T value;
@@ -905,7 +905,7 @@ Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser,
 
 /// Parse a DenseArrayAttr: `[ 1, 2, 3 ]`
 template <typename T>
-Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) {
+Attribute DenseArrayAttrImpl<T>::parse(AsmParser &parser, Type odsType) {
   if (parser.parseLSquare())
     return {};
   // Handle empty list case.
@@ -919,7 +919,7 @@ Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) {
 
 /// Conversion from DenseArrayAttr<T> to ArrayRef<T>.
 template <typename T>
-DenseArrayAttr<T>::operator ArrayRef<T>() const {
+DenseArrayAttrImpl<T>::operator ArrayRef<T>() const {
   ArrayRef<char> raw = getRawData();
   assert((raw.size() % sizeof(T)) == 0);
   return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()),
@@ -928,19 +928,19 @@ DenseArrayAttr<T>::operator ArrayRef<T>() const {
 
 /// Builds a DenseArrayAttr<T> from an ArrayRef<T>.
 template <typename T>
-DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context,
-                                         ArrayRef<T> content) {
+DenseArrayAttrImpl<T> DenseArrayAttrImpl<T>::get(MLIRContext *context,
+                                                 ArrayRef<T> content) {
   auto shapedType = RankedTensorType::get(
       content.size(), DenseArrayAttrUtil<T>::getElementType(context));
   auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
                                  content.size() * sizeof(T));
   return Base::get(context, shapedType, rawArray)
-      .template cast<DenseArrayAttr<T>>();
+      .template cast<DenseArrayAttrImpl<T>>();
 }
 
 template <typename T>
-bool DenseArrayAttr<T>::classof(Attribute attr) {
-  if (auto denseArray = attr.dyn_cast<DenseArrayBaseAttr>())
+bool DenseArrayAttrImpl<T>::classof(Attribute attr) {
+  if (auto denseArray = attr.dyn_cast<DenseArrayAttr>())
     return DenseArrayAttrUtil<T>::checkElementType(denseArray.getElementType());
   return false;
 }
@@ -948,13 +948,13 @@ bool DenseArrayAttr<T>::classof(Attribute attr) {
 namespace mlir {
 namespace detail {
 // Explicit instantiation for all the supported DenseArrayAttr.
-template class DenseArrayAttr<bool>;
-template class DenseArrayAttr<int8_t>;
-template class DenseArrayAttr<int16_t>;
-template class DenseArrayAttr<int32_t>;
-template class DenseArrayAttr<int64_t>;
-template class DenseArrayAttr<float>;
-template class DenseArrayAttr<double>;
+template class DenseArrayAttrImpl<bool>;
+template class DenseArrayAttrImpl<int8_t>;
+template class DenseArrayAttrImpl<int16_t>;
+template class DenseArrayAttrImpl<int32_t>;
+template class DenseArrayAttrImpl<int64_t>;
+template class DenseArrayAttrImpl<float>;
+template class DenseArrayAttrImpl<double>;
 } // namespace detail
 } // namespace mlir
 
index cf2d533..6405155 100644 (file)
@@ -569,6 +569,26 @@ func.func @dense_array_attr() attributes {
                f64attr = [-142.]
 // CHECK-SAME: emptyattr = []
                emptyattr = []
+
+  // CHECK: array.sizes
+  // CHECK-SAME: i0 = array<i0: 0, 0>
+  // CHECK-SAME: ui0 = array<ui0: 0, 0>
+  // CHECK-SAME: si0 = array<si0: 0, 0>
+  // CHECK-SAME: i24 = array<i24: -42, 42, 8388607>
+  // CHECK-SAME: ui24 = array<ui24: 16777215>
+  // CHECK-SAME: si24 = array<si24: -8388608>
+  // CHECK-SAME: bf16 = array<bf16: 1.2{{[0-9]+}}e+00, 3.4{{[0-9]+}}e+00>
+  // CHECK-SAME: f16 = array<f16: 1.{{[0-9]+}}e+00, 3.{{[0-9]+}}e+00>
+  "array.sizes"() {
+    x0_i0 = array<i0: 0, 0>,
+    x1_ui0 = array<ui0: 0, 0>,
+    x2_si0 = array<si0: 0, 0>,
+    x3_i24 = array<i24: -42, 42, 8388607>,
+    x4_ui24 = array<ui24: 16777215>,
+    x5_si24 = array<si24: -8388608>,
+    x6_bf16 = array<bf16: 1.2, 3.4>,
+    x7_f16 = array<f16: 1., 3.>
+  }: () -> ()
   return
 }
 
index 8095119..0444bef 100644 (file)
@@ -521,3 +521,28 @@ func.func @duplicate_dictionary_attr_key() {
 
 // expected-error@+1 {{`dense_resource` expected a shaped type}}
 #attr = dense_resource<resource> : i32
+
+// -----
+
+// expected-error@below {{expected '<' after 'array'}}
+#attr = array
+
+// -----
+
+// expected-error@below {{expected integer or float type}}
+#attr = array<vector<i32>>
+
+// -----
+
+// expected-error@below {{element type bitwidth must be a multiple of 8}}
+#attr = array<i7>
+
+// -----
+
+// expected-error@below {{expected ':' after dense array type}}
+#attr = array<i8)
+
+// -----
+
+// expected-error@below {{expected '>' to close an array attribute}}
+#attr = array<i8: 1)
index 6257799..23fde12 100644 (file)
@@ -41,9 +41,8 @@ struct TestElementsAttrInterface
         auto elementsAttr = attr.getValue().dyn_cast<ElementsAttr>();
         if (!elementsAttr)
           continue;
-        if (auto concreteAttr =
-                attr.getValue().dyn_cast<DenseArrayBaseAttr>()) {
-          llvm::TypeSwitch<DenseArrayBaseAttr>(concreteAttr)
+        if (auto concreteAttr = attr.getValue().dyn_cast<DenseArrayAttr>()) {
+          llvm::TypeSwitch<DenseArrayAttr>(concreteAttr)
               .Case([&](DenseBoolArrayAttr attr) {
                 testElementsAttrIteration<bool>(op, attr, "bool");
               })