Introduce a new Dense Array attribute
authorMehdi Amini <joker.eph@gmail.com>
Tue, 28 Jun 2022 11:29:27 +0000 (11:29 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 28 Jun 2022 13:28:06 +0000 (13:28 +0000)
This attribute is similar to DenseElementsAttr but does not support
splat. As such it has a much simpler API and does not need any smart
iterator: it exposes direct ArrayRef access.

A new syntax is introduced so that the generic printing/parsing looks
like:

  [:i64 1, -2, 3]

This attribute beings like an ArrayAttr but has a `:` token after the
opening square brace to introduce the element type (supported are I8,
I16, I32, I64, F32, F64) and the comma separated list for the data.

This is particularly convenient for attributes intended to be small,
like those referring to shapes.
For example a `transpose` operation with a `dims` attribute could be
defined as such:

  let arguments = (ins AnyTensor:$input, DenseI64ArrayAttr:$dims);
  let assemblyFormat = "$input `dims` `=` $dims attr-dict : type($input)";

And printed this way (the element type is elided in this case):

  transpose %input dims = [0, 2, 1] : tensor<2x3x4xf32>

The C++ API for dims would just directly return an ArrayRef<int64>

RFC: https://discourse.llvm.org/t/rfc-introduce-a-new-dense-array-attribute/63279

Recommit with a custom DenseArrayBaseAttrStorage class to ensure
over-alignment of the storage to the largest type.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D123774

12 files changed:
mlir/include/mlir/IR/BuiltinAttributes.h
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/include/mlir/IR/OpBase.td
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/Parser/AttributeParser.cpp
mlir/lib/Parser/Parser.h
mlir/test/IR/attribute.mlir
mlir/test/IR/elements-attr-interface.mlir
mlir/test/IR/invalid.mlir
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp

index 85f6d3f4e638edbe2685338ee6a17ce9eb378471..f22f66fd6ac2caaab1b5d5beb968e7bf669a1c77 100644 (file)
@@ -66,8 +66,8 @@ template <typename T>
 struct is_complex_t<std::complex<T>> : public std::true_type {};
 } // namespace detail
 
