[mlir] Remove TypedAttr and ElementsAttr from DenseArrayAttr
authorJeff Niu <jeff@modular.com>
Tue, 8 Nov 2022 04:20:59 +0000 (20:20 -0800)
committerJeff Niu <jeff@modular.com>
Mon, 5 Dec 2022 21:27:55 +0000 (13:27 -0800)
This patch removes the implementation of TypedAttr and ElementsAttr
from DenseArrayAttr and, in doing so, removes the need store a shaped
type. The attribute now stores a size (number of elements), an MLIR type
as a discriminator, and a raw byte array.

The intent of DenseArrayAttr was not to be a drop-in replacement for DenseElementsAttr. It was meant to be a simple container of integers or floats that map to C++ types. The ElementsAttr implementation on DenseArrayAttr had many holes in it, and fixing those holes would require evolving DenseArrayAttr in a way that is incompatible with its original purpose.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D137606

mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/lib/AsmParser/AttributeParser.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/BuiltinDialectBytecode.cpp
mlir/test/IR/attribute.mlir
mlir/test/IR/elements-attr-interface.mlir
mlir/test/IR/invalid-builtin-attributes.mlir
mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp

index 64e5e0a..2494773 100644 (file)
@@ -165,9 +165,8 @@ public:
       if (*rawConstantIter == GEPOp::kDynamicIndex)
         return *valuesIter;
 
-      return IntegerAttr::get(
-          ElementsAttr::getElementType(base->rawConstantIndices),
-          *rawConstantIter);
+      return IntegerAttr::get(base->rawConstantIndices.getElementType(),
+                              *rawConstantIter);
     }
 
     iterator &operator++() {
index 1a06c92..adb19db 100644 (file)
@@ -155,8 +155,7 @@ def Builtin_DenseArrayRawDataParameter : ArrayRefParameter<
   }];
 }
 
