struct is_complex_t<std::complex<T>> : public std::true_type {};
} // namespace detail
-/// An attribute that represents a reference to a dense vector or tensor object.
-///
+/// An attribute that represents a reference to a dense vector or tensor
+/// object.
class DenseElementsAttr : public Attribute {
public:
using Attribute::Attribute;
//===----------------------------------------------------------------------===//
namespace mlir {
+namespace detail {
+/// Base class for DenseArrayAttr that is instantiated and specialized for each
+/// supported element type below.
+template <typename T>
+class DenseArrayAttr : public DenseArrayBaseAttr {
+public:
+ using DenseArrayBaseAttr::DenseArrayBaseAttr;
+
+ /// Implicit conversion to ArrayRef<T>.
+ operator ArrayRef<T>() const;
+ ArrayRef<T> asArrayRef() { return ArrayRef<T>{*this}; }
+
+ /// Builder from ArrayRef<T>.
+ static DenseArrayAttr get(MLIRContext *context, ArrayRef<T> content);
+
+ /// Print the short form `[42, 100, -1]` without any type prefix.
+ void print(AsmPrinter &printer) const;
+ void print(raw_ostream &os) const;
+ /// Print the short form `42, 100, -1` without any braces or type prefix.
+ void printWithoutBraces(raw_ostream &os) const;
+
+ /// Parse the short form `[42, 100, -1]` without any type prefix.
+ static Attribute parse(AsmParser &parser, Type odsType);
+
+ /// Parse the short form `42, 100, -1` without any type prefix or braces.
+ static Attribute parseWithoutBraces(AsmParser &parser, Type odsType);
+
+ /// Support for isa<>/cast<>.
+ static bool classof(Attribute attr);
+};
+template <>
+void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const;
+
+extern template class DenseArrayAttr<int8_t>;
+extern template class DenseArrayAttr<int16_t>;
+extern template class DenseArrayAttr<int32_t>;
+extern template class DenseArrayAttr<int64_t>;
+extern template class DenseArrayAttr<float>;
+extern template class DenseArrayAttr<double>;
+} // namespace detail
+
+// Public name for all the supported DenseArrayAttr
+using DenseI8ArrayAttr = detail::DenseArrayAttr<int8_t>;
+using DenseI16ArrayAttr = detail::DenseArrayAttr<int16_t>;
+using DenseI32ArrayAttr = detail::DenseArrayAttr<int32_t>;
+using DenseI64ArrayAttr = detail::DenseArrayAttr<int64_t>;
+using DenseF32ArrayAttr = detail::DenseArrayAttr<float>;
+using DenseF64ArrayAttr = detail::DenseArrayAttr<double>;
+
//===----------------------------------------------------------------------===//
// BoolAttr
//===----------------------------------------------------------------------===//
// DenseIntOrFPElementsAttr
//===----------------------------------------------------------------------===//
+def Builtin_DenseArrayBase : Builtin_Attr<
+ "DenseArrayBase", [ElementsAttrInterface]> {
+ let summary = "A dense array of i8, i16, i32, i64, f32, or f64.";
+ let description = [{
+ A dense array attribute is an attribute that represents a dense array of
+ primitive element types. Contrary to DenseIntOrFPElementsAttr this is a
+ flat unidimensional array which does not have a storage optimization for
+ splat. This allows to expose the raw array through a C++ API as
+ `ArrayRef<T>`. This is the base class attribute, the actual access is
+ intended to be managed through the subclasses `DenseI8ArrayAttr`,
+ `DenseI16ArrayAttr`, `DenseI32ArrayAttr`, `DenseI64ArrayAttr`,
+ `DenseF32ArrayAttr`, and `DenseF64ArrayAttr`.
+
+ Syntax:
+
+ ```
+ dense-array-attribute ::= `[` `:` (integer-type | float-type) tensor-literal `]`
+ ```
+ Examples:
+
+ ```mlir
+ [:i8]
+ [:i32 10, 42]
+ [:f64 42., 12.]
+ ```
+
+ when a specific subclass is used as argument of an operation, the declarative
+ assembly will omit the type and print directly:
+ ```
+ [1, 2, 3]
+ ```
+ }];
+ let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type,
+ "DenseArrayBaseAttr::EltType":$eltType,
+ ArrayRefParameter<"char">:$elements);
+ let extraClassDeclaration = [{
+ // All possible supported element type.
+ enum class EltType { I8, I16, I32, I64, F32, F64 };
+
+ /// Allow implicit conversion to ElementsAttr.
+ operator ElementsAttr() const {
+ return *this ? cast<ElementsAttr>() : nullptr;
+ }
+
+ /// ElementsAttr implementation.
+ using ContiguousIterableTypesT =
+ std::tuple<int8_t, int16_t, int32_t, int64_t, float, double>;
+ const int8_t *value_begin_impl(OverloadToken<int8_t>) const;
+ const int16_t *value_begin_impl(OverloadToken<int16_t>) const;
+ const int32_t *value_begin_impl(OverloadToken<int32_t>) const;
+ const int64_t *value_begin_impl(OverloadToken<int64_t>) const;
+ const float *value_begin_impl(OverloadToken<float>) const;
+ const double *value_begin_impl(OverloadToken<double>) const;
+
+ /// Methods to support type inquiry through isa, cast, and dyn_cast.
+ EltType getElementType() const;
+ /// Printer for the short form: will dispatch to the appropriate subclass.
+ void print(AsmPrinter &printer) const;
+ void print(raw_ostream &os) const;
+ /// Print the short form `42, 100, -1` without any braces or prefix.
+ void printWithoutBraces(raw_ostream &os) const;
+ }];
+ // Do not generate the storage class, we need to handle custom storage alignment.
+ let genStorageClass = 0;
+ let genAccessors = 0;
+ let skipDefaultBuilders = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// DenseIntOrFPElementsAttr
+//===----------------------------------------------------------------------===//
+
def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
"DenseIntOrFPElements", [ElementsAttrInterface], "DenseElementsAttr"
> {
let convertFromStorage = "$_self";
}
+class DenseArrayAttrBase<string denseAttrName, string cppType, string summaryName> :
+ ElementsAttrBase<CPred<"$_self.isa<::mlir::" # denseAttrName # ">()">,
+ summaryName # " dense array attribute"> {
+ let storageType = "::mlir::" # denseAttrName;
+ let returnType = "::llvm::ArrayRef<" # cppType # ">";
+}
+def DenseI8ArrayAttr : DenseArrayAttrBase<"DenseI8ArrayAttr", "int8_t", "i8">;
+def DenseI16ArrayAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
+def DenseI32ArrayAttr : DenseArrayAttrBase<"DenseI32ArrayAttr", "int32_t", "i32">;
+def DenseI64ArrayAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">;
+def DenseF32ArrayAttr : DenseArrayAttrBase<"DenseF32ArrayAttr", "float", "f32">;
+def DenseF64ArrayAttr : DenseArrayAttrBase<"DenseF64ArrayAttr", "double", "f64">;
+
def IndexElementsAttr
: IntElementsAttrBase<CPred<[{$_self.cast<::mlir::DenseIntElementsAttr>()
.getType()
}
os << '>';
}
-
+ } else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayBaseAttr>()) {
+ typeElision = AttrTypeElision::Must;
+ switch (denseArrayAttr.getElementType()) {
+ case DenseArrayBaseAttr::EltType::I8:
+ os << "[:i8 ";
+ break;
+ case DenseArrayBaseAttr::EltType::I16:
+ os << "[:i16 ";
+ break;
+ case DenseArrayBaseAttr::EltType::I32:
+ os << "[:i32 ";
+ break;
+ case DenseArrayBaseAttr::EltType::I64:
+ os << "[:i64 ";
+ break;
+ case DenseArrayBaseAttr::EltType::F32:
+ os << "[:f32 ";
+ break;
+ case DenseArrayBaseAttr::EltType::F64:
+ os << "[:f64 ";
+ break;
+ }
+ denseArrayAttr.printWithoutBraces(os);
+ os << "]";
} else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
printLocation(locAttr);
+ } else {
+ llvm::report_fatal_error("Unknown builtin attribute");
}
// Don't print the type if we must elide it, or if it is a None type.
if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Types.h"
//===----------------------------------------------------------------------===//
void BuiltinDialect::registerAttributes() {
- addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
- DenseStringElementsAttr, DictionaryAttr, FloatAttr,
- SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
- OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
- UnitAttr>();
+ addAttributes<AffineMapAttr, ArrayAttr, DenseArrayBaseAttr,
+ DenseIntOrFPElementsAttr, DenseStringElementsAttr,
+ DictionaryAttr, FloatAttr, SymbolRefAttr, IntegerAttr,
+ IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr,
+ SparseElementsAttr, StringAttr, TypeAttr, UnitAttr>();
}
//===----------------------------------------------------------------------===//
readBits(getData(), offset + storageWidth, bitWidth)};
}
+//===----------------------------------------------------------------------===//
+// DenseArrayAttr
+//===----------------------------------------------------------------------===//
+
+/// Custom storage to ensure proper memory alignment for the allocation of
+/// DenseArray of any element type.
+struct ::mlir::detail::DenseArrayBaseAttrStorage : public AttributeStorage {
+ using KeyTy = std::tuple<ShapedType, DenseArrayBaseAttr::EltType,
+ ::llvm::ArrayRef<char>>;
+ DenseArrayBaseAttrStorage(ShapedType type,
+ DenseArrayBaseAttr::EltType eltType,
+ ::llvm::ArrayRef<char> elements)
+ : AttributeStorage(type), eltType(eltType), elements(elements) {}
+
+ bool operator==(const KeyTy &tblgenKey) const {
+ return (getType() == std::get<0>(tblgenKey)) &&
+ (eltType == std::get<1>(tblgenKey)) &&
+ (elements == std::get<2>(tblgenKey));
+ }
+
+ static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) {
+ return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey),
+ std::get<2>(tblgenKey));
+ }
+
+ static DenseArrayBaseAttrStorage *
+ construct(AttributeStorageAllocator &allocator, const KeyTy &tblgenKey) {
+ auto type = std::get<0>(tblgenKey);
+ auto eltType = std::get<1>(tblgenKey);
+ auto elements = std::get<2>(tblgenKey);
+ if (!elements.empty()) {
+ char *alloc = static_cast<char *>(
+ allocator.allocate(elements.size(), alignof(uint64_t)));
+ std::uninitialized_copy(elements.begin(), elements.end(), alloc);
+ elements = ArrayRef<char>(alloc, elements.size());
+ }
+ return new (allocator.allocate<DenseArrayBaseAttrStorage>())
+ DenseArrayBaseAttrStorage(type, eltType, elements);
+ }
+
+ DenseArrayBaseAttr::EltType eltType;
+ ::llvm::ArrayRef<char> elements;
+};
+
+DenseArrayBaseAttr::EltType DenseArrayBaseAttr::getElementType() const {
+ return getImpl()->eltType;
+}
+
+const int8_t *
+DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const {
+ return cast<DenseI8ArrayAttr>().asArrayRef().begin();
+}
+const int16_t *
+DenseArrayBaseAttr::value_begin_impl(OverloadToken<int16_t>) const {
+ return cast<DenseI16ArrayAttr>().asArrayRef().begin();
+}
+const int32_t *
+DenseArrayBaseAttr::value_begin_impl(OverloadToken<int32_t>) const {
+ return cast<DenseI32ArrayAttr>().asArrayRef().begin();
+}
+const int64_t *
+DenseArrayBaseAttr::value_begin_impl(OverloadToken<int64_t>) const {
+ return cast<DenseI64ArrayAttr>().asArrayRef().begin();
+}
+const float *DenseArrayBaseAttr::value_begin_impl(OverloadToken<float>) const {
+ return cast<DenseF32ArrayAttr>().asArrayRef().begin();
+}
+const double *
+DenseArrayBaseAttr::value_begin_impl(OverloadToken<double>) const {
+ return cast<DenseF64ArrayAttr>().asArrayRef().begin();
+}
+
+void DenseArrayBaseAttr::print(AsmPrinter &printer) const {
+ print(printer.getStream());
+}
+
+void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const {
+ switch (getElementType()) {
+ case DenseArrayBaseAttr::EltType::I8:
+ this->cast<DenseI8ArrayAttr>().printWithoutBraces(os);
+ return;
+ case DenseArrayBaseAttr::EltType::I16:
+ this->cast<DenseI16ArrayAttr>().printWithoutBraces(os);
+ return;
+ case DenseArrayBaseAttr::EltType::I32:
+ this->cast<DenseI32ArrayAttr>().printWithoutBraces(os);
+ return;
+ case DenseArrayBaseAttr::EltType::I64:
+ this->cast<DenseI64ArrayAttr>().printWithoutBraces(os);
+ return;
+ case DenseArrayBaseAttr::EltType::F32:
+ this->cast<DenseF32ArrayAttr>().printWithoutBraces(os);
+ return;
+ case DenseArrayBaseAttr::EltType::F64:
+ this->cast<DenseF64ArrayAttr>().printWithoutBraces(os);
+ return;
+ }
+ llvm_unreachable("<unknown DenseArrayBaseAttr>");
+}
+
+void DenseArrayBaseAttr::print(raw_ostream &os) const {
+ os << "[";
+ printWithoutBraces(os);
+ os << "]";
+}
+
+template <typename T>
+void DenseArrayAttr<T>::print(AsmPrinter &printer) const {
+ print(printer.getStream());
+}
+
+template <typename T>
+void DenseArrayAttr<T>::printWithoutBraces(raw_ostream &os) const {
+ ArrayRef<T> values{*this};
+ llvm::interleaveComma(values, os);
+}
+
+/// Specialization for int8_t for forcing printing as number instead of chars.
+template <>
+void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const {
+ ArrayRef<int8_t> values{*this};
+ llvm::interleaveComma(values, os, [&](int64_t v) { os << v; });
+}
+
+template <typename T>
+void DenseArrayAttr<T>::print(raw_ostream &os) const {
+ os << "[";
+ printWithoutBraces(os);
+ os << "]";
+}
+
+/// Parse a single element: generic template for int types, specialized for
+/// floating points below.
+template <typename T>
+static ParseResult parseDenseArrayAttrElt(AsmParser &parser, T &value) {
+ return parser.parseInteger(value);
+}
+
+template <>
+ParseResult parseDenseArrayAttrElt<float>(AsmParser &parser, float &value) {
+ double doubleVal;
+ if (parser.parseFloat(doubleVal))
+ return failure();
+ value = doubleVal;
+ return success();
+}
+
+template <>
+ParseResult parseDenseArrayAttrElt<double>(AsmParser &parser, double &value) {
+ return parser.parseFloat(value);
+}
+
+/// Parse a DenseArrayAttr without the braces: `1, 2, 3`
+template <typename T>
+Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser,
+ Type odsType) {
+ SmallVector<T> data;
+ if (failed(parser.parseCommaSeparatedList([&]() {
+ T value;
+ if (parseDenseArrayAttrElt(parser, value))
+ return failure();
+ data.push_back(value);
+ return success();
+ })))
+ return {};
+ return get(parser.getContext(), data);
+}
+
+/// Parse a DenseArrayAttr: `[ 1, 2, 3 ]`
+template <typename T>
+Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) {
+ if (parser.parseLSquare())
+ return {};
+ Attribute result = parseWithoutBraces(parser, odsType);
+ if (parser.parseRSquare())
+ return {};
+ return result;
+}
+
+/// Conversion from DenseArrayAttr<T> to ArrayRef<T>.
+template <typename T>
+DenseArrayAttr<T>::operator ArrayRef<T>() const {
+ ArrayRef<char> raw = getImpl()->elements;
+ assert((raw.size() % sizeof(T)) == 0);
+ return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()),
+ raw.size() / sizeof(T));
+}
+
+namespace {
+/// Mapping from C++ element type to MLIR DenseArrayAttr internals.
+template <typename T>
+struct denseArrayAttrEltTypeBuilder;
+template <>
+struct denseArrayAttrEltTypeBuilder<int8_t> {
+ constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8;
+ static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+ return VectorType::get(shape, IntegerType::get(context, 8));
+ }
+};
+template <>
+struct denseArrayAttrEltTypeBuilder<int16_t> {
+ constexpr static auto eltType = DenseArrayBaseAttr::EltType::I16;
+ static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+ return VectorType::get(shape, IntegerType::get(context, 16));
+ }
+};
+template <>
+struct denseArrayAttrEltTypeBuilder<int32_t> {
+ constexpr static auto eltType = DenseArrayBaseAttr::EltType::I32;
+ static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+ return VectorType::get(shape, IntegerType::get(context, 32));
+ }
+};
+template <>
+struct denseArrayAttrEltTypeBuilder<int64_t> {
+ constexpr static auto eltType = DenseArrayBaseAttr::EltType::I64;
+ static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+ return VectorType::get(shape, IntegerType::get(context, 64));
+ }
+};
+template <>
+struct denseArrayAttrEltTypeBuilder<float> {
+ constexpr static auto eltType = DenseArrayBaseAttr::EltType::F32;
+ static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+ return VectorType::get(shape, Float32Type::get(context));
+ }
+};
+template <>
+struct denseArrayAttrEltTypeBuilder<double> {
+ constexpr static auto eltType = DenseArrayBaseAttr::EltType::F64;
+ static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+ return VectorType::get(shape, Float64Type::get(context));
+ }
+};
+} // namespace
+
+/// Builds a DenseArrayAttr<T> from an ArrayRef<T>.
+template <typename T>
+DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context,
+ ArrayRef<T> content) {
+ auto shapedType =
+ denseArrayAttrEltTypeBuilder<T>::getShapedType(context, content.size());
+ auto eltType = denseArrayAttrEltTypeBuilder<T>::eltType;
+ auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
+ content.size() * sizeof(T));
+ return Base::get(context, shapedType, eltType, rawArray)
+ .template cast<DenseArrayAttr<T>>();
+}
+
+template <typename T>
+bool DenseArrayAttr<T>::classof(Attribute attr) {
+ return attr.isa<DenseArrayBaseAttr>() &&
+ attr.cast<DenseArrayBaseAttr>().getElementType() ==
+ denseArrayAttrEltTypeBuilder<T>::eltType;
+}
+
+namespace mlir {
+namespace detail {
+// Explicit instantiation for all the supported DenseArrayAttr.
+template class DenseArrayAttr<int8_t>;
+template class DenseArrayAttr<int16_t>;
+template class DenseArrayAttr<int32_t>;
+template class DenseArrayAttr<int64_t>;
+template class DenseArrayAttr<float>;
+template class DenseArrayAttr<double>;
+} // namespace detail
+} // namespace mlir
+
//===----------------------------------------------------------------------===//
// DenseElementsAttr
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
#include "Parser.h"
+
+#include "AsmParserImpl.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Parser/AsmParserState.h"
#include "llvm/ADT/StringExtras.h"
/// | float-literal (`:` float-type)?
/// | string-literal (`:` type)?
/// | type
+/// | `[` `:` (integer-type | float-type) tensor-literal `]`
/// | `[` (attribute-value (`,` attribute-value)*)? `]`
/// | `{` (attribute-entry (`,` attribute-entry)*)? `}`
/// | symbol-ref-id (`::` symbol-ref-id)*
// Parse an array attribute.
case Token::l_square: {
+ consumeToken(Token::l_square);
+ if (consumeIf(Token::colon))
+ return parseDenseArrayAttr();
SmallVector<Attribute, 4> elements;
auto parseElt = [&]() -> ParseResult {
elements.push_back(parseAttribute());
return elements.back() ? success() : failure();
};
- if (parseCommaSeparatedList(Delimiter::Square, parseElt))
+ if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
return nullptr;
return builder.getArrayAttr(elements);
}
// ElementsAttr Parser
//===----------------------------------------------------------------------===//
+namespace {
+/// This class provides an implementation of AsmParser, allowing to call back
+/// into the libMLIRIR-provided APIs for invoking attribute parsing code defined
+/// in libMLIRIR.
+class CustomAsmParser : public AsmParserImpl<AsmParser> {
+public:
+ CustomAsmParser(Parser &parser)
+ : AsmParserImpl<AsmParser>(parser.getToken().getLoc(), parser) {}
+};
+} // namespace
+
+/// Parse a dense array attribute.
+Attribute Parser::parseDenseArrayAttr() {
+ auto typeLoc = getToken().getLoc();
+ auto type = parseType();
+ if (!type)
+ return {};
+ CustomAsmParser parser(*this);
+ Attribute result;
+ if (auto intType = type.dyn_cast<IntegerType>()) {
+ switch (type.getIntOrFloatBitWidth()) {
+ case 8:
+ result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{});
+ break;
+ case 16:
+ result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{});
+ break;
+ case 32:
+ result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{});
+ break;
+ case 64:
+ result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
+ break;
+ default:
+ emitError(typeLoc, "expected i8, i16, i32, or i64 but got: ") << type;
+ return {};
+ }
+ } else if (auto floatType = type.dyn_cast<FloatType>()) {
+ switch (type.getIntOrFloatBitWidth()) {
+ case 32:
+ result = DenseF32ArrayAttr::parseWithoutBraces(parser, Type{});
+ break;
+ case 64:
+ result = DenseF64ArrayAttr::parseWithoutBraces(parser, Type{});
+ break;
+ default:
+ emitError(typeLoc, "expected f32 or f64 but got: ") << type;
+ return {};
+ }
+ } else {
+ emitError(typeLoc, "expected integer or float type, got: ") << type;
+ return {};
+ }
+ if (!consumeIf(Token::r_square)) {
+ emitError("expected ']' to close an array attribute");
+ return {};
+ }
+ return result;
+}
+
/// Parse a dense elements attribute.
Attribute Parser::parseDenseElementsAttr(Type attrType) {
auto attribLoc = getToken().getLoc();
Attribute parseDenseElementsAttr(Type attrType);
ShapedType parseElementsLiteralType(Type type);
+ /// Parse a DenseArrayAttr.
+ Attribute parseDenseArrayAttr();
+
/// Parse a sparse elements attribute.
Attribute parseSparseElementsAttr(Type attrType);
return
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test DenseArrayAttr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @dense_array_attr
+func.func @dense_array_attr() attributes{
+// CHECK-SAME: f32attr = [:f32 1.024000e+03, 4.530000e+02, -6.435000e+03],
+ f32attr = [:f32 1024., 453., -6435.],
+// CHECK-SAME: f64attr = [:f64 -1.420000e+02],
+ f64attr = [:f64 -142.],
+// CHECK-SAME: i16attr = [:i16 3, 5, -4, 10],
+ i16attr = [:i16 3, 5, -4, 10],
+// CHECK-SAME: i32attr = [:i32 1024, 453, -6435],
+ i32attr = [:i32 1024, 453, -6435],
+// CHECK-SAME: i64attr = [:i64 -142],
+ i64attr = [:i64 -142],
+// CHECK-SAME: i8attr = [:i8 1, -2, 3]
+ i8attr = [:i8 1, -2, 3]
+ } {
+// CHECK: test.dense_array_attr
+ test.dense_array_attr
+// CHECK-SAME: i8attr = [1, -2, 3]
+ i8attr = [1, -2, 3]
+// CHECK-SAME: i16attr = [3, 5, -4, 10]
+ i16attr = [3, 5, -4, 10]
+// CHECK-SAME: i32attr = [1024, 453, -6435]
+ i32attr = [1024, 453, -6435]
+// CHECK-SAME: i64attr = [-142]
+ i64attr = [-142]
+// CHECK-SAME: f32attr = [1.024000e+03, 4.530000e+02, -6.435000e+03]
+ f32attr = [1024., 453., -6435.]
+// CHECK-SAME: f64attr = [-1.420000e+02]
+ f64attr = [-142.]
+ return
+}
+
// -----
//===----------------------------------------------------------------------===//
// This tests that the abstract iteration of ElementsAttr works properly, and
// is properly failable when necessary.
+// expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
// expected-error@below {{Test iterating `uint64_t`: 10, 11, 12, 13, 14}}
// expected-error@below {{Test iterating `APInt`: 10, 11, 12, 13, 14}}
// expected-error@below {{Test iterating `IntegerAttr`: 10 : i64, 11 : i64, 12 : i64, 13 : i64, 14 : i64}}
arith.constant #test.i64_elements<[10, 11, 12, 13, 14]> : tensor<5xi64>
+// expected-error@below {{Test iterating `int64_t`: 10, 11, 12, 13, 14}}
// expected-error@below {{Test iterating `uint64_t`: 10, 11, 12, 13, 14}}
// expected-error@below {{Test iterating `APInt`: 10, 11, 12, 13, 14}}
// expected-error@below {{Test iterating `IntegerAttr`: 10 : i64, 11 : i64, 12 : i64, 13 : i64, 14 : i64}}
arith.constant dense<[10, 11, 12, 13, 14]> : tensor<5xi64>
+// expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
// expected-error@below {{Test iterating `uint64_t`: unable to iterate type}}
// expected-error@below {{Test iterating `APInt`: unable to iterate type}}
// expected-error@below {{Test iterating `IntegerAttr`: unable to iterate type}}
arith.constant opaque<"_", "0xDEADBEEF"> : tensor<5xi64>
// Check that we don't crash on empty element attributes.
+// expected-error@below {{Test iterating `int64_t`: }}
// expected-error@below {{Test iterating `uint64_t`: }}
// expected-error@below {{Test iterating `APInt`: }}
// expected-error@below {{Test iterating `IntegerAttr`: }}
arith.constant dense<> : tensor<0xi64>
+
+// expected-error@below {{Test iterating `int8_t`: 10, 11, -12, 13, 14}}
+arith.constant [:i8 10, 11, -12, 13, 14]
+// expected-error@below {{Test iterating `int16_t`: 10, 11, -12, 13, 14}}
+arith.constant [:i16 10, 11, -12, 13, 14]
+// expected-error@below {{Test iterating `int32_t`: 10, 11, -12, 13, 14}}
+arith.constant [:i32 10, 11, -12, 13, 14]
+// expected-error@below {{Test iterating `int64_t`: 10, 11, -12, 13, 14}}
+arith.constant [:i64 10, 11, -12, 13, 14]
+// expected-error@below {{Test iterating `float`: 10.00, 11.00, -12.00, 13.00, 14.00}}
+arith.constant [:f32 10., 11., -12., 13., 14.]
+// expected-error@below {{Test iterating `double`: 10.00, 11.00, -12.00, 13.00, 14.00}}
+arith.constant [:f64 10., 11., -12., 13., 14.]
// -----
-// expected-error@+1 {{expected ']'}}
+// expected-error@+1 {{expected ',' or ']'}}
"f"() { b = [@m:
// -----
);
}
+def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> {
+ let arguments = (ins
+ DenseI8ArrayAttr:$i8attr,
+ DenseI16ArrayAttr:$i16attr,
+ DenseI32ArrayAttr:$i32attr,
+ DenseI64ArrayAttr:$i64attr,
+ DenseF32ArrayAttr:$f32attr,
+ DenseF64ArrayAttr:$f64attr
+ );
+ let assemblyFormat = [{
+ `i8attr` `=` $i8attr `i16attr` `=` $i16attr `i32attr` `=` $i32attr
+ `i64attr` `=` $i64attr `f32attr` `=` $f32attr `f64attr` `=` $f64attr
+ attr-dict
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Test Enum Attributes
//===----------------------------------------------------------------------===//
using namespace mlir;
using namespace test;
+// Helper to print one scalar value, force int8_t to print as integer instead of
+// char.
+template <typename T>
+static void printOneElement(InFlightDiagnostic &os, T value) {
+ os << llvm::formatv("{0}", value).str();
+}
+template <>
+void printOneElement<int8_t>(InFlightDiagnostic &os, int8_t value) {
+ os << llvm::formatv("{0}", static_cast<int64_t>(value)).str();
+}
+
namespace {
struct TestElementsAttrInterface
: public PassWrapper<TestElementsAttrInterface, OperationPass<ModuleOp>> {
auto elementsAttr = attr.getValue().dyn_cast<ElementsAttr>();
if (!elementsAttr)
continue;
+ if (auto concreteAttr =
+ attr.getValue().dyn_cast<DenseArrayBaseAttr>()) {
+ switch (concreteAttr.getElementType()) {
+ case DenseArrayBaseAttr::EltType::I8:
+ testElementsAttrIteration<int8_t>(op, elementsAttr, "int8_t");
+ break;
+ case DenseArrayBaseAttr::EltType::I16:
+ testElementsAttrIteration<int16_t>(op, elementsAttr, "int16_t");
+ break;
+ case DenseArrayBaseAttr::EltType::I32:
+ testElementsAttrIteration<int32_t>(op, elementsAttr, "int32_t");
+ break;
+ case DenseArrayBaseAttr::EltType::I64:
+ testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
+ break;
+ case DenseArrayBaseAttr::EltType::F32:
+ testElementsAttrIteration<float>(op, elementsAttr, "float");
+ break;
+ case DenseArrayBaseAttr::EltType::F64:
+ testElementsAttrIteration<double>(op, elementsAttr, "double");
+ break;
+ }
+ continue;
+ }
+ testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
testElementsAttrIteration<uint64_t>(op, elementsAttr, "uint64_t");
testElementsAttrIteration<APInt>(op, elementsAttr, "APInt");
testElementsAttrIteration<IntegerAttr>(op, elementsAttr, "IntegerAttr");
return;
}
- llvm::interleaveComma(*values, diag, [&](T value) {
- diag << llvm::formatv("{0}", value).str();
- });
+ llvm::interleaveComma(*values, diag,
+ [&](T value) { printOneElement(diag, value); });
}
};
} // namespace