};
//===----------------------------------------------------------------------===//
+// ReturnStmt
+//===----------------------------------------------------------------------===//
+
+/// This statement represents a return from a "callable" like decl, e.g. a
+/// Constraint or a Rewrite.
+class ReturnStmt final : public Node::NodeBase<ReturnStmt, Stmt> {
+public:
+ static ReturnStmt *create(Context &ctx, SMRange loc, Expr *resultExpr);
+
+ /// Return the result expression of this statement.
+ Expr *getResultExpr() { return resultExpr; }
+ const Expr *getResultExpr() const { return resultExpr; }
+
+ /// Set the result expression of this statement.
+ void setResultExpr(Expr *expr) { resultExpr = expr; }
+
+private:
+ ReturnStmt(SMRange loc, Expr *resultExpr)
+ : Base(loc), resultExpr(resultExpr) {}
+
+ // The result expression of this statement.
+ Expr *resultExpr;
+};
+
+//===----------------------------------------------------------------------===//
// Expr
//===----------------------------------------------------------------------===//
};
//===----------------------------------------------------------------------===//
+// CallExpr
+//===----------------------------------------------------------------------===//
+
+/// This expression represents a call to a decl, such as a
+/// UserConstraintDecl/UserRewriteDecl.
+class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
+ private llvm::TrailingObjects<CallExpr, Expr *> {
+public:
+ static CallExpr *create(Context &ctx, SMRange loc, Expr *callable,
+ ArrayRef<Expr *> arguments, Type resultType);
+
+ /// Return the callable of this call.
+ Expr *getCallableExpr() const { return callable; }
+
+ /// Return the arguments of this call.
+ MutableArrayRef<Expr *> getArguments() {
+ return {getTrailingObjects<Expr *>(), numArgs};
+ }
+ ArrayRef<Expr *> getArguments() const {
+ return const_cast<CallExpr *>(this)->getArguments();
+ }
+
+private:
+ CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs)
+ : Base(loc, type), callable(callable), numArgs(numArgs) {}
+
+ /// The callable of this call.
+ Expr *callable;
+
+ /// The number of arguments of the call.
+ unsigned numArgs;
+
+ /// TrailingObject utilities.
+ friend llvm::TrailingObjects<CallExpr, Expr *>;
+};
+
+//===----------------------------------------------------------------------===//
// DeclRefExpr
//===----------------------------------------------------------------------===//
};
//===----------------------------------------------------------------------===//
+// UserConstraintDecl
+//===----------------------------------------------------------------------===//
+
+/// This decl represents a user defined constraint. This is either:
+/// * an imported native constraint
+/// - Similar to an external function declaration. This is a native
+/// constraint defined externally, and imported into PDLL via a
+/// declaration.
+/// * a native constraint defined in PDLL
+/// - This is a native constraint, i.e. a constraint whose implementation is
+/// defined in C++(or potentially some other non-PDLL language). The
+/// implementation of this constraint is specified as a string code block
+/// in PDLL.
+/// * a PDLL constraint
+/// - This is a constraint which is defined using only PDLL constructs.
+class UserConstraintDecl final
+ : public Node::NodeBase<UserConstraintDecl, ConstraintDecl>,
+ llvm::TrailingObjects<UserConstraintDecl, VariableDecl *> {
+public:
+ /// Create a native constraint with the given optional code block.
+ static UserConstraintDecl *createNative(Context &ctx, const Name &name,
+ ArrayRef<VariableDecl *> inputs,
+ ArrayRef<VariableDecl *> results,
+ Optional<StringRef> codeBlock,
+ Type resultType) {
+ return createImpl(ctx, name, inputs, results, codeBlock, /*body=*/nullptr,
+ resultType);
+ }
+
+ /// Create a PDLL constraint with the given body.
+ static UserConstraintDecl *createPDLL(Context &ctx, const Name &name,
+ ArrayRef<VariableDecl *> inputs,
+ ArrayRef<VariableDecl *> results,
+ const CompoundStmt *body,
+ Type resultType) {
+ return createImpl(ctx, name, inputs, results, /*codeBlock=*/llvm::None,
+ body, resultType);
+ }
+
+ /// Return the name of the constraint.
+ const Name &getName() const { return *Decl::getName(); }
+
+ /// Return the input arguments of this constraint.
+ MutableArrayRef<VariableDecl *> getInputs() {
+ return {getTrailingObjects<VariableDecl *>(), numInputs};
+ }
+ ArrayRef<VariableDecl *> getInputs() const {
+ return const_cast<UserConstraintDecl *>(this)->getInputs();
+ }
+
+ /// Return the explicit results of the constraint declaration. May be empty,
+ /// even if the constraint has results (e.g. in the case of inferred results).
+ MutableArrayRef<VariableDecl *> getResults() {
+ return {getTrailingObjects<VariableDecl *>() + numInputs, numResults};
+ }
+ ArrayRef<VariableDecl *> getResults() const {
+ return const_cast<UserConstraintDecl *>(this)->getResults();
+ }
+
+ /// Return the optional code block of this constraint, if this is a native
+ /// constraint with a provided implementation.
+ Optional<StringRef> getCodeBlock() const { return codeBlock; }
+
+ /// Return the body of this constraint if this constraint is a PDLL
+ /// constraint, otherwise returns nullptr.
+ const CompoundStmt *getBody() const { return constraintBody; }
+
+ /// Return the result type of this constraint.
+ Type getResultType() const { return resultType; }
+
+ /// Returns true if this constraint is external.
+ bool isExternal() const { return !constraintBody && !codeBlock; }
+
+private:
+ /// Create either a PDLL constraint or a native constraint with the given
+ /// components.
+ static UserConstraintDecl *
+ createImpl(Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
+ ArrayRef<VariableDecl *> results, Optional<StringRef> codeBlock,
+ const CompoundStmt *body, Type resultType);
+
+ UserConstraintDecl(const Name &name, unsigned numInputs, unsigned numResults,
+ Optional<StringRef> codeBlock, const CompoundStmt *body,
+ Type resultType)
+ : Base(name.getLoc(), &name), numInputs(numInputs),
+ numResults(numResults), codeBlock(codeBlock), constraintBody(body),
+ resultType(resultType) {}
+
+ /// The number of inputs to this constraint.
+ unsigned numInputs;
+
+ /// The number of explicit results to this constraint.
+ unsigned numResults;
+
+ /// The optional code block of this constraint.
+ Optional<StringRef> codeBlock;
+
+ /// The optional body of this constraint.
+ const CompoundStmt *constraintBody;
+
+ /// The result type of the constraint.
+ Type resultType;
+
+ /// Allow access to various internals.
+ friend llvm::TrailingObjects<UserConstraintDecl, VariableDecl *>;
+};
+
+//===----------------------------------------------------------------------===//
// NamedAttributeDecl
//===----------------------------------------------------------------------===//
};
//===----------------------------------------------------------------------===//
+// UserRewriteDecl
+//===----------------------------------------------------------------------===//
+
+/// This decl represents a user defined rewrite. This is either:
+/// * an imported native rewrite
+/// - Similar to an external function declaration. This is a native
+/// rewrite defined externally, and imported into PDLL via a declaration.
+/// * a native rewrite defined in PDLL
+/// - This is a native rewrite, i.e. a rewrite whose implementation is
+/// defined in C++(or potentially some other non-PDLL language). The
+/// implementation of this rewrite is specified as a string code block
+/// in PDLL.
+/// * a PDLL rewrite
+/// - This is a rewrite which is defined using only PDLL constructs.
+class UserRewriteDecl final
+ : public Node::NodeBase<UserRewriteDecl, Decl>,
+ llvm::TrailingObjects<UserRewriteDecl, VariableDecl *> {
+public:
+ /// Create a native rewrite with the given optional code block.
+ static UserRewriteDecl *createNative(Context &ctx, const Name &name,
+ ArrayRef<VariableDecl *> inputs,
+ ArrayRef<VariableDecl *> results,
+ Optional<StringRef> codeBlock,
+ Type resultType) {
+ return createImpl(ctx, name, inputs, results, codeBlock, /*body=*/nullptr,
+ resultType);
+ }
+
+ /// Create a PDLL rewrite with the given body.
+ static UserRewriteDecl *createPDLL(Context &ctx, const Name &name,
+ ArrayRef<VariableDecl *> inputs,
+ ArrayRef<VariableDecl *> results,
+ const CompoundStmt *body,
+ Type resultType) {
+ return createImpl(ctx, name, inputs, results, /*codeBlock=*/llvm::None,
+ body, resultType);
+ }
+
+ /// Return the name of the rewrite.
+ const Name &getName() const { return *Decl::getName(); }
+
+ /// Return the input arguments of this rewrite.
+ MutableArrayRef<VariableDecl *> getInputs() {
+ return {getTrailingObjects<VariableDecl *>(), numInputs};
+ }
+ ArrayRef<VariableDecl *> getInputs() const {
+ return const_cast<UserRewriteDecl *>(this)->getInputs();
+ }
+
+ /// Return the explicit results of the rewrite declaration. May be empty,
+ /// even if the rewrite has results (e.g. in the case of inferred results).
+ MutableArrayRef<VariableDecl *> getResults() {
+ return {getTrailingObjects<VariableDecl *>() + numInputs, numResults};
+ }
+ ArrayRef<VariableDecl *> getResults() const {
+ return const_cast<UserRewriteDecl *>(this)->getResults();
+ }
+
+ /// Return the optional code block of this rewrite, if this is a native
+ /// rewrite with a provided implementation.
+ Optional<StringRef> getCodeBlock() const { return codeBlock; }
+
+ /// Return the body of this rewrite if this rewrite is a PDLL rewrite,
+ /// otherwise returns nullptr.
+ const CompoundStmt *getBody() const { return rewriteBody; }
+
+ /// Return the result type of this rewrite.
+ Type getResultType() const { return resultType; }
+
+ /// Returns true if this rewrite is external.
+ bool isExternal() const { return !rewriteBody && !codeBlock; }
+
+private:
+ /// Create either a PDLL rewrite or a native rewrite with the given
+ /// components.
+ static UserRewriteDecl *createImpl(Context &ctx, const Name &name,
+ ArrayRef<VariableDecl *> inputs,
+ ArrayRef<VariableDecl *> results,
+ Optional<StringRef> codeBlock,
+ const CompoundStmt *body, Type resultType);
+
+ UserRewriteDecl(const Name &name, unsigned numInputs, unsigned numResults,
+ Optional<StringRef> codeBlock, const CompoundStmt *body,
+ Type resultType)
+ : Base(name.getLoc(), &name), numInputs(numInputs),
+ numResults(numResults), codeBlock(codeBlock), rewriteBody(body),
+ resultType(resultType) {}
+
+ /// The number of inputs to this rewrite.
+ unsigned numInputs;
+
+ /// The number of explicit results to this rewrite.
+ unsigned numResults;
+
+ /// The optional code block of this rewrite.
+ Optional<StringRef> codeBlock;
+
+ /// The optional body of this rewrite.
+ const CompoundStmt *rewriteBody;
+
+ /// The result type of the rewrite.
+ Type resultType;
+
+ /// Allow access to various internals.
+ friend llvm::TrailingObjects<UserRewriteDecl, VariableDecl *>;
+};
+
+//===----------------------------------------------------------------------===//
+// CallableDecl
+//===----------------------------------------------------------------------===//
+
+/// This decl represents a shared interface for all callable decls.
+class CallableDecl : public Decl {
+public:
+ /// Return the callable type of this decl.
+ StringRef getCallableType() const {
+ if (isa<UserConstraintDecl>(this))
+ return "constraint";
+ assert(isa<UserRewriteDecl>(this) && "unknown callable type");
+ return "rewrite";
+ }
+
+ /// Return the inputs of this decl.
+ ArrayRef<VariableDecl *> getInputs() const {
+ if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
+ return cst->getInputs();
+ return cast<UserRewriteDecl>(this)->getInputs();
+ }
+
+ /// Return the result type of this decl.
+ Type getResultType() const {
+ if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
+ return cst->getResultType();
+ return cast<UserRewriteDecl>(this)->getResultType();
+ }
+
+ /// Support LLVM type casting facilities.
+ static bool classof(const Node *decl) {
+ return isa<UserConstraintDecl, UserRewriteDecl>(decl);
+ }
+};
+
+//===----------------------------------------------------------------------===//
// VariableDecl
//===----------------------------------------------------------------------===//
inline bool Decl::classof(const Node *node) {
return isa<ConstraintDecl, NamedAttributeDecl, OpNameDecl, PatternDecl,
- VariableDecl>(node);
+ UserRewriteDecl, VariableDecl>(node);
}
inline bool ConstraintDecl::classof(const Node *node) {
- return isa<CoreConstraintDecl>(node);
+ return isa<CoreConstraintDecl, UserConstraintDecl>(node);
}
inline bool CoreConstraintDecl::classof(const Node *node) {
enum class ParserContext {
/// The parser is in the global context.
Global,
+ /// The parser is currently within a Constraint, which disallows all types
+ /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
+ Constraint,
/// The parser is currently within the matcher portion of a Pattern, which
/// is allows a terminal operation rewrite statement but no other rewrite
/// transformations.
FailureOr<ast::Decl *> parseTopLevelDecl();
FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl();
+
+ /// Parse an argument variable as part of the signature of a
+ /// UserConstraintDecl or UserRewriteDecl.
+ FailureOr<ast::VariableDecl *> parseArgumentDecl();
+
+ /// Parse a result variable as part of the signature of a UserConstraintDecl
+ /// or UserRewriteDecl.
+ FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
+
+ /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
+ /// defined in a non-global context.
+ FailureOr<ast::UserConstraintDecl *>
+ parseUserConstraintDecl(bool isInline = false);
+
+ /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
+ /// non-global context, such as within a Pattern/Constraint/etc.
+ FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
+
+ /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
+ /// PDLL constructs.
+ FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
+ const ast::Name &name, bool isInline,
+ ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
+ ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
+
+ /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being
+ /// defined in a non-global context.
+ FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
+
+ /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
+ /// non-global context, such as within a Pattern/Rewrite/etc.
+ FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
+
+ /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
+ /// PDLL constructs.
+ FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
+ const ast::Name &name, bool isInline,
+ ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
+ ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
+
+ /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
+ /// effectively the same syntax, and only differ on slight semantics (given
+ /// the different parsing contexts).
+ template <typename T, typename ParseUserPDLLDeclFnT>
+ FailureOr<T *> parseUserConstraintOrRewriteDecl(
+ ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
+ StringRef anonymousNamePrefix, bool isInline);
+
+ /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
+ /// These decls have effectively the same syntax.
+ template <typename T>
+ FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
+ const ast::Name &name, bool isInline,
+ ArrayRef<ast::VariableDecl *> arguments,
+ ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
+
+ /// Parse the functional signature (i.e. the arguments and results) of a
+ /// UserConstraintDecl or UserRewriteDecl.
+ LogicalResult parseUserConstraintOrRewriteSignature(
+ SmallVectorImpl<ast::VariableDecl *> &arguments,
+ SmallVectorImpl<ast::VariableDecl *> &results,
+ ast::DeclScope *&argumentScope, ast::Type &resultType);
+
+ /// Validate the return (which if present is specified by bodyIt) of a
+ /// UserConstraintDecl or UserRewriteDecl.
+ LogicalResult validateUserConstraintOrRewriteReturn(
+ StringRef declType, ast::CompoundStmt *body,
+ ArrayRef<ast::Stmt *>::iterator bodyIt,
+ ArrayRef<ast::Stmt *>::iterator bodyE,
+ ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
+
FailureOr<ast::CompoundStmt *>
parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
bool expectTerminalSemicolon = true);
/// location of a previously parsed type constraint for the entity that will
/// be constrained by the parsed constraint. `existingConstraints` are any
/// existing constraints that have already been parsed for the same entity
- /// that will be constrained by this constraint.
+ /// that will be constrained by this constraint. `allowInlineTypeConstraints`
+ /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
FailureOr<ast::ConstraintRef>
parseConstraint(Optional<SMRange> &typeConstraint,
- ArrayRef<ast::ConstraintRef> existingConstraints);
+ ArrayRef<ast::ConstraintRef> existingConstraints,
+ bool allowInlineTypeConstraints);
+
+ /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
+ /// argument or result variable. The constraints for these variables do not
+ /// allow inline type constraints, and only permit a single constraint.
+ FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
//===--------------------------------------------------------------------===//
// Exprs
/// Identifier expressions.
FailureOr<ast::Expr *> parseAttributeExpr();
+ FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
FailureOr<ast::Expr *> parseIdentifierExpr();
+ FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
+ FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
FailureOr<ast::EraseStmt *> parseEraseStmt();
FailureOr<ast::LetStmt *> parseLetStmt();
FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
+ FailureOr<ast::ReturnStmt *> parseReturnStmt();
FailureOr<ast::RewriteStmt *> parseRewriteStmt();
//===--------------------------------------------------------------------===//
//===--------------------------------------------------------------------===//
// Decls
+ /// Try to extract a callable from the given AST node. Returns nullptr on
+ /// failure.
+ ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
+
/// Try to create a pattern decl with the given components, returning the
/// Pattern on success.
FailureOr<ast::PatternDecl *>
const ParsedPatternMetadata &metadata,
ast::CompoundStmt *body);
+ /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
+ /// of results, defined as part of the signature.
+ ast::Type
+ createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
+
+ /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
+ template <typename T>
+ FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
+ const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
+ ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
+ ast::CompoundStmt *body);
+
/// Try to create a variable decl with the given components, returning the
/// Variable on success.
FailureOr<ast::VariableDecl *>
createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
ArrayRef<ast::ConstraintRef> constraints);
+ /// Create a variable for an argument or result defined as part of the
+ /// signature of a UserConstraintDecl/UserRewriteDecl.
+ FailureOr<ast::VariableDecl *>
+ createArgOrResultVariableDecl(StringRef name, SMRange loc,
+ const ast::ConstraintRef &constraint);
+
/// Validate the constraints used to constraint a variable decl.
/// `inferredType` is the type of the variable inferred by the constraints
/// within the list, and is updated to the most refined type as determined by
/// Validate a single reference to a constraint. `inferredType` contains the
/// currently inferred variabled type and is refined within the type defined
/// by the constraint. Returns success if the constraint is valid, failure
- /// otherwise.
+ /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
+ /// defined constraints) may be used with the variable.
LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
- ast::Type &inferredType);
+ ast::Type &inferredType,
+ bool allowNonCoreConstraints = true);
LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
//===--------------------------------------------------------------------===//
// Exprs
- FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc,
- ast::Decl *decl);
+ FailureOr<ast::CallExpr *>
+ createCallExpr(SMRange loc, ast::Expr *parentExpr,
+ MutableArrayRef<ast::Expr *> arguments);
+ FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
FailureOr<ast::DeclRefExpr *>
createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
ArrayRef<ast::ConstraintRef> constraints);
FailureOr<ast::MemberAccessExpr *>
- createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
- SMRange loc);
+ createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
/// Validate the member access `name` into the given parent expression. On
/// success, this also returns the type of the member accessed.
LogicalResult
validateOperationOperands(SMRange loc, Optional<StringRef> name,
MutableArrayRef<ast::Expr *> operands);
- LogicalResult validateOperationResults(SMRange loc,
- Optional<StringRef> name,
+ LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name,
MutableArrayRef<ast::Expr *> results);
LogicalResult
- validateOperationOperandsOrResults(SMRange loc,
- Optional<StringRef> name,
+ validateOperationOperandsOrResults(SMRange loc, Optional<StringRef> name,
MutableArrayRef<ast::Expr *> values,
ast::Type singleTy, ast::Type rangeTy);
FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
//===--------------------------------------------------------------------===//
// Stmts
- FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc,
- ast::Expr *rootOp);
+ FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
FailureOr<ast::ReplaceStmt *>
createReplaceStmt(SMRange loc, ast::Expr *rootOp,
MutableArrayRef<ast::Expr *> replValues);
LogicalResult emitError(const Twine &msg) {
return emitError(curToken.getLoc(), msg);
}
- LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg,
- SMRange noteLoc, const Twine ¬e) {
+ LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
+ const Twine ¬e) {
lexer.emitErrorAndNote(loc, msg, noteLoc, note);
return failure();
}
/// Cached types to simplify verification and expression creation.
ast::Type valueTy, valueRangeTy;
ast::Type typeTy, typeRangeTy;
+
+ /// A counter used when naming anonymous constraints and rewrites.
+ unsigned anonymousDeclNameCounter = 0;
};
} // namespace
FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
FailureOr<ast::Decl *> decl;
switch (curToken.getKind()) {
+ case Token::kw_Constraint:
+ decl = parseUserConstraintDecl();
+ break;
case Token::kw_Pattern:
decl = parsePatternDecl();
break;
+ case Token::kw_Rewrite:
+ decl = parseUserRewriteDecl();
+ break;
default:
return emitError("expected top-level declaration, such as a `Pattern`");
}
return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
}
+FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
+ // Ensure that the argument is named.
+ if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
+ return emitError("expected identifier argument name");
+
+ // Parse the argument similarly to a normal variable.
+ StringRef name = curToken.getSpelling();
+ SMRange nameLoc = curToken.getLoc();
+ consumeToken();
+
+ if (failed(
+ parseToken(Token::colon, "expected `:` before argument constraint")))
+ return failure();
+
+ FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
+ if (failed(cst))
+ return failure();
+
+ return createArgOrResultVariableDecl(name, nameLoc, *cst);
+}
+
+FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
+ // Check to see if this result is named.
+ if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
+ // Check to see if this name actually refers to a Constraint.
+ ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling());
+ if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) {
+ // If yes, and this is a Rewrite, give a nice error message as non-Core
+ // constraints are not supported on Rewrite results.
+ if (parserContext == ParserContext::Rewrite) {
+ return emitError(
+ "`Rewrite` results are only permitted to use core constraints, "
+ "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`");
+ }
+
+ // Otherwise, parse this as an unnamed result variable.
+ } else {
+ // If it wasn't a constraint, parse the result similarly to a variable. If
+ // there is already an existing decl, we will emit an error when defining
+ // this variable later.
+ StringRef name = curToken.getSpelling();
+ SMRange nameLoc = curToken.getLoc();
+ consumeToken();
+
+ if (failed(parseToken(Token::colon,
+ "expected `:` before result constraint")))
+ return failure();
+
+ FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
+ if (failed(cst))
+ return failure();
+
+ return createArgOrResultVariableDecl(name, nameLoc, *cst);
+ }
+ }
+
+ // If it isn't named, we parse the constraint directly and create an unnamed
+ // result variable.
+ FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
+ if (failed(cst))
+ return failure();
+
+ return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
+}
+
+FailureOr<ast::UserConstraintDecl *>
+Parser::parseUserConstraintDecl(bool isInline) {
+ // Constraints and rewrites have very similar formats, dispatch to a shared
+ // interface for parsing.
+ return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
+ [&](auto &&...args) { return parseUserPDLLConstraintDecl(args...); },
+ ParserContext::Constraint, "constraint", isInline);
+}
+
+FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
+ FailureOr<ast::UserConstraintDecl *> decl =
+ parseUserConstraintDecl(/*isInline=*/true);
+ if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
+ return failure();
+
+ curDeclScope->add(*decl);
+ return decl;
+}
+
+FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
+ const ast::Name &name, bool isInline,
+ ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
+ ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
+ // Push the argument scope back onto the list, so that the body can
+ // reference arguments.
+ pushDeclScope(argumentScope);
+
+ // Parse the body of the constraint. The body is either defined as a compound
+ // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
+ ast::CompoundStmt *body;
+ if (curToken.is(Token::equal_arrow)) {
+ FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
+ [&](ast::Stmt *&stmt) -> LogicalResult {
+ ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
+ if (!stmtExpr) {
+ return emitError(stmt->getLoc(),
+ "expected `Constraint` lambda body to contain a "
+ "single expression");
+ }
+ stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
+ return success();
+ },
+ /*expectTerminalSemicolon=*/!isInline);
+ if (failed(bodyResult))
+ return failure();
+ body = *bodyResult;
+ } else {
+ FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
+ if (failed(bodyResult))
+ return failure();
+ body = *bodyResult;
+
+ // Verify the structure of the body.
+ auto bodyIt = body->begin(), bodyE = body->end();
+ for (; bodyIt != bodyE; ++bodyIt)
+ if (isa<ast::ReturnStmt>(*bodyIt))
+ break;
+ if (failed(validateUserConstraintOrRewriteReturn(
+ "Constraint", body, bodyIt, bodyE, results, resultType)))
+ return failure();
+ }
+ popDeclScope();
+
+ return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
+ name, arguments, results, resultType, body);
+}
+
+FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
+ // Constraints and rewrites have very similar formats, dispatch to a shared
+ // interface for parsing.
+ return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
+ [&](auto &&...args) { return parseUserPDLLRewriteDecl(args...); },
+ ParserContext::Rewrite, "rewrite", isInline);
+}
+
+FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
+ FailureOr<ast::UserRewriteDecl *> decl =
+ parseUserRewriteDecl(/*isInline=*/true);
+ if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
+ return failure();
+
+ curDeclScope->add(*decl);
+ return decl;
+}
+
+FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
+ const ast::Name &name, bool isInline,
+ ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
+ ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
+ // Push the argument scope back onto the list, so that the body can
+ // reference arguments.
+ curDeclScope = argumentScope;
+ ast::CompoundStmt *body;
+ if (curToken.is(Token::equal_arrow)) {
+ FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
+ [&](ast::Stmt *&statement) -> LogicalResult {
+ if (isa<ast::OpRewriteStmt>(statement))
+ return success();
+
+ ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
+ if (!statementExpr) {
+ return emitError(
+ statement->getLoc(),
+ "expected `Rewrite` lambda body to contain a single expression "
+ "or an operation rewrite statement; such as `erase`, "
+ "`replace`, or `rewrite`");
+ }
+ statement =
+ ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
+ return success();
+ },
+ /*expectTerminalSemicolon=*/!isInline);
+ if (failed(bodyResult))
+ return failure();
+ body = *bodyResult;
+ } else {
+ FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
+ if (failed(bodyResult))
+ return failure();
+ body = *bodyResult;
+ }
+ popDeclScope();
+
+ // Verify the structure of the body.
+ auto bodyIt = body->begin(), bodyE = body->end();
+ for (; bodyIt != bodyE; ++bodyIt)
+ if (isa<ast::ReturnStmt>(*bodyIt))
+ break;
+ if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
+ bodyE, results, resultType)))
+ return failure();
+ return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
+ name, arguments, results, resultType, body);
+}
+
+template <typename T, typename ParseUserPDLLDeclFnT>
+FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
+ ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
+ StringRef anonymousNamePrefix, bool isInline) {
+ SMRange loc = curToken.getLoc();
+ consumeToken();
+ llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext);
+
+ // Parse the name of the decl.
+ const ast::Name *name = nullptr;
+ if (curToken.isNot(Token::identifier)) {
+ // Only inline decls can be un-named. Inline decls are similar to "lambdas"
+ // in C++, so being unnamed is fine.
+ if (!isInline)
+ return emitError("expected identifier name");
+
+ // Create a unique anonymous name to use, as the name for this decl is not
+ // important.
+ std::string anonName =
+ llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
+ anonymousDeclNameCounter++)
+ .str();
+ name = &ast::Name::create(ctx, anonName, loc);
+ } else {
+ // If a name was provided, we can use it directly.
+ name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
+ consumeToken(Token::identifier);
+ }
+
+ // Parse the functional signature of the decl.
+ SmallVector<ast::VariableDecl *> arguments, results;
+ ast::DeclScope *argumentScope;
+ ast::Type resultType;
+ if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
+ argumentScope, resultType)))
+ return failure();
+
+ // Check to see which type of constraint this is. If the constraint contains a
+ // compound body, this is a PDLL decl.
+ if (curToken.isAny(Token::l_brace, Token::equal_arrow))
+ return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
+ resultType);
+
+ // Otherwise, this is a native decl.
+ return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
+ results, resultType);
+}
+
+template <typename T>
+FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
+ const ast::Name &name, bool isInline,
+ ArrayRef<ast::VariableDecl *> arguments,
+ ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
+ // If followed by a string, the native code body has also been specified.
+ std::string codeStrStorage;
+ Optional<StringRef> optCodeStr;
+ if (curToken.isString()) {
+ codeStrStorage = curToken.getStringValue();
+ optCodeStr = codeStrStorage;
+ consumeToken();
+ } else if (isInline) {
+ return emitError(name.getLoc(),
+ "external declarations must be declared in global scope");
+ }
+ if (failed(parseToken(Token::semicolon,
+ "expected `;` after native declaration")))
+ return failure();
+ return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
+}
+
+LogicalResult Parser::parseUserConstraintOrRewriteSignature(
+ SmallVectorImpl<ast::VariableDecl *> &arguments,
+ SmallVectorImpl<ast::VariableDecl *> &results,
+ ast::DeclScope *&argumentScope, ast::Type &resultType) {
+ // Parse the argument list of the decl.
+ if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
+ return failure();
+
+ argumentScope = pushDeclScope();
+ if (curToken.isNot(Token::r_paren)) {
+ do {
+ FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
+ if (failed(argument))
+ return failure();
+ arguments.emplace_back(*argument);
+ } while (consumeIf(Token::comma));
+ }
+ popDeclScope();
+ if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
+ return failure();
+
+ // Parse the results of the decl.
+ pushDeclScope();
+ if (consumeIf(Token::arrow)) {
+ auto parseResultFn = [&]() -> LogicalResult {
+ FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
+ if (failed(result))
+ return failure();
+ results.emplace_back(*result);
+ return success();
+ };
+
+ // Check for a list of results.
+ if (consumeIf(Token::l_paren)) {
+ do {
+ if (failed(parseResultFn()))
+ return failure();
+ } while (consumeIf(Token::comma));
+ if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
+ return failure();
+
+ // Otherwise, there is only one result.
+ } else if (failed(parseResultFn())) {
+ return failure();
+ }
+ }
+ popDeclScope();
+
+ // Compute the result type of the decl.
+ resultType = createUserConstraintRewriteResultType(results);
+
+ // Verify that results are only named if there are more than one.
+ if (results.size() == 1 && !results.front()->getName().getName().empty()) {
+ return emitError(
+ results.front()->getLoc(),
+ "cannot create a single-element tuple with an element label");
+ }
+ return success();
+}
+
+LogicalResult Parser::validateUserConstraintOrRewriteReturn(
+ StringRef declType, ast::CompoundStmt *body,
+ ArrayRef<ast::Stmt *>::iterator bodyIt,
+ ArrayRef<ast::Stmt *>::iterator bodyE,
+ ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
+ // Handle if a `return` was provided.
+ if (bodyIt != bodyE) {
+ // Emit an error if we have trailing statements after the return.
+ if (std::next(bodyIt) != bodyE) {
+ return emitError(
+ (*std::next(bodyIt))->getLoc(),
+ llvm::formatv("`return` terminated the `{0}` body, but found "
+ "trailing statements afterwards",
+ declType));
+ }
+
+ // Otherwise if a return wasn't provided, check that no results are
+ // expected.
+ } else if (!results.empty()) {
+ return emitError(
+ {body->getLoc().End, body->getLoc().End},
+ llvm::formatv("missing return in a `{0}` expected to return `{1}`",
+ declType, resultType));
+ }
+ return success();
+}
+
FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
if (isa<ast::OpRewriteStmt>(statement))
// Verify the body of the pattern.
auto bodyIt = body->begin(), bodyE = body->end();
for (; bodyIt != bodyE; ++bodyIt) {
+ if (isa<ast::ReturnStmt>(*bodyIt)) {
+ return emitError((*bodyIt)->getLoc(),
+ "`return` statements are only permitted within a "
+ "`Constraint` or `Rewrite` body");
+ }
// Break when we've found the rewrite statement.
if (isa<ast::OpRewriteStmt>(*bodyIt))
break;
}
FailureOr<ast::VariableDecl *>
-Parser::defineVariableDecl(StringRef name, SMRange nameLoc,
- ast::Type type, ast::Expr *initExpr,
+Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
+ ast::Expr *initExpr,
ArrayRef<ast::ConstraintRef> constraints) {
assert(curDeclScope && "defining variable outside of decl scope");
const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
}
FailureOr<ast::VariableDecl *>
-Parser::defineVariableDecl(StringRef name, SMRange nameLoc,
- ast::Type type,
+Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
ArrayRef<ast::ConstraintRef> constraints) {
return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
constraints);
SmallVectorImpl<ast::ConstraintRef> &constraints) {
Optional<SMRange> typeConstraint;
auto parseSingleConstraint = [&] {
- FailureOr<ast::ConstraintRef> constraint =
- parseConstraint(typeConstraint, constraints);
+ FailureOr<ast::ConstraintRef> constraint = parseConstraint(
+ typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
if (failed(constraint))
return failure();
constraints.push_back(*constraint);
FailureOr<ast::ConstraintRef>
Parser::parseConstraint(Optional<SMRange> &typeConstraint,
- ArrayRef<ast::ConstraintRef> existingConstraints) {
+ ArrayRef<ast::ConstraintRef> existingConstraints,
+ bool allowInlineTypeConstraints) {
auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
+ if (!allowInlineTypeConstraints) {
+ return emitError(
+ curToken.getLoc(),
+ "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
+ "permitted on arguments or results");
+ }
if (typeConstraint)
return emitErrorAndNote(
curToken.getLoc(),
return ast::ConstraintRef(
ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
}
+
+ case Token::kw_Constraint: {
+ // Handle an inline constraint.
+ FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
+ if (failed(decl))
+ return failure();
+ return ast::ConstraintRef(*decl, loc);
+ }
case Token::identifier: {
StringRef constraintName = curToken.getSpelling();
consumeToken(Token::identifier);
return emitError(loc, "expected identifier constraint");
}
+FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
+ Optional<SMRange> typeConstraint;
+ return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
+ /*allowInlineTypeConstraints=*/false);
+}
+
//===----------------------------------------------------------------------===//
// Exprs
case Token::kw_attr:
lhsExpr = parseAttributeExpr();
break;
+ case Token::kw_Constraint:
+ lhsExpr = parseInlineConstraintLambdaExpr();
+ break;
case Token::identifier:
lhsExpr = parseIdentifierExpr();
break;
case Token::kw_op:
lhsExpr = parseOperationExpr();
break;
+ case Token::kw_Rewrite:
+ lhsExpr = parseInlineRewriteLambdaExpr();
+ break;
case Token::kw_type:
lhsExpr = parseTypeExpr();
break;
case Token::dot:
lhsExpr = parseMemberAccessExpr(*lhsExpr);
break;
+ case Token::l_paren:
+ lhsExpr = parseCallExpr(*lhsExpr);
+ break;
default:
return lhsExpr;
}
return ast::AttributeExpr::create(ctx, loc, attrExpr);
}
-FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name,
- SMRange loc) {
+FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
+ SMRange loc = curToken.getLoc();
+ consumeToken(Token::l_paren);
+
+ // Parse the arguments of the call.
+ SmallVector<ast::Expr *> arguments;
+ if (curToken.isNot(Token::r_paren)) {
+ do {
+ FailureOr<ast::Expr *> argument = parseExpr();
+ if (failed(argument))
+ return failure();
+ arguments.push_back(*argument);
+ } while (consumeIf(Token::comma));
+ }
+ loc.End = curToken.getEndLoc();
+ if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
+ return failure();
+
+ return createCallExpr(loc, parentExpr, arguments);
+}
+
+FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
ast::Decl *decl = curDeclScope->lookup(name);
if (!decl)
return emitError(loc, "undefined reference to `" + name + "`");
return parseDeclRefExpr(name, nameLoc);
}
+FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
+ FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
+ if (failed(decl))
+ return failure();
+
+ return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
+ ast::ConstraintType::get(ctx));
+}
+
+FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
+ FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
+ if (failed(decl))
+ return failure();
+
+ return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
+ ast::RewriteType::get(ctx));
+}
+
FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
SMRange loc = curToken.getLoc();
consumeToken(Token::dot);
case Token::kw_replace:
stmt = parseReplaceStmt();
break;
+ case Token::kw_return:
+ stmt = parseReturnStmt();
+ break;
case Token::kw_rewrite:
stmt = parseRewriteStmt();
break;
}
FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
+ if (parserContext == ParserContext::Constraint)
+ return emitError("`erase` cannot be used within a Constraint");
SMRange loc = curToken.getLoc();
consumeToken(Token::kw_erase);
}
FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
+ if (parserContext == ParserContext::Constraint)
+ return emitError("`replace` cannot be used within a Constraint");
SMRange loc = curToken.getLoc();
consumeToken(Token::kw_replace);
return createReplaceStmt(loc, *rootOp, replValues);
}
+FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
+ SMRange loc = curToken.getLoc();
+ consumeToken(Token::kw_return);
+
+ // Parse the result value.
+ FailureOr<ast::Expr *> resultExpr = parseExpr();
+ if (failed(resultExpr))
+ return failure();
+
+ return ast::ReturnStmt::create(ctx, loc, *resultExpr);
+}
+
FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
+ if (parserContext == ParserContext::Constraint)
+ return emitError("`rewrite` cannot be used within a Constraint");
SMRange loc = curToken.getLoc();
consumeToken(Token::kw_rewrite);
if (failed(rewriteBody))
return failure();
+ // Verify the rewrite body.
+ for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
+ if (isa<ast::ReturnStmt>(stmt)) {
+ return emitError(stmt->getLoc(),
+ "`return` statements are only permitted within a "
+ "`Constraint` or `Rewrite` body");
+ }
+ }
+
return createRewriteStmt(loc, *rootOp, *rewriteBody);
}
//===----------------------------------------------------------------------===//
// Decls
+ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
+ // Unwrap reference expressions.
+ if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
+ node = init->getDecl();
+ return dyn_cast<ast::CallableDecl>(node);
+}
+
FailureOr<ast::PatternDecl *>
Parser::createPatternDecl(SMRange loc, const ast::Name *name,
const ParsedPatternMetadata &metadata,
metadata.hasBoundedRecursion, body);
}
+ast::Type Parser::createUserConstraintRewriteResultType(
+ ArrayRef<ast::VariableDecl *> results) {
+ // Single result decls use the type of the single result.
+ if (results.size() == 1)
+ return results[0]->getType();
+
+ // Multiple results use a tuple type, with the types and names grabbed from
+ // the result variable decls.
+ auto resultTypes = llvm::map_range(
+ results, [&](const auto *result) { return result->getType(); });
+ auto resultNames = llvm::map_range(
+ results, [&](const auto *result) { return result->getName().getName(); });
+ return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
+ llvm::to_vector(resultNames));
+}
+
+template <typename T>
+FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
+ const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
+ ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
+ ast::CompoundStmt *body) {
+ if (!body->getChildren().empty()) {
+ if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
+ ast::Expr *resultExpr = retStmt->getResultExpr();
+
+ // Process the result of the decl. If no explicit signature results
+ // were provided, check for return type inference. Otherwise, check that
+ // the return expression can be converted to the expected type.
+ if (results.empty())
+ resultType = resultExpr->getType();
+ else if (failed(convertExpressionTo(resultExpr, resultType)))
+ return failure();
+ else
+ retStmt->setResultExpr(resultExpr);
+ }
+ }
+ return T::createPDLL(ctx, name, arguments, results, body, resultType);
+}
+
FailureOr<ast::VariableDecl *>
-Parser::createVariableDecl(StringRef name, SMRange loc,
- ast::Expr *initializer,
+Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
ArrayRef<ast::ConstraintRef> constraints) {
// The type of the variable, which is expected to be inferred by either a
// constraint or an initializer expression.
"list or the initializer");
}
+ // Constraint types cannot be used when defining variables.
+ if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
+ return emitError(
+ loc, llvm::formatv("unable to define variable of `{0}` type", type));
+ }
+
// Try to define a variable with the given name.
FailureOr<ast::VariableDecl *> varDecl =
defineVariableDecl(name, loc, type, initializer, constraints);
return *varDecl;
}
+FailureOr<ast::VariableDecl *>
+Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
+ const ast::ConstraintRef &constraint) {
+ // Constraint arguments may apply more complex constraints via the arguments.
+ bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
+ ast::Type argType;
+ if (failed(validateVariableConstraint(constraint, argType,
+ allowNonCoreConstraints)))
+ return failure();
+ return defineVariableDecl(name, loc, argType, constraint);
+}
+
LogicalResult
Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
ast::Type &inferredType) {
}
LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
- ast::Type &inferredType) {
+ ast::Type &inferredType,
+ bool allowNonCoreConstraints) {
ast::Type constraintType;
if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
return failure();
}
constraintType = valueRangeTy;
+ } else if (const auto *cst =
+ dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
+ if (!allowNonCoreConstraints) {
+ return emitError(ref.referenceLoc,
+ "`Rewrite` arguments and results are only permitted to "
+ "use core constraints, such as `Attr`, `Op`, `Type`, "
+ "`TypeRange`, `Value`, `ValueRange`");
+ }
+
+ ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
+ if (inputs.size() != 1) {
+ return emitErrorAndNote(ref.referenceLoc,
+ "`Constraint`s applied via a variable constraint "
+ "list must take a single input, but got " +
+ Twine(inputs.size()),
+ cst->getLoc(),
+ "see definition of constraint here");
+ }
+ constraintType = inputs.front()->getType();
} else {
llvm_unreachable("unknown constraint type");
}
//===----------------------------------------------------------------------===//
// Exprs
+FailureOr<ast::CallExpr *>
+Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
+ MutableArrayRef<ast::Expr *> arguments) {
+ ast::Type parentType = parentExpr->getType();
+
+ ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
+ if (!callableDecl) {
+ return emitError(loc,
+ llvm::formatv("expected a reference to a callable "
+ "`Constraint` or `Rewrite`, but got: `{0}`",
+ parentType));
+ }
+ if (parserContext == ParserContext::Rewrite) {
+ if (isa<ast::UserConstraintDecl>(callableDecl))
+ return emitError(
+ loc, "unable to invoke `Constraint` within a rewrite section");
+ } else if (isa<ast::UserRewriteDecl>(callableDecl)) {
+ return emitError(loc, "unable to invoke `Rewrite` within a match section");
+ }
+
+ // Verify the arguments of the call.
+ /// Handle size mismatch.
+ ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
+ if (callArgs.size() != arguments.size()) {
+ return emitErrorAndNote(
+ loc,
+ llvm::formatv("invalid number of arguments for {0} call; expected "
+ "{1}, but got {2}",
+ callableDecl->getCallableType(), callArgs.size(),
+ arguments.size()),
+ callableDecl->getLoc(),
+ llvm::formatv("see the definition of {0} here",
+ callableDecl->getName()->getName()));
+ }
+
+ /// Handle argument type mismatch.
+ auto attachDiagFn = [&](ast::Diagnostic &diag) {
+ diag.attachNote(llvm::formatv("see the definition of `{0}` here",
+ callableDecl->getName()->getName()),
+ callableDecl->getLoc());
+ };
+ for (auto it : llvm::zip(callArgs, arguments)) {
+ if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
+ attachDiagFn)))
+ return failure();
+ }
+
+ return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
+ callableDecl->getResultType());
+}
+
FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
ast::Decl *decl) {
// Check the type of decl being referenced.
ast::Type declType;
- if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
+ if (isa<ast::ConstraintDecl>(decl))
+ declType = ast::ConstraintType::get(ctx);
+ else if (isa<ast::UserRewriteDecl>(decl))
+ declType = ast::RewriteType::get(ctx);
+ else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
declType = varDecl->getType();
else
return emitError(loc, "invalid reference to `" +
}
FailureOr<ast::DeclRefExpr *>
-Parser::createInlineVariableExpr(ast::Type type, StringRef name,
- SMRange loc,
+Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
ArrayRef<ast::ConstraintRef> constraints) {
FailureOr<ast::VariableDecl *> decl =
defineVariableDecl(name, loc, type, constraints);
}
FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
- StringRef name,
- SMRange loc) {
+ StringRef name, SMRange loc) {
ast::Type parentType = parentExpr->getType();
if (parentType.isa<ast::OperationType>()) {
if (name == ast::AllResultsMemberAccessExpr::getMemberName())
}
LogicalResult Parser::validateOperationOperandsOrResults(
- SMRange loc, Optional<StringRef> name,
- MutableArrayRef<ast::Expr *> values, ast::Type singleTy,
- ast::Type rangeTy) {
+ SMRange loc, Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
+ ast::Type singleTy, ast::Type rangeTy) {
// All operation types accept a single range parameter.
if (values.size() == 1) {
if (failed(convertExpressionTo(values[0], rangeTy)))
ArrayRef<StringRef> elementNames) {
for (const ast::Expr *element : elements) {
ast::Type eleTy = element->getType();
- if (eleTy.isa<ast::ConstraintType, ast::TupleType>()) {
+ if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) {
return emitError(
element->getLoc(),
llvm::formatv("unable to build a tuple with `{0}` element", eleTy));