[mlir:PDLL] Allow complex constraints on Rewrite arguments/results
authorRiver Riddle <riddleriver@gmail.com>
Thu, 15 Sep 2022 23:18:32 +0000 (16:18 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Tue, 8 Nov 2022 09:57:58 +0000 (01:57 -0800)
The documentation already has examples of this, and it allows for
using nicer C++ types when defining native Rewrites.

Differential Revision: https://reviews.llvm.org/D133989

mlir/include/mlir/Tools/PDLL/AST/Nodes.h
mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h
mlir/lib/Tools/PDLL/Parser/CodeComplete.cpp
mlir/lib/Tools/PDLL/Parser/Parser.cpp
mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
mlir/test/mlir-pdll/CodeGen/CPP/include/ods.td [new file with mode: 0644]
mlir/test/mlir-pdll/Parser/rewrite-failure.pdll

index 5f282b0..2281115 100644 (file)
@@ -943,7 +943,7 @@ private:
                      Type resultType)
       : Base(name.getLoc(), &name), numInputs(numInputs),
         numResults(numResults), codeBlock(codeBlock), constraintBody(body),
-        resultType(resultType) {}
+        resultType(resultType), hasNativeInputTypes(hasNativeInputTypes) {}
 
   /// The number of inputs to this constraint.
   unsigned numInputs;
index d0ccbe9..90ceda9 100644 (file)
@@ -48,12 +48,9 @@ public:
 
   /// Signal code completion for a constraint name with an optional decl scope.
   /// `currentType` is the current type of the variable that will use the
-  /// constraint, or nullptr if a type is unknown. `allowNonCoreConstraints`
-  /// indicates if user defined constraints are allowed in the completion
-  /// results. `allowInlineTypeConstraints` enables inline type constraints for
-  /// Attr/Value/ValueRange.
+  /// constraint, or nullptr if a type is unknown. `allowInlineTypeConstraints`
+  /// enables inline type constraints for Attr/Value/ValueRange.
   virtual void codeCompleteConstraintName(ast::Type currentType,
-                                          bool allowNonCoreConstraints,
                                           bool allowInlineTypeConstraints,
                                           const ast::DeclScope *scope);
 
index acc2ca8..9d241ea 100644 (file)
@@ -24,5 +24,5 @@ void CodeCompleteContext::codeCompleteOperationMemberAccess(
     ast::OperationType opType) {}
 
 void CodeCompleteContext::codeCompleteConstraintName(
-    ast::Type currentType, bool allowNonCoreConstraints,
-    bool allowInlineTypeConstraints, const ast::DeclScope *scope) {}
+    ast::Type currentType, bool allowInlineTypeConstraints,
+    const ast::DeclScope *scope) {}
index de19f57..3af285e 100644 (file)
@@ -297,13 +297,10 @@ private:
   /// existing constraints that have already been parsed for the same entity
   /// that will be constrained by this constraint. `allowInlineTypeConstraints`
   /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
-  /// If `allowNonCoreConstraints` is true, then complex (e.g. user defined
-  /// constraints) may be used with the variable.
   FailureOr<ast::ConstraintRef>
   parseConstraint(Optional<SMRange> &typeConstraint,
                   ArrayRef<ast::ConstraintRef> existingConstraints,
-                  bool allowInlineTypeConstraints,
-                  bool allowNonCoreConstraints);
+                  bool allowInlineTypeConstraints);
 
   /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
   /// argument or result variable. The constraints for these variables do not
@@ -389,20 +386,16 @@ private:
   /// `inferredType` is the type of the variable inferred by the constraints
   /// within the list, and is updated to the most refined type as determined by
   /// the constraints. Returns success if the constraint list is valid, failure
-  /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
-  /// defined constraints) may be used with the variable.
+  /// otherwise.
   LogicalResult
   validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
-                              ast::Type &inferredType,
-                              bool allowNonCoreConstraints = true);
+                              ast::Type &inferredType);
   /// Validate a single reference to a constraint. `inferredType` contains the
   /// currently inferred variabled type and is refined within the type defined
   /// by the constraint. Returns success if the constraint is valid, failure
-  /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
-  /// defined constraints) may be used with the variable.
+  /// otherwise.
   LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
-                                           ast::Type &inferredType,
-                                           bool allowNonCoreConstraints = true);
+                                           ast::Type &inferredType);
   LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
   LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
 
@@ -469,7 +462,6 @@ private:
   LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
   LogicalResult codeCompleteAttributeName(Optional<StringRef> opName);
   LogicalResult codeCompleteConstraintName(ast::Type inferredType,
-                                           bool allowNonCoreConstraints,
                                            bool allowInlineTypeConstraints);
   LogicalResult codeCompleteDialectName();
   LogicalResult codeCompleteOperationName(StringRef dialectName);
