ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
/// Parse an extended attribute.
- Attribute parseExtendedAttribute(Type type);
+ Attribute parseExtendedAttr(Type type);
+
+ /// Parse a float attribute.
+ Attribute parseFloatAttr(Type type, bool isNegative);
+
+ /// Parse an integer attribute.
+ Attribute parseIntegerAttr(Type type, bool isSigned);
+
+ /// Parse an opaque elements attribute.
+ Attribute parseOpaqueElementsAttr();
+
+ /// Parse a sparse elements attribute.
+ Attribute parseSparseElementsAttr();
+
+ /// Parse a splat elements attribute.
+ Attribute parseSplatElementsAttr();
/// Parse a dense elements attribute.
- DenseElementsAttr parseDenseElementsAttr(ShapedType type);
+ Attribute parseDenseElementsAttr();
DenseElementsAttr parseDenseElementsAttrAsTensor(Type eltType);
ShapedType parseElementsLiteralType();
///
Attribute Parser::parseAttribute(Type type) {
switch (getToken().getKind()) {
- case Token::hash_identifier:
- return parseExtendedAttribute(type);
- case Token::kw_unit:
- consumeToken(Token::kw_unit);
- return builder.getUnitAttr();
-
- case Token::kw_true:
- consumeToken(Token::kw_true);
- return builder.getBoolAttr(true);
- case Token::kw_false:
- consumeToken(Token::kw_false);
- return builder.getBoolAttr(false);
-
- case Token::floatliteral: {
- auto val = getToken().getFloatingPointValue();
- if (!val.hasValue())
- return (emitError("floating point value too large for attribute"),
- nullptr);
- auto valTok = getToken().getLoc();
- consumeToken(Token::floatliteral);
- if (!type) {
- if (consumeIf(Token::colon)) {
- if (!(type = parseType()))
- return nullptr;
- } else {
- // Default to F64 when no type is specified.
- type = builder.getF64Type();
- }
- }
- if (!type.isa<FloatType>())
- return (emitError("floating point value not valid for specified type"),
- nullptr);
- return FloatAttr::getChecked(type, val.getValue(),
- getEncodedSourceLocation(valTok));
- }
- case Token::integer: {
- auto val = getToken().getUInt64IntegerValue();
- if (!val.hasValue() || (int64_t)val.getValue() < 0)
- return (emitError("integer constant out of range for attribute"),
- nullptr);
- consumeToken(Token::integer);
- if (!type) {
- if (consumeIf(Token::colon)) {
- if (!(type = parseType()))
- return nullptr;
- } else {
- // Default to i64 if not type is specified.
- type = builder.getIntegerType(64);
- }
- }
- if (!type.isIntOrIndex())
- return (emitError("integer value not valid for specified type"), nullptr);
- int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
- APInt apInt(width, val.getValue());
- if (apInt != *val)
- return emitError("integer constant out of range for attribute"), nullptr;
- return builder.getIntegerAttr(type, apInt);
- }
-
- case Token::minus: {
- consumeToken(Token::minus);
- if (getToken().is(Token::integer)) {
- auto val = getToken().getUInt64IntegerValue();
- if (!val.hasValue() || (int64_t)-val.getValue() >= 0)
- return (emitError("integer constant out of range for attribute"),
- nullptr);
- consumeToken(Token::integer);
- if (!type) {
- if (consumeIf(Token::colon)) {
- if (!(type = parseType()))
- return nullptr;
- } else {
- // Default to i64 if not type is specified.
- type = builder.getIntegerType(64);
- }
- }
- if (!type.isIntOrIndex())
- return (emitError("integer value not valid for type"), nullptr);
- int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
- APInt apInt(width, *val, /*isSigned=*/true);
- if (apInt != *val)
- return (emitError("integer constant out of range for attribute"),
- nullptr);
- return builder.getIntegerAttr(type, -apInt);
- }
- if (getToken().is(Token::floatliteral)) {
- auto val = getToken().getFloatingPointValue();
- if (!val.hasValue())
- return (emitError("floating point value too large for attribute"),
- nullptr);
- auto valTok = getToken().getLoc();
- consumeToken(Token::floatliteral);
- if (!type) {
- if (consumeIf(Token::colon)) {
- if (!(type = parseType()))
- return nullptr;
- } else {
- // Default to F64 when no type is specified.
- type = builder.getF64Type();
- }
- }
- if (!type.isa<FloatType>())
- return (emitError("floating point value not valid for type"), nullptr);
- return FloatAttr::getChecked(type, -val.getValue(),
- getEncodedSourceLocation(valTok));
- }
-
- return (emitError("expected constant integer or floating point value"),
- nullptr);
- }
-
- case Token::string: {
- auto val = getToken().getStringValue();
- consumeToken(Token::string);
- return builder.getStringAttr(val);
- }
-
- case Token::l_brace: {
- SmallVector<NamedAttribute, 4> elements;
- if (parseAttributeDict(elements))
- return nullptr;
- return builder.getDictionaryAttr(elements);
- }
- case Token::l_square: {
- consumeToken(Token::l_square);
- SmallVector<Attribute, 4> elements;
-
- auto parseElt = [&]() -> ParseResult {
- elements.push_back(parseAttribute());
- return elements.back() ? success() : failure();
- };
-
- if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
- return nullptr;
- return builder.getArrayAttr(elements);
- }
+ // Parse an AffineMap or IntegerSet attribute.
case Token::l_paren: {
// Try to parse an affine map or an integer set reference.
AffineMap map;
return builder.getIntegerSetAttr(set);
}
- case Token::at_identifier: {
- auto nameStr = getTokenSpelling();
- consumeToken(Token::at_identifier);
- return builder.getFunctionAttr(nameStr.drop_front());
- }
- case Token::kw_opaque: {
- consumeToken(Token::kw_opaque);
- if (parseToken(Token::less, "expected '<' after 'opaque'"))
- return nullptr;
-
- if (getToken().getKind() != Token::string)
- return (emitError("expected dialect namespace"), nullptr);
- auto name = getToken().getStringValue();
- auto *dialect = builder.getContext()->getRegisteredDialect(name);
- // TODO(shpeisman): Allow for having an unknown dialect on an opaque
- // attribute. Otherwise, it can't be roundtripped without having the dialect
- // registered.
- if (!dialect)
- return (emitError("no registered dialect with namespace '" + name + "'"),
- nullptr);
-
- consumeToken(Token::string);
- if (parseToken(Token::comma, "expected ','"))
- return nullptr;
-
- auto type = parseElementsLiteralType();
- if (!type)
- return nullptr;
+ // Parse an array attribute.
+ case Token::l_square: {
+ consumeToken(Token::l_square);
- if (parseToken(Token::comma, "expected ',' after elements literal type"))
- return nullptr;
+ SmallVector<Attribute, 4> elements;
+ auto parseElt = [&]() -> ParseResult {
+ elements.push_back(parseAttribute());
+ return elements.back() ? success() : failure();
+ };
- if (getToken().getKind() != Token::string)
- return (emitError("opaque string should start with '0x'"), nullptr);
- auto val = getToken().getStringValue();
- if (val.size() < 2 || val[0] != '0' || val[1] != 'x')
- return (emitError("opaque string should start with '0x'"), nullptr);
- val = val.substr(2);
- if (!std::all_of(val.begin(), val.end(),
- [](char c) { return llvm::isHexDigit(c); })) {
- return (emitError("opaque string only contains hex digits"), nullptr);
- }
- consumeToken(Token::string);
- if (parseToken(Token::greater, "expected '>'"))
+ if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
return nullptr;
- return builder.getOpaqueElementsAttr(dialect, type, llvm::fromHex(val));
+ return builder.getArrayAttr(elements);
}
- case Token::kw_splat: {
- consumeToken(Token::kw_splat);
- if (parseToken(Token::less, "expected '<' after 'splat'"))
- return nullptr;
- auto type = parseElementsLiteralType();
- if (!type)
- return nullptr;
- if (parseToken(Token::comma, "expected ',' after elements literal type"))
- return nullptr;
- switch (getToken().getKind()) {
- case Token::floatliteral:
- case Token::integer:
- case Token::kw_false:
- case Token::kw_true:
- case Token::minus: {
- auto scalar = parseAttribute(type.getElementType());
- if (!scalar)
- return nullptr;
- if (parseToken(Token::greater, "expected '>'"))
- return nullptr;
- return builder.getSplatElementsAttr(type, scalar);
- }
- default:
- return (emitError("expected scalar constant inside tensor literal"),
- nullptr);
- }
- }
- case Token::kw_dense: {
- consumeToken(Token::kw_dense);
- if (parseToken(Token::less, "expected '<' after 'dense'"))
- return nullptr;
+ // Parse a boolean attribute.
+ case Token::kw_false:
+ consumeToken(Token::kw_false);
+ return builder.getBoolAttr(false);
+ case Token::kw_true:
+ consumeToken(Token::kw_true);
+ return builder.getBoolAttr(true);
- auto type = parseElementsLiteralType();
- if (!type)
- return nullptr;
+ // Parse a dense elements attribute.
+ case Token::kw_dense:
+ return parseDenseElementsAttr();
- if (parseToken(Token::comma, "expected ',' after elements literal type"))
+ // Parse a dictionary attribute.
+ case Token::l_brace: {
+ SmallVector<NamedAttribute, 4> elements;
+ if (parseAttributeDict(elements))
return nullptr;
+ return builder.getDictionaryAttr(elements);
+ }
- auto attr = parseDenseElementsAttr(type);
- if (!attr)
- return nullptr;
+ // Parse an extended attribute, i.e. alias or dialect attribute.
+ case Token::hash_identifier:
+ return parseExtendedAttr(type);
- if (parseToken(Token::greater, "expected '>'"))
- return nullptr;
+ // Parse floating point and integer attributes.
+ case Token::floatliteral:
+ return parseFloatAttr(type, /*isNegative=*/false);
+ case Token::integer:
+ return parseIntegerAttr(type, /*isSigned=*/false);
+ case Token::minus: {
+ consumeToken(Token::minus);
+ if (getToken().is(Token::integer))
+ return parseIntegerAttr(type, /*isSigned=*/true);
+ if (getToken().is(Token::floatliteral))
+ return parseFloatAttr(type, /*isNegative=*/true);
- return attr;
+ return (emitError("expected constant integer or floating point value"),
+ nullptr);
}
- case Token::kw_sparse: {
- consumeToken(Token::kw_sparse);
- if (parseToken(Token::less, "Expected '<' after 'sparse'"))
- return nullptr;
- auto type = parseElementsLiteralType();
- if (!type)
- return nullptr;
+ // Parse a function attribute.
+ case Token::at_identifier: {
+ auto nameStr = getTokenSpelling();
+ consumeToken(Token::at_identifier);
+ return builder.getFunctionAttr(nameStr.drop_front());
+ }
- if (parseToken(Token::comma, "expected ',' after elements literal type"))
- return nullptr;
+ // Parse an opaque elements attribute.
+ case Token::kw_opaque:
+ return parseOpaqueElementsAttr();
- switch (getToken().getKind()) {
- case Token::l_square: {
- /// Parse indices
- auto indicesEltType = builder.getIntegerType(64);
- auto indices = parseDenseElementsAttrAsTensor(indicesEltType);
- if (!indices)
- return nullptr;
-
- if (parseToken(Token::comma, "expected ','"))
- return nullptr;
+ // Parse a sparse elements attribute.
+ case Token::kw_sparse:
+ return parseSparseElementsAttr();
- /// Parse values.
- auto valuesEltType = type.getElementType();
- auto values = parseDenseElementsAttrAsTensor(valuesEltType);
- if (!values)
- return nullptr;
+ // Parse a splat elements attribute.
+ case Token::kw_splat:
+ return parseSplatElementsAttr();
- /// Sanity check.
- auto valuesType = values.getType();
- if (valuesType.getRank() != 1) {
- return (emitError("expected 1-d tensor for values"), nullptr);
- }
- auto indicesType = indices.getType();
- auto sameShape = (indicesType.getRank() == 1) ||
- (type.getRank() == indicesType.getDimSize(1));
- auto sameElementNum =
- indicesType.getDimSize(0) == valuesType.getDimSize(0);
- if (!sameShape || !sameElementNum) {
- emitError() << "expected shape ([" << type.getShape()
- << "]); inferred shape of indices literal (["
- << indicesType.getShape()
- << "]); inferred shape of values literal (["
- << valuesType.getShape() << "])";
- return nullptr;
- }
+ // Parse a string attribute.
+ case Token::string: {
+ auto val = getToken().getStringValue();
+ consumeToken(Token::string);
+ return builder.getStringAttr(val);
+ }
- if (parseToken(Token::greater, "expected '>'"))
- return nullptr;
+ // Parse a 'unit' attribute.
+ case Token::kw_unit:
+ consumeToken(Token::kw_unit);
+ return builder.getUnitAttr();
- // Build the sparse elements attribute by the indices and values.
- return builder.getSparseElementsAttr(
- type, indices.cast<DenseIntElementsAttr>(), values);
- }
- default:
- return (emitError("expected '[' to start sparse tensor literal"),
- nullptr);
- }
- return (emitError("expected elements literal has a tensor or vector type"),
- nullptr);
- }
- default: {
+ default:
+ // Parse a type attribute.
if (Type type = parseType())
return builder.getTypeAttr(type);
return nullptr;
}
- }
}
/// Attribute dictionary.
/// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body?
/// attribute-alias ::= `#` alias-name
///
-Attribute Parser::parseExtendedAttribute(Type type) {
+Attribute Parser::parseExtendedAttr(Type type) {
Attribute attr = parseExtendedSymbol<Attribute>(
*this, Token::hash_identifier, state.attributeAliasDefinitions,
[&](StringRef dialectName, StringRef symbolData,
return attr;
}
+/// Parse a float attribute.
+Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
+ auto val = getToken().getFloatingPointValue();
+ if (!val.hasValue())
+ return (emitError("floating point value too large for attribute"), nullptr);
+ auto valTok = getToken().getLoc();
+ consumeToken(Token::floatliteral);
+ if (!type) {
+ // Default to F64 when no type is specified.
+ if (!consumeIf(Token::colon))
+ type = builder.getF64Type();
+ else if (!(type = parseType()))
+ return nullptr;
+ }
+ if (!type.isa<FloatType>())
+ return (emitError("floating point value not valid for specified type"),
+ nullptr);
+ return FloatAttr::getChecked(type,
+ isNegative ? -val.getValue() : val.getValue(),
+ getEncodedSourceLocation(valTok));
+}
+
+/// Parse an integer attribute.
+Attribute Parser::parseIntegerAttr(Type type, bool isSigned) {
+ auto val = getToken().getUInt64IntegerValue();
+ if (!val.hasValue() ||
+ (isSigned ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0))
+ return (emitError("integer constant out of range for attribute"), nullptr);
+ consumeToken(Token::integer);
+ if (!type) {
+ // Default to i64 if not type is specified.
+ if (!consumeIf(Token::colon))
+ type = builder.getIntegerType(64);
+ else if (!(type = parseType()))
+ return nullptr;
+ }
+ if (!type.isIntOrIndex())
+ return (emitError("integer value not valid for specified type"), nullptr);
+
+ int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
+ APInt apInt(width, *val, isSigned);
+ if (apInt != *val)
+ return (emitError("integer constant out of range for attribute"), nullptr);
+ return builder.getIntegerAttr(type, isSigned ? -apInt : apInt);
+}
+
+/// Parse an opaque elements attribute.
+Attribute Parser::parseOpaqueElementsAttr() {
+ consumeToken(Token::kw_opaque);
+ if (parseToken(Token::less, "expected '<' after 'opaque'"))
+ return nullptr;
+
+ if (getToken().isNot(Token::string))
+ return (emitError("expected dialect namespace"), nullptr);
+
+ auto name = getToken().getStringValue();
+ auto *dialect = builder.getContext()->getRegisteredDialect(name);
+ // TODO(shpeisman): Allow for having an unknown dialect on an opaque
+ // attribute. Otherwise, it can't be roundtripped without having the dialect
+ // registered.
+ if (!dialect)
+ return (emitError("no registered dialect with namespace '" + name + "'"),
+ nullptr);
+
+ consumeToken(Token::string);
+ if (parseToken(Token::comma, "expected ','"))
+ return nullptr;
+
+ auto type = parseElementsLiteralType();
+ if (!type)
+ return nullptr;
+
+ if (parseToken(Token::comma, "expected ',' after elements literal type"))
+ return nullptr;
+ if (getToken().getKind() != Token::string)
+ return (emitError("opaque string should start with '0x'"), nullptr);
+
+ auto val = getToken().getStringValue();
+ if (val.size() < 2 || val[0] != '0' || val[1] != 'x')
+ return (emitError("opaque string should start with '0x'"), nullptr);
+
+ val = val.substr(2);
+ if (!llvm::all_of(val, llvm::isHexDigit))
+ return (emitError("opaque string only contains hex digits"), nullptr);
+
+ consumeToken(Token::string);
+ if (parseToken(Token::greater, "expected '>'"))
+ return nullptr;
+ return builder.getOpaqueElementsAttr(dialect, type, llvm::fromHex(val));
+}
+
+/// Parse a sparse elements attribute.
+Attribute Parser::parseSparseElementsAttr() {
+ consumeToken(Token::kw_sparse);
+ if (parseToken(Token::less, "Expected '<' after 'sparse'"))
+ return nullptr;
+
+ auto type = parseElementsLiteralType();
+ if (!type)
+ return nullptr;
+
+ if (parseToken(Token::comma, "expected ',' after elements literal type"))
+ return nullptr;
+ if (getToken().isNot(Token::l_square))
+ return emitError("expected '[' to start sparse tensor literal"), nullptr;
+
+ /// Parse indices
+ auto indicesEltType = builder.getIntegerType(64);
+ auto indices = parseDenseElementsAttrAsTensor(indicesEltType);
+ if (!indices)
+ return nullptr;
+
+ if (parseToken(Token::comma, "expected ','"))
+ return nullptr;
+
+ /// Parse values.
+ auto valuesEltType = type.getElementType();
+ auto values = parseDenseElementsAttrAsTensor(valuesEltType);
+ if (!values)
+ return nullptr;
+
+ /// Sanity check.
+ auto valuesType = values.getType();
+ if (valuesType.getRank() != 1)
+ return (emitError("expected 1-d tensor for values"), nullptr);
+
+ auto indicesType = indices.getType();
+ auto sameShape = (indicesType.getRank() == 1) ||
+ (type.getRank() == indicesType.getDimSize(1));
+ auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0);
+ if (!sameShape || !sameElementNum) {
+ emitError() << "expected shape ([" << type.getShape()
+ << "]); inferred shape of indices literal (["
+ << indicesType.getShape()
+ << "]); inferred shape of values literal (["
+ << valuesType.getShape() << "])";
+ return nullptr;
+ }
+
+ if (parseToken(Token::greater, "expected '>'"))
+ return nullptr;
+
+ // Build the sparse elements attribute by the indices and values.
+ return builder.getSparseElementsAttr(
+ type, indices.cast<DenseIntElementsAttr>(), values);
+}
+
+/// Parse a splat elements attribute.
+Attribute Parser::parseSplatElementsAttr() {
+ consumeToken(Token::kw_splat);
+ if (parseToken(Token::less, "expected '<' after 'splat'"))
+ return nullptr;
+
+ auto type = parseElementsLiteralType();
+ if (!type)
+ return nullptr;
+ if (parseToken(Token::comma, "expected ',' after elements literal type"))
+ return nullptr;
+
+ switch (getToken().getKind()) {
+ case Token::floatliteral:
+ case Token::integer:
+ case Token::kw_false:
+ case Token::kw_true:
+ case Token::minus: {
+ auto scalar = parseAttribute(type.getElementType());
+ if (!scalar)
+ return nullptr;
+ if (parseToken(Token::greater, "expected '>'"))
+ return nullptr;
+ return builder.getSplatElementsAttr(type, scalar);
+ }
+ default:
+ return emitError("expected scalar constant inside tensor literal"), nullptr;
+ }
+}
+
namespace {
class TensorLiteralParser {
public:
/// This method compares the shapes from the parsing result and that from the
/// input argument. It returns a constructed dense elements attribute if both
/// match.
-DenseElementsAttr Parser::parseDenseElementsAttr(ShapedType type) {
- auto eltTy = type.getElementType();
- TensorLiteralParser literalParser(*this, eltTy);
+Attribute Parser::parseDenseElementsAttr() {
+ consumeToken(Token::kw_dense);
+ if (parseToken(Token::less, "expected '<' after 'dense'"))
+ return nullptr;
+
+ auto type = parseElementsLiteralType();
+ if (!type)
+ return nullptr;
+
+ if (parseToken(Token::comma, "expected ',' after elements literal type"))
+ return nullptr;
+
+ TensorLiteralParser literalParser(*this, type.getElementType());
if (literalParser.parse())
return nullptr;
return nullptr;
}
+ if (parseToken(Token::greater, "expected '>'"))
+ return nullptr;
+
return builder.getDenseElementsAttr(type, literalParser.getValues())
.cast<DenseElementsAttr>();
}
return nullptr;
if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) {
- return (
- emitError("elements literal must be a ranked tensor or vector type"),
- nullptr);
+ emitError("elements literal must be a ranked tensor or vector type");
+ return nullptr;
}
auto sType = type.cast<ShapedType>();