-/// An attribute that represents a reference to a dense vector or tensor object.
-///
+/// An attribute that represents a reference to a dense vector or tensor
+/// object.
 class DenseElementsAttr : public Attribute {
 public:
   using Attribute::Attribute;
@@ -743,6 +743,55 @@ public:
 //===----------------------------------------------------------------------===//
 
 namespace mlir {
+namespace detail {
+/// Base class for DenseArrayAttr that is instantiated and specialized for each
+/// supported element type below.
+template <typename T>
+class DenseArrayAttr : public DenseArrayBaseAttr {
+public:
+  using DenseArrayBaseAttr::DenseArrayBaseAttr;
+
+  /// Implicit conversion to ArrayRef<T>.
+  operator ArrayRef<T>() const;
+  ArrayRef<T> asArrayRef() { return ArrayRef<T>{*this}; }
+
+  /// Builder from ArrayRef<T>.
+  static DenseArrayAttr get(MLIRContext *context, ArrayRef<T> content);
+
+  /// Print the short form `[42, 100, -1]` without any type prefix.
+  void print(AsmPrinter &printer) const;
+  void print(raw_ostream &os) const;
+  /// Print the short form `42, 100, -1` without any braces or type prefix.
+  void printWithoutBraces(raw_ostream &os) const;
+
+  /// Parse the short form `[42, 100, -1]` without any type prefix.
+  static Attribute parse(AsmParser &parser, Type odsType);
+
+  /// Parse the short form `42, 100, -1` without any type prefix or braces.
+  static Attribute parseWithoutBraces(AsmParser &parser, Type odsType);
+
+  /// Support for isa<>/cast<>.
+  static bool classof(Attribute attr);
+};
+template <>
+void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const;
+
+extern template class DenseArrayAttr<int8_t>;
+extern template class DenseArrayAttr<int16_t>;
+extern template class DenseArrayAttr<int32_t>;
+extern template class DenseArrayAttr<int64_t>;
+extern template class DenseArrayAttr<float>;
+extern template class DenseArrayAttr<double>;
+} // namespace detail
+
+// Public name for all the supported DenseArrayAttr
+using DenseI8ArrayAttr = detail::DenseArrayAttr<int8_t>;
+using DenseI16ArrayAttr = detail::DenseArrayAttr<int16_t>;
+using DenseI32ArrayAttr = detail::DenseArrayAttr<int32_t>;
+using DenseI64ArrayAttr = detail::DenseArrayAttr<int64_t>;
+using DenseF32ArrayAttr = detail::DenseArrayAttr<float>;
+using DenseF64ArrayAttr = detail::DenseArrayAttr<double>;
+
 //===----------------------------------------------------------------------===//
 // BoolAttr
 //===----------------------------------------------------------------------===//
index 2fab3920883811e38f560a1d914461001c77cc96..503c0209eafa13f9b5450cb83eb6101f05090fc0 100644 (file)
@@ -144,6 +144,78 @@ def Builtin_ArrayAttr : Builtin_Attr<"Array", [
 // DenseIntOrFPElementsAttr
 //===----------------------------------------------------------------------===//
 
+def Builtin_DenseArrayBase : Builtin_Attr<
+    "DenseArrayBase", [ElementsAttrInterface]> {
+  let summary = "A dense array of i8, i16, i32, i64, f32, or f64.";
+  let description = [{
+    A dense array attribute is an attribute that represents a dense array of
+    primitive element types. Contrary to DenseIntOrFPElementsAttr this is a
+    flat unidimensional array which does not have a storage optimization for
+    splat. This allows to expose the raw array through a C++ API as
+    `ArrayRef<T>`. This is the base class attribute, the actual access is
+    intended to be managed through the subclasses `DenseI8ArrayAttr`,
+    `DenseI16ArrayAttr`, `DenseI32ArrayAttr`, `DenseI64ArrayAttr`,
+    `DenseF32ArrayAttr`, and `DenseF64ArrayAttr`.
+
+    Syntax:
+
+    ```
+    dense-array-attribute ::= `[` `:` (integer-type | float-type) tensor-literal `]`
+    ```
+    Examples:
+
+    ```mlir
+    [:i8]
+    [:i32 10, 42]
+    [:f64 42., 12.]
+    ```
+
+    when a specific subclass is used as argument of an operation, the declarative
+    assembly will omit the type and print directly:
+    ```
+    [1, 2, 3]
+    ```
+  }];
+  let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type,
+                        "DenseArrayBaseAttr::EltType":$eltType,
+                        ArrayRefParameter<"char">:$elements);
+  let extraClassDeclaration = [{
+    // All possible supported element type.
+    enum class EltType { I8, I16, I32, I64, F32, F64 };
+
+    /// Allow implicit conversion to ElementsAttr.
+    operator ElementsAttr() const {
+      return *this ? cast<ElementsAttr>() : nullptr;
+    }
+
+    /// ElementsAttr implementation.
+    using ContiguousIterableTypesT =
+        std::tuple<int8_t, int16_t, int32_t, int64_t, float, double>;
+    const int8_t *value_begin_impl(OverloadToken<int8_t>) const;
+    const int16_t *value_begin_impl(OverloadToken<int16_t>) const;
+    const int32_t *value_begin_impl(OverloadToken<int32_t>) const;
+    const int64_t *value_begin_impl(OverloadToken<int64_t>) const;
+    const float *value_begin_impl(OverloadToken<float>) const;
+    const double *value_begin_impl(OverloadToken<double>) const;
+
+    /// Methods to support type inquiry through isa, cast, and dyn_cast.
+    EltType getElementType() const;
+    /// Printer for the short form: will dispatch to the appropriate subclass.
+    void print(AsmPrinter &printer) const;
+    void print(raw_ostream &os) const;
+    /// Print the short form `42, 100, -1` without any braces or prefix.
+    void printWithoutBraces(raw_ostream &os) const;
+  }];
+  // Do not generate the storage class, we need to handle custom storage alignment.
+  let genStorageClass = 0;
+  let genAccessors = 0;
+  let skipDefaultBuilders = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// DenseIntOrFPElementsAttr
+//===----------------------------------------------------------------------===//
+
 def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
     "DenseIntOrFPElements", [ElementsAttrInterface], "DenseElementsAttr"
   > {
index 4e092484a6512cc039fcdff14b9ead08f47b4cde..807216abd98132c6dd1976833042faa4290813c4 100644 (file)
@@ -1258,6 +1258,19 @@ class IntElementsAttrBase<Pred condition, string summary> :
   let convertFromStorage = "$_self";
 }
 
+class DenseArrayAttrBase<string denseAttrName, string cppType, string summaryName> :
+    ElementsAttrBase<CPred<"$_self.isa<::mlir::" # denseAttrName # ">()">,
+                     summaryName # " dense array attribute"> {
+  let storageType = "::mlir::" # denseAttrName;
+  let returnType = "::llvm::ArrayRef<" # cppType # ">";
+}
+def DenseI8ArrayAttr : DenseArrayAttrBase<"DenseI8ArrayAttr", "int8_t", "i8">;
+def DenseI16ArrayAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
+def DenseI32ArrayAttr : DenseArrayAttrBase<"DenseI32ArrayAttr", "int32_t", "i32">;
+def DenseI64ArrayAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">;
+def DenseF32ArrayAttr : DenseArrayAttrBase<"DenseF32ArrayAttr", "float", "f32">;
+def DenseF64ArrayAttr : DenseArrayAttrBase<"DenseF64ArrayAttr", "double", "f64">;
+
 def IndexElementsAttr
     : IntElementsAttrBase<CPred<[{$_self.cast<::mlir::DenseIntElementsAttr>()
                                       .getType()
index 2c8c9b0c640e15a6da6c7aeff3bf70567ebdffb9..981097e9101b6ebfddd4853b4054c5f6ea62d3e1 100644 (file)
@@ -1878,9 +1878,34 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
       }
       os << '>';
     }
-
+  } else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayBaseAttr>()) {
+    typeElision = AttrTypeElision::Must;
+    switch (denseArrayAttr.getElementType()) {
+    case DenseArrayBaseAttr::EltType::I8:
+      os << "[:i8 ";
+      break;
+    case DenseArrayBaseAttr::EltType::I16:
+      os << "[:i16 ";
+      break;
+    case DenseArrayBaseAttr::EltType::I32:
+      os << "[:i32 ";
+      break;
+    case DenseArrayBaseAttr::EltType::I64:
+      os << "[:i64 ";
+      break;
+    case DenseArrayBaseAttr::EltType::F32:
+      os << "[:f32 ";
+      break;
+    case DenseArrayBaseAttr::EltType::F64:
+      os << "[:f64 ";
+      break;
+    }
+    denseArrayAttr.printWithoutBraces(os);
+    os << "]";
   } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
     printLocation(locAttr);