@@ -1129,18 +1121,7 @@ FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
   // Check to see if this result is named.
   if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
     // Check to see if this name actually refers to a Constraint.
-    ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling());
-    if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) {
-      // If yes, and this is a Rewrite, give a nice error message as non-Core
-      // constraints are not supported on Rewrite results.
-      if (parserContext == ParserContext::Rewrite) {
-        return emitError(
-            "`Rewrite` results are only permitted to use core constraints, "
-            "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`");
-      }
-
-      // Otherwise, parse this as an unnamed result variable.
-    } else {
+    if (!curDeclScope->lookup<ast::ConstraintDecl>(curToken.getSpelling())) {
       // If it wasn't a constraint, parse the result similarly to a variable. If
       // there is already an existing decl, we will emit an error when defining
       // this variable later.
@@ -1662,8 +1643,7 @@ LogicalResult Parser::parseVariableDeclConstraintList(
   Optional<SMRange> typeConstraint;
   auto parseSingleConstraint = [&] {
     FailureOr<ast::ConstraintRef> constraint = parseConstraint(
-        typeConstraint, constraints, /*allowInlineTypeConstraints=*/true,
-        /*allowNonCoreConstraints=*/true);
+        typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
     if (failed(constraint))
       return failure();
     constraints.push_back(*constraint);
@@ -1684,8 +1664,7 @@ LogicalResult Parser::parseVariableDeclConstraintList(
 FailureOr<ast::ConstraintRef>
 Parser::parseConstraint(Optional<SMRange> &typeConstraint,
                         ArrayRef<ast::ConstraintRef> existingConstraints,
-                        bool allowInlineTypeConstraints,
-                        bool allowNonCoreConstraints) {
+                        bool allowInlineTypeConstraints) {
   auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
     if (!allowInlineTypeConstraints) {
       return emitError(
@@ -1791,12 +1770,10 @@ Parser::parseConstraint(Optional<SMRange> &typeConstraint,
   case Token::code_complete: {
     // Try to infer the current type for use by code completion.
     ast::Type inferredType;
-    if (failed(validateVariableConstraints(existingConstraints, inferredType,
-                                           allowNonCoreConstraints)))
+    if (failed(validateVariableConstraints(existingConstraints, inferredType)))
       return failure();
 
-    return codeCompleteConstraintName(inferredType, allowNonCoreConstraints,
-                                      allowInlineTypeConstraints);
+    return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints);
   }
   default:
     break;
@@ -1805,13 +1782,9 @@ Parser::parseConstraint(Optional<SMRange> &typeConstraint,
 }
 
 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
-  // Constraint arguments may apply more complex constraints via the arguments.
-  bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
-
   Optional<SMRange> typeConstraint;
   return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
-                         /*allowInlineTypeConstraints=*/false,
-                         allowNonCoreConstraints);
+                         /*allowInlineTypeConstraints=*/false);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2598,29 +2571,23 @@ Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
 FailureOr<ast::VariableDecl *>
 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
                                       const ast::ConstraintRef &constraint) {
-  // Constraint arguments may apply more complex constraints via the arguments.
-  bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
   ast::Type argType;
-  if (failed(validateVariableConstraint(constraint, argType,
-                                        allowNonCoreConstraints)))
+  if (failed(validateVariableConstraint(constraint, argType)))
     return failure();
   return defineVariableDecl(name, loc, argType, constraint);
 }
 
 LogicalResult
 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
-                                    ast::Type &inferredType,
-                                    bool allowNonCoreConstraints) {
+                                    ast::Type &inferredType) {
   for (const ast::ConstraintRef &ref : constraints)
-    if (failed(validateVariableConstraint(ref, inferredType,
-                                          allowNonCoreConstraints)))
+    if (failed(validateVariableConstraint(ref, inferredType)))
       return failure();
   return success();
 }
 
 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
-                                                 ast::Type &inferredType,
-                                                 bool allowNonCoreConstraints) {
+                                                 ast::Type &inferredType) {
   ast::Type constraintType;
   if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
@@ -2652,13 +2619,6 @@ LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
     constraintType = valueRangeTy;
   } else if (const auto *cst =
                  dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
-    if (!allowNonCoreConstraints) {
-      return emitError(ref.referenceLoc,
-                       "`Rewrite` arguments and results are only permitted to "
-                       "use core constraints, such as `Attr`, `Op`, `Type`, "
-                       "`TypeRange`, `Value`, `ValueRange`");
-    }
-
     ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
     if (inputs.size() != 1) {
       return emitErrorAndNote(ref.referenceLoc,
@@ -3160,11 +3120,9 @@ LogicalResult Parser::codeCompleteAttributeName(Optional<StringRef> opName) {
 
 LogicalResult
 Parser::codeCompleteConstraintName(ast::Type inferredType,
-                                   bool allowNonCoreConstraints,
                                    bool allowInlineTypeConstraints) {
   codeCompleteContext->codeCompleteConstraintName(
-      inferredType, allowNonCoreConstraints, allowInlineTypeConstraints,
-      curDeclScope);
+      inferredType, allowInlineTypeConstraints, curDeclScope);
   return failure();
 }
 
index 476bf16..7846103 100644 (file)
@@ -760,7 +760,6 @@ public:
   }
 
   void codeCompleteConstraintName(ast::Type currentType,
-                                  bool allowNonCoreConstraints,
                                   bool allowInlineTypeConstraints,
                                   const ast::DeclScope *scope) final {
     auto addCoreConstraint = [&](StringRef constraint, StringRef mlirType,
@@ -808,9 +807,6 @@ public:
     while (scope) {
       for (const ast::Decl *decl : scope->getDecls()) {
         if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl)) {
-          if (!allowNonCoreConstraints)
-            continue;
-
           lsp::CompletionItem item;
           item.label = cst->getName().getName().str();
           item.kind = lsp::CompletionItemKind::Interface;
index f975307..21a8966 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-pdll %s -I %S -split-input-file -x cpp | FileCheck %s
+// RUN: mlir-pdll %s -I %S -I %S/../../../../include -split-input-file -x cpp | FileCheck %s
 
 // Check that we generate a wrapper pattern for each PDL pattern. Also
 // add in a pattern awkwardly named the same as our generated patterns to
@@ -44,6 +44,8 @@ Pattern => erase op<test.op3>;
 
 // Check the generation of native constraints and rewrites.
 
+#include "include/ods.td"
+
 // CHECK:      static ::mlir::LogicalResult TestCstPDLFn(::mlir::PatternRewriter &rewriter,
 // CHECK-SAME:     ::mlir::Attribute attr, ::mlir::Operation * op, ::mlir::Type type,
 // CHECK-SAME:     ::mlir::Value value, ::mlir::TypeRange typeRange, ::mlir::ValueRange valueRange) {
@@ -58,6 +60,7 @@ Pattern => erase op<test.op3>;
 // CHECK: foo;
 // CHECK: }
 
+// CHECK: TestAttrInterface TestRewriteODSPDLFn(::mlir::PatternRewriter &rewriter, TestAttrInterface attr) {
 // CHECK: static ::mlir::Attribute TestRewriteSinglePDLFn(::mlir::PatternRewriter &rewriter) {
 // CHECK: std::tuple<::mlir::Attribute, ::mlir::Type> TestRewriteTuplePDLFn(::mlir::PatternRewriter &rewriter) {
 
@@ -73,6 +76,7 @@ Constraint TestCst(attr: Attr, op: Op, type: Type, value: Value, typeRange: Type
 Constraint TestUnusedCst() [{ return success(); }];
 
 Rewrite TestRewrite(attr: Attr, op: Op, type: Type, value: Value, typeRange: TypeRange, valueRange: ValueRange) [{ foo; }];
+Rewrite TestRewriteODS(attr: TestAttrInterface) -> TestAttrInterface [{}];
 Rewrite TestRewriteSingle() -> Attr [{}];
 Rewrite TestRewriteTuple() -> (Attr, Type) [{}];
 Rewrite TestUnusedRewrite(op: Op) [{}];
@@ -82,6 +86,7 @@ Pattern TestCstAndRewrite {
   TestCst(attr<"true">, root, type, operand, types, operands);
   rewrite root with {
     TestRewrite(attr<"true">, root, type, operand, types, operands);
+    TestRewriteODS(attr<"true">);
     TestRewriteSingle();
     TestRewriteTuple();
     erase root;
diff --git a/mlir/test/mlir-pdll/CodeGen/CPP/include/ods.td b/mlir/test/mlir-pdll/CodeGen/CPP/include/ods.td
new file mode 100644 (file)
index 0000000..3eb57a4
--- /dev/null
@@ -0,0 +1,3 @@
+include "mlir/IR/OpBase.td"
+
+def TestAttrInterface : AttrInterface<"TestAttrInterface">;
index 1cdb32b..dd8843d 100644 (file)
@@ -88,13 +88,6 @@ Rewrite Foo(arg: Value<type>){}
 
 // -----
 
-Constraint ValueConstraint(value: Value);
-
-// CHECK: arguments and results are only permitted to use core constraints, such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`
-Rewrite Foo(arg: ValueConstraint);
-
-// -----
-
 // CHECK: expected `)` to end argument list
 Rewrite Foo(arg: Value{}
 
@@ -139,13 +132,6 @@ Rewrite Foo() -> Value<type>){}
 
 // -----
 
-Constraint ValueConstraint(value: Value);
-
-// CHECK: results are only permitted to use core constraints, such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`
-Rewrite Foo() -> ValueConstraint;
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // Native Rewrites
 //===----------------------------------------------------------------------===//