[mlir:PDLL] Don't require users to provide operands/results when all are variadic
authorRiver Riddle <riddleriver@gmail.com>
Mon, 12 Sep 2022 18:43:44 +0000 (11:43 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Tue, 8 Nov 2022 09:57:58 +0000 (01:57 -0800)
When all operands or results are variadic, zero values is a perfectly valid behavior
to expect, and we shouldn't force the user to provide values in this case. For example,
when creating a call or a return operation we often don't want/need to provide return
values.

Differential Revision: https://reviews.llvm.org/D133721

mlir/lib/Tools/PDLL/Parser/Parser.cpp
mlir/test/lib/Transforms/TestDialectConversion.pdll
mlir/test/mlir-pdll/Parser/expr.pdll
mlir/test/mlir-pdll/Parser/include/ops.td

index ffa7f0c..de19f57 100644 (file)
@@ -426,23 +426,23 @@ private:
   FailureOr<ast::OperationExpr *>
   createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
                       OpResultTypeContext resultTypeContext,
-                      MutableArrayRef<ast::Expr *> operands,
+                      SmallVectorImpl<ast::Expr *> &operands,
                       MutableArrayRef<ast::NamedAttributeDecl *> attributes,
-                      MutableArrayRef<ast::Expr *> results);
+                      SmallVectorImpl<ast::Expr *> &results);
   LogicalResult
   validateOperationOperands(SMRange loc, Optional<StringRef> name,
                             const ods::Operation *odsOp,
-                            MutableArrayRef<ast::Expr *> operands);
+                            SmallVectorImpl<ast::Expr *> &operands);
   LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name,
                                          const ods::Operation *odsOp,
-                                         MutableArrayRef<ast::Expr *> results);
+                                         SmallVectorImpl<ast::Expr *> &results);
   void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
                                           const ods::Operation *odsOp);
   LogicalResult validateOperationOperandsOrResults(
       StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
-      Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
+      Optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
       ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
-      ast::Type rangeTy);
+      ast::RangeType rangeTy);
   FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
                                               ArrayRef<ast::Expr *> elements,
                                               ArrayRef<StringRef> elementNames);
@@ -2851,9 +2851,9 @@ FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
     SMRange loc, const ast::OpNameDecl *name,
     OpResultTypeContext resultTypeContext,
-    MutableArrayRef<ast::Expr *> operands,
+    SmallVectorImpl<ast::Expr *> &operands,
     MutableArrayRef<ast::NamedAttributeDecl *> attributes,