+  } else {
+    llvm::report_fatal_error("Unknown builtin attribute");
   }
   // Don't print the type if we must elide it, or if it is a None type.
   if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
index 4358460badf2c2ebe6f807563c163dead722a53d..bf72d463975491849866a0c0bdd32a4e5191b944 100644 (file)
@@ -12,6 +12,7 @@
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/Types.h"
@@ -35,11 +36,11 @@ using namespace mlir::detail;
 //===----------------------------------------------------------------------===//
 
 void BuiltinDialect::registerAttributes() {
-  addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
-                DenseStringElementsAttr, DictionaryAttr, FloatAttr,
-                SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
-                OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
-                UnitAttr>();
+  addAttributes<AffineMapAttr, ArrayAttr, DenseArrayBaseAttr,
+                DenseIntOrFPElementsAttr, DenseStringElementsAttr,
+                DictionaryAttr, FloatAttr, SymbolRefAttr, IntegerAttr,
+                IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr,
+                SparseElementsAttr, StringAttr, TypeAttr, UnitAttr>();
 }
 
 //===----------------------------------------------------------------------===//
@@ -664,6 +665,274 @@ DenseElementsAttr::ComplexIntElementIterator::operator*() const {
           readBits(getData(), offset + storageWidth, bitWidth)};
 }
 
