From e40624ae604fc292cb8a7102b0b91b571b26a32a Mon Sep 17 00:00:00 2001 From: Mogball Date: Fri, 10 Dec 2021 15:04:46 +0000 Subject: [PATCH] [mlir][ods] Fix OpFormatGen sometimes not calling inferReturnTypes Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D115522 --- mlir/test/lib/Dialect/Test/TestOps.td | 51 +++++++++++++++++++++++++++++++++- mlir/test/mlir-tblgen/op-format.mlir | 12 ++++++++ mlir/tools/mlir-tblgen/OpFormatGen.cpp | 47 ++++++++++++++++++------------- 3 files changed, 90 insertions(+), 20 deletions(-) diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 120749e..655ad2d 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1136,7 +1136,9 @@ def ThreeResultOp : TEST_Op<"three_result"> { let results = (outs I32:$result1, F32:$result2, F32:$result3); } -def AnotherThreeResultOp : TEST_Op<"another_three_result", [DeclareOpInterfaceMethods]> { +def AnotherThreeResultOp + : TEST_Op<"another_three_result", + [DeclareOpInterfaceMethods]> { let arguments = (ins MultiResultOpEnum:$kind); let results = (outs I32:$result1, F32:$result2, F32:$result3); } @@ -2101,6 +2103,53 @@ def FormatInferTypeOp : TEST_Op<"format_infer_type", [InferTypeOpInterface]> { }]; } +// Base class for testing mixing allOperandTypes, allOperands, and +// inferResultTypes. +class FormatInferAllTypesBaseOp traits = []> + : TEST_Op { + let arguments = (ins Variadic:$args); + let results = (outs Variadic:$outs); + let extraClassDeclaration = [{ + static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, + ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands, + ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions, + ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { + ::mlir::TypeRange operandTypes = operands.getTypes(); + inferredReturnTypes.assign(operandTypes.begin(), operandTypes.end()); + return ::mlir::success(); + } + }]; +} + +// Test inferReturnTypes is called when allOperandTypes and allOperands is true. +def FormatInferTypeAllOperandsAndTypesOp + : FormatInferAllTypesBaseOp<"format_infer_type_all_operands_and_types"> { + let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)"; +} + +// Test inferReturnTypes is called when allOperandTypes is true and there is one +// ODS operand. +def FormatInferTypeAllOperandsAndTypesOneOperandOp + : FormatInferAllTypesBaseOp<"format_infer_type_all_types_one_operand"> { + let assemblyFormat = "`(` $args `)` attr-dict `:` type(operands)"; +} + +// Test inferReturnTypes is called when allOperandTypes is true and there are +// more than one ODS operands. +def FormatInferTypeAllOperandsAndTypesTwoOperandsOp + : FormatInferAllTypesBaseOp<"format_infer_type_all_types_two_operands", + [SameVariadicOperandSize]> { + let arguments = (ins Variadic:$args0, Variadic:$args1); + let assemblyFormat = "`(` $args0 `)` `(` $args1 `)` attr-dict `:` type(operands)"; +} + +// Test inferReturnTypes is called when allOperands is true and operand types +// are separately specified. +def FormatInferTypeAllTypesOp + : FormatInferAllTypesBaseOp<"format_infer_type_all_types"> { + let assemblyFormat = "`(` operands `)` attr-dict `:` type($args)"; +} + //===----------------------------------------------------------------------===// // Test SideEffects //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir index c3214c7..c65d216 100644 --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -411,6 +411,18 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64 // CHECK: test.format_infer_type %ignored_res7 = test.format_infer_type +// CHECK: test.format_infer_type_all_operands_and_types(%[[I64]], %[[I32]]) : i64, i32 +%ignored_res8:2 = test.format_infer_type_all_operands_and_types(%i64, %i32) : i64, i32 + +// CHECK: test.format_infer_type_all_types_one_operand(%[[I64]], %[[I32]]) : i64, i32 +%ignored_res9:2 = test.format_infer_type_all_types_one_operand(%i64, %i32) : i64, i32 + +// CHECK: test.format_infer_type_all_types_two_operands(%[[I64]], %[[I32]]) (%[[I64]], %[[I32]]) : i64, i32, i64, i32 +%ignored_res10:4 = test.format_infer_type_all_types_two_operands(%i64, %i32) (%i64, %i32) : i64, i32, i64, i32 + +// CHECK: test.format_infer_type_all_types(%[[I64]], %[[I32]]) : i64, i32 +%ignored_res11:2 = test.format_infer_type_all_types(%i64, %i32) : i64, i32 + //===----------------------------------------------------------------------===// // Check DefaultValuedStrAttr //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index 4223c2d..6203edb 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -424,14 +424,18 @@ struct OperationFormat { /// Generate the parser code for a specific format element. void genElementParser(Element *element, MethodBody &body, FmtContext &attrTypeCtx); - /// Generate the c++ to resolve the types of operands and results during + /// Generate the C++ to resolve the types of operands and results during /// parsing. void genParserTypeResolution(Operator &op, MethodBody &body); - /// Generate the c++ to resolve regions during parsing. + /// Generate the C++ to resolve the types of the operands during parsing. + void genParserOperandTypeResolution( + Operator &op, MethodBody &body, + function_ref emitTypeResolver); + /// Generate the C++ to resolve regions during parsing. void genParserRegionResolution(Operator &op, MethodBody &body); - /// Generate the c++ to resolve successors during parsing. + /// Generate the C++ to resolve successors during parsing. void genParserSuccessorResolution(Operator &op, MethodBody &body); - /// Generate the c++ to handling variadic segment size traits. + /// Generate the C++ to handling variadic segment size traits. void genParserVariadicSegmentResolution(Operator &op, MethodBody &body); /// Generate the operation printer from this format. @@ -1462,17 +1466,25 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) { } } + // Emit the operand type resolutions. + genParserOperandTypeResolution(op, body, emitTypeResolver); + + // Handle return type inference once all operands have been resolved + if (infersResultTypes) + body << formatv(inferReturnTypesParserCode, op.getCppClassName()); +} + +void OperationFormat::genParserOperandTypeResolution( + Operator &op, MethodBody &body, + function_ref emitTypeResolver) { // Early exit if there are no operands. - if (op.getNumOperands() == 0) { - // Handle return type inference here if there are no operands - if (infersResultTypes) - body << formatv(inferReturnTypesParserCode, op.getCppClassName()); + if (op.getNumOperands() == 0) return; - } - // Handle the case where all operand types are in one group. + // Handle the case where all operand types are grouped together with + // "types(operands)". if (allOperandTypes) { - // If we have all operands together, use the full operand list directly. + // If `operands` was specified, use the full operand list directly. if (allOperands) { body << " if (parser.resolveOperands(allOperands, allOperandTypes, " "allOperandLoc, result.operands))\n" @@ -1496,7 +1508,8 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) { << " return ::mlir::failure();\n"; return; } - // Handle the case where all of the operands were grouped together. + + // Handle the case where all operands are grouped together with "operands". if (allOperands) { body << " if (parser.resolveOperands(allOperands, "; @@ -1551,10 +1564,6 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) { body << ", " << operand.name << "OperandsLoc"; body << ", result.operands))\n return ::mlir::failure();\n"; } - - // Handle return type inference once all operands have been resolved - if (infersResultTypes) - body << formatv(inferReturnTypesParserCode, op.getCppClassName()); } void OperationFormat::genParserRegionResolution(Operator &op, @@ -1833,7 +1842,7 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op, // keyword. llvm::BitVector nonKeywordCases(cases.size()); bool hasStrCase = false; - for (auto it : llvm::enumerate(cases)) { + for (auto &it : llvm::enumerate(cases)) { hasStrCase = it.value().isStrCase(); if (!canFormatStringAsKeyword(it.value().getStr())) nonKeywordCases.set(it.index()); @@ -1860,7 +1869,7 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op, // overlap with other cases. For simplicity sake, only allow cases with a // single bit value. if (enumAttr.isBitEnum()) { - for (auto it : llvm::enumerate(cases)) { + for (auto &it : llvm::enumerate(cases)) { int64_t value = it.value().getValue(); if (value < 0 || !llvm::isPowerOf2_64(value)) nonKeywordCases.set(it.index()); @@ -1873,7 +1882,7 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op, body << " switch (caseValue) {\n"; StringRef cppNamespace = enumAttr.getCppNamespace(); StringRef enumName = enumAttr.getEnumClassName(); - for (auto it : llvm::enumerate(cases)) { + for (auto &it : llvm::enumerate(cases)) { if (nonKeywordCases.test(it.index())) continue; StringRef symbol = it.value().getSymbol(); -- 2.7.4