namespace {
class Parser;
+//===----------------------------------------------------------------------===//
+// ParserState
+//===----------------------------------------------------------------------===//
+
/// This class refers to all of the state maintained globally by the parser,
-/// such as the current lexer position etc. The Parser base class provides
+/// such as the current lexer position etc. The Parser base class provides
/// methods to access this.
class ParserState {
public:
// This is the next token that hasn't been consumed yet.
Token curToken;
};
-} // end anonymous namespace
-namespace {
+//===----------------------------------------------------------------------===//
+// Parser
+//===----------------------------------------------------------------------===//
/// This class implement support for parsing global entities like types and
/// shared entities like SSA names. It is intended to be subclassed by
Module *getModule() { return state.module; }
const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); }
- /// Return the current token the parser is inspecting.
- const Token &getToken() const { return state.curToken; }
- StringRef getTokenSpelling() const { return state.curToken.getSpelling(); }
+ /// Parse a comma-separated list of elements up until the specified end token.
+ ParseResult
+ parseCommaSeparatedListUntil(Token::Kind rightToken,
+ const std::function<ParseResult()> &parseElement,
+ bool allowEmptyList = true);
+
+ /// Parse a comma separated list of elements that must have at least one entry
+ /// in it.
+ ParseResult
+ parseCommaSeparatedList(const std::function<ParseResult()> &parseElement);
+
+ ParseResult parsePrettyDialectSymbolName(StringRef &prettyName);
+
+ // We have two forms of parsing methods - those that return a non-null
+ // pointer on success, and those that return a ParseResult to indicate whether
+ // they returned a failure. The second class fills in by-reference arguments
+ // as the results of their action.
+
+ //===--------------------------------------------------------------------===//
+ // Error Handling
+ //===--------------------------------------------------------------------===//
+
+ /// Emit an error and return failure.
+ InFlightDiagnostic emitError(const Twine &message = {}) {
+ return emitError(state.curToken.getLoc(), message);
+ }
+ InFlightDiagnostic emitError(SMLoc loc, const Twine &message = {});
/// Encode the specified source location information into an attribute for
/// attachment to the IR.
return state.lex.getEncodedSourceLocation(loc);
}
- /// Emit an error and return failure.
- InFlightDiagnostic emitError(const Twine &message = {}) {
- return emitError(state.curToken.getLoc(), message);
+ //===--------------------------------------------------------------------===//
+ // Token Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Return the current token the parser is inspecting.
+ const Token &getToken() const { return state.curToken; }
+ StringRef getTokenSpelling() const { return state.curToken.getSpelling(); }
+
+ /// If the current token has the specified kind, consume it and return true.
+ /// If not, return false.
+ bool consumeIf(Token::Kind kind) {
+ if (state.curToken.isNot(kind))
+ return false;
+ consumeToken(kind);
+ return true;
}
- InFlightDiagnostic emitError(SMLoc loc, const Twine &message = {});
/// Advance the current lexer onto the next token.
void consumeToken() {
consumeToken();
}
- /// If the current token has the specified kind, consume it and return true.
- /// If not, return false.
- bool consumeIf(Token::Kind kind) {
- if (state.curToken.isNot(kind))
- return false;
- consumeToken(kind);
- return true;
- }
-
/// Consume the specified token if present and return success. On failure,
/// output a diagnostic and return failure.
ParseResult parseToken(Token::Kind expectedToken, const Twine &message);
- /// Parse a comma-separated list of elements up until the specified end token.
- ParseResult
- parseCommaSeparatedListUntil(Token::Kind rightToken,
- const std::function<ParseResult()> &parseElement,
- bool allowEmptyList = true);
+ //===--------------------------------------------------------------------===//
+ // Type Parsing
+ //===--------------------------------------------------------------------===//
- /// Parse a comma separated list of elements that must have at least one entry
- /// in it.
- ParseResult
- parseCommaSeparatedList(const std::function<ParseResult()> &parseElement);
+ ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements);
+ ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
+ ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements);
- // We have two forms of parsing methods - those that return a non-null
- // pointer on success, and those that return a ParseResult to indicate whether
- // they returned a failure. The second class fills in by-reference arguments
- // as the results of their action.
+ /// Parse an arbitrary type.
+ Type parseType();
- // Type parsing.
- VectorType parseVectorType();
- ParseResult parseXInDimensionList();
- ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
- bool allowDynamic);
- Type parseExtendedType();
- ParseResult parsePrettyDialectSymbolName(StringRef &prettyName);
- Type parseTensorType();
+ /// Parse a complex type.
Type parseComplexType();
- Type parseTupleType();
- Type parseMemRefType();
+
+ /// Parse an extended type.
+ Type parseExtendedType();
+
+ /// Parse a function type.
Type parseFunctionType();
+
+ /// Parse a memref type.
+ Type parseMemRefType();
+
+ /// Parse a non function type.
Type parseNonFunctionType();
- Type parseType();
- ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
- ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements);
- ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements);
- // Attribute parsing.
- Attribute parseExtendedAttribute(Type type);
+ /// Parse a tensor type.
+ Type parseTensorType();
+
+ /// Parse a tuple type.
+ Type parseTupleType();
+
+ /// Parse a vector type.
+ VectorType parseVectorType();
+ ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
+ bool allowDynamic = true);
+ ParseResult parseXInDimensionList();
+
+ //===--------------------------------------------------------------------===//
+ // Attribute Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Parse an arbitrary attribute with an optional type.
Attribute parseAttribute(Type type = {});
+ /// Parse an attribute dictionary.
ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
- // Polyhedral structures.
- ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map,
- IntegerSet &set);
+ /// Parse an extended attribute.
+ Attribute parseExtendedAttribute(Type type);
+
+ /// Parse a dense elements attribute.
DenseElementsAttr parseDenseElementsAttr(ShapedType type);
DenseElementsAttr parseDenseElementsAttrAsTensor(Type eltType);
ShapedType parseElementsLiteralType();
- // Location Parsing.
+ //===--------------------------------------------------------------------===//
+ // Location Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Parse an inline location.
+ ParseResult parseLocation(llvm::Optional<Location> *loc);
+
+ /// Parse a raw location instance.
+ ParseResult parseLocationInstance(llvm::Optional<Location> *loc);
- /// Trailing locations.
+ /// Parse an optional trailing location.
///
/// trailing-location ::= location?
///
return success();
}
- /// Parse an inline location.
- ParseResult parseLocation(llvm::Optional<Location> *loc);
+ //===--------------------------------------------------------------------===//
+ // Affine Parsing
+ //===--------------------------------------------------------------------===//
- /// Parse a raw location instance.
- ParseResult parseLocationInstance(llvm::Optional<Location> *loc);
+ ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map,
+ IntegerSet &set);
private:
- // The Parser is subclassed and reinstantiated. Do not add additional
- // non-trivial state here, add it to the ParserState class.
+ /// The Parser is subclassed and reinstantiated. Do not add additional
+ /// non-trivial state here, add it to the ParserState class.
ParserState &state;
};
} // end anonymous namespace
// Helper methods.
//===----------------------------------------------------------------------===//
-InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) {
- auto diag = getContext()->emitError(getEncodedSourceLocation(loc), message);
-
- // If we hit a parse error in response to a lexer error, then the lexer
- // already reported the error.
- if (getToken().is(Token::error))
- diag.abandon();
- return diag;
-}
-
-/// Consume the specified token if present and return success. On failure,
-/// output a diagnostic and return failure.
-ParseResult Parser::parseToken(Token::Kind expectedToken,
- const Twine &message) {
- if (consumeIf(expectedToken))
- return success();
- return emitError(message);
-}
-
/// Parse a comma separated list of elements that must have at least one entry
/// in it.
ParseResult Parser::parseCommaSeparatedList(
return success();
}
-//===----------------------------------------------------------------------===//
-// Type Parsing
-//===----------------------------------------------------------------------===//
-
-/// Parse any type except the function type.
-///
-/// non-function-type ::= integer-type
-/// | index-type
-/// | float-type
-/// | extended-type
-/// | vector-type
-/// | tensor-type
-/// | memref-type
-/// | complex-type
-/// | tuple-type
-/// | none-type
+/// Parse the body of a pretty dialect symbol, which starts and ends with <>'s,
+/// and may be recursive. Return with the 'prettyName' StringRef encompasing
+/// the entire pretty name.
///
-/// index-type ::= `index`
-/// float-type ::= `f16` | `bf16` | `f32` | `f64`
-/// none-type ::= `none`
+/// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
+/// pretty-dialect-sym-contents ::= pretty-dialect-sym-body
+/// | '(' pretty-dialect-sym-contents+ ')'
+/// | '[' pretty-dialect-sym-contents+ ']'
+/// | '{' pretty-dialect-sym-contents+ '}'
+/// | '[^[<({>\])}\0]+'
///
-Type Parser::parseNonFunctionType() {
- switch (getToken().getKind()) {
- default:
- return (emitError("expected non-function type"), nullptr);
- case Token::kw_memref:
- return parseMemRefType();
- case Token::kw_tensor:
- return parseTensorType();
- case Token::kw_complex:
- return parseComplexType();
- case Token::kw_tuple:
- return parseTupleType();
- case Token::kw_vector:
- return parseVectorType();
- // integer-type
- case Token::inttype: {
- auto width = getToken().getIntTypeBitwidth();
- if (!width.hasValue())
- return (emitError("invalid integer width"), nullptr);
- auto loc = getEncodedSourceLocation(getToken().getLoc());
- consumeToken(Token::inttype);
- return IntegerType::getChecked(width.getValue(), builder.getContext(), loc);
- }
-
- // float-type
- case Token::kw_bf16:
- consumeToken(Token::kw_bf16);
- return builder.getBF16Type();
- case Token::kw_f16:
- consumeToken(Token::kw_f16);
- return builder.getF16Type();
- case Token::kw_f32:
- consumeToken(Token::kw_f32);
- return builder.getF32Type();
- case Token::kw_f64:
- consumeToken(Token::kw_f64);
- return builder.getF64Type();
-
- // index-type
- case Token::kw_index:
- consumeToken(Token::kw_index);
- return builder.getIndexType();
+ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) {
+ // Pretty symbol names are a relatively unstructured format that contains a
+ // series of properly nested punctuation, with anything else in the middle.
+ // Scan ahead to find it and consume it if successful, otherwise emit an
+ // error.
+ auto *curPtr = getTokenSpelling().data();
- // none-type
- case Token::kw_none:
- consumeToken(Token::kw_none);
- return builder.getNoneType();
+ SmallVector<char, 8> nestedPunctuation;
- // extended type
- case Token::exclamation_identifier:
- return parseExtendedType();
- }
-}
-
-/// Parse an arbitrary type.
-///
-/// type ::= function-type
-/// | non-function-type
-///
-Type Parser::parseType() {
- if (getToken().is(Token::l_paren))
- return parseFunctionType();
- return parseNonFunctionType();
-}
-
-/// Parse a vector type.
-///
-/// vector-type ::= `vector` `<` static-dimension-list primitive-type `>`
-/// static-dimension-list ::= (decimal-literal `x`)+
-///
-VectorType Parser::parseVectorType() {
- consumeToken(Token::kw_vector);
-
- if (parseToken(Token::less, "expected '<' in vector type"))
- return nullptr;
-
- SmallVector<int64_t, 4> dimensions;
- if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false))
- return nullptr;
- if (dimensions.empty())
- return (emitError("expected dimension size in vector type"), nullptr);
-
- // Parse the element type.
- auto typeLoc = getToken().getLoc();
- auto elementType = parseType();
- if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
- return nullptr;
-
- return VectorType::getChecked(dimensions, elementType,
- getEncodedSourceLocation(typeLoc));
-}
-
-/// Parse an 'x' token in a dimension list, handling the case where the x is
-/// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next
-/// token.
-ParseResult Parser::parseXInDimensionList() {
- if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x')
- return emitError("expected 'x' in dimension list");
-
- // If we had a prefix of 'x', lex the next token immediately after the 'x'.
- if (getTokenSpelling().size() != 1)
- state.lex.resetPointer(getTokenSpelling().data() + 1);
-
- // Consume the 'x'.
- consumeToken(Token::bare_identifier);
-
- return success();
-}
-
-/// Parse a dimension list of a tensor or memref type. This populates the
-/// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and
-/// errors out on `?` otherwise.
-///
-/// dimension-list-ranked ::= (dimension `x`)*
-/// dimension ::= `?` | decimal-literal
-///
-/// When `allowDynamic` is not set, this can be also used to parse
-///
-/// static-dimension-list ::= (decimal-literal `x`)*
-ParseResult
-Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
- bool allowDynamic = true) {
- while (getToken().isAny(Token::integer, Token::question)) {
- if (consumeIf(Token::question)) {
- if (!allowDynamic)
- return emitError("expected static shape");
- dimensions.push_back(-1);
- } else {
- // Hexadecimal integer literals (starting with `0x`) are not allowed in
- // aggregate type declarations. Therefore, `0xf32` should be processed as
- // a sequence of separate elements `0`, `x`, `f32`.
- if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
- // We can get here only if the token is an integer literal. Hexadecimal
- // integer literals can only start with `0x` (`1x` wouldn't lex as a
- // literal, just `1` would, at which point we don't get into this
- // branch).
- assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
- dimensions.push_back(0);
- state.lex.resetPointer(getTokenSpelling().data() + 1);
- consumeToken();
- } else {
- // Make sure this integer value is in bound and valid.
- auto dimension = getToken().getUnsignedIntegerValue();
- if (!dimension.hasValue())
- return emitError("invalid dimension");
- dimensions.push_back((int64_t)dimension.getValue());
- consumeToken(Token::integer);
- }
- }
-
- // Make sure we have an 'x' or something like 'xbf32'.
- if (parseXInDimensionList())
- return failure();
- }
-
- return success();
-}
-
-/// Parse the body of a pretty dialect symbol, which starts and ends with <>'s,
-/// and may be recursive. Return with the 'prettyName' StringRef encompasing
-/// the entire pretty name.
-///
-/// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
-/// pretty-dialect-sym-contents ::= pretty-dialect-sym-body
-/// | '(' pretty-dialect-sym-contents+ ')'
-/// | '[' pretty-dialect-sym-contents+ ']'
-/// | '{' pretty-dialect-sym-contents+ '}'
-/// | '[^[<({>\])}\0]+'
-///
-ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) {
- // Pretty symbol names are a relatively unstructured format that contains a
- // series of properly nested punctuation, with anything else in the middle.
- // Scan ahead to find it and consume it if successful, otherwise emit an
- // error.
- auto *curPtr = getTokenSpelling().data();
-
- SmallVector<char, 8> nestedPunctuation;
-
- // Scan over the nested punctuation, bailing out on error and consuming until
- // we find the end. We know that we're currently looking at the '<', so we
- // can go until we find the matching '>' character.
- assert(*curPtr == '<');
- do {
- char c = *curPtr++;
- switch (c) {
- case '\0':
- // This also handles the EOF case.
- return emitError("unexpected nul or EOF in pretty dialect name");
- case '<':
- case '[':
- case '(':
- case '{':
- nestedPunctuation.push_back(c);
- continue;
+ // Scan over the nested punctuation, bailing out on error and consuming until
+ // we find the end. We know that we're currently looking at the '<', so we
+ // can go until we find the matching '>' character.
+ assert(*curPtr == '<');
+ do {
+ char c = *curPtr++;
+ switch (c) {
+ case '\0':
+ // This also handles the EOF case.
+ return emitError("unexpected nul or EOF in pretty dialect name");
+ case '<':
+ case '[':
+ case '(':
+ case '{':
+ nestedPunctuation.push_back(c);
+ continue;
case '>':
if (nestedPunctuation.pop_back_val() != '<')
return createSymbol(dialectName, symbolData, encodedLoc);
}
+//===----------------------------------------------------------------------===//
+// Error Handling
+//===----------------------------------------------------------------------===//
+
+InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) {
+ auto diag = getContext()->emitError(getEncodedSourceLocation(loc), message);
+
+ // If we hit a parse error in response to a lexer error, then the lexer
+ // already reported the error.
+ if (getToken().is(Token::error))
+ diag.abandon();
+ return diag;
+}
+
+//===----------------------------------------------------------------------===//
+// Token Parsing
+//===----------------------------------------------------------------------===//
+
+/// Consume the specified token if present and return success. On failure,
+/// output a diagnostic and return failure.
+ParseResult Parser::parseToken(Token::Kind expectedToken,
+ const Twine &message) {
+ if (consumeIf(expectedToken))
+ return success();
+ return emitError(message);
+}
+
+//===----------------------------------------------------------------------===//
+// Type Parsing
+//===----------------------------------------------------------------------===//
+
+/// Parse an arbitrary type.
+///
+/// type ::= function-type
+/// | non-function-type
+///
+Type Parser::parseType() {
+ if (getToken().is(Token::l_paren))
+ return parseFunctionType();
+ return parseNonFunctionType();
+}
+
+/// Parse a function result type.
+///
+/// function-result-type ::= type-list-parens
+/// | non-function-type
+///
+ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) {
+ if (getToken().is(Token::l_paren))
+ return parseTypeListParens(elements);
+
+ Type t = parseNonFunctionType();
+ if (!t)
+ return failure();
+ elements.push_back(t);
+ return success();
+}
+
+/// Parse a list of types without an enclosing parenthesis. The list must have
+/// at least one member.
+///
+/// type-list-no-parens ::= type (`,` type)*
+///
+ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
+ auto parseElt = [&]() -> ParseResult {
+ auto elt = parseType();
+ elements.push_back(elt);
+ return elt ? success() : failure();
+ };
+
+ return parseCommaSeparatedList(parseElt);
+}
+
+/// Parse a parenthesized list of types.
+///
+/// type-list-parens ::= `(` `)`
+/// | `(` type-list-no-parens `)`
+///
+ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
+ if (parseToken(Token::l_paren, "expected '('"))
+ return failure();
+
+ // Handle empty lists.
+ if (getToken().is(Token::r_paren))
+ return consumeToken(), success();
+
+ if (parseTypeListNoParens(elements) ||
+ parseToken(Token::r_paren, "expected ')'"))
+ return failure();
+ return success();
+}
+
+/// Parse a complex type.
+///
+/// complex-type ::= `complex` `<` type `>`
+///
+Type Parser::parseComplexType() {
+ consumeToken(Token::kw_complex);
+
+ // Parse the '<'.
+ if (parseToken(Token::less, "expected '<' in complex type"))
+ return nullptr;
+
+ auto typeLocation = getEncodedSourceLocation(getToken().getLoc());
+ auto elementType = parseType();
+ if (!elementType ||
+ parseToken(Token::greater, "expected '>' in complex type"))
+ return nullptr;
+
+ return ComplexType::getChecked(elementType, typeLocation);
+}
+
/// Parse an extended type.
///
/// extended-type ::= (dialect-type | type-alias)
});
}
-/// Parse a tensor type.
+/// Parse a function type.
///
-/// tensor-type ::= `tensor` `<` dimension-list element-type `>`
-/// dimension-list ::= dimension-list-ranked | `*x`
+/// function-type ::= type-list-parens `->` type-list
///
-Type Parser::parseTensorType() {
- consumeToken(Token::kw_tensor);
+Type Parser::parseFunctionType() {
+ assert(getToken().is(Token::l_paren));
- if (parseToken(Token::less, "expected '<' in tensor type"))
- return nullptr;
-
- bool isUnranked;
- SmallVector<int64_t, 4> dimensions;
-
- if (consumeIf(Token::star)) {
- // This is an unranked tensor type.
- isUnranked = true;
-
- if (parseXInDimensionList())
- return nullptr;
-
- } else {
- isUnranked = false;
- if (parseDimensionListRanked(dimensions))
- return nullptr;
- }
-
- // Parse the element type.
- auto typeLocation = getEncodedSourceLocation(getToken().getLoc());
- auto elementType = parseType();
- if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
- return nullptr;
-
- if (isUnranked)
- return UnrankedTensorType::getChecked(elementType, typeLocation);
- return RankedTensorType::getChecked(dimensions, elementType, typeLocation);
-}
-
-/// Parse a complex type.
-///
-/// complex-type ::= `complex` `<` type `>`
-///
-Type Parser::parseComplexType() {
- consumeToken(Token::kw_complex);
-
- // Parse the '<'.
- if (parseToken(Token::less, "expected '<' in complex type"))
- return nullptr;
-
- auto typeLocation = getEncodedSourceLocation(getToken().getLoc());
- auto elementType = parseType();
- if (!elementType ||
- parseToken(Token::greater, "expected '>' in complex type"))
- return nullptr;
-
- return ComplexType::getChecked(elementType, typeLocation);
-}
-
-/// Parse a tuple type.
-///
-/// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
-///
-Type Parser::parseTupleType() {
- consumeToken(Token::kw_tuple);
-
- // Parse the '<'.
- if (parseToken(Token::less, "expected '<' in tuple type"))
- return nullptr;
-
- // Check for an empty tuple by directly parsing '>'.
- if (consumeIf(Token::greater))
- return TupleType::get(getContext());
-
- // Parse the element types and the '>'.
- SmallVector<Type, 4> types;
- if (parseTypeListNoParens(types) ||
- parseToken(Token::greater, "expected '>' in tuple type"))
+ SmallVector<Type, 4> arguments, results;
+ if (parseTypeListParens(arguments) ||
+ parseToken(Token::arrow, "expected '->' in function type") ||
+ parseFunctionResultTypes(results))
return nullptr;
- return TupleType::get(types, getContext());
+ return builder.getFunctionType(arguments, results);
}
/// Parse a memref type.
memorySpace, getEncodedSourceLocation(typeLoc));
}
-/// Parse a function type.
+/// Parse any type except the function type.
///
-/// function-type ::= type-list-parens `->` type-list
+/// non-function-type ::= integer-type
+/// | index-type
+/// | float-type
+/// | extended-type
+/// | vector-type
+/// | tensor-type
+/// | memref-type
+/// | complex-type
+/// | tuple-type
+/// | none-type
///
-Type Parser::parseFunctionType() {
- assert(getToken().is(Token::l_paren));
+/// index-type ::= `index`
+/// float-type ::= `f16` | `bf16` | `f32` | `f64`
+/// none-type ::= `none`
+///
+Type Parser::parseNonFunctionType() {
+ switch (getToken().getKind()) {
+ default:
+ return (emitError("expected non-function type"), nullptr);
+ case Token::kw_memref:
+ return parseMemRefType();
+ case Token::kw_tensor:
+ return parseTensorType();
+ case Token::kw_complex:
+ return parseComplexType();
+ case Token::kw_tuple:
+ return parseTupleType();
+ case Token::kw_vector:
+ return parseVectorType();
+ // integer-type
+ case Token::inttype: {
+ auto width = getToken().getIntTypeBitwidth();
+ if (!width.hasValue())
+ return (emitError("invalid integer width"), nullptr);
+ auto loc = getEncodedSourceLocation(getToken().getLoc());
+ consumeToken(Token::inttype);
+ return IntegerType::getChecked(width.getValue(), builder.getContext(), loc);
+ }
- SmallVector<Type, 4> arguments, results;
- if (parseTypeListParens(arguments) ||
- parseToken(Token::arrow, "expected '->' in function type") ||
- parseFunctionResultTypes(results))
- return nullptr;
+ // float-type
+ case Token::kw_bf16:
+ consumeToken(Token::kw_bf16);
+ return builder.getBF16Type();
+ case Token::kw_f16:
+ consumeToken(Token::kw_f16);
+ return builder.getF16Type();
+ case Token::kw_f32:
+ consumeToken(Token::kw_f32);
+ return builder.getF32Type();
+ case Token::kw_f64:
+ consumeToken(Token::kw_f64);
+ return builder.getF64Type();
- return builder.getFunctionType(arguments, results);
-}
+ // index-type
+ case Token::kw_index:
+ consumeToken(Token::kw_index);
+ return builder.getIndexType();
-/// Parse a list of types without an enclosing parenthesis. The list must have
-/// at least one member.
-///
-/// type-list-no-parens ::= type (`,` type)*
-///
-ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
- auto parseElt = [&]() -> ParseResult {
- auto elt = parseType();
- elements.push_back(elt);
- return elt ? success() : failure();
- };
+ // none-type
+ case Token::kw_none:
+ consumeToken(Token::kw_none);
+ return builder.getNoneType();
- return parseCommaSeparatedList(parseElt);
+ // extended type
+ case Token::exclamation_identifier:
+ return parseExtendedType();
+ }
}
-/// Parse a parenthesized list of types.
+/// Parse a tensor type.
///
-/// type-list-parens ::= `(` `)`
-/// | `(` type-list-no-parens `)`
+/// tensor-type ::= `tensor` `<` dimension-list element-type `>`
+/// dimension-list ::= dimension-list-ranked | `*x`
///
-ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
- if (parseToken(Token::l_paren, "expected '('"))
- return failure();
+Type Parser::parseTensorType() {
+ consumeToken(Token::kw_tensor);
- // Handle empty lists.
- if (getToken().is(Token::r_paren))
- return consumeToken(), success();
+ if (parseToken(Token::less, "expected '<' in tensor type"))
+ return nullptr;
- if (parseTypeListNoParens(elements) ||
- parseToken(Token::r_paren, "expected ')'"))
- return failure();
- return success();
-}
+ bool isUnranked;
+ SmallVector<int64_t, 4> dimensions;
-/// Parse a function result type.
-///
-/// function-result-type ::= type-list-parens
-/// | non-function-type
-///
-ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) {
- if (getToken().is(Token::l_paren))
- return parseTypeListParens(elements);
+ if (consumeIf(Token::star)) {
+ // This is an unranked tensor type.
+ isUnranked = true;
- Type t = parseNonFunctionType();
- if (!t)
- return failure();
- elements.push_back(t);
- return success();
+ if (parseXInDimensionList())
+ return nullptr;
+
+ } else {
+ isUnranked = false;
+ if (parseDimensionListRanked(dimensions))
+ return nullptr;
+ }
+
+ // Parse the element type.
+ auto typeLocation = getEncodedSourceLocation(getToken().getLoc());
+ auto elementType = parseType();
+ if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
+ return nullptr;
+
+ if (isUnranked)
+ return UnrankedTensorType::getChecked(elementType, typeLocation);
+ return RankedTensorType::getChecked(dimensions, elementType, typeLocation);
}
-//===----------------------------------------------------------------------===//
-// Attribute parsing.
-//===----------------------------------------------------------------------===//
+/// Parse a tuple type.
+///
+/// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
+///
+Type Parser::parseTupleType() {
+ consumeToken(Token::kw_tuple);
-namespace {
-class TensorLiteralParser {
-public:
- TensorLiteralParser(Parser &p, Type eltTy) : p(p), eltTy(eltTy) {}
+ // Parse the '<'.
+ if (parseToken(Token::less, "expected '<' in tuple type"))
+ return nullptr;
- ParseResult parse() {
- if (p.getToken().is(Token::l_square)) {
- return parseList(shape);
- }
- return parseElement();
- }
+ // Check for an empty tuple by directly parsing '>'.
+ if (consumeIf(Token::greater))
+ return TupleType::get(getContext());
- ArrayRef<Attribute> getValues() const { return storage; }
+ // Parse the element types and the '>'.
+ SmallVector<Type, 4> types;
+ if (parseTypeListNoParens(types) ||
+ parseToken(Token::greater, "expected '>' in tuple type"))
+ return nullptr;
- ArrayRef<int64_t> getShape() const { return shape; }
+ return TupleType::get(types, getContext());
+}
-private:
- /// Parse a single element, returning failure if it isn't a valid element
- /// literal. For example:
- /// parseElement(1) -> Success, 1
- /// parseElement([1]) -> Failure
- ParseResult parseElement();
+/// Parse a vector type.
+///
+/// vector-type ::= `vector` `<` static-dimension-list primitive-type `>`
+/// static-dimension-list ::= (decimal-literal `x`)+
+///
+VectorType Parser::parseVectorType() {
+ consumeToken(Token::kw_vector);
- /// Parse a list of either lists or elements, returning the dimensions of the
- /// parsed sub-tensors in dims. For example:
- /// parseList([1, 2, 3]) -> Success, [3]
- /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
- /// parseList([[1, 2], 3]) -> Failure
- /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
- ParseResult parseList(llvm::SmallVectorImpl<int64_t> &dims);
+ if (parseToken(Token::less, "expected '<' in vector type"))
+ return nullptr;
- Parser &p;
- Type eltTy;
- SmallVector<int64_t, 4> shape;
- std::vector<Attribute> storage;
-};
-} // namespace
+ SmallVector<int64_t, 4> dimensions;
+ if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false))
+ return nullptr;
+ if (dimensions.empty())
+ return (emitError("expected dimension size in vector type"), nullptr);
-ParseResult TensorLiteralParser::parseElement() {
- switch (p.getToken().getKind()) {
- case Token::floatliteral:
- case Token::integer:
- case Token::minus: {
- auto result = p.parseAttribute(eltTy);
- if (!result)
- return failure();
- // check result matches the element type.
- switch (eltTy.getKind()) {
- case StandardTypes::BF16:
- case StandardTypes::F16:
- case StandardTypes::F32:
- case StandardTypes::F64: {
- // Bitcast the APFloat value to APInt and store the bit representation.
- auto fpAttrResult = result.dyn_cast<FloatAttr>();
- if (!fpAttrResult)
- return p.emitError(
- "expected tensor literal element with floating point type");
- auto apInt = fpAttrResult.getValue().bitcastToAPInt();
+ // Parse the element type.
+ auto typeLoc = getToken().getLoc();
+ auto elementType = parseType();
+ if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
+ return nullptr;
- // FIXME: using 64 bits and double semantics for BF16 because APFloat does
- // not support BF16 directly.
- size_t bitWidth = eltTy.isBF16() ? 64 : eltTy.getIntOrFloatBitWidth();
- assert(apInt.getBitWidth() == bitWidth);
- (void)bitWidth;
- (void)apInt;
- break;
- }
- case StandardTypes::Integer: {
- if (!result.isa<IntegerAttr>())
- return p.emitError("expected tensor literal element has integer type");
- auto value = result.cast<IntegerAttr>().getValue();
- if (value.getMinSignedBits() > eltTy.getIntOrFloatBitWidth())
- return p.emitError("tensor literal element has more bits than that "
- "specified in the type");
- break;
- }
- default:
- return p.emitError("expected integer or float tensor element");
- }
- storage.push_back(result);
- break;
- }
- default:
- return p.emitError("expected element literal of primitive type");
- }
- return success();
+ return VectorType::getChecked(dimensions, elementType,
+ getEncodedSourceLocation(typeLoc));
}
-/// Parse a list of either lists or elements, returning the dimensions of the
-/// parsed sub-tensors in dims. For example:
-/// parseList([1, 2, 3]) -> Success, [3]
-/// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
-/// parseList([[1, 2], 3]) -> Failure
-/// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
+/// Parse a dimension list of a tensor or memref type. This populates the
+/// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and
+/// errors out on `?` otherwise.
+///
+/// dimension-list-ranked ::= (dimension `x`)*
+/// dimension ::= `?` | decimal-literal
+///
+/// When `allowDynamic` is not set, this can be also used to parse
+///
+/// static-dimension-list ::= (decimal-literal `x`)*
ParseResult
-TensorLiteralParser::parseList(llvm::SmallVectorImpl<int64_t> &dims) {
- p.consumeToken(Token::l_square);
-
- auto checkDims =
- [&](const llvm::SmallVectorImpl<int64_t> &prevDims,
- const llvm::SmallVectorImpl<int64_t> &newDims) -> ParseResult {
- if (prevDims == newDims)
- return success();
- return p.emitError("tensor literal is invalid; ranks are not consistent "
- "between elements");
- };
+Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
+ bool allowDynamic) {
+ while (getToken().isAny(Token::integer, Token::question)) {
+ if (consumeIf(Token::question)) {
+ if (!allowDynamic)
+ return emitError("expected static shape");
+ dimensions.push_back(-1);
+ } else {
+ // Hexadecimal integer literals (starting with `0x`) are not allowed in
+ // aggregate type declarations. Therefore, `0xf32` should be processed as
+ // a sequence of separate elements `0`, `x`, `f32`.
+ if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
+ // We can get here only if the token is an integer literal. Hexadecimal
+ // integer literals can only start with `0x` (`1x` wouldn't lex as a
+ // literal, just `1` would, at which point we don't get into this
+ // branch).
+ assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
+ dimensions.push_back(0);
+ state.lex.resetPointer(getTokenSpelling().data() + 1);
+ consumeToken();
+ } else {
+ // Make sure this integer value is in bound and valid.
+ auto dimension = getToken().getUnsignedIntegerValue();
+ if (!dimension.hasValue())
+ return emitError("invalid dimension");
+ dimensions.push_back((int64_t)dimension.getValue());
+ consumeToken(Token::integer);
+ }
+ }
- bool first = true;
- llvm::SmallVector<int64_t, 4> newDims;
- unsigned size = 0;
- auto parseCommaSeparatedList = [&]() -> ParseResult {
- llvm::SmallVector<int64_t, 4> thisDims;
- if (p.getToken().getKind() == Token::l_square) {
- if (parseList(thisDims))
- return failure();
- } else if (parseElement()) {
+ // Make sure we have an 'x' or something like 'xbf32'.
+ if (parseXInDimensionList())
return failure();
- }
- ++size;
- if (!first)
- return checkDims(newDims, thisDims);
- newDims = thisDims;
- first = false;
- return success();
- };
- if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList))
- return failure();
+ }
- // Return the sublists' dimensions with 'size' prepended.
- dims.clear();
- dims.push_back(size);
- dims.append(newDims.begin(), newDims.end());
return success();
}
-/// Parse an extended attribute.
-///
-/// extended-attribute ::= (dialect-attribute | attribute-alias)
-/// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>`
-/// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body?
-/// attribute-alias ::= `#` alias-name
-///
-Attribute Parser::parseExtendedAttribute(Type type) {
- Attribute attr = parseExtendedSymbol<Attribute>(
- *this, Token::hash_identifier, state.attributeAliasDefinitions,
- [&](StringRef dialectName, StringRef symbolData,
- Location loc) -> Attribute {
- // If we found a registered dialect, then ask it to parse the attribute.
- if (auto *dialect = state.context->getRegisteredDialect(dialectName))
- return dialect->parseAttribute(symbolData, loc);
+/// Parse an 'x' token in a dimension list, handling the case where the x is
+/// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next
+/// token.
+ParseResult Parser::parseXInDimensionList() {
+ if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x')
+ return emitError("expected 'x' in dimension list");
- // Otherwise, form a new opaque attribute.
- return OpaqueAttr::getChecked(
- Identifier::get(dialectName, state.context), symbolData,
- state.context, loc);
- });
+ // If we had a prefix of 'x', lex the next token immediately after the 'x'.
+ if (getTokenSpelling().size() != 1)
+ state.lex.resetPointer(getTokenSpelling().data() + 1);
- // Ensure that the attribute has the same type as requested.
- if (type && attr.getType() != type) {
- emitError("attribute type different than expected: expected ")
- << type << ", but got " << attr.getType();
- return nullptr;
- }
- return attr;
+ // Consume the 'x'.
+ consumeToken(Token::bare_identifier);
+
+ return success();
}
-/// Attribute parsing.
+//===----------------------------------------------------------------------===//
+// Attribute parsing.
+//===----------------------------------------------------------------------===//
+
+/// Parse an arbitrary attribute.
///
/// attribute-value ::= `unit`
/// | bool-literal
}
}
-/// Dense elements attribute.
+/// Attribute dictionary.
///
-/// dense-attr-list ::= `[` attribute-value `]`
-/// attribute-value ::= integer-literal
-/// | float-literal
-/// | `[` (attribute-value (`,` attribute-value)*)? `]`
+/// attribute-dict ::= `{` `}`
+/// | `{` attribute-entry (`,` attribute-entry)* `}`
+/// attribute-entry ::= bare-id `:` attribute-value
///
-/// This method returns a constructed dense elements attribute of tensor type
-/// with the shape from the parsing result.
-DenseElementsAttr Parser::parseDenseElementsAttrAsTensor(Type eltType) {
- TensorLiteralParser literalParser(*this, eltType);
- if (literalParser.parse())
+ParseResult
+Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) {
+ if (!consumeIf(Token::l_brace))
+ return failure();
+
+ auto parseElt = [&]() -> ParseResult {
+ // We allow keywords as attribute names.
+ if (getToken().isNot(Token::bare_identifier, Token::inttype) &&
+ !getToken().isKeyword())
+ return emitError("expected attribute name");
+ Identifier nameId = builder.getIdentifier(getTokenSpelling());
+ consumeToken();
+
+ // Try to parse the ':' for the attribute value.
+ if (!consumeIf(Token::colon)) {
+ // If there is no ':', we treat this as a unit attribute.
+ attributes.push_back({nameId, builder.getUnitAttr()});
+ return success();
+ }
+
+ auto attr = parseAttribute();
+ if (!attr)
+ return failure();
+
+ attributes.push_back({nameId, attr});
+ return success();
+ };
+
+ if (parseCommaSeparatedListUntil(Token::r_brace, parseElt))
+ return failure();
+
+ return success();
+}
+
+/// Parse an extended attribute.
+///
+/// extended-attribute ::= (dialect-attribute | attribute-alias)
+/// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>`
+/// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body?
+/// attribute-alias ::= `#` alias-name
+///
+Attribute Parser::parseExtendedAttribute(Type type) {
+ Attribute attr = parseExtendedSymbol<Attribute>(
+ *this, Token::hash_identifier, state.attributeAliasDefinitions,
+ [&](StringRef dialectName, StringRef symbolData,
+ Location loc) -> Attribute {
+ // If we found a registered dialect, then ask it to parse the attribute.
+ if (auto *dialect = state.context->getRegisteredDialect(dialectName))
+ return dialect->parseAttribute(symbolData, loc);
+
+ // Otherwise, form a new opaque attribute.
+ return OpaqueAttr::getChecked(
+ Identifier::get(dialectName, state.context), symbolData,
+ state.context, loc);
+ });
+
+ // Ensure that the attribute has the same type as requested.
+ if (type && attr.getType() != type) {
+ emitError("attribute type different than expected: expected ")
+ << type << ", but got " << attr.getType();
return nullptr;
+ }
+ return attr;
+}
+
+namespace {
+class TensorLiteralParser {
+public:
+ TensorLiteralParser(Parser &p, Type eltTy) : p(p), eltTy(eltTy) {}
+
+ ParseResult parse() {
+ if (p.getToken().is(Token::l_square))
+ return parseList(shape);
+ return parseElement();
+ }
+
+ ArrayRef<Attribute> getValues() const { return storage; }
+
+ ArrayRef<int64_t> getShape() const { return shape; }
+
+private:
+ /// Parse a single element, returning failure if it isn't a valid element
+ /// literal. For example:
+ /// parseElement(1) -> Success, 1
+ /// parseElement([1]) -> Failure
+ ParseResult parseElement();
+
+ /// Parse a list of either lists or elements, returning the dimensions of the
+ /// parsed sub-tensors in dims. For example:
+ /// parseList([1, 2, 3]) -> Success, [3]
+ /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
+ /// parseList([[1, 2], 3]) -> Failure
+ /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
+ ParseResult parseList(llvm::SmallVectorImpl<int64_t> &dims);
+
+ Parser &p;
+ Type eltTy;
+ SmallVector<int64_t, 4> shape;
+ std::vector<Attribute> storage;
+};
+} // namespace
+
+ParseResult TensorLiteralParser::parseElement() {
+ switch (p.getToken().getKind()) {
+ case Token::floatliteral:
+ case Token::integer:
+ case Token::minus: {
+ auto result = p.parseAttribute(eltTy);
+ if (!result)
+ return failure();
+ // check result matches the element type.
+ switch (eltTy.getKind()) {
+ case StandardTypes::BF16:
+ case StandardTypes::F16:
+ case StandardTypes::F32:
+ case StandardTypes::F64: {
+ // Bitcast the APFloat value to APInt and store the bit representation.
+ auto fpAttrResult = result.dyn_cast<FloatAttr>();
+ if (!fpAttrResult)
+ return p.emitError(
+ "expected tensor literal element with floating point type");
+ auto apInt = fpAttrResult.getValue().bitcastToAPInt();
+
+ // FIXME: using 64 bits and double semantics for BF16 because APFloat does
+ // not support BF16 directly.
+ size_t bitWidth = eltTy.isBF16() ? 64 : eltTy.getIntOrFloatBitWidth();
+ assert(apInt.getBitWidth() == bitWidth);
+ (void)bitWidth;
+ (void)apInt;
+ break;
+ }
+ case StandardTypes::Integer: {
+ if (!result.isa<IntegerAttr>())
+ return p.emitError("expected tensor literal element has integer type");
+ auto value = result.cast<IntegerAttr>().getValue();
+ if (value.getMinSignedBits() > eltTy.getIntOrFloatBitWidth())
+ return p.emitError("tensor literal element has more bits than that "
+ "specified in the type");
+ break;
+ }
+ default:
+ return p.emitError("expected integer or float tensor element");
+ }
+ storage.push_back(result);
+ break;
+ }
+ default:
+ return p.emitError("expected element literal of primitive type");
+ }
+ return success();
+}
+
+/// Parse a list of either lists or elements, returning the dimensions of the
+/// parsed sub-tensors in dims. For example:
+/// parseList([1, 2, 3]) -> Success, [3]
+/// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
+/// parseList([[1, 2], 3]) -> Failure
+/// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
+ParseResult
+TensorLiteralParser::parseList(llvm::SmallVectorImpl<int64_t> &dims) {
+ p.consumeToken(Token::l_square);
+
+ auto checkDims =
+ [&](const llvm::SmallVectorImpl<int64_t> &prevDims,
+ const llvm::SmallVectorImpl<int64_t> &newDims) -> ParseResult {
+ if (prevDims == newDims)
+ return success();
+ return p.emitError("tensor literal is invalid; ranks are not consistent "
+ "between elements");
+ };
- auto type = builder.getTensorType(literalParser.getShape(), eltType);
- return builder.getDenseElementsAttr(type, literalParser.getValues())
- .cast<DenseElementsAttr>();
+ bool first = true;
+ llvm::SmallVector<int64_t, 4> newDims;
+ unsigned size = 0;
+ auto parseCommaSeparatedList = [&]() -> ParseResult {
+ llvm::SmallVector<int64_t, 4> thisDims;
+ if (p.getToken().getKind() == Token::l_square) {
+ if (parseList(thisDims))
+ return failure();
+ } else if (parseElement()) {
+ return failure();
+ }
+ ++size;
+ if (!first)
+ return checkDims(newDims, thisDims);
+ newDims = thisDims;
+ first = false;
+ return success();
+ };
+ if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList))
+ return failure();
+
+ // Return the sublists' dimensions with 'size' prepended.
+ dims.clear();
+ dims.push_back(size);
+ dims.append(newDims.begin(), newDims.end());
+ return success();
}
-/// Dense elements attribute.
+/// Parse a dense elements attribute.
///
/// dense-attr-list ::= `[` attribute-value `]`
/// attribute-value ::= integer-literal
.cast<DenseElementsAttr>();
}
+/// Parse a dense elements attribute.
+///
+/// dense-attr-list ::= `[` attribute-value `]`
+/// attribute-value ::= integer-literal
+/// | float-literal
+/// | `[` (attribute-value (`,` attribute-value)*)? `]`
+///
+/// This method returns a constructed dense elements attribute of tensor type
+/// with the shape from the parsing result.
+DenseElementsAttr Parser::parseDenseElementsAttrAsTensor(Type eltType) {
+ TensorLiteralParser literalParser(*this, eltType);
+ if (literalParser.parse())
+ return nullptr;
+
+ auto type = builder.getTensorType(literalParser.getShape(), eltType);
+ return builder.getDenseElementsAttr(type, literalParser.getValues())
+ .cast<DenseElementsAttr>();
+}
+
/// Shaped type for elements attribute.
///
/// elements-literal-type ::= vector-type | ranked-tensor-type
return sType;
}
-/// Debug Location.
+//===----------------------------------------------------------------------===//
+// Location parsing.
+//===----------------------------------------------------------------------===//
+
+/// Parse a location.
///
/// location ::= `loc` inline-location
/// inline-location ::= '(' location-inst ')'
return emitError("expected location instance");
}
-/// Attribute dictionary.
-///
-/// attribute-dict ::= `{` `}`
-/// | `{` attribute-entry (`,` attribute-entry)* `}`
-/// attribute-entry ::= bare-id `:` attribute-value
-///
-ParseResult
-Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) {
- if (!consumeIf(Token::l_brace))
- return failure();
-
- auto parseElt = [&]() -> ParseResult {
- // We allow keywords as attribute names.
- if (getToken().isNot(Token::bare_identifier, Token::inttype) &&
- !getToken().isKeyword())
- return emitError("expected attribute name");
- Identifier nameId = builder.getIdentifier(getTokenSpelling());
- consumeToken();
-
- // Try to parse the ':' for the attribute value.
- if (!consumeIf(Token::colon)) {
- // If there is no ':', we treat this as a unit attribute.
- attributes.push_back({nameId, builder.getUnitAttr()});
- return success();
- }
-
- auto attr = parseAttribute();
- if (!attr)
- return failure();
-
- attributes.push_back({nameId, attr});
- return success();
- };
-
- if (parseCommaSeparatedListUntil(Token::r_brace, parseElt))
- return failure();
-
- return success();
-}
-
//===----------------------------------------------------------------------===//
-// Polyhedral structures.
+// Affine parsing.
//===----------------------------------------------------------------------===//
/// Lower precedence ops (all at the same precedence level). LNoOp is false in
return builder.getAffineMap(numDims, numSymbols, exprs);
}
+/// Parse an affine constraint.
+/// affine-constraint ::= affine-expr `>=` `0`
+/// | affine-expr `==` `0`
+///
+/// isEq is set to true if the parsed constraint is an equality, false if it
+/// is an inequality (greater than or equal).
+///
+AffineExpr AffineParser::parseAffineConstraint(bool *isEq) {
+ AffineExpr expr = parseAffineExpr();
+ if (!expr)
+ return nullptr;
+
+ if (consumeIf(Token::greater) && consumeIf(Token::equal) &&
+ getToken().is(Token::integer)) {
+ auto dim = getToken().getUnsignedIntegerValue();
+ if (dim.hasValue() && dim.getValue() == 0) {
+ consumeToken(Token::integer);
+ *isEq = false;
+ return expr;
+ }
+ return (emitError("expected '0' after '>='"), nullptr);
+ }
+
+ if (consumeIf(Token::equal) && consumeIf(Token::equal) &&
+ getToken().is(Token::integer)) {
+ auto dim = getToken().getUnsignedIntegerValue();
+ if (dim.hasValue() && dim.getValue() == 0) {
+ consumeToken(Token::integer);
+ *isEq = true;
+ return expr;
+ }
+ return (emitError("expected '0' after '=='"), nullptr);
+ }
+
+ return (emitError("expected '== 0' or '>= 0' at end of affine constraint"),
+ nullptr);
+}
+
+/// Parse the constraints that are part of an integer set definition.
+/// integer-set-inline
+/// ::= dim-and-symbol-id-lists `:`
+/// '(' affine-constraint-conjunction? ')'
+/// affine-constraint-conjunction ::= affine-constraint (`,`
+/// affine-constraint)*
+///
+IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims,
+ unsigned numSymbols) {
+ if (parseToken(Token::l_paren,
+ "expected '(' at start of integer set constraint list"))
+ return IntegerSet();
+
+ SmallVector<AffineExpr, 4> constraints;
+ SmallVector<bool, 4> isEqs;
+ auto parseElt = [&]() -> ParseResult {
+ bool isEq;
+ auto elt = parseAffineConstraint(&isEq);
+ ParseResult res = elt ? success() : failure();
+ if (elt) {
+ constraints.push_back(elt);
+ isEqs.push_back(isEq);
+ }
+ return res;
+ };
+
+ // Parse a list of affine constraints (comma-separated).
+ if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true))
+ return IntegerSet();
+
+ // If no constraints were parsed, then treat this as a degenerate 'true' case.
+ if (constraints.empty()) {
+ /* 0 == 0 */
+ auto zero = getAffineConstantExpr(0, getContext());
+ return builder.getIntegerSet(numDims, numSymbols, zero, true);
+ }
+
+ // Parsed a valid integer set.
+ return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs);
+}
+
/// Parse an ambiguous reference to either and affine map or an integer set.
ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map,
IntegerSet &set) {
return builder.createOperation(opState);
}
-/// Parse an affine constraint.
-/// affine-constraint ::= affine-expr `>=` `0`
-/// | affine-expr `==` `0`
-///
-/// isEq is set to true if the parsed constraint is an equality, false if it
-/// is an inequality (greater than or equal).
-///
-AffineExpr AffineParser::parseAffineConstraint(bool *isEq) {
- AffineExpr expr = parseAffineExpr();
- if (!expr)
- return nullptr;
-
- if (consumeIf(Token::greater) && consumeIf(Token::equal) &&
- getToken().is(Token::integer)) {
- auto dim = getToken().getUnsignedIntegerValue();
- if (dim.hasValue() && dim.getValue() == 0) {
- consumeToken(Token::integer);
- *isEq = false;
- return expr;
- }
- return (emitError("expected '0' after '>='"), nullptr);
- }
-
- if (consumeIf(Token::equal) && consumeIf(Token::equal) &&
- getToken().is(Token::integer)) {
- auto dim = getToken().getUnsignedIntegerValue();
- if (dim.hasValue() && dim.getValue() == 0) {
- consumeToken(Token::integer);
- *isEq = true;
- return expr;
- }
- return (emitError("expected '0' after '=='"), nullptr);
- }
-
- return (emitError("expected '== 0' or '>= 0' at end of affine constraint"),
- nullptr);
-}
-
-/// Parse the constraints that are part of an integer set definition.
-/// integer-set-inline
-/// ::= dim-and-symbol-id-lists `:`
-/// '(' affine-constraint-conjunction? ')'
-/// affine-constraint-conjunction ::= affine-constraint (`,`
-/// affine-constraint)*
-///
-IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims,
- unsigned numSymbols) {
- if (parseToken(Token::l_paren,
- "expected '(' at start of integer set constraint list"))
- return IntegerSet();
-
- SmallVector<AffineExpr, 4> constraints;
- SmallVector<bool, 4> isEqs;
- auto parseElt = [&]() -> ParseResult {
- bool isEq;
- auto elt = parseAffineConstraint(&isEq);
- ParseResult res = elt ? success() : failure();
- if (elt) {
- constraints.push_back(elt);
- isEqs.push_back(isEq);
- }
- return res;
- };
-
- // Parse a list of affine constraints (comma-separated).
- if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true))
- return IntegerSet();
-
- // If no constraints were parsed, then treat this as a degenerate 'true' case.
- if (constraints.empty()) {
- /* 0 == 0 */
- auto zero = getAffineConstantExpr(0, getContext());
- return builder.getIntegerSet(numDims, numSymbols, zero, true);
- }
-
- // Parsed a valid integer set.
- return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs);
-}
-
//===----------------------------------------------------------------------===//
// Top-level entity parsing.
//===----------------------------------------------------------------------===//