let results = (outs AnyType:$result);
}
-// Single variadic arg with SameOperandsAndResultType and InferTypeOpInterface.
-// Tests suppression of ambiguous build methods for operations with
-// SameOperandsAndResultType and InferTypeOpInterface.
-def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
- [SameOperandsAndResultType, InferTypeOpInterface]> {
+// Base class for testing `build` methods for ops with
+// InferReturnTypeOpInterface.
+class TableGenBuildInferReturnTypeBaseOp<string mnemonic,
+ list<OpTrait> traits = []>
+ : TEST_Op<mnemonic, [InferTypeOpInterface] # traits> {
let arguments = (ins Variadic<AnyType>:$inputs);
let results = (outs AnyType:$result);
}];
}
+// Single variadic arg with SameOperandsAndResultType and InferTypeOpInterface.
+// Tests suppression of ambiguous build methods for operations with
+// SameOperandsAndResultType and InferTypeOpInterface.
+def TableGenBuildOp5 : TableGenBuildInferReturnTypeBaseOp<
+ "tblgen_build_5", [SameOperandsAndResultType]>;
+
+// Op with InferTypeOpInterface and regions.
+def TableGenBuildOp6 : TableGenBuildInferReturnTypeBaseOp<
+ "tblgen_build_6", [InferTypeOpInterface]> {
+ let regions = (region AnyRegion:$body);
+}
+
//===----------------------------------------------------------------------===//
// Test BufferPlacement
//===----------------------------------------------------------------------===//
}
static bool canInferType(Operator &op) {
- return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
- op.getNumRegions() == 0;
+ return op.getTrait("::mlir::InferTypeOpInterface::Trait");
}
void OpEmitter::genSeparateArgParamBuilder() {
// ambiguous function detection will elide those ones.
for (auto attrType : attrBuilderType) {
emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
- if (canInferType(op))
+ if (canInferType(op) && op.getNumRegions() == 0)
emit(attrType, TypeParamKind::None, /*inferType=*/true);
emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
}
// Result types
body << formatv(R"(
- ::mlir::SmallVector<::mlir::Type, 2> inferredReturnTypes;
- if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
- {1}.location, operands,
- {1}.attributes.getDictionary({1}.getContext()),
- /*regions=*/{{}, inferredReturnTypes))) {{)",
+ ::mlir::SmallVector<::mlir::Type, 2> inferredReturnTypes;
+ if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
+ {1}.location, operands,
+ {1}.attributes.getDictionary({1}.getContext()),
+ {1}.regions, inferredReturnTypes))) {{)",
opClass.getClassName(), builderOpState);
if (numVariadicResults == 0 || numNonVariadicResults != 0)
- body << " assert(inferredReturnTypes.size()"
+ body << "\n assert(inferredReturnTypes.size()"
<< (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
- << "u && \"mismatched number of return types\");\n";
- body << " " << builderOpState << ".addTypes(inferredReturnTypes);";
+ << "u && \"mismatched number of return types\");";
+ body << "\n " << builderOpState << ".addTypes(inferredReturnTypes);";
body << formatv(R"(
- } else
- ::llvm::report_fatal_error("Failed to infer result type(s).");)",
+ } else {{
+ ::llvm::report_fatal_error("Failed to infer result type(s).");
+ })",
opClass.getClassName(), builderOpState);
}
body << " " << builderOpState << ".addTypes(resultTypes);\n";
// Generate builder that infers type too.
- // TODO: Expand to handle regions and successors.
+ // TODO: Expand to handle successors.
if (canInferType(op) && op.getNumSuccessors() == 0)
genInferredTypeCollectiveParamBuilder();
}
testSingleVariadicInputInferredType<test::TableGenBuildOp5>();
}
+TEST_F(OpBuildGenTest, BuildMethodsRegionsAndInferredType) {
+ auto op = builder.create<test::TableGenBuildOp6>(
+ loc, ValueRange{*cstI32, *cstF32}, /*attributes=*/noAttrs);
+ ASSERT_EQ(op->getNumRegions(), 1u);
+ verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstF32}, noAttrs);
+}
+
} // namespace mlir