[mlir][ods] Fix OpDefinitionsGen infer return types builder with regions
authorMogball <jeffniu22@gmail.com>
Fri, 10 Dec 2021 15:51:02 +0000 (15:51 +0000)
committerMogball <jeffniu22@gmail.com>
Mon, 13 Dec 2021 15:11:35 +0000 (15:11 +0000)
Despite handling regions and inferred return types, the builder was never generated for ops with both InferReturnTypeOpInterface and regions.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D115525

mlir/test/lib/Dialect/Test/TestOps.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/unittests/TableGen/OpBuildGen.cpp

index 655ad2d..627fac4 100644 (file)
@@ -2315,11 +2315,11 @@ def TableGenBuildOp4 : TEST_Op<"tblgen_build_4", [SameOperandsAndResultType]> {
   let results = (outs AnyType:$result);
 }
 
-// Single variadic arg with SameOperandsAndResultType and InferTypeOpInterface.
-// Tests suppression of ambiguous build methods for operations with
-// SameOperandsAndResultType and InferTypeOpInterface.
-def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
-      [SameOperandsAndResultType, InferTypeOpInterface]> {
+// Base class for testing `build` methods for ops with
+// InferReturnTypeOpInterface.
+class TableGenBuildInferReturnTypeBaseOp<string mnemonic,
+                                         list<OpTrait> traits = []>
+    : TEST_Op<mnemonic, [InferTypeOpInterface] # traits> {
   let arguments = (ins Variadic<AnyType>:$inputs);
   let results = (outs AnyType:$result);
 
@@ -2334,6 +2334,18 @@ def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
    }];
 }
 
+// Single variadic arg with SameOperandsAndResultType and InferTypeOpInterface.
+// Tests suppression of ambiguous build methods for operations with
+// SameOperandsAndResultType and InferTypeOpInterface.
+def TableGenBuildOp5 : TableGenBuildInferReturnTypeBaseOp<
+    "tblgen_build_5", [SameOperandsAndResultType]>;
+
+// Op with InferTypeOpInterface and regions.
+def TableGenBuildOp6 : TableGenBuildInferReturnTypeBaseOp<
+    "tblgen_build_6", [InferTypeOpInterface]> {
+  let regions = (region AnyRegion:$body);
+}
+
 //===----------------------------------------------------------------------===//
 // Test BufferPlacement
 //===----------------------------------------------------------------------===//
index f84646c..e331397 100644 (file)
@@ -1220,8 +1220,7 @@ static bool canGenerateUnwrappedBuilder(Operator &op) {
 }
 
 static bool canInferType(Operator &op) {
-  return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
-         op.getNumRegions() == 0;
+  return op.getTrait("::mlir::InferTypeOpInterface::Trait");
 }
 
 void OpEmitter::genSeparateArgParamBuilder() {
@@ -1304,7 +1303,7 @@ void OpEmitter::genSeparateArgParamBuilder() {
   // ambiguous function detection will elide those ones.
   for (auto attrType : attrBuilderType) {
     emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
-    if (canInferType(op))
+    if (canInferType(op) && op.getNumRegions() == 0)
       emit(attrType, TypeParamKind::None, /*inferType=*/true);
     emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
   }
@@ -1392,21 +1391,22 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder() {
 
   // Result types
   body << formatv(R"(
-    ::mlir::SmallVector<::mlir::Type, 2> inferredReturnTypes;
-    if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
-                  {1}.location, operands,
-                  {1}.attributes.getDictionary({1}.getContext()),
-                  /*regions=*/{{}, inferredReturnTypes))) {{)",
+  ::mlir::SmallVector<::mlir::Type, 2> inferredReturnTypes;
+  if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
+          {1}.location, operands,
+          {1}.attributes.getDictionary({1}.getContext()),
+          {1}.regions, inferredReturnTypes))) {{)",
                   opClass.getClassName(), builderOpState);
   if (numVariadicResults == 0 || numNonVariadicResults != 0)
-    body << "  assert(inferredReturnTypes.size()"
+    body << "\n    assert(inferredReturnTypes.size()"
          << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
-         << "u && \"mismatched number of return types\");\n";
-  body << "      " << builderOpState << ".addTypes(inferredReturnTypes);";
+         << "u && \"mismatched number of return types\");";
+  body << "\n    " << builderOpState << ".addTypes(inferredReturnTypes);";
 
   body << formatv(R"(
-    } else
-      ::llvm::report_fatal_error("Failed to infer result type(s).");)",
+  } else {{
+    ::llvm::report_fatal_error("Failed to infer result type(s).");
+  })",
                   opClass.getClassName(), builderOpState);
 }
 
@@ -1606,7 +1606,7 @@ void OpEmitter::genCollectiveParamBuilder() {
   body << "  " << builderOpState << ".addTypes(resultTypes);\n";
 
   // Generate builder that infers type too.
-  // TODO: Expand to handle regions and successors.
+  // TODO: Expand to handle successors.
   if (canInferType(op) && op.getNumSuccessors() == 0)
     genInferredTypeCollectiveParamBuilder();
 }
index 4e692ca..3b6f489 100644 (file)
@@ -219,4 +219,11 @@ TEST_F(
   testSingleVariadicInputInferredType<test::TableGenBuildOp5>();
 }
 
+TEST_F(OpBuildGenTest, BuildMethodsRegionsAndInferredType) {
+  auto op = builder.create<test::TableGenBuildOp6>(
+      loc, ValueRange{*cstI32, *cstF32}, /*attributes=*/noAttrs);
+  ASSERT_EQ(op->getNumRegions(), 1u);
+  verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstF32}, noAttrs);
+}
+
 } // namespace mlir