From: River Riddle Date: Fri, 31 May 2019 20:48:43 +0000 (-0700) Subject: NFC: Cleanup method definitions within Parser and add header blocks to improve... X-Git-Tag: llvmorg-11-init~1466^2~1534 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=05bb27fac2c658292d52c82721a35e8cafea300e;p=platform%2Fupstream%2Fllvm.git NFC: Cleanup method definitions within Parser and add header blocks to improve readability. -- PiperOrigin-RevId: 250949195 --- diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 10f80ed..891289d 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -48,8 +48,12 @@ using llvm::SourceMgr; namespace { class Parser; +//===----------------------------------------------------------------------===// +// ParserState +//===----------------------------------------------------------------------===// + /// This class refers to all of the state maintained globally by the parser, -/// such as the current lexer position etc. The Parser base class provides +/// such as the current lexer position etc. The Parser base class provides /// methods to access this. class ParserState { public: @@ -81,9 +85,10 @@ private: // This is the next token that hasn't been consumed yet. Token curToken; }; -} // end anonymous namespace -namespace { +//===----------------------------------------------------------------------===// +// Parser +//===----------------------------------------------------------------------===// /// This class implement support for parsing global entities like types and /// shared entities like SSA names. It is intended to be subclassed by @@ -100,9 +105,33 @@ public: Module *getModule() { return state.module; } const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); } - /// Return the current token the parser is inspecting. - const Token &getToken() const { return state.curToken; } - StringRef getTokenSpelling() const { return state.curToken.getSpelling(); } + /// Parse a comma-separated list of elements up until the specified end token. + ParseResult + parseCommaSeparatedListUntil(Token::Kind rightToken, + const std::function &parseElement, + bool allowEmptyList = true); + + /// Parse a comma separated list of elements that must have at least one entry + /// in it. + ParseResult + parseCommaSeparatedList(const std::function &parseElement); + + ParseResult parsePrettyDialectSymbolName(StringRef &prettyName); + + // We have two forms of parsing methods - those that return a non-null + // pointer on success, and those that return a ParseResult to indicate whether + // they returned a failure. The second class fills in by-reference arguments + // as the results of their action. + + //===--------------------------------------------------------------------===// + // Error Handling + //===--------------------------------------------------------------------===// + + /// Emit an error and return failure. + InFlightDiagnostic emitError(const Twine &message = {}) { + return emitError(state.curToken.getLoc(), message); + } + InFlightDiagnostic emitError(SMLoc loc, const Twine &message = {}); /// Encode the specified source location information into an attribute for /// attachment to the IR. @@ -110,11 +139,22 @@ public: return state.lex.getEncodedSourceLocation(loc); } - /// Emit an error and return failure. - InFlightDiagnostic emitError(const Twine &message = {}) { - return emitError(state.curToken.getLoc(), message); + //===--------------------------------------------------------------------===// + // Token Parsing + //===--------------------------------------------------------------------===// + + /// Return the current token the parser is inspecting. + const Token &getToken() const { return state.curToken; } + StringRef getTokenSpelling() const { return state.curToken.getSpelling(); } + + /// If the current token has the specified kind, consume it and return true. + /// If not, return false. + bool consumeIf(Token::Kind kind) { + if (state.curToken.isNot(kind)) + return false; + consumeToken(kind); + return true; } - InFlightDiagnostic emitError(SMLoc loc, const Twine &message = {}); /// Advance the current lexer onto the next token. void consumeToken() { @@ -131,69 +171,77 @@ public: consumeToken(); } - /// If the current token has the specified kind, consume it and return true. - /// If not, return false. - bool consumeIf(Token::Kind kind) { - if (state.curToken.isNot(kind)) - return false; - consumeToken(kind); - return true; - } - /// Consume the specified token if present and return success. On failure, /// output a diagnostic and return failure. ParseResult parseToken(Token::Kind expectedToken, const Twine &message); - /// Parse a comma-separated list of elements up until the specified end token. - ParseResult - parseCommaSeparatedListUntil(Token::Kind rightToken, - const std::function &parseElement, - bool allowEmptyList = true); + //===--------------------------------------------------------------------===// + // Type Parsing + //===--------------------------------------------------------------------===// - /// Parse a comma separated list of elements that must have at least one entry - /// in it. - ParseResult - parseCommaSeparatedList(const std::function &parseElement); + ParseResult parseFunctionResultTypes(SmallVectorImpl &elements); + ParseResult parseTypeListNoParens(SmallVectorImpl &elements); + ParseResult parseTypeListParens(SmallVectorImpl &elements); - // We have two forms of parsing methods - those that return a non-null - // pointer on success, and those that return a ParseResult to indicate whether - // they returned a failure. The second class fills in by-reference arguments - // as the results of their action. + /// Parse an arbitrary type. + Type parseType(); - // Type parsing. - VectorType parseVectorType(); - ParseResult parseXInDimensionList(); - ParseResult parseDimensionListRanked(SmallVectorImpl &dimensions, - bool allowDynamic); - Type parseExtendedType(); - ParseResult parsePrettyDialectSymbolName(StringRef &prettyName); - Type parseTensorType(); + /// Parse a complex type. Type parseComplexType(); - Type parseTupleType(); - Type parseMemRefType(); + + /// Parse an extended type. + Type parseExtendedType(); + + /// Parse a function type. Type parseFunctionType(); + + /// Parse a memref type. + Type parseMemRefType(); + + /// Parse a non function type. Type parseNonFunctionType(); - Type parseType(); - ParseResult parseTypeListNoParens(SmallVectorImpl &elements); - ParseResult parseTypeListParens(SmallVectorImpl &elements); - ParseResult parseFunctionResultTypes(SmallVectorImpl &elements); - // Attribute parsing. - Attribute parseExtendedAttribute(Type type); + /// Parse a tensor type. + Type parseTensorType(); + + /// Parse a tuple type. + Type parseTupleType(); + + /// Parse a vector type. + VectorType parseVectorType(); + ParseResult parseDimensionListRanked(SmallVectorImpl &dimensions, + bool allowDynamic = true); + ParseResult parseXInDimensionList(); + + //===--------------------------------------------------------------------===// + // Attribute Parsing + //===--------------------------------------------------------------------===// + + /// Parse an arbitrary attribute with an optional type. Attribute parseAttribute(Type type = {}); + /// Parse an attribute dictionary. ParseResult parseAttributeDict(SmallVectorImpl &attributes); - // Polyhedral structures. - ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map, - IntegerSet &set); + /// Parse an extended attribute. + Attribute parseExtendedAttribute(Type type); + + /// Parse a dense elements attribute. DenseElementsAttr parseDenseElementsAttr(ShapedType type); DenseElementsAttr parseDenseElementsAttrAsTensor(Type eltType); ShapedType parseElementsLiteralType(); - // Location Parsing. + //===--------------------------------------------------------------------===// + // Location Parsing + //===--------------------------------------------------------------------===// + + /// Parse an inline location. + ParseResult parseLocation(llvm::Optional *loc); + + /// Parse a raw location instance. + ParseResult parseLocationInstance(llvm::Optional *loc); - /// Trailing locations. + /// Parse an optional trailing location. /// /// trailing-location ::= location? /// @@ -211,15 +259,16 @@ public: return success(); } - /// Parse an inline location. - ParseResult parseLocation(llvm::Optional *loc); + //===--------------------------------------------------------------------===// + // Affine Parsing + //===--------------------------------------------------------------------===// - /// Parse a raw location instance. - ParseResult parseLocationInstance(llvm::Optional *loc); + ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map, + IntegerSet &set); private: - // The Parser is subclassed and reinstantiated. Do not add additional - // non-trivial state here, add it to the ParserState class. + /// The Parser is subclassed and reinstantiated. Do not add additional + /// non-trivial state here, add it to the ParserState class. ParserState &state; }; } // end anonymous namespace @@ -228,25 +277,6 @@ private: // Helper methods. //===----------------------------------------------------------------------===// -InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) { - auto diag = getContext()->emitError(getEncodedSourceLocation(loc), message); - - // If we hit a parse error in response to a lexer error, then the lexer - // already reported the error. - if (getToken().is(Token::error)) - diag.abandon(); - return diag; -} - -/// Consume the specified token if present and return success. On failure, -/// output a diagnostic and return failure. -ParseResult Parser::parseToken(Token::Kind expectedToken, - const Twine &message) { - if (consumeIf(expectedToken)) - return success(); - return emitError(message); -} - /// Parse a comma separated list of elements that must have at least one entry /// in it. ParseResult Parser::parseCommaSeparatedList( @@ -288,221 +318,42 @@ ParseResult Parser::parseCommaSeparatedListUntil( return success(); } -//===----------------------------------------------------------------------===// -// Type Parsing -//===----------------------------------------------------------------------===// - -/// Parse any type except the function type. -/// -/// non-function-type ::= integer-type -/// | index-type -/// | float-type -/// | extended-type -/// | vector-type -/// | tensor-type -/// | memref-type -/// | complex-type -/// | tuple-type -/// | none-type +/// Parse the body of a pretty dialect symbol, which starts and ends with <>'s, +/// and may be recursive. Return with the 'prettyName' StringRef encompasing +/// the entire pretty name. /// -/// index-type ::= `index` -/// float-type ::= `f16` | `bf16` | `f32` | `f64` -/// none-type ::= `none` +/// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>' +/// pretty-dialect-sym-contents ::= pretty-dialect-sym-body +/// | '(' pretty-dialect-sym-contents+ ')' +/// | '[' pretty-dialect-sym-contents+ ']' +/// | '{' pretty-dialect-sym-contents+ '}' +/// | '[^[<({>\])}\0]+' /// -Type Parser::parseNonFunctionType() { - switch (getToken().getKind()) { - default: - return (emitError("expected non-function type"), nullptr); - case Token::kw_memref: - return parseMemRefType(); - case Token::kw_tensor: - return parseTensorType(); - case Token::kw_complex: - return parseComplexType(); - case Token::kw_tuple: - return parseTupleType(); - case Token::kw_vector: - return parseVectorType(); - // integer-type - case Token::inttype: { - auto width = getToken().getIntTypeBitwidth(); - if (!width.hasValue()) - return (emitError("invalid integer width"), nullptr); - auto loc = getEncodedSourceLocation(getToken().getLoc()); - consumeToken(Token::inttype); - return IntegerType::getChecked(width.getValue(), builder.getContext(), loc); - } - - // float-type - case Token::kw_bf16: - consumeToken(Token::kw_bf16); - return builder.getBF16Type(); - case Token::kw_f16: - consumeToken(Token::kw_f16); - return builder.getF16Type(); - case Token::kw_f32: - consumeToken(Token::kw_f32); - return builder.getF32Type(); - case Token::kw_f64: - consumeToken(Token::kw_f64); - return builder.getF64Type(); - - // index-type - case Token::kw_index: - consumeToken(Token::kw_index); - return builder.getIndexType(); +ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) { + // Pretty symbol names are a relatively unstructured format that contains a + // series of properly nested punctuation, with anything else in the middle. + // Scan ahead to find it and consume it if successful, otherwise emit an + // error. + auto *curPtr = getTokenSpelling().data(); - // none-type - case Token::kw_none: - consumeToken(Token::kw_none); - return builder.getNoneType(); + SmallVector nestedPunctuation; - // extended type - case Token::exclamation_identifier: - return parseExtendedType(); - } -} - -/// Parse an arbitrary type. -/// -/// type ::= function-type -/// | non-function-type -/// -Type Parser::parseType() { - if (getToken().is(Token::l_paren)) - return parseFunctionType(); - return parseNonFunctionType(); -} - -/// Parse a vector type. -/// -/// vector-type ::= `vector` `<` static-dimension-list primitive-type `>` -/// static-dimension-list ::= (decimal-literal `x`)+ -/// -VectorType Parser::parseVectorType() { - consumeToken(Token::kw_vector); - - if (parseToken(Token::less, "expected '<' in vector type")) - return nullptr; - - SmallVector dimensions; - if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false)) - return nullptr; - if (dimensions.empty()) - return (emitError("expected dimension size in vector type"), nullptr); - - // Parse the element type. - auto typeLoc = getToken().getLoc(); - auto elementType = parseType(); - if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) - return nullptr; - - return VectorType::getChecked(dimensions, elementType, - getEncodedSourceLocation(typeLoc)); -} - -/// Parse an 'x' token in a dimension list, handling the case where the x is -/// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next -/// token. -ParseResult Parser::parseXInDimensionList() { - if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x') - return emitError("expected 'x' in dimension list"); - - // If we had a prefix of 'x', lex the next token immediately after the 'x'. - if (getTokenSpelling().size() != 1) - state.lex.resetPointer(getTokenSpelling().data() + 1); - - // Consume the 'x'. - consumeToken(Token::bare_identifier); - - return success(); -} - -/// Parse a dimension list of a tensor or memref type. This populates the -/// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and -/// errors out on `?` otherwise. -/// -/// dimension-list-ranked ::= (dimension `x`)* -/// dimension ::= `?` | decimal-literal -/// -/// When `allowDynamic` is not set, this can be also used to parse -/// -/// static-dimension-list ::= (decimal-literal `x`)* -ParseResult -Parser::parseDimensionListRanked(SmallVectorImpl &dimensions, - bool allowDynamic = true) { - while (getToken().isAny(Token::integer, Token::question)) { - if (consumeIf(Token::question)) { - if (!allowDynamic) - return emitError("expected static shape"); - dimensions.push_back(-1); - } else { - // Hexadecimal integer literals (starting with `0x`) are not allowed in - // aggregate type declarations. Therefore, `0xf32` should be processed as - // a sequence of separate elements `0`, `x`, `f32`. - if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') { - // We can get here only if the token is an integer literal. Hexadecimal - // integer literals can only start with `0x` (`1x` wouldn't lex as a - // literal, just `1` would, at which point we don't get into this - // branch). - assert(getTokenSpelling()[0] == '0' && "invalid integer literal"); - dimensions.push_back(0); - state.lex.resetPointer(getTokenSpelling().data() + 1); - consumeToken(); - } else { - // Make sure this integer value is in bound and valid. - auto dimension = getToken().getUnsignedIntegerValue(); - if (!dimension.hasValue()) - return emitError("invalid dimension"); - dimensions.push_back((int64_t)dimension.getValue()); - consumeToken(Token::integer); - } - } - - // Make sure we have an 'x' or something like 'xbf32'. - if (parseXInDimensionList()) - return failure(); - } - - return success(); -} - -/// Parse the body of a pretty dialect symbol, which starts and ends with <>'s, -/// and may be recursive. Return with the 'prettyName' StringRef encompasing -/// the entire pretty name. -/// -/// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>' -/// pretty-dialect-sym-contents ::= pretty-dialect-sym-body -/// | '(' pretty-dialect-sym-contents+ ')' -/// | '[' pretty-dialect-sym-contents+ ']' -/// | '{' pretty-dialect-sym-contents+ '}' -/// | '[^[<({>\])}\0]+' -/// -ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) { - // Pretty symbol names are a relatively unstructured format that contains a - // series of properly nested punctuation, with anything else in the middle. - // Scan ahead to find it and consume it if successful, otherwise emit an - // error. - auto *curPtr = getTokenSpelling().data(); - - SmallVector nestedPunctuation; - - // Scan over the nested punctuation, bailing out on error and consuming until - // we find the end. We know that we're currently looking at the '<', so we - // can go until we find the matching '>' character. - assert(*curPtr == '<'); - do { - char c = *curPtr++; - switch (c) { - case '\0': - // This also handles the EOF case. - return emitError("unexpected nul or EOF in pretty dialect name"); - case '<': - case '[': - case '(': - case '{': - nestedPunctuation.push_back(c); - continue; + // Scan over the nested punctuation, bailing out on error and consuming until + // we find the end. We know that we're currently looking at the '<', so we + // can go until we find the matching '>' character. + assert(*curPtr == '<'); + do { + char c = *curPtr++; + switch (c) { + case '\0': + // This also handles the EOF case. + return emitError("unexpected nul or EOF in pretty dialect name"); + case '<': + case '[': + case '(': + case '{': + nestedPunctuation.push_back(c); + continue; case '>': if (nestedPunctuation.pop_back_val() != '<') @@ -603,6 +454,118 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, return createSymbol(dialectName, symbolData, encodedLoc); } +//===----------------------------------------------------------------------===// +// Error Handling +//===----------------------------------------------------------------------===// + +InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) { + auto diag = getContext()->emitError(getEncodedSourceLocation(loc), message); + + // If we hit a parse error in response to a lexer error, then the lexer + // already reported the error. + if (getToken().is(Token::error)) + diag.abandon(); + return diag; +} + +//===----------------------------------------------------------------------===// +// Token Parsing +//===----------------------------------------------------------------------===// + +/// Consume the specified token if present and return success. On failure, +/// output a diagnostic and return failure. +ParseResult Parser::parseToken(Token::Kind expectedToken, + const Twine &message) { + if (consumeIf(expectedToken)) + return success(); + return emitError(message); +} + +//===----------------------------------------------------------------------===// +// Type Parsing +//===----------------------------------------------------------------------===// + +/// Parse an arbitrary type. +/// +/// type ::= function-type +/// | non-function-type +/// +Type Parser::parseType() { + if (getToken().is(Token::l_paren)) + return parseFunctionType(); + return parseNonFunctionType(); +} + +/// Parse a function result type. +/// +/// function-result-type ::= type-list-parens +/// | non-function-type +/// +ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl &elements) { + if (getToken().is(Token::l_paren)) + return parseTypeListParens(elements); + + Type t = parseNonFunctionType(); + if (!t) + return failure(); + elements.push_back(t); + return success(); +} + +/// Parse a list of types without an enclosing parenthesis. The list must have +/// at least one member. +/// +/// type-list-no-parens ::= type (`,` type)* +/// +ParseResult Parser::parseTypeListNoParens(SmallVectorImpl &elements) { + auto parseElt = [&]() -> ParseResult { + auto elt = parseType(); + elements.push_back(elt); + return elt ? success() : failure(); + }; + + return parseCommaSeparatedList(parseElt); +} + +/// Parse a parenthesized list of types. +/// +/// type-list-parens ::= `(` `)` +/// | `(` type-list-no-parens `)` +/// +ParseResult Parser::parseTypeListParens(SmallVectorImpl &elements) { + if (parseToken(Token::l_paren, "expected '('")) + return failure(); + + // Handle empty lists. + if (getToken().is(Token::r_paren)) + return consumeToken(), success(); + + if (parseTypeListNoParens(elements) || + parseToken(Token::r_paren, "expected ')'")) + return failure(); + return success(); +} + +/// Parse a complex type. +/// +/// complex-type ::= `complex` `<` type `>` +/// +Type Parser::parseComplexType() { + consumeToken(Token::kw_complex); + + // Parse the '<'. + if (parseToken(Token::less, "expected '<' in complex type")) + return nullptr; + + auto typeLocation = getEncodedSourceLocation(getToken().getLoc()); + auto elementType = parseType(); + if (!elementType || + parseToken(Token::greater, "expected '>' in complex type")) + return nullptr; + + return ComplexType::getChecked(elementType, typeLocation); +} + /// Parse an extended type. /// /// extended-type ::= (dialect-type | type-alias) @@ -625,86 +588,20 @@ Type Parser::parseExtendedType() { }); } -/// Parse a tensor type. +/// Parse a function type. /// -/// tensor-type ::= `tensor` `<` dimension-list element-type `>` -/// dimension-list ::= dimension-list-ranked | `*x` +/// function-type ::= type-list-parens `->` type-list /// -Type Parser::parseTensorType() { - consumeToken(Token::kw_tensor); +Type Parser::parseFunctionType() { + assert(getToken().is(Token::l_paren)); - if (parseToken(Token::less, "expected '<' in tensor type")) - return nullptr; - - bool isUnranked; - SmallVector dimensions; - - if (consumeIf(Token::star)) { - // This is an unranked tensor type. - isUnranked = true; - - if (parseXInDimensionList()) - return nullptr; - - } else { - isUnranked = false; - if (parseDimensionListRanked(dimensions)) - return nullptr; - } - - // Parse the element type. - auto typeLocation = getEncodedSourceLocation(getToken().getLoc()); - auto elementType = parseType(); - if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) - return nullptr; - - if (isUnranked) - return UnrankedTensorType::getChecked(elementType, typeLocation); - return RankedTensorType::getChecked(dimensions, elementType, typeLocation); -} - -/// Parse a complex type. -/// -/// complex-type ::= `complex` `<` type `>` -/// -Type Parser::parseComplexType() { - consumeToken(Token::kw_complex); - - // Parse the '<'. - if (parseToken(Token::less, "expected '<' in complex type")) - return nullptr; - - auto typeLocation = getEncodedSourceLocation(getToken().getLoc()); - auto elementType = parseType(); - if (!elementType || - parseToken(Token::greater, "expected '>' in complex type")) - return nullptr; - - return ComplexType::getChecked(elementType, typeLocation); -} - -/// Parse a tuple type. -/// -/// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` -/// -Type Parser::parseTupleType() { - consumeToken(Token::kw_tuple); - - // Parse the '<'. - if (parseToken(Token::less, "expected '<' in tuple type")) - return nullptr; - - // Check for an empty tuple by directly parsing '>'. - if (consumeIf(Token::greater)) - return TupleType::get(getContext()); - - // Parse the element types and the '>'. - SmallVector types; - if (parseTypeListNoParens(types) || - parseToken(Token::greater, "expected '>' in tuple type")) + SmallVector arguments, results; + if (parseTypeListParens(arguments) || + parseToken(Token::arrow, "expected '->' in function type") || + parseFunctionResultTypes(results)) return nullptr; - return TupleType::get(types, getContext()); + return builder.getFunctionType(arguments, results); } /// Parse a memref type. @@ -780,243 +677,237 @@ Type Parser::parseMemRefType() { memorySpace, getEncodedSourceLocation(typeLoc)); } -/// Parse a function type. +/// Parse any type except the function type. /// -/// function-type ::= type-list-parens `->` type-list +/// non-function-type ::= integer-type +/// | index-type +/// | float-type +/// | extended-type +/// | vector-type +/// | tensor-type +/// | memref-type +/// | complex-type +/// | tuple-type +/// | none-type /// -Type Parser::parseFunctionType() { - assert(getToken().is(Token::l_paren)); +/// index-type ::= `index` +/// float-type ::= `f16` | `bf16` | `f32` | `f64` +/// none-type ::= `none` +/// +Type Parser::parseNonFunctionType() { + switch (getToken().getKind()) { + default: + return (emitError("expected non-function type"), nullptr); + case Token::kw_memref: + return parseMemRefType(); + case Token::kw_tensor: + return parseTensorType(); + case Token::kw_complex: + return parseComplexType(); + case Token::kw_tuple: + return parseTupleType(); + case Token::kw_vector: + return parseVectorType(); + // integer-type + case Token::inttype: { + auto width = getToken().getIntTypeBitwidth(); + if (!width.hasValue()) + return (emitError("invalid integer width"), nullptr); + auto loc = getEncodedSourceLocation(getToken().getLoc()); + consumeToken(Token::inttype); + return IntegerType::getChecked(width.getValue(), builder.getContext(), loc); + } - SmallVector arguments, results; - if (parseTypeListParens(arguments) || - parseToken(Token::arrow, "expected '->' in function type") || - parseFunctionResultTypes(results)) - return nullptr; + // float-type + case Token::kw_bf16: + consumeToken(Token::kw_bf16); + return builder.getBF16Type(); + case Token::kw_f16: + consumeToken(Token::kw_f16); + return builder.getF16Type(); + case Token::kw_f32: + consumeToken(Token::kw_f32); + return builder.getF32Type(); + case Token::kw_f64: + consumeToken(Token::kw_f64); + return builder.getF64Type(); - return builder.getFunctionType(arguments, results); -} + // index-type + case Token::kw_index: + consumeToken(Token::kw_index); + return builder.getIndexType(); -/// Parse a list of types without an enclosing parenthesis. The list must have -/// at least one member. -/// -/// type-list-no-parens ::= type (`,` type)* -/// -ParseResult Parser::parseTypeListNoParens(SmallVectorImpl &elements) { - auto parseElt = [&]() -> ParseResult { - auto elt = parseType(); - elements.push_back(elt); - return elt ? success() : failure(); - }; + // none-type + case Token::kw_none: + consumeToken(Token::kw_none); + return builder.getNoneType(); - return parseCommaSeparatedList(parseElt); + // extended type + case Token::exclamation_identifier: + return parseExtendedType(); + } } -/// Parse a parenthesized list of types. +/// Parse a tensor type. /// -/// type-list-parens ::= `(` `)` -/// | `(` type-list-no-parens `)` +/// tensor-type ::= `tensor` `<` dimension-list element-type `>` +/// dimension-list ::= dimension-list-ranked | `*x` /// -ParseResult Parser::parseTypeListParens(SmallVectorImpl &elements) { - if (parseToken(Token::l_paren, "expected '('")) - return failure(); +Type Parser::parseTensorType() { + consumeToken(Token::kw_tensor); - // Handle empty lists. - if (getToken().is(Token::r_paren)) - return consumeToken(), success(); + if (parseToken(Token::less, "expected '<' in tensor type")) + return nullptr; - if (parseTypeListNoParens(elements) || - parseToken(Token::r_paren, "expected ')'")) - return failure(); - return success(); -} + bool isUnranked; + SmallVector dimensions; -/// Parse a function result type. -/// -/// function-result-type ::= type-list-parens -/// | non-function-type -/// -ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl &elements) { - if (getToken().is(Token::l_paren)) - return parseTypeListParens(elements); + if (consumeIf(Token::star)) { + // This is an unranked tensor type. + isUnranked = true; - Type t = parseNonFunctionType(); - if (!t) - return failure(); - elements.push_back(t); - return success(); + if (parseXInDimensionList()) + return nullptr; + + } else { + isUnranked = false; + if (parseDimensionListRanked(dimensions)) + return nullptr; + } + + // Parse the element type. + auto typeLocation = getEncodedSourceLocation(getToken().getLoc()); + auto elementType = parseType(); + if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) + return nullptr; + + if (isUnranked) + return UnrankedTensorType::getChecked(elementType, typeLocation); + return RankedTensorType::getChecked(dimensions, elementType, typeLocation); } -//===----------------------------------------------------------------------===// -// Attribute parsing. -//===----------------------------------------------------------------------===// +/// Parse a tuple type. +/// +/// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` +/// +Type Parser::parseTupleType() { + consumeToken(Token::kw_tuple); -namespace { -class TensorLiteralParser { -public: - TensorLiteralParser(Parser &p, Type eltTy) : p(p), eltTy(eltTy) {} + // Parse the '<'. + if (parseToken(Token::less, "expected '<' in tuple type")) + return nullptr; - ParseResult parse() { - if (p.getToken().is(Token::l_square)) { - return parseList(shape); - } - return parseElement(); - } + // Check for an empty tuple by directly parsing '>'. + if (consumeIf(Token::greater)) + return TupleType::get(getContext()); - ArrayRef getValues() const { return storage; } + // Parse the element types and the '>'. + SmallVector types; + if (parseTypeListNoParens(types) || + parseToken(Token::greater, "expected '>' in tuple type")) + return nullptr; - ArrayRef getShape() const { return shape; } + return TupleType::get(types, getContext()); +} -private: - /// Parse a single element, returning failure if it isn't a valid element - /// literal. For example: - /// parseElement(1) -> Success, 1 - /// parseElement([1]) -> Failure - ParseResult parseElement(); +/// Parse a vector type. +/// +/// vector-type ::= `vector` `<` static-dimension-list primitive-type `>` +/// static-dimension-list ::= (decimal-literal `x`)+ +/// +VectorType Parser::parseVectorType() { + consumeToken(Token::kw_vector); - /// Parse a list of either lists or elements, returning the dimensions of the - /// parsed sub-tensors in dims. For example: - /// parseList([1, 2, 3]) -> Success, [3] - /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] - /// parseList([[1, 2], 3]) -> Failure - /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure - ParseResult parseList(llvm::SmallVectorImpl &dims); + if (parseToken(Token::less, "expected '<' in vector type")) + return nullptr; - Parser &p; - Type eltTy; - SmallVector shape; - std::vector storage; -}; -} // namespace + SmallVector dimensions; + if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false)) + return nullptr; + if (dimensions.empty()) + return (emitError("expected dimension size in vector type"), nullptr); -ParseResult TensorLiteralParser::parseElement() { - switch (p.getToken().getKind()) { - case Token::floatliteral: - case Token::integer: - case Token::minus: { - auto result = p.parseAttribute(eltTy); - if (!result) - return failure(); - // check result matches the element type. - switch (eltTy.getKind()) { - case StandardTypes::BF16: - case StandardTypes::F16: - case StandardTypes::F32: - case StandardTypes::F64: { - // Bitcast the APFloat value to APInt and store the bit representation. - auto fpAttrResult = result.dyn_cast(); - if (!fpAttrResult) - return p.emitError( - "expected tensor literal element with floating point type"); - auto apInt = fpAttrResult.getValue().bitcastToAPInt(); + // Parse the element type. + auto typeLoc = getToken().getLoc(); + auto elementType = parseType(); + if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) + return nullptr; - // FIXME: using 64 bits and double semantics for BF16 because APFloat does - // not support BF16 directly. - size_t bitWidth = eltTy.isBF16() ? 64 : eltTy.getIntOrFloatBitWidth(); - assert(apInt.getBitWidth() == bitWidth); - (void)bitWidth; - (void)apInt; - break; - } - case StandardTypes::Integer: { - if (!result.isa()) - return p.emitError("expected tensor literal element has integer type"); - auto value = result.cast().getValue(); - if (value.getMinSignedBits() > eltTy.getIntOrFloatBitWidth()) - return p.emitError("tensor literal element has more bits than that " - "specified in the type"); - break; - } - default: - return p.emitError("expected integer or float tensor element"); - } - storage.push_back(result); - break; - } - default: - return p.emitError("expected element literal of primitive type"); - } - return success(); + return VectorType::getChecked(dimensions, elementType, + getEncodedSourceLocation(typeLoc)); } -/// Parse a list of either lists or elements, returning the dimensions of the -/// parsed sub-tensors in dims. For example: -/// parseList([1, 2, 3]) -> Success, [3] -/// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] -/// parseList([[1, 2], 3]) -> Failure -/// parseList([[1, [2, 3]], [4, [5]]]) -> Failure +/// Parse a dimension list of a tensor or memref type. This populates the +/// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and +/// errors out on `?` otherwise. +/// +/// dimension-list-ranked ::= (dimension `x`)* +/// dimension ::= `?` | decimal-literal +/// +/// When `allowDynamic` is not set, this can be also used to parse +/// +/// static-dimension-list ::= (decimal-literal `x`)* ParseResult -TensorLiteralParser::parseList(llvm::SmallVectorImpl &dims) { - p.consumeToken(Token::l_square); - - auto checkDims = - [&](const llvm::SmallVectorImpl &prevDims, - const llvm::SmallVectorImpl &newDims) -> ParseResult { - if (prevDims == newDims) - return success(); - return p.emitError("tensor literal is invalid; ranks are not consistent " - "between elements"); - }; +Parser::parseDimensionListRanked(SmallVectorImpl &dimensions, + bool allowDynamic) { + while (getToken().isAny(Token::integer, Token::question)) { + if (consumeIf(Token::question)) { + if (!allowDynamic) + return emitError("expected static shape"); + dimensions.push_back(-1); + } else { + // Hexadecimal integer literals (starting with `0x`) are not allowed in + // aggregate type declarations. Therefore, `0xf32` should be processed as + // a sequence of separate elements `0`, `x`, `f32`. + if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') { + // We can get here only if the token is an integer literal. Hexadecimal + // integer literals can only start with `0x` (`1x` wouldn't lex as a + // literal, just `1` would, at which point we don't get into this + // branch). + assert(getTokenSpelling()[0] == '0' && "invalid integer literal"); + dimensions.push_back(0); + state.lex.resetPointer(getTokenSpelling().data() + 1); + consumeToken(); + } else { + // Make sure this integer value is in bound and valid. + auto dimension = getToken().getUnsignedIntegerValue(); + if (!dimension.hasValue()) + return emitError("invalid dimension"); + dimensions.push_back((int64_t)dimension.getValue()); + consumeToken(Token::integer); + } + } - bool first = true; - llvm::SmallVector newDims; - unsigned size = 0; - auto parseCommaSeparatedList = [&]() -> ParseResult { - llvm::SmallVector thisDims; - if (p.getToken().getKind() == Token::l_square) { - if (parseList(thisDims)) - return failure(); - } else if (parseElement()) { + // Make sure we have an 'x' or something like 'xbf32'. + if (parseXInDimensionList()) return failure(); - } - ++size; - if (!first) - return checkDims(newDims, thisDims); - newDims = thisDims; - first = false; - return success(); - }; - if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList)) - return failure(); + } - // Return the sublists' dimensions with 'size' prepended. - dims.clear(); - dims.push_back(size); - dims.append(newDims.begin(), newDims.end()); return success(); } -/// Parse an extended attribute. -/// -/// extended-attribute ::= (dialect-attribute | attribute-alias) -/// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>` -/// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body? -/// attribute-alias ::= `#` alias-name -/// -Attribute Parser::parseExtendedAttribute(Type type) { - Attribute attr = parseExtendedSymbol( - *this, Token::hash_identifier, state.attributeAliasDefinitions, - [&](StringRef dialectName, StringRef symbolData, - Location loc) -> Attribute { - // If we found a registered dialect, then ask it to parse the attribute. - if (auto *dialect = state.context->getRegisteredDialect(dialectName)) - return dialect->parseAttribute(symbolData, loc); +/// Parse an 'x' token in a dimension list, handling the case where the x is +/// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next +/// token. +ParseResult Parser::parseXInDimensionList() { + if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x') + return emitError("expected 'x' in dimension list"); - // Otherwise, form a new opaque attribute. - return OpaqueAttr::getChecked( - Identifier::get(dialectName, state.context), symbolData, - state.context, loc); - }); + // If we had a prefix of 'x', lex the next token immediately after the 'x'. + if (getTokenSpelling().size() != 1) + state.lex.resetPointer(getTokenSpelling().data() + 1); - // Ensure that the attribute has the same type as requested. - if (type && attr.getType() != type) { - emitError("attribute type different than expected: expected ") - << type << ", but got " << attr.getType(); - return nullptr; - } - return attr; + // Consume the 'x'. + consumeToken(Token::bare_identifier); + + return success(); } -/// Attribute parsing. +//===----------------------------------------------------------------------===// +// Attribute parsing. +//===----------------------------------------------------------------------===// + +/// Parse an arbitrary attribute. /// /// attribute-value ::= `unit` /// | bool-literal @@ -1351,26 +1242,212 @@ Attribute Parser::parseAttribute(Type type) { } } -/// Dense elements attribute. +/// Attribute dictionary. /// -/// dense-attr-list ::= `[` attribute-value `]` -/// attribute-value ::= integer-literal -/// | float-literal -/// | `[` (attribute-value (`,` attribute-value)*)? `]` +/// attribute-dict ::= `{` `}` +/// | `{` attribute-entry (`,` attribute-entry)* `}` +/// attribute-entry ::= bare-id `:` attribute-value /// -/// This method returns a constructed dense elements attribute of tensor type -/// with the shape from the parsing result. -DenseElementsAttr Parser::parseDenseElementsAttrAsTensor(Type eltType) { - TensorLiteralParser literalParser(*this, eltType); - if (literalParser.parse()) +ParseResult +Parser::parseAttributeDict(SmallVectorImpl &attributes) { + if (!consumeIf(Token::l_brace)) + return failure(); + + auto parseElt = [&]() -> ParseResult { + // We allow keywords as attribute names. + if (getToken().isNot(Token::bare_identifier, Token::inttype) && + !getToken().isKeyword()) + return emitError("expected attribute name"); + Identifier nameId = builder.getIdentifier(getTokenSpelling()); + consumeToken(); + + // Try to parse the ':' for the attribute value. + if (!consumeIf(Token::colon)) { + // If there is no ':', we treat this as a unit attribute. + attributes.push_back({nameId, builder.getUnitAttr()}); + return success(); + } + + auto attr = parseAttribute(); + if (!attr) + return failure(); + + attributes.push_back({nameId, attr}); + return success(); + }; + + if (parseCommaSeparatedListUntil(Token::r_brace, parseElt)) + return failure(); + + return success(); +} + +/// Parse an extended attribute. +/// +/// extended-attribute ::= (dialect-attribute | attribute-alias) +/// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>` +/// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body? +/// attribute-alias ::= `#` alias-name +/// +Attribute Parser::parseExtendedAttribute(Type type) { + Attribute attr = parseExtendedSymbol( + *this, Token::hash_identifier, state.attributeAliasDefinitions, + [&](StringRef dialectName, StringRef symbolData, + Location loc) -> Attribute { + // If we found a registered dialect, then ask it to parse the attribute. + if (auto *dialect = state.context->getRegisteredDialect(dialectName)) + return dialect->parseAttribute(symbolData, loc); + + // Otherwise, form a new opaque attribute. + return OpaqueAttr::getChecked( + Identifier::get(dialectName, state.context), symbolData, + state.context, loc); + }); + + // Ensure that the attribute has the same type as requested. + if (type && attr.getType() != type) { + emitError("attribute type different than expected: expected ") + << type << ", but got " << attr.getType(); return nullptr; + } + return attr; +} + +namespace { +class TensorLiteralParser { +public: + TensorLiteralParser(Parser &p, Type eltTy) : p(p), eltTy(eltTy) {} + + ParseResult parse() { + if (p.getToken().is(Token::l_square)) + return parseList(shape); + return parseElement(); + } + + ArrayRef getValues() const { return storage; } + + ArrayRef getShape() const { return shape; } + +private: + /// Parse a single element, returning failure if it isn't a valid element + /// literal. For example: + /// parseElement(1) -> Success, 1 + /// parseElement([1]) -> Failure + ParseResult parseElement(); + + /// Parse a list of either lists or elements, returning the dimensions of the + /// parsed sub-tensors in dims. For example: + /// parseList([1, 2, 3]) -> Success, [3] + /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] + /// parseList([[1, 2], 3]) -> Failure + /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure + ParseResult parseList(llvm::SmallVectorImpl &dims); + + Parser &p; + Type eltTy; + SmallVector shape; + std::vector storage; +}; +} // namespace + +ParseResult TensorLiteralParser::parseElement() { + switch (p.getToken().getKind()) { + case Token::floatliteral: + case Token::integer: + case Token::minus: { + auto result = p.parseAttribute(eltTy); + if (!result) + return failure(); + // check result matches the element type. + switch (eltTy.getKind()) { + case StandardTypes::BF16: + case StandardTypes::F16: + case StandardTypes::F32: + case StandardTypes::F64: { + // Bitcast the APFloat value to APInt and store the bit representation. + auto fpAttrResult = result.dyn_cast(); + if (!fpAttrResult) + return p.emitError( + "expected tensor literal element with floating point type"); + auto apInt = fpAttrResult.getValue().bitcastToAPInt(); + + // FIXME: using 64 bits and double semantics for BF16 because APFloat does + // not support BF16 directly. + size_t bitWidth = eltTy.isBF16() ? 64 : eltTy.getIntOrFloatBitWidth(); + assert(apInt.getBitWidth() == bitWidth); + (void)bitWidth; + (void)apInt; + break; + } + case StandardTypes::Integer: { + if (!result.isa()) + return p.emitError("expected tensor literal element has integer type"); + auto value = result.cast().getValue(); + if (value.getMinSignedBits() > eltTy.getIntOrFloatBitWidth()) + return p.emitError("tensor literal element has more bits than that " + "specified in the type"); + break; + } + default: + return p.emitError("expected integer or float tensor element"); + } + storage.push_back(result); + break; + } + default: + return p.emitError("expected element literal of primitive type"); + } + return success(); +} + +/// Parse a list of either lists or elements, returning the dimensions of the +/// parsed sub-tensors in dims. For example: +/// parseList([1, 2, 3]) -> Success, [3] +/// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] +/// parseList([[1, 2], 3]) -> Failure +/// parseList([[1, [2, 3]], [4, [5]]]) -> Failure +ParseResult +TensorLiteralParser::parseList(llvm::SmallVectorImpl &dims) { + p.consumeToken(Token::l_square); + + auto checkDims = + [&](const llvm::SmallVectorImpl &prevDims, + const llvm::SmallVectorImpl &newDims) -> ParseResult { + if (prevDims == newDims) + return success(); + return p.emitError("tensor literal is invalid; ranks are not consistent " + "between elements"); + }; - auto type = builder.getTensorType(literalParser.getShape(), eltType); - return builder.getDenseElementsAttr(type, literalParser.getValues()) - .cast(); + bool first = true; + llvm::SmallVector newDims; + unsigned size = 0; + auto parseCommaSeparatedList = [&]() -> ParseResult { + llvm::SmallVector thisDims; + if (p.getToken().getKind() == Token::l_square) { + if (parseList(thisDims)) + return failure(); + } else if (parseElement()) { + return failure(); + } + ++size; + if (!first) + return checkDims(newDims, thisDims); + newDims = thisDims; + first = false; + return success(); + }; + if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList)) + return failure(); + + // Return the sublists' dimensions with 'size' prepended. + dims.clear(); + dims.push_back(size); + dims.append(newDims.begin(), newDims.end()); + return success(); } -/// Dense elements attribute. +/// Parse a dense elements attribute. /// /// dense-attr-list ::= `[` attribute-value `]` /// attribute-value ::= integer-literal @@ -1397,6 +1474,25 @@ DenseElementsAttr Parser::parseDenseElementsAttr(ShapedType type) { .cast(); } +/// Parse a dense elements attribute. +/// +/// dense-attr-list ::= `[` attribute-value `]` +/// attribute-value ::= integer-literal +/// | float-literal +/// | `[` (attribute-value (`,` attribute-value)*)? `]` +/// +/// This method returns a constructed dense elements attribute of tensor type +/// with the shape from the parsing result. +DenseElementsAttr Parser::parseDenseElementsAttrAsTensor(Type eltType) { + TensorLiteralParser literalParser(*this, eltType); + if (literalParser.parse()) + return nullptr; + + auto type = builder.getTensorType(literalParser.getShape(), eltType); + return builder.getDenseElementsAttr(type, literalParser.getValues()) + .cast(); +} + /// Shaped type for elements attribute. /// /// elements-literal-type ::= vector-type | ranked-tensor-type @@ -1420,7 +1516,11 @@ ShapedType Parser::parseElementsLiteralType() { return sType; } -/// Debug Location. +//===----------------------------------------------------------------------===// +// Location parsing. +//===----------------------------------------------------------------------===// + +/// Parse a location. /// /// location ::= `loc` inline-location /// inline-location ::= '(' location-inst ')' @@ -1608,48 +1708,8 @@ ParseResult Parser::parseLocationInstance(llvm::Optional *loc) { return emitError("expected location instance"); } -/// Attribute dictionary. -/// -/// attribute-dict ::= `{` `}` -/// | `{` attribute-entry (`,` attribute-entry)* `}` -/// attribute-entry ::= bare-id `:` attribute-value -/// -ParseResult -Parser::parseAttributeDict(SmallVectorImpl &attributes) { - if (!consumeIf(Token::l_brace)) - return failure(); - - auto parseElt = [&]() -> ParseResult { - // We allow keywords as attribute names. - if (getToken().isNot(Token::bare_identifier, Token::inttype) && - !getToken().isKeyword()) - return emitError("expected attribute name"); - Identifier nameId = builder.getIdentifier(getTokenSpelling()); - consumeToken(); - - // Try to parse the ':' for the attribute value. - if (!consumeIf(Token::colon)) { - // If there is no ':', we treat this as a unit attribute. - attributes.push_back({nameId, builder.getUnitAttr()}); - return success(); - } - - auto attr = parseAttribute(); - if (!attr) - return failure(); - - attributes.push_back({nameId, attr}); - return success(); - }; - - if (parseCommaSeparatedListUntil(Token::r_brace, parseElt)) - return failure(); - - return success(); -} - //===----------------------------------------------------------------------===// -// Polyhedral structures. +// Affine parsing. //===----------------------------------------------------------------------===// /// Lower precedence ops (all at the same precedence level). LNoOp is false in @@ -2146,6 +2206,85 @@ AffineMap AffineParser::parseAffineMapRange(unsigned numDims, return builder.getAffineMap(numDims, numSymbols, exprs); } +/// Parse an affine constraint. +/// affine-constraint ::= affine-expr `>=` `0` +/// | affine-expr `==` `0` +/// +/// isEq is set to true if the parsed constraint is an equality, false if it +/// is an inequality (greater than or equal). +/// +AffineExpr AffineParser::parseAffineConstraint(bool *isEq) { + AffineExpr expr = parseAffineExpr(); + if (!expr) + return nullptr; + + if (consumeIf(Token::greater) && consumeIf(Token::equal) && + getToken().is(Token::integer)) { + auto dim = getToken().getUnsignedIntegerValue(); + if (dim.hasValue() && dim.getValue() == 0) { + consumeToken(Token::integer); + *isEq = false; + return expr; + } + return (emitError("expected '0' after '>='"), nullptr); + } + + if (consumeIf(Token::equal) && consumeIf(Token::equal) && + getToken().is(Token::integer)) { + auto dim = getToken().getUnsignedIntegerValue(); + if (dim.hasValue() && dim.getValue() == 0) { + consumeToken(Token::integer); + *isEq = true; + return expr; + } + return (emitError("expected '0' after '=='"), nullptr); + } + + return (emitError("expected '== 0' or '>= 0' at end of affine constraint"), + nullptr); +} + +/// Parse the constraints that are part of an integer set definition. +/// integer-set-inline +/// ::= dim-and-symbol-id-lists `:` +/// '(' affine-constraint-conjunction? ')' +/// affine-constraint-conjunction ::= affine-constraint (`,` +/// affine-constraint)* +/// +IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, + unsigned numSymbols) { + if (parseToken(Token::l_paren, + "expected '(' at start of integer set constraint list")) + return IntegerSet(); + + SmallVector constraints; + SmallVector isEqs; + auto parseElt = [&]() -> ParseResult { + bool isEq; + auto elt = parseAffineConstraint(&isEq); + ParseResult res = elt ? success() : failure(); + if (elt) { + constraints.push_back(elt); + isEqs.push_back(isEq); + } + return res; + }; + + // Parse a list of affine constraints (comma-separated). + if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) + return IntegerSet(); + + // If no constraints were parsed, then treat this as a degenerate 'true' case. + if (constraints.empty()) { + /* 0 == 0 */ + auto zero = getAffineConstantExpr(0, getContext()); + return builder.getIntegerSet(numDims, numSymbols, zero, true); + } + + // Parsed a valid integer set. + return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs); +} + /// Parse an ambiguous reference to either and affine map or an integer set. ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map, IntegerSet &set) { @@ -3347,85 +3486,6 @@ Operation *FunctionParser::parseCustomOperation() { return builder.createOperation(opState); } -/// Parse an affine constraint. -/// affine-constraint ::= affine-expr `>=` `0` -/// | affine-expr `==` `0` -/// -/// isEq is set to true if the parsed constraint is an equality, false if it -/// is an inequality (greater than or equal). -/// -AffineExpr AffineParser::parseAffineConstraint(bool *isEq) { - AffineExpr expr = parseAffineExpr(); - if (!expr) - return nullptr; - - if (consumeIf(Token::greater) && consumeIf(Token::equal) && - getToken().is(Token::integer)) { - auto dim = getToken().getUnsignedIntegerValue(); - if (dim.hasValue() && dim.getValue() == 0) { - consumeToken(Token::integer); - *isEq = false; - return expr; - } - return (emitError("expected '0' after '>='"), nullptr); - } - - if (consumeIf(Token::equal) && consumeIf(Token::equal) && - getToken().is(Token::integer)) { - auto dim = getToken().getUnsignedIntegerValue(); - if (dim.hasValue() && dim.getValue() == 0) { - consumeToken(Token::integer); - *isEq = true; - return expr; - } - return (emitError("expected '0' after '=='"), nullptr); - } - - return (emitError("expected '== 0' or '>= 0' at end of affine constraint"), - nullptr); -} - -/// Parse the constraints that are part of an integer set definition. -/// integer-set-inline -/// ::= dim-and-symbol-id-lists `:` -/// '(' affine-constraint-conjunction? ')' -/// affine-constraint-conjunction ::= affine-constraint (`,` -/// affine-constraint)* -/// -IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, - unsigned numSymbols) { - if (parseToken(Token::l_paren, - "expected '(' at start of integer set constraint list")) - return IntegerSet(); - - SmallVector constraints; - SmallVector isEqs; - auto parseElt = [&]() -> ParseResult { - bool isEq; - auto elt = parseAffineConstraint(&isEq); - ParseResult res = elt ? success() : failure(); - if (elt) { - constraints.push_back(elt); - isEqs.push_back(isEq); - } - return res; - }; - - // Parse a list of affine constraints (comma-separated). - if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) - return IntegerSet(); - - // If no constraints were parsed, then treat this as a degenerate 'true' case. - if (constraints.empty()) { - /* 0 == 0 */ - auto zero = getAffineConstantExpr(0, getContext()); - return builder.getIntegerSet(numDims, numSymbols, zero, true); - } - - // Parsed a valid integer set. - return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs); -} - //===----------------------------------------------------------------------===// // Top-level entity parsing. //===----------------------------------------------------------------------===//