-def Builtin_DenseArray : Builtin_Attr<
-    "DenseArray", [ElementsAttrInterface, TypedAttrInterface]> {
+def Builtin_DenseArray : Builtin_Attr<"DenseArray"> {
   let summary = "A dense array of integer or floating point elements.";
   let description = [{
     A dense array attribute is an attribute that represents a dense array of
@@ -195,43 +194,26 @@ def Builtin_DenseArray : Builtin_Attr<
   }];
 
   let parameters = (ins
-    AttributeSelfTypeParameter<"", "RankedTensorType">:$type,
+    "Type":$elementType,
+    "int64_t":$size,
     Builtin_DenseArrayRawDataParameter:$rawData
   );
 
   let builders = [
-    AttrBuilderWithInferredContext<(ins "RankedTensorType":$type,
+    AttrBuilderWithInferredContext<(ins "Type":$elementType, "unsigned":$size,
                                         "ArrayRef<char>":$rawData), [{
-      return $_get(type.getContext(), type, rawData);
+      return $_get(elementType.getContext(), elementType, size, rawData);
     }]>,
   ];
 
-  let extraClassDeclaration = [{
-    /// Allow implicit conversion to ElementsAttr.
-    operator ElementsAttr() const {
-      return *this ? cast<ElementsAttr>() : nullptr;
-    }
+  let genVerifyDecl = 1;
 
-    /// ElementsAttr implementation.
-    using ContiguousIterableTypesT =
-        std::tuple<bool, int8_t, int16_t, int32_t, int64_t, float, double>;
-    FailureOr<const bool *>
-    try_value_begin_impl(OverloadToken<bool>) const;
-    FailureOr<const int8_t *>
-    try_value_begin_impl(OverloadToken<int8_t>) const;
-    FailureOr<const int16_t *>
-    try_value_begin_impl(OverloadToken<int16_t>) const;
-    FailureOr<const int32_t *>
-    try_value_begin_impl(OverloadToken<int32_t>) const;
-    FailureOr<const int64_t *>
-    try_value_begin_impl(OverloadToken<int64_t>) const;
-    FailureOr<const float *>
-    try_value_begin_impl(OverloadToken<float>) const;
-    FailureOr<const double *>
-    try_value_begin_impl(OverloadToken<double>) const;
+  let extraClassDeclaration = [{
+    /// Get the number of elements in the array.
+    int64_t size() const { return getSize(); }
+    /// Return true if there are no elements in the dense array.
+    bool empty() const { return !size(); }
   }];
-
-  let genVerifyDecl = 1;
 }
 
 //===----------------------------------------------------------------------===//
index e10dd5e..b768be0 100644 (file)
@@ -844,9 +844,7 @@ public:
   ParseResult parseFloatElement(Parser &p);
 
   /// Convert the current contents to a dense array.
-  DenseArrayAttr getAttr() {
-    return DenseArrayAttr::get(RankedTensorType::get(size, type), rawData);
-  }
+  DenseArrayAttr getAttr() { return DenseArrayAttr::get(type, size, rawData); }
 
 private:
   /// Append the raw data of an APInt to the result.
@@ -934,18 +932,9 @@ Attribute Parser::parseDenseArrayAttr(Type attrType) {
     return {};
 
   SMLoc typeLoc = getToken().getLoc();
-  Type eltType;
-  // If an attribute type was provided, use its element type.
-  if (attrType) {
-    auto tensorType = attrType.dyn_cast<RankedTensorType>();
-    if (!tensorType) {
-      emitError(typeLoc, "dense array attribute expected ranked tensor type");
-      return {};
-    }
-    eltType = tensorType.getElementType();
-
-    // Otherwise, parse a type.
-  } else if (!(eltType = parseType())) {
+  Type eltType = parseType();
+  if (!eltType) {
+    emitError(typeLoc, "expected an integer or floating point type");
     return {};
   }
 
@@ -960,23 +949,11 @@ Attribute Parser::parseDenseArrayAttr(Type attrType) {
     return {};
   }
 
-  // If a type was provided, check that it matches the parsed type.
-  auto checkProvidedType = [&](DenseArrayAttr result) -> Attribute {
-    if (attrType && result.getType() != attrType) {
-      emitError(typeLoc, "expected attribute type ")
-          << attrType << " does not match parsed type " << result.getType();
-      return {};
-    }
-    return result;
-  };
-
   // Check for empty list.
-  if (consumeIf(Token::greater)) {
-    return checkProvidedType(
-        DenseArrayAttr::get(RankedTensorType::get(0, eltType), {}));
-  }
-  if (!attrType &&
-      parseToken(Token::colon, "expected ':' after dense array type"))
+  if (consumeIf(Token::greater))
+    return DenseArrayAttr::get(eltType, 0, {});
+
+  if (parseToken(Token::colon, "expected ':' after dense array type"))
     return {};
 
   DenseArrayElementParser eltParser(eltType);
@@ -991,7 +968,7 @@ Attribute Parser::parseDenseArrayAttr(Type attrType) {
   }
   if (parseToken(Token::greater, "expected '>' to close an array attribute"))
     return {};
-  return checkProvidedType(eltParser.getAttr());
+  return eltParser.getAttr();
 }
 
 /// Parse a dense elements attribute.
index 8fa7b3e..f691655 100644 (file)
@@ -2197,11 +2197,9 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
     stridedLayoutAttr.print(os);
   } else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayAttr>()) {
     os << "array<";
-    if (typeElision != AttrTypeElision::Must)
-      printType(denseArrayAttr.getType().getElementType());
+    printType(denseArrayAttr.getElementType());
     if (!denseArrayAttr.empty()) {
-      if (typeElision != AttrTypeElision::Must)
-        os << ": ";
+      os << ": ";
       printDenseArrayAttr(denseArrayAttr);
     }
     os << ">";
index 99f7380..e73ca99 100644 (file)
@@ -690,69 +690,21 @@ DenseElementsAttr::ComplexIntElementIterator::operator*() const {
 
 LogicalResult
 DenseArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError,
-                       RankedTensorType type, ArrayRef<char> rawData) {
-  if (type.getRank() != 1)
-    return emitError() << "expected rank 1 tensor type";
-  if (!type.getElementType().isIntOrIndexOrFloat())
+                       Type elementType, int64_t size, ArrayRef<char> rawData) {
+  if (!elementType.isIntOrIndexOrFloat())
     return emitError() << "expected integer or floating point element type";
   int64_t dataSize = rawData.size();
-  int64_t size = type.getShape().front();
-  if (type.getElementType().isInteger(1)) {
-    if (size != dataSize)
-      return emitError() << "expected " << size
-                         << " bytes for i1 array but got " << dataSize;
-  } else if (size * type.getElementTypeBitWidth() != dataSize * 8) {
+  int64_t elementSize =
+      llvm::divideCeil(elementType.getIntOrFloatBitWidth(), CHAR_BIT);
+  if (size * elementSize != dataSize) {
     return emitError() << "expected data size (" << size << " elements, "
-                       << type.getElementTypeBitWidth()
-                       << " bits each) does not match: " << dataSize
+                       << elementSize
+                       << " bytes each) does not match: " << dataSize
                        << " bytes";
   }
   return success();
 }
 
-FailureOr<const bool *>
-DenseArrayAttr::try_value_begin_impl(OverloadToken<bool>) const {
-  if (auto attr = dyn_cast<DenseBoolArrayAttr>())
-    return attr.asArrayRef().begin();
-  return failure();
-}
-FailureOr<const int8_t *>
-DenseArrayAttr::try_value_begin_impl(OverloadToken<int8_t>) const {
-  if (auto attr = dyn_cast<DenseI8ArrayAttr>())
-    return attr.asArrayRef().begin();
-  return failure();
-}
-FailureOr<const int16_t *>
-DenseArrayAttr::try_value_begin_impl(OverloadToken<int16_t>) const {
-  if (auto attr = dyn_cast<DenseI16ArrayAttr>())
-    return attr.asArrayRef().begin();
-  return failure();
-}
-FailureOr<const int32_t *>
-DenseArrayAttr::try_value_begin_impl(OverloadToken<int32_t>) const {
-  if (auto attr = dyn_cast<DenseI32ArrayAttr>())
-    return attr.asArrayRef().begin();
-  return failure();
-}
-FailureOr<const int64_t *>
-DenseArrayAttr::try_value_begin_impl(OverloadToken<int64_t>) const {
-  if (auto attr = dyn_cast<DenseI64ArrayAttr>())
-    return attr.asArrayRef().begin();
-  return failure();
-}
-FailureOr<const float *>
-DenseArrayAttr::try_value_begin_impl(OverloadToken<float>) const {
-  if (auto attr = dyn_cast<DenseF32ArrayAttr>())
-    return attr.asArrayRef().begin();
-  return failure();
-}
-FailureOr<const double *>
-DenseArrayAttr::try_value_begin_impl(OverloadToken<double>) const {
-  if (auto attr = dyn_cast<DenseF64ArrayAttr>())
-    return attr.asArrayRef().begin();
-  return failure();
-}
-
 namespace {
 /// Instantiations of this class provide utilities for interacting with native
 /// data types in the context of DenseArrayAttr.
@@ -898,12 +850,11 @@ DenseArrayAttrImpl<T>::operator ArrayRef<T>() const {
 template <typename T>
 DenseArrayAttrImpl<T> DenseArrayAttrImpl<T>::get(MLIRContext *context,
                                                  ArrayRef<T> content) {
-  auto shapedType = RankedTensorType::get(
-      content.size(), DenseArrayAttrUtil<T>::getElementType(context));
+  Type elementType = DenseArrayAttrUtil<T>::getElementType(context);
   auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
                                  content.size() * sizeof(T));
-  return Base::get(context, shapedType, rawArray)
-      .template cast<DenseArrayAttrImpl<T>>();
+  return llvm::cast<DenseArrayAttrImpl<T>>(
+      Base::get(context, elementType, content.size(), rawArray));
 }
 
 template <typename T>
index 18f3ea8..22a563d 100644 (file)
@@ -494,17 +494,20 @@ void BuiltinDialectBytecodeInterface::write(
 
 DenseArrayAttr BuiltinDialectBytecodeInterface::readDenseArrayAttr(
     DialectBytecodeReader &reader) const {
-  RankedTensorType type;
+  Type elementType;
+  uint64_t size;
   ArrayRef<char> blob;
-  if (failed(reader.readType(type)) || failed(reader.readBlob(blob)))
+  if (failed(reader.readType(elementType)) || failed(reader.readVarInt(size)) ||
+      failed(reader.readBlob(blob)))
     return DenseArrayAttr();
-  return DenseArrayAttr::get(type, blob);
+  return DenseArrayAttr::get(elementType, size, blob);
 }
 
 void BuiltinDialectBytecodeInterface::write(
     DenseArrayAttr attr, DialectBytecodeWriter &writer) const {
   writer.writeVarInt(builtin_encoding::kDenseArrayAttr);
-  writer.writeType(attr.getType());
+  writer.writeType(attr.getElementType());
+  writer.writeVarInt(attr.getSize());
   writer.writeOwnedBlob(attr.getRawData());
 }
 
index 3081702..d494824 100644 (file)
@@ -630,8 +630,6 @@ func.func @dense_array_attr() attributes {
     x7_f16 = array<f16: 1., 3.>
   }: () -> ()
 
-  // CHECK: test.typed_attr tensor<4xi32> = array<1, 2, 3, 4>
-  test.typed_attr tensor<4xi32> = array<1, 2, 3, 4>
   return
 }
 
index 38b2b8a..aac87f2 100644 (file)
@@ -27,27 +27,6 @@ arith.constant dense<[10, 11, 12, 13, 14]> : tensor<5xi64>
 // expected-error@below {{Test iterating `IntegerAttr`: }}
 arith.constant dense<> : tensor<0xi64>
 
-// expected-error@below {{Test iterating `bool`: true, false, true, false, true, false}}
-// expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
-arith.constant array<i1: true, false, true, false, true, false>
-// expected-error@below {{Test iterating `int8_t`: 10, 11, -12, 13, 14}}
-// expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
-arith.constant array<i8: 10, 11, -12, 13, 14>
-// expected-error@below {{Test iterating `int16_t`: 10, 11, -12, 13, 14}}
-// expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
-arith.constant array<i16: 10, 11, -12, 13, 14>
-// expected-error@below {{Test iterating `int32_t`: 10, 11, -12, 13, 14}}
-// expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
-arith.constant array<i32: 10, 11, -12, 13, 14>
-// expected-error@below {{Test iterating `int64_t`: 10, 11, -12, 13, 14}}
-arith.constant array<i64: 10, 11, -12, 13, 14>
-// expected-error@below {{Test iterating `float`: 10.00, 11.00, -12.00, 13.00, 14.00}}
-// expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
-arith.constant array<f32: 10., 11., -12., 13., 14.>
-// expected-error@below {{Test iterating `double`: 10.00, 11.00, -12.00, 13.00, 14.00}}
-// expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
-arith.constant array<f64: 10., 11., -12., 13., 14.>
-
 // Check that we handle an external constant parsed from the config.
 // expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
 // expected-error@below {{Test iterating `uint64_t`: 1, 2, 3}}
index 8e57afa..5343f97 100644 (file)
@@ -546,18 +546,3 @@ func.func @duplicate_dictionary_attr_key() {
 
 // expected-error@below {{expected '>' to close an array attribute}}
 #attr = array<i8: 1)
-
-// -----
-
-// expected-error@below {{dense array attribute expected ranked tensor type}}
-test.typed_attr i32 = array<1>
-
-// -----
-
-// expected-error@below {{does not match parsed type}}
-test.typed_attr tensor<1xi32> = array<>
-
-// -----
-
-// expected-error@below {{does not match parsed type}}
-test.typed_attr tensor<0xi32> = array<1>
index cbef0bc..9313f40 100644 (file)
@@ -21,10 +21,6 @@ template <typename T>
 static void printOneElement(InFlightDiagnostic &os, T value) {
   os << llvm::formatv("{0}", value).str();
 }
-template <>
-void printOneElement<int8_t>(InFlightDiagnostic &os, int8_t value) {
-  os << llvm::formatv("{0}", static_cast<int64_t>(value)).str();
-}
 
 namespace {
 struct TestElementsAttrInterface
@@ -41,32 +37,6 @@ struct TestElementsAttrInterface
         auto elementsAttr = attr.getValue().dyn_cast<ElementsAttr>();
         if (!elementsAttr)
           continue;
-        if (auto concreteAttr = attr.getValue().dyn_cast<DenseArrayAttr>()) {
-          llvm::TypeSwitch<DenseArrayAttr>(concreteAttr)
-              .Case([&](DenseBoolArrayAttr attr) {
-                testElementsAttrIteration<bool>(op, attr, "bool");
-              })
-              .Case([&](DenseI8ArrayAttr attr) {
-                testElementsAttrIteration<int8_t>(op, attr, "int8_t");
-              })
-              .Case([&](DenseI16ArrayAttr attr) {
-                testElementsAttrIteration<int16_t>(op, attr, "int16_t");
-              })
-              .Case([&](DenseI32ArrayAttr attr) {
-                testElementsAttrIteration<int32_t>(op, attr, "int32_t");
-              })
-              .Case([&](DenseI64ArrayAttr attr) {
-                testElementsAttrIteration<int64_t>(op, attr, "int64_t");
-              })
-              .Case([&](DenseF32ArrayAttr attr) {
-                testElementsAttrIteration<float>(op, attr, "float");
-              })
-              .Case([&](DenseF64ArrayAttr attr) {
-                testElementsAttrIteration<double>(op, attr, "double");
-              });
-          testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
-          continue;
-        }
         testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
         testElementsAttrIteration<uint64_t>(op, elementsAttr, "uint64_t");
         testElementsAttrIteration<APInt>(op, elementsAttr, "APInt");