let results = (outs I32:$result1, F32:$result2, F32:$result3);
}
-def AnotherThreeResultOp : TEST_Op<"another_three_result", [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+def AnotherThreeResultOp
+ : TEST_Op<"another_three_result",
+ [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let arguments = (ins MultiResultOpEnum:$kind);
let results = (outs I32:$result1, F32:$result2, F32:$result3);
}
}];
}
+// Base class for testing mixing allOperandTypes, allOperands, and
+// inferResultTypes.
+class FormatInferAllTypesBaseOp<string mnemonic, list<OpTrait> traits = []>
+ : TEST_Op<mnemonic, [InferTypeOpInterface] # traits> {
+ let arguments = (ins Variadic<AnyType>:$args);
+ let results = (outs Variadic<AnyType>:$outs);
+ let extraClassDeclaration = [{
+ static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+ ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
+ ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+ ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+ ::mlir::TypeRange operandTypes = operands.getTypes();
+ inferredReturnTypes.assign(operandTypes.begin(), operandTypes.end());
+ return ::mlir::success();
+ }
+ }];
+}
+
+// Test inferReturnTypes is called when allOperandTypes and allOperands is true.
+def FormatInferTypeAllOperandsAndTypesOp
+ : FormatInferAllTypesBaseOp<"format_infer_type_all_operands_and_types"> {
+ let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)";
+}
+
+// Test inferReturnTypes is called when allOperandTypes is true and there is one
+// ODS operand.
+def FormatInferTypeAllOperandsAndTypesOneOperandOp
+ : FormatInferAllTypesBaseOp<"format_infer_type_all_types_one_operand"> {
+ let assemblyFormat = "`(` $args `)` attr-dict `:` type(operands)";
+}
+
+// Test inferReturnTypes is called when allOperandTypes is true and there are
+// more than one ODS operands.
+def FormatInferTypeAllOperandsAndTypesTwoOperandsOp
+ : FormatInferAllTypesBaseOp<"format_infer_type_all_types_two_operands",
+ [SameVariadicOperandSize]> {
+ let arguments = (ins Variadic<AnyType>:$args0, Variadic<AnyType>:$args1);
+ let assemblyFormat = "`(` $args0 `)` `(` $args1 `)` attr-dict `:` type(operands)";
+}
+
+// Test inferReturnTypes is called when allOperands is true and operand types
+// are separately specified.
+def FormatInferTypeAllTypesOp
+ : FormatInferAllTypesBaseOp<"format_infer_type_all_types"> {
+ let assemblyFormat = "`(` operands `)` attr-dict `:` type($args)";
+}
+
//===----------------------------------------------------------------------===//
// Test SideEffects
//===----------------------------------------------------------------------===//
// CHECK: test.format_infer_type
%ignored_res7 = test.format_infer_type
+// CHECK: test.format_infer_type_all_operands_and_types(%[[I64]], %[[I32]]) : i64, i32
+%ignored_res8:2 = test.format_infer_type_all_operands_and_types(%i64, %i32) : i64, i32
+
+// CHECK: test.format_infer_type_all_types_one_operand(%[[I64]], %[[I32]]) : i64, i32
+%ignored_res9:2 = test.format_infer_type_all_types_one_operand(%i64, %i32) : i64, i32
+
+// CHECK: test.format_infer_type_all_types_two_operands(%[[I64]], %[[I32]]) (%[[I64]], %[[I32]]) : i64, i32, i64, i32
+%ignored_res10:4 = test.format_infer_type_all_types_two_operands(%i64, %i32) (%i64, %i32) : i64, i32, i64, i32
+
+// CHECK: test.format_infer_type_all_types(%[[I64]], %[[I32]]) : i64, i32
+%ignored_res11:2 = test.format_infer_type_all_types(%i64, %i32) : i64, i32
+
//===----------------------------------------------------------------------===//
// Check DefaultValuedStrAttr
//===----------------------------------------------------------------------===//
/// Generate the parser code for a specific format element.
void genElementParser(Element *element, MethodBody &body,
FmtContext &attrTypeCtx);
- /// Generate the c++ to resolve the types of operands and results during
+ /// Generate the C++ to resolve the types of operands and results during
/// parsing.
void genParserTypeResolution(Operator &op, MethodBody &body);
- /// Generate the c++ to resolve regions during parsing.
+ /// Generate the C++ to resolve the types of the operands during parsing.
+ void genParserOperandTypeResolution(
+ Operator &op, MethodBody &body,
+ function_ref<void(TypeResolution &, StringRef)> emitTypeResolver);
+ /// Generate the C++ to resolve regions during parsing.
void genParserRegionResolution(Operator &op, MethodBody &body);
- /// Generate the c++ to resolve successors during parsing.
+ /// Generate the C++ to resolve successors during parsing.
void genParserSuccessorResolution(Operator &op, MethodBody &body);
- /// Generate the c++ to handling variadic segment size traits.
+ /// Generate the C++ to handling variadic segment size traits.
void genParserVariadicSegmentResolution(Operator &op, MethodBody &body);
/// Generate the operation printer from this format.
}
}
+ // Emit the operand type resolutions.
+ genParserOperandTypeResolution(op, body, emitTypeResolver);
+
+ // Handle return type inference once all operands have been resolved
+ if (infersResultTypes)
+ body << formatv(inferReturnTypesParserCode, op.getCppClassName());
+}
+
+void OperationFormat::genParserOperandTypeResolution(
+ Operator &op, MethodBody &body,
+ function_ref<void(TypeResolution &, StringRef)> emitTypeResolver) {
// Early exit if there are no operands.
- if (op.getNumOperands() == 0) {
- // Handle return type inference here if there are no operands
- if (infersResultTypes)
- body << formatv(inferReturnTypesParserCode, op.getCppClassName());
+ if (op.getNumOperands() == 0)
return;
- }
- // Handle the case where all operand types are in one group.
+ // Handle the case where all operand types are grouped together with
+ // "types(operands)".
if (allOperandTypes) {
- // If we have all operands together, use the full operand list directly.
+ // If `operands` was specified, use the full operand list directly.
if (allOperands) {
body << " if (parser.resolveOperands(allOperands, allOperandTypes, "
"allOperandLoc, result.operands))\n"
<< " return ::mlir::failure();\n";
return;
}
- // Handle the case where all of the operands were grouped together.
+
+ // Handle the case where all operands are grouped together with "operands".
if (allOperands) {
body << " if (parser.resolveOperands(allOperands, ";
body << ", " << operand.name << "OperandsLoc";
body << ", result.operands))\n return ::mlir::failure();\n";
}
-
- // Handle return type inference once all operands have been resolved
- if (infersResultTypes)
- body << formatv(inferReturnTypesParserCode, op.getCppClassName());
}
void OperationFormat::genParserRegionResolution(Operator &op,
// keyword.
llvm::BitVector nonKeywordCases(cases.size());
bool hasStrCase = false;
- for (auto it : llvm::enumerate(cases)) {
+ for (auto &it : llvm::enumerate(cases)) {
hasStrCase = it.value().isStrCase();
if (!canFormatStringAsKeyword(it.value().getStr()))
nonKeywordCases.set(it.index());
// overlap with other cases. For simplicity sake, only allow cases with a
// single bit value.
if (enumAttr.isBitEnum()) {
- for (auto it : llvm::enumerate(cases)) {
+ for (auto &it : llvm::enumerate(cases)) {
int64_t value = it.value().getValue();
if (value < 0 || !llvm::isPowerOf2_64(value))
nonKeywordCases.set(it.index());
body << " switch (caseValue) {\n";
StringRef cppNamespace = enumAttr.getCppNamespace();
StringRef enumName = enumAttr.getEnumClassName();
- for (auto it : llvm::enumerate(cases)) {
+ for (auto &it : llvm::enumerate(cases)) {
if (nonKeywordCases.test(it.index()))
continue;
StringRef symbol = it.value().getSymbol();