-    MutableArrayRef<ast::Expr *> results) {
+    SmallVectorImpl<ast::Expr *> &results) {
   Optional<StringRef> opNameRef = name->getName();
   const ods::Operation *odsOp = lookupODSOperation(opNameRef);
 
@@ -2896,7 +2896,7 @@ FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
 LogicalResult
 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
                                   const ods::Operation *odsOp,
-                                  MutableArrayRef<ast::Expr *> operands) {
+                                  SmallVectorImpl<ast::Expr *> &operands) {
   return validateOperationOperandsOrResults(
       "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
       operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy,
@@ -2906,7 +2906,7 @@ Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
 LogicalResult
 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
                                  const ods::Operation *odsOp,
-                                 MutableArrayRef<ast::Expr *> results) {
+                                 SmallVectorImpl<ast::Expr *> &results) {
   return validateOperationOperandsOrResults(
       "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
       results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy);
@@ -2956,9 +2956,9 @@ void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
 
 LogicalResult Parser::validateOperationOperandsOrResults(
     StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
-    Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
+    Optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
     ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
-    ast::Type rangeTy) {
+    ast::RangeType rangeTy) {
   // All operation types accept a single range parameter.
   if (values.size() == 1) {
     if (failed(convertExpressionTo(values[0], rangeTy)))
@@ -2969,14 +2969,56 @@ LogicalResult Parser::validateOperationOperandsOrResults(
   /// If the operation has ODS information, we can more accurately verify the
   /// values.
   if (odsOpLoc) {
-    if (odsValues.size() != values.size()) {
+    auto emitSizeMismatchError = [&] {
       return emitErrorAndNote(
           loc,
           llvm::formatv("invalid number of {0} groups for `{1}`; expected "
                         "{2}, but got {3}",
                         groupName, *name, odsValues.size(), values.size()),
           *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
+    };
+
+    // Handle the case where no values were provided.
+    if (values.empty()) {
+      // If we don't expect any on the ODS side, we are done.
+      if (odsValues.empty())
+        return success();
+
+      // If we do, check if we actually need to provide values (i.e. if any of
+      // the values are actually required).
+      unsigned numVariadic = 0;
+      for (const auto &odsValue : odsValues) {
+        if (!odsValue.isVariableLength())
+          return emitSizeMismatchError();
+        ++numVariadic;
+      }
+
+      // If we are in a non-rewrite context, we don't need to do anything more.
+      // Zero-values is a valid constraint on the operation.
+      if (parserContext != ParserContext::Rewrite)
+        return success();
+
+      // Otherwise, when in a rewrite we may need to provide values to match the
+      // ODS signature of the operation to create.
+
+      // If we only have one variadic value, just use an empty list.
+      if (numVariadic == 1)
+        return success();
+
+      // Otherwise, create dummy values for each of the entries so that we
+      // adhere to the ODS signature.
+      for (unsigned i = 0, e = odsValues.size(); i < e; ++i) {
+        values.push_back(
+            ast::RangeExpr::create(ctx, loc, /*elements=*/llvm::None, rangeTy));
+      }
+      return success();
     }
+
+    // Verify that the number of values provided matches the number of value
+    // groups ODS expects.
+    if (odsValues.size() != values.size())
+      return emitSizeMismatchError();
+
     auto diagFn = [&](ast::Diagnostic &diag) {
       diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
                       *odsOpLoc);
index c29e852..a6cd211 100644 (file)
@@ -10,9 +10,8 @@
 #include "mlir/Transforms/DialectConversion.pdll"
 
 /// Change the result type of a producer.
-// FIXME: We shouldn't need to specify arguments for the result cast.
-Pattern => replace op<test.cast>(args: ValueRange) -> (results: TypeRange)
-  with op<test.cast>(args) -> (convertTypes(results));
+Pattern => replace op<test.cast> -> (results: TypeRange)
+  with op<test.cast> -> (convertTypes(results));
 
 /// Pass through test.return conversion.
 Pattern => replace op<test.return>(args: ValueRange)
index 6e68883..0736962 100644 (file)
@@ -213,6 +213,34 @@ Pattern {
 
 // -----
 
+// Test that we don't need to provide values if all elements
+// are optional.
+
+#include "include/ops.td"
+
+// CHECK: Module
+// CHECK:  -OperationExpr {{.*}} Type<Op<test.multi_variadic>>
+// CHECK-NOT:   `Operands`
+// CHECK-NOT:   `Result Types`
+// CHECK:  -OperationExpr {{.*}} Type<Op<test.all_variadic>>
+// CHECK-NOT:   `Operands`
+// CHECK-NOT:   `Result Types`
+// CHECK:  -OperationExpr {{.*}} Type<Op<test.multi_variadic>>
+// CHECK:    `Operands`
+// CHECK:      -RangeExpr {{.*}} Type<ValueRange>
+// CHECK:      -RangeExpr {{.*}} Type<ValueRange>
+// CHECK:    `Result Types`
+// CHECK:      -RangeExpr {{.*}} Type<TypeRange>
+// CHECK:      -RangeExpr {{.*}} Type<TypeRange>
+Pattern {
+  rewrite op<test.multi_variadic>() -> () with {
+    op<test.all_variadic> -> ();
+    op<test.multi_variadic> -> ();
+  };
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // TupleExpr
 //===----------------------------------------------------------------------===//
index 91c4122..575b475 100644 (file)
@@ -28,3 +28,8 @@ def OpAllVariadic : Op<Test_Dialect, "all_variadic"> {
 def OpMultipleSingleResult : Op<Test_Dialect, "multiple_single_result"> {
   let results = (outs I64:$result, I64:$result2);
 }
+
+def OpMultiVariadic : Op<Test_Dialect, "multi_variadic"> {
+  let arguments = (ins Variadic<I64>:$operands, Variadic<I64>:$operand2);
+  let results = (outs Variadic<I64>:$results, Variadic<I64>:$results2);
+}