From 41a73ddce8923f506eaf6e8c5a61d32add5e4c06 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Thu, 12 Dec 2019 10:35:40 -0800 Subject: [PATCH] Add type inference variant for separate params builder generated Add variant that does invoke infer type op interface where defined. Also add entry function that invokes that different separate argument builders for wrapped, unwrapped and inference variant. PiperOrigin-RevId: 285220709 --- mlir/g3doc/OpDefinitions.md | 18 ++- mlir/test/mlir-tblgen/op-decl.td | 5 +- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 183 ++++++++++++++-------------- 3 files changed, 108 insertions(+), 98 deletions(-) diff --git a/mlir/g3doc/OpDefinitions.md b/mlir/g3doc/OpDefinitions.md index d00b19b..0e786a0 100644 --- a/mlir/g3doc/OpDefinitions.md +++ b/mlir/g3doc/OpDefinitions.md @@ -290,7 +290,7 @@ class. See [Constraints](#constraints) for more information. ### Operation interfaces [Operation interfaces](Interfaces.md#operation-interfaces) are a mechanism by -which to opaquely call methods and access information on an *Op instance, +which to opaquely call methods and access information on an *Op instance*, without knowing the exact operation type. Operation interfaces defined in C++ can be accessed in the ODS framework via the `OpInterfaceTrait` class. Aside from using pre-existing interfaces in the C++ API, the ODS framework also @@ -414,7 +414,7 @@ The following builders are generated: // All result-types/operands/attributes have one aggregate parameter. static void build(Builder *tblgen_builder, OperationState &tblgen_state, ArrayRef resultTypes, - ArrayRef operands, + ValueRange operands, ArrayRef attributes); // Each result-type/operand/attribute has a separate parameter. The parameters @@ -433,7 +433,19 @@ static void build(Builder *tblgen_builder, OperationState &tblgen_state, Value *i32_operand, Value *f32_operand, ..., APInt i32_attr, StringRef f32_attr, ...); -// (And potentially others depending on the specific op.) +// Each operand/attribute has a separate parameter but result type is aggregate. +static void build(Builder *tblgen_builder, OperationState &tblgen_state, + ArrayRef resultTypes, + Value *i32_operand, Value *f32_operand, ..., + IntegerAttr i32_attr, FloatAttr f32_attr, ...); + +// All operands/attributes have aggregate parameters. +// Generated if InferTypeOpInterface interface is specified. +static void build(Builder *tblgen_builder, OperationState &tblgen_state, + ValueRange operands, + ArrayRef attributes); + +// (And manually specified builders depending on the specific op.) ``` The first form provides basic uniformity so that we can create ops using the diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td index 2c90c27..c0420cb 100644 --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -68,8 +68,9 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect]> { // CHECK: FloatAttr attr2Attr() // CHECK: Optional< APFloat > attr2(); // CHECK: static void build(Value *val); -// CHECK: static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef s, Value *a, ValueRange b, IntegerAttr attr1, /*optional*/FloatAttr attr2); -// CHECK: static void build(Builder *, OperationState &tblgen_state, ArrayRef resultTypes, ValueRange operands, ArrayRef attributes); +// CHECK: static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef s, Value *a, ValueRange b, IntegerAttr attr1, /*optional*/FloatAttr attr2) +// CHECK: static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef s, Value *a, ValueRange b, APInt attr1, /*optional*/FloatAttr attr2) +// CHECK: static void build(Builder *, OperationState &tblgen_state, ArrayRef resultTypes, ValueRange operands, ArrayRef attributes) // CHECK: static ParseResult parse(OpAsmParser &parser, OperationState &result); // CHECK: void print(OpAsmPrinter &p); // CHECK: LogicalResult verify(); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 37fa9c7..a73b113 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -514,21 +514,9 @@ private: // Generates builder methods for the operation. void genBuilder(); - // Generates the build() method that takes each result-type/operand/attribute - // as a stand-alone parameter. Attributes will take wrapped mlir::Attribute - // values. The generated build() method also requires specifying result types - // for all results. - void genSeparateParamWrappedAttrBuilder(); - - // Generates the build() method that takes each result-type/operand/attribute - // as a stand-alone parameter. Attributes will take raw values without - // mlir::Attribute wrapper. The generated build() method also requires - // specifying result types for all results. - void genSeparateParamUnwrappedAttrBuilder(); - - // Generates the build() method that takes a single parameter for all the - // result types and a separate parameter for each operand/attribute. - void genCollectiveTypeParamBuilder(); + // Generates the build() method that takes each operand/attribute + // as a stand-alone parameter. + void genSeparateArgParamBuilder(); // Generates the build() method that takes each operand/attribute as a // stand-alone parameter. The generated build() method uses first operand's @@ -897,26 +885,11 @@ void OpEmitter::genNamedRegionGetters() { } } -void OpEmitter::genSeparateParamWrappedAttrBuilder() { - std::string paramList; - llvm::SmallVector resultNames; - buildParamList(paramList, resultNames, TypeParamKind::Separate); - - auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static); - genCodeForAddingArgAndRegionForBuilder(m.body()); - - // Push all result types to the operation state - for (int i = 0, e = op.getNumResults(); i < e; ++i) { - m.body() << " " << builderOpState << ".addTypes(" << resultNames[i] - << ");\n"; - } -} - -void OpEmitter::genSeparateParamUnwrappedAttrBuilder() { +static bool canGenerateUnwrappedBuilder(Operator &op) { // If this op does not have native attributes at all, return directly to avoid // redefining builders. if (op.getNumNativeAttributes() == 0) - return; + return false; bool canGenerate = false; // We are generating builders that take raw values for attributes. We need to @@ -930,47 +903,75 @@ void OpEmitter::genSeparateParamUnwrappedAttrBuilder() { break; } } - if (!canGenerate) - return; - - std::string paramList; - llvm::SmallVector resultNames; - buildParamList(paramList, resultNames, TypeParamKind::Separate, - AttrParamKind::UnwrappedValue); - - auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static); - genCodeForAddingArgAndRegionForBuilder(m.body(), /*isRawValueAttr=*/true); - - // Push all result types to the operation state. - for (int i = 0, e = op.getNumResults(); i < e; ++i) { - m.body() << " " << builderOpState << ".addTypes(" << resultNames[i] - << ");\n"; - } + return canGenerate; } -void OpEmitter::genCollectiveTypeParamBuilder() { - auto numResults = op.getNumResults(); - - // If this op has no results, then just skip generating this builder. - // Otherwise we are generating the same signature as the separate-parameter - // builder. - if (numResults == 0) - return; - - // Similarly for ops with one single variadic result, which will also have one - // `ArrayRef` parameter for the result type. - if (numResults == 1 && op.getResult(0).isVariadic()) - return; - - std::string paramList; - llvm::SmallVector resultNames; - buildParamList(paramList, resultNames, TypeParamKind::Collective); +void OpEmitter::genSeparateArgParamBuilder() { + SmallVector attrBuilderType; + attrBuilderType.push_back(AttrParamKind::WrappedAttr); + if (canGenerateUnwrappedBuilder(op)) + attrBuilderType.push_back(AttrParamKind::UnwrappedValue); + + // Emit with separate builders with or without unwrapped attributes and/or + // inferring result type. + auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind, + bool inferType) { + std::string paramList; + llvm::SmallVector resultNames; + buildParamList(paramList, resultNames, paramKind, attrType); + + auto &m = + opClass.newMethod("void", "build", paramList, OpMethod::MP_Static); + auto &body = m.body(); + genCodeForAddingArgAndRegionForBuilder( + body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue); + + // Push all result types to the operation state + + if (inferType) { + // Generate builder that infers type too. + // TODO(jpienaar): Subsume this with general checking if type can be + // infered automatically. + // TODO(jpienaar): Expand to handle regions. + body << formatv(R"( + SmallVector inferedReturnTypes; + if (succeeded({0}::inferReturnTypes({1}.location, {1}.operands, + {1}.attributes, /*regions=*/{{}, inferedReturnTypes))) + {1}.addTypes(inferedReturnTypes); + else + llvm::report_fatal_error("Failed to infer result type(s).");)", + opClass.getClassName(), builderOpState); + return; + } - auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static); - genCodeForAddingArgAndRegionForBuilder(m.body()); + switch (paramKind) { + case TypeParamKind::None: + return; + case TypeParamKind::Separate: + for (int i = 0, e = op.getNumResults(); i < e; ++i) { + body << " " << builderOpState << ".addTypes(" << resultNames[i] + << ");\n"; + } + return; + case TypeParamKind::Collective: + body << " " << builderOpState << ".addTypes(resultTypes);\n"; + return; + }; + llvm_unreachable("unhandled TypeParamKind"); + }; - // Push all result types to the operation state - m.body() << formatv(" {0}.addTypes(resultTypes);\n", builderOpState); + bool canInferType = + op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0; + for (auto attrType : attrBuilderType) { + emit(attrType, TypeParamKind::Separate, /*inferType=*/false); + if (canInferType) + emit(attrType, TypeParamKind::None, /*inferType=*/true); + // Emit separate arg build with collective type, unless there is only one + // variadic result, in which case the above would have already generated + // the same build method. + if (op.getNumResults() == 1 && !op.getResult(0).isVariadic()) + emit(attrType, TypeParamKind::Collective, /*inferType=*/false); + } } void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() { @@ -1021,8 +1022,7 @@ void OpEmitter::genInferedTypeCollectiveParamBuilder() { /*regions=*/{{}, inferedReturnTypes))) build(builder, tblgen_state, inferedReturnTypes, operands, attributes); else - llvm::report_fatal_error("Failed to infer result type(s)."); - )", + llvm::report_fatal_error("Failed to infer result type(s).");)", opClass.getClassName(), builderOpState); } @@ -1111,18 +1111,13 @@ void OpEmitter::genBuilder() { // Generate default builders that requires all result type, operands, and // attributes as parameters. - // We generate three builders here: - // 1. one having a stand-alone parameter for each result type / operand / - // attribute, and - genSeparateParamWrappedAttrBuilder(); - genSeparateParamUnwrappedAttrBuilder(); - // 2. one having a stand-alone parameter for each operand / attribute and - // an aggregated parameter for all result types, and - genCollectiveTypeParamBuilder(); - // 3. one having an aggregated parameter for all result types / operands / + // We generate three classes of builders here: + // 1. one having a stand-alone parameter for each operand / attribute, and + genSeparateArgParamBuilder(); + // 2. one having an aggregated parameter for all result types / operands / // attributes, and genCollectiveParamBuilder(); - // 4. one having a stand-alone parameter for each operand and attribute, + // 3. one having a stand-alone parameter for each operand and attribute, // use the first operand or attribute's type as all result types // to facilitate different call patterns. if (op.getNumVariadicResults() == 0) { @@ -1133,11 +1128,6 @@ void OpEmitter::genBuilder() { if (op.getTrait("OpTrait::FirstAttrDerivedResultType")) genUseAttrAsResultTypeBuilder(); } - // TODO(jpienaar): Subsume this with general checking if type can be infered - // automatically. - // TODO(jpienaar): Expand to handle regions. - if (op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0) - genInferedTypeCollectiveParamBuilder(); } void OpEmitter::genCollectiveParamBuilder() { @@ -1156,13 +1146,6 @@ void OpEmitter::genCollectiveParamBuilder() { auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static); auto &body = m.body(); - // Result types - if (numVariadicResults == 0 || numNonVariadicResults != 0) - body << " assert(resultTypes.size()" - << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults - << "u && \"mismatched number of return types\");\n"; - body << " " << builderOpState << ".addTypes(resultTypes);\n"; - // Operands if (numVariadicOperands == 0 || numNonVariadicOperands != 0) body << " assert(operands.size()" @@ -1179,6 +1162,20 @@ void OpEmitter::genCollectiveParamBuilder() { for (int i = 0; i < numRegions; ++i) m.body() << " (void)" << builderOpState << ".addRegion();\n"; } + + // Result types + if (numVariadicResults == 0 || numNonVariadicResults != 0) + body << " assert(resultTypes.size()" + << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults + << "u && \"mismatched number of return types\");\n"; + body << " " << builderOpState << ".addTypes(resultTypes);\n"; + + // Generate builder that infers type too. + // TODO(jpienaar): Subsume this with general checking if type can be infered + // automatically. + // TODO(jpienaar): Expand to handle regions. + if (op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0) + genInferedTypeCollectiveParamBuilder(); } void OpEmitter::buildParamList(std::string ¶mList, -- 2.7.4