+//===----------------------------------------------------------------------===//
+// DenseArrayAttr
+//===----------------------------------------------------------------------===//
+
+/// Custom storage to ensure proper memory alignment for the allocation of
+/// DenseArray of any element type.
+struct ::mlir::detail::DenseArrayBaseAttrStorage : public AttributeStorage {
+  using KeyTy = std::tuple<ShapedType, DenseArrayBaseAttr::EltType,
+                           ::llvm::ArrayRef<char>>;
+  DenseArrayBaseAttrStorage(ShapedType type,
+                            DenseArrayBaseAttr::EltType eltType,
+                            ::llvm::ArrayRef<char> elements)
+      : AttributeStorage(type), eltType(eltType), elements(elements) {}
+
+  bool operator==(const KeyTy &tblgenKey) const {
+    return (getType() == std::get<0>(tblgenKey)) &&
+           (eltType == std::get<1>(tblgenKey)) &&
+           (elements == std::get<2>(tblgenKey));
+  }
+
+  static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) {
+    return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey),
+                                std::get<2>(tblgenKey));
+  }
+
+  static DenseArrayBaseAttrStorage *
+  construct(AttributeStorageAllocator &allocator, const KeyTy &tblgenKey) {
+    auto type = std::get<0>(tblgenKey);
+    auto eltType = std::get<1>(tblgenKey);
+    auto elements = std::get<2>(tblgenKey);
+    if (!elements.empty()) {
+      char *alloc = static_cast<char *>(
+          allocator.allocate(elements.size(), alignof(uint64_t)));
+      std::uninitialized_copy(elements.begin(), elements.end(), alloc);
+      elements = ArrayRef<char>(alloc, elements.size());
+    }
+    return new (allocator.allocate<DenseArrayBaseAttrStorage>())
+        DenseArrayBaseAttrStorage(type, eltType, elements);
+  }
+
+  DenseArrayBaseAttr::EltType eltType;
+  ::llvm::ArrayRef<char> elements;
+};
+
+DenseArrayBaseAttr::EltType DenseArrayBaseAttr::getElementType() const {
+  return getImpl()->eltType;
+}
+
+const int8_t *
+DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const {
+  return cast<DenseI8ArrayAttr>().asArrayRef().begin();
+}
+const int16_t *
+DenseArrayBaseAttr::value_begin_impl(OverloadToken<int16_t>) const {
+  return cast<DenseI16ArrayAttr>().asArrayRef().begin();
+}
+const int32_t *
+DenseArrayBaseAttr::value_begin_impl(OverloadToken<int32_t>) const {
+  return cast<DenseI32ArrayAttr>().asArrayRef().begin();
+}
+const int64_t *
+DenseArrayBaseAttr::value_begin_impl(OverloadToken<int64_t>) const {
+  return cast<DenseI64ArrayAttr>().asArrayRef().begin();
+}
+const float *DenseArrayBaseAttr::value_begin_impl(OverloadToken<float>) const {
+  return cast<DenseF32ArrayAttr>().asArrayRef().begin();
+}
+const double *
+DenseArrayBaseAttr::value_begin_impl(OverloadToken<double>) const {
+  return cast<DenseF64ArrayAttr>().asArrayRef().begin();
+}
+
+void DenseArrayBaseAttr::print(AsmPrinter &printer) const {
+  print(printer.getStream());
+}
+
+void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const {
+  switch (getElementType()) {
+  case DenseArrayBaseAttr::EltType::I8:
+    this->cast<DenseI8ArrayAttr>().printWithoutBraces(os);
+    return;
+  case DenseArrayBaseAttr::EltType::I16:
+    this->cast<DenseI16ArrayAttr>().printWithoutBraces(os);
+    return;
+  case DenseArrayBaseAttr::EltType::I32:
+    this->cast<DenseI32ArrayAttr>().printWithoutBraces(os);
+    return;
+  case DenseArrayBaseAttr::EltType::I64:
+    this->cast<DenseI64ArrayAttr>().printWithoutBraces(os);
+    return;
+  case DenseArrayBaseAttr::EltType::F32:
+    this->cast<DenseF32ArrayAttr>().printWithoutBraces(os);
+    return;
+  case DenseArrayBaseAttr::EltType::F64:
+    this->cast<DenseF64ArrayAttr>().printWithoutBraces(os);
+    return;
+  }
+  llvm_unreachable("<unknown DenseArrayBaseAttr>");
+}
+
+void DenseArrayBaseAttr::print(raw_ostream &os) const {
+  os << "[";
+  printWithoutBraces(os);
+  os << "]";
+}
+
+template <typename T>
+void DenseArrayAttr<T>::print(AsmPrinter &printer) const {
+  print(printer.getStream());
+}
+
+template <typename T>
+void DenseArrayAttr<T>::printWithoutBraces(raw_ostream &os) const {
+  ArrayRef<T> values{*this};
+  llvm::interleaveComma(values, os);
+}
+
+/// Specialization for int8_t for forcing printing as number instead of chars.
+template <>
+void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const {
+  ArrayRef<int8_t> values{*this};
+  llvm::interleaveComma(values, os, [&](int64_t v) { os << v; });
+}
+
+template <typename T>
+void DenseArrayAttr<T>::print(raw_ostream &os) const {
+  os << "[";
+  printWithoutBraces(os);
+  os << "]";
+}
+
+/// Parse a single element: generic template for int types, specialized for
+/// floating points below.
+template <typename T>
+static ParseResult parseDenseArrayAttrElt(AsmParser &parser, T &value) {
+  return parser.parseInteger(value);
+}
+
+template <>
+ParseResult parseDenseArrayAttrElt<float>(AsmParser &parser, float &value) {
+  double doubleVal;
+  if (parser.parseFloat(doubleVal))
+    return failure();
+  value = doubleVal;
+  return success();
+}
+
+template <>
+ParseResult parseDenseArrayAttrElt<double>(AsmParser &parser, double &value) {
+  return parser.parseFloat(value);
+}
+
+/// Parse a DenseArrayAttr without the braces: `1, 2, 3`
+template <typename T>
+Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser,
+                                                Type odsType) {
+  SmallVector<T> data;
+  if (failed(parser.parseCommaSeparatedList([&]() {
+        T value;
+        if (parseDenseArrayAttrElt(parser, value))
+          return failure();
+        data.push_back(value);
+        return success();
+      })))
+    return {};
+  return get(parser.getContext(), data);
+}
+
+/// Parse a DenseArrayAttr: `[ 1, 2, 3 ]`
+template <typename T>
+Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) {
+  if (parser.parseLSquare())
+    return {};
+  Attribute result = parseWithoutBraces(parser, odsType);
+  if (parser.parseRSquare())
+    return {};
+  return result;
+}
+
+/// Conversion from DenseArrayAttr<T> to ArrayRef<T>.
+template <typename T>
+DenseArrayAttr<T>::operator ArrayRef<T>() const {
+  ArrayRef<char> raw = getImpl()->elements;
+  assert((raw.size() % sizeof(T)) == 0);
+  return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()),
+                     raw.size() / sizeof(T));
+}
+
+namespace {
+/// Mapping from C++ element type to MLIR DenseArrayAttr internals.
+template <typename T>
+struct denseArrayAttrEltTypeBuilder;
+template <>
+struct denseArrayAttrEltTypeBuilder<int8_t> {
+  constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8;
+  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+    return VectorType::get(shape, IntegerType::get(context, 8));
+  }
+};
+template <>
+struct denseArrayAttrEltTypeBuilder<int16_t> {
+  constexpr static auto eltType = DenseArrayBaseAttr::EltType::I16;
+  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+    return VectorType::get(shape, IntegerType::get(context, 16));
+  }
+};
+template <>
+struct denseArrayAttrEltTypeBuilder<int32_t> {
+  constexpr static auto eltType = DenseArrayBaseAttr::EltType::I32;
+  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+    return VectorType::get(shape, IntegerType::get(context, 32));
+  }
+};
+template <>
+struct denseArrayAttrEltTypeBuilder<int64_t> {
+  constexpr static auto eltType = DenseArrayBaseAttr::EltType::I64;
+  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+    return VectorType::get(shape, IntegerType::get(context, 64));
+  }
+};
+template <>
+struct denseArrayAttrEltTypeBuilder<float> {
+  constexpr static auto eltType = DenseArrayBaseAttr::EltType::F32;
+  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+    return VectorType::get(shape, Float32Type::get(context));
+  }
+};
+template <>
+struct denseArrayAttrEltTypeBuilder<double> {
+  constexpr static auto eltType = DenseArrayBaseAttr::EltType::F64;
+  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+    return VectorType::get(shape, Float64Type::get(context));
+  }
+};
+} // namespace
+
+/// Builds a DenseArrayAttr<T> from an ArrayRef<T>.
+template <typename T>
+DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context,
+                                         ArrayRef<T> content) {
+  auto shapedType =
+      denseArrayAttrEltTypeBuilder<T>::getShapedType(context, content.size());
+  auto eltType = denseArrayAttrEltTypeBuilder<T>::eltType;
+  auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
+                                 content.size() * sizeof(T));
+  return Base::get(context, shapedType, eltType, rawArray)
+      .template cast<DenseArrayAttr<T>>();
+}
+
+template <typename T>
+bool DenseArrayAttr<T>::classof(Attribute attr) {
+  return attr.isa<DenseArrayBaseAttr>() &&
+         attr.cast<DenseArrayBaseAttr>().getElementType() ==
+             denseArrayAttrEltTypeBuilder<T>::eltType;
+}
+
+namespace mlir {
+namespace detail {
+// Explicit instantiation for all the supported DenseArrayAttr.
+template class DenseArrayAttr<int8_t>;
+template class DenseArrayAttr<int16_t>;
+template class DenseArrayAttr<int32_t>;
+template class DenseArrayAttr<int64_t>;
+template class DenseArrayAttr<float>;
+template class DenseArrayAttr<double>;
+} // namespace detail
+} // namespace mlir
+
 //===----------------------------------------------------------------------===//
 // DenseElementsAttr
 //===----------------------------------------------------------------------===//
index efc1c226a2aa00d8147e95eaa8dda822804377c4..177420668a385790ed58786c49b31f12d8d0bdb9 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "Parser.h"
+
+#include "AsmParserImpl.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Parser/AsmParserState.h"
 #include "llvm/ADT/StringExtras.h"
@@ -30,6 +33,7 @@ using namespace mlir::detail;
 ///                    | float-literal (`:` float-type)?
 ///                    | string-literal (`:` type)?
 ///                    | type
+///                    | `[` `:` (integer-type | float-type) tensor-literal `]`
 ///                    | `[` (attribute-value (`,` attribute-value)*)? `]`
 ///                    | `{` (attribute-entry (`,` attribute-entry)*)? `}`
 ///                    | symbol-ref-id (`::` symbol-ref-id)*
@@ -67,13 +71,16 @@ Attribute Parser::parseAttribute(Type type) {
 
   // Parse an array attribute.
   case Token::l_square: {
+    consumeToken(Token::l_square);
+    if (consumeIf(Token::colon))
+      return parseDenseArrayAttr();
     SmallVector<Attribute, 4> elements;
     auto parseElt = [&]() -> ParseResult {
       elements.push_back(parseAttribute());
       return elements.back() ? success() : failure();
     };
 
-    if (parseCommaSeparatedList(Delimiter::Square, parseElt))
+    if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
       return nullptr;
     return builder.getArrayAttr(elements);
   }
@@ -812,6 +819,66 @@ ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
 // ElementsAttr Parser
 //===----------------------------------------------------------------------===//
 
+namespace {
+/// This class provides an implementation of AsmParser, allowing to call back
+/// into the libMLIRIR-provided APIs for invoking attribute parsing code defined
+/// in libMLIRIR.
+class CustomAsmParser : public AsmParserImpl<AsmParser> {
+public:
+  CustomAsmParser(Parser &parser)
+      : AsmParserImpl<AsmParser>(parser.getToken().getLoc(), parser) {}
+};
+} // namespace
+
+/// Parse a dense array attribute.
+Attribute Parser::parseDenseArrayAttr() {
+  auto typeLoc = getToken().getLoc();
+  auto type = parseType();
+  if (!type)
+    return {};
+  CustomAsmParser parser(*this);
+  Attribute result;
+  if (auto intType = type.dyn_cast<IntegerType>()) {
+    switch (type.getIntOrFloatBitWidth()) {
+    case 8:
+      result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{});
+      break;
+    case 16:
+      result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{});
+      break;
+    case 32:
+      result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{});
+      break;
+    case 64:
+      result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
+      break;
+    default:
+      emitError(typeLoc, "expected i8, i16, i32, or i64 but got: ") << type;
+      return {};
+    }
+  } else if (auto floatType = type.dyn_cast<FloatType>()) {
+    switch (type.getIntOrFloatBitWidth()) {
+    case 32:
+      result = DenseF32ArrayAttr::parseWithoutBraces(parser, Type{});
+      break;
+    case 64:
+      result = DenseF64ArrayAttr::parseWithoutBraces(parser, Type{});
+      break;
+    default:
+      emitError(typeLoc, "expected f32 or f64 but got: ") << type;
+      return {};
+    }
+  } else {
+    emitError(typeLoc, "expected integer or float type, got: ") << type;
+    return {};
+  }
+  if (!consumeIf(Token::r_square)) {
+    emitError("expected ']' to close an array attribute");
+    return {};
+  }
+  return result;
+}
+
 /// Parse a dense elements attribute.
 Attribute Parser::parseDenseElementsAttr(Type attrType) {
   auto attribLoc = getToken().getLoc();
index 357de93d73a1a65c613b8f9fc30315ff8baf4583..e97c62cd91d8f8b3c1336dfa7641e81232d8641f 100644 (file)
@@ -264,6 +264,9 @@ public:
   Attribute parseDenseElementsAttr(Type attrType);
   ShapedType parseElementsLiteralType(Type type);
 
+  /// Parse a DenseArrayAttr.
+  Attribute parseDenseArrayAttr();
+
   /// Parse a sparse elements attribute.
   Attribute parseSparseElementsAttr(Type attrType);
 
index 19f87679560629456ebe99fef83478ad5a4e8552..f6b274015a3c42296357ec5e8247ddf600dd5cdd 100644 (file)
@@ -513,6 +513,45 @@ func.func @simple_scalar_example() {
   return
 }
 
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test DenseArrayAttr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @dense_array_attr
+func.func @dense_array_attr() attributes{ 
+// CHECK-SAME: f32attr = [:f32 1.024000e+03, 4.530000e+02, -6.435000e+03],
+               f32attr = [:f32 1024., 453., -6435.],
+// CHECK-SAME: f64attr = [:f64 -1.420000e+02],
+               f64attr = [:f64 -142.],
+// CHECK-SAME: i16attr = [:i16 3, 5, -4, 10],
+               i16attr = [:i16 3, 5, -4, 10],
+// CHECK-SAME: i32attr = [:i32 1024, 453, -6435],
+               i32attr = [:i32 1024, 453, -6435],
+// CHECK-SAME: i64attr = [:i64 -142],
+               i64attr = [:i64 -142],
+// CHECK-SAME: i8attr = [:i8 1, -2, 3]
+               i8attr = [:i8 1, -2, 3]
+ } {
+// CHECK:  test.dense_array_attr
+  test.dense_array_attr
+// CHECK-SAME: i8attr = [1, -2, 3]
+               i8attr = [1, -2, 3]
+// CHECK-SAME: i16attr = [3, 5, -4, 10]
+               i16attr = [3, 5, -4, 10]
+// CHECK-SAME: i32attr = [1024, 453, -6435]
+               i32attr = [1024, 453, -6435]
+// CHECK-SAME: i64attr = [-142]
+               i64attr = [-142]
+// CHECK-SAME: f32attr = [1.024000e+03, 4.530000e+02, -6.435000e+03]
+               f32attr = [1024., 453., -6435.]
+// CHECK-SAME: f64attr = [-1.420000e+02]
+               f64attr = [-142.]
+  return
+}
+
 // -----
 
 //===----------------------------------------------------------------------===//
index be1020d952271e48417e6d197bf88ceb6bbcc783..a476fb1ca3b3b8174e6c2048c2f7f41583fec367 100644 (file)
@@ -5,23 +5,40 @@
 // This tests that the abstract iteration of ElementsAttr works properly, and
 // is properly failable when necessary.
 
+// expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
 // expected-error@below {{Test iterating `uint64_t`: 10, 11, 12, 13, 14}}
 // expected-error@below {{Test iterating `APInt`: 10, 11, 12, 13, 14}}
 // expected-error@below {{Test iterating `IntegerAttr`: 10 : i64, 11 : i64, 12 : i64, 13 : i64, 14 : i64}}
 arith.constant #test.i64_elements<[10, 11, 12, 13, 14]> : tensor<5xi64>
 
+// expected-error@below {{Test iterating `int64_t`: 10, 11, 12, 13, 14}}
 // expected-error@below {{Test iterating `uint64_t`: 10, 11, 12, 13, 14}}
 // expected-error@below {{Test iterating `APInt`: 10, 11, 12, 13, 14}}
 // expected-error@below {{Test iterating `IntegerAttr`: 10 : i64, 11 : i64, 12 : i64, 13 : i64, 14 : i64}}
 arith.constant dense<[10, 11, 12, 13, 14]> : tensor<5xi64>
 
+// expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
 // expected-error@below {{Test iterating `uint64_t`: unable to iterate type}}
 // expected-error@below {{Test iterating `APInt`: unable to iterate type}}
 // expected-error@below {{Test iterating `IntegerAttr`: unable to iterate type}}
 arith.constant opaque<"_", "0xDEADBEEF"> : tensor<5xi64>
 
 // Check that we don't crash on empty element attributes.
+// expected-error@below {{Test iterating `int64_t`: }}
 // expected-error@below {{Test iterating `uint64_t`: }}
 // expected-error@below {{Test iterating `APInt`: }}
 // expected-error@below {{Test iterating `IntegerAttr`: }}
 arith.constant dense<> : tensor<0xi64>
+
+// expected-error@below {{Test iterating `int8_t`: 10, 11, -12, 13, 14}}
+arith.constant [:i8 10, 11, -12, 13, 14]
+// expected-error@below {{Test iterating `int16_t`: 10, 11, -12, 13, 14}}
+arith.constant [:i16 10, 11, -12, 13, 14]
+// expected-error@below {{Test iterating `int32_t`: 10, 11, -12, 13, 14}}
+arith.constant [:i32 10, 11, -12, 13, 14]
+// expected-error@below {{Test iterating `int64_t`: 10, 11, -12, 13, 14}}
+arith.constant [:i64 10, 11, -12, 13, 14]
+// expected-error@below {{Test iterating `float`: 10.00, 11.00, -12.00, 13.00, 14.00}}
+arith.constant [:f32 10., 11., -12., 13., 14.]
+// expected-error@below {{Test iterating `double`: 10.00, 11.00, -12.00, 13.00, 14.00}}
+arith.constant [:f64 10., 11., -12., 13., 14.]
index 2b98485bf8d2158e45e87703a3d98200a98adaba..3a8b7911638fee6c112dbe4ad6b0b5a75dc32782 100644 (file)
@@ -1654,7 +1654,7 @@ func.func @foo() {} // expected-error {{expected non-empty function body}}
 
 // -----
 
-// expected-error@+1 {{expected ']'}}
+// expected-error@+1 {{expected ',' or ']'}}
 "f"() { b = [@m:
 
 // -----
index c100aa2dbd67fe0668f08b93463bc9444bf342e3..325e5d91caa9bc51bceae1187b7b004486c5ebd5 100644 (file)
@@ -270,6 +270,22 @@ def StringElementsAttrOp : TEST_Op<"string_elements_attr"> {
   );
 }
 
+def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> {
+  let arguments = (ins
+    DenseI8ArrayAttr:$i8attr,
+    DenseI16ArrayAttr:$i16attr,
+    DenseI32ArrayAttr:$i32attr,
+    DenseI64ArrayAttr:$i64attr,
+    DenseF32ArrayAttr:$f32attr,
+    DenseF64ArrayAttr:$f64attr
+  );
+  let assemblyFormat = [{
+   `i8attr` `=` $i8attr `i16attr` `=` $i16attr `i32attr` `=` $i32attr
+   `i64attr` `=` $i64attr  `f32attr` `=` $f32attr `f64attr` `=` $f64attr
+   attr-dict
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Test Enum Attributes
 //===----------------------------------------------------------------------===//
index 783512d72aae2cc91ef45ab9e17abc68885387fa..f32a49bd5bedb998b4597e8cc51b2e20d8d20b39 100644 (file)
 using namespace mlir;
 using namespace test;
 
+// Helper to print one scalar value, force int8_t to print as integer instead of
+// char.
+template <typename T>
+static void printOneElement(InFlightDiagnostic &os, T value) {
+  os << llvm::formatv("{0}", value).str();
+}
+template <>
+void printOneElement<int8_t>(InFlightDiagnostic &os, int8_t value) {
+  os << llvm::formatv("{0}", static_cast<int64_t>(value)).str();
+}
+
 namespace {
 struct TestElementsAttrInterface
     : public PassWrapper<TestElementsAttrInterface, OperationPass<ModuleOp>> {
@@ -29,6 +40,31 @@ struct TestElementsAttrInterface
         auto elementsAttr = attr.getValue().dyn_cast<ElementsAttr>();
         if (!elementsAttr)
           continue;
+        if (auto concreteAttr =
+                attr.getValue().dyn_cast<DenseArrayBaseAttr>()) {
+          switch (concreteAttr.getElementType()) {
+          case DenseArrayBaseAttr::EltType::I8:
+            testElementsAttrIteration<int8_t>(op, elementsAttr, "int8_t");
+            break;
+          case DenseArrayBaseAttr::EltType::I16:
+            testElementsAttrIteration<int16_t>(op, elementsAttr, "int16_t");
+            break;
+          case DenseArrayBaseAttr::EltType::I32:
+            testElementsAttrIteration<int32_t>(op, elementsAttr, "int32_t");
+            break;
+          case DenseArrayBaseAttr::EltType::I64:
+            testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
+            break;
+          case DenseArrayBaseAttr::EltType::F32:
+            testElementsAttrIteration<float>(op, elementsAttr, "float");
+            break;
+          case DenseArrayBaseAttr::EltType::F64:
+            testElementsAttrIteration<double>(op, elementsAttr, "double");
+            break;
+          }
+          continue;
+        }
+        testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
         testElementsAttrIteration<uint64_t>(op, elementsAttr, "uint64_t");
         testElementsAttrIteration<APInt>(op, elementsAttr, "APInt");
         testElementsAttrIteration<IntegerAttr>(op, elementsAttr, "IntegerAttr");
@@ -48,9 +84,8 @@ struct TestElementsAttrInterface
       return;
     }
 
-    llvm::interleaveComma(*values, diag, [&](T value) {
-      diag << llvm::formatv("{0}", value).str();
-    });
+    llvm::interleaveComma(*values, diag,
+                          [&](T value) { printOneElement(diag, value); });
   }
 };
 } // namespace