From 9744606614df4ba85a4d546c94b3b5ef9d3a3a96 Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Tue, 22 Sep 2020 20:58:30 -0700 Subject: [PATCH] [MLIR] Change default builders generated by TableGen to use TypeRange for result types - Change the default builders to use TypeRange instead of ArrayRef - Custom builders defined in LinalgStructuredOps now conflict with the default separate param ones, but the default collective params one is still needed. Resolve this by replicating the collective param builder as a custom builder and skipping the generation of default builders for these ops. Differential Revision: https://reviews.llvm.org/D87926 --- mlir/include/mlir/IR/OpBase.td | 4 +-- mlir/test/mlir-tblgen/op-decl.td | 32 +++++++++++----------- mlir/test/mlir-tblgen/op-result.td | 10 +++---- .../mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp | 11 +++++++- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 8 +++--- 5 files changed, 37 insertions(+), 28 deletions(-) diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 330cdca..f1befa4 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2006,8 +2006,8 @@ class Op props = []> { // // ```c++ // static void build(OpBuilder &, OperationState &odsState, - // ArrayRef resultTypes, - // ArrayRef operands, + // TypeRange resultTypes, + // ValueRange operands, // ArrayRef attributes); // ``` list builders = ?; diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td index 0392264..4ff77dc 100644 --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -81,9 +81,9 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> { // CHECK: ::mlir::FloatAttr attr2Attr() // CHECK: ::llvm::Optional< ::llvm::APFloat > attr2(); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value val); -// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::llvm::ArrayRef<::mlir::Type> s, ::mlir::Value a, ::mlir::ValueRange b, ::mlir::IntegerAttr attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount) -// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::llvm::ArrayRef<::mlir::Type> s, ::mlir::Value a, ::mlir::ValueRange b, uint32_t attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount) -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes, unsigned numRegions) +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::TypeRange s, ::mlir::Value a, ::mlir::ValueRange b, ::mlir::IntegerAttr attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount) +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::TypeRange s, ::mlir::Value a, ::mlir::ValueRange b, uint32_t attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount) +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes, unsigned numRegions) // CHECK: static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); // CHECK: void print(::mlir::OpAsmPrinter &p); // CHECK: ::mlir::LogicalResult verify(); @@ -180,8 +180,8 @@ def NS_HCollectiveParamsOp : NS_Op<"op_collective_params", []> { // CHECK_LABEL: class NS_HCollectiveParamsOp : // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type b, ::mlir::Value a); -// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::Value a); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}) +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}) // Check suppression of "separate arg, separate result" build method for an op // with single variadic arg and single variadic result (since it will be @@ -192,8 +192,8 @@ def NS_HCollectiveParamsSuppress0Op : NS_Op<"op_collective_suppress0", []> { } // CHECK_LABEL: class NS_HCollectiveParamsSuppress0Op : -// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // Check suppression of "separate arg, collective result" build method for an op // with single variadic arg and non variadic result (since it will be @@ -204,8 +204,8 @@ def NS_HCollectiveParamsSuppress1Op : NS_Op<"op_collective_suppress1", []> { } // CHECK_LABEL: class NS_HCollectiveParamsSuppress1Op : -// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // Check suppression of "separate arg, collective result" build method for an op // with single variadic arg and > 1 variadic result (since it will be @@ -217,9 +217,9 @@ def NS_HCollectiveParamsSuppress2Op : NS_Op<"op_collective_suppress2", [SameVari let results = (outs Variadic:$b, Variadic:$c); } // CHECK_LABEL: class NS_HCollectiveParamsSuppress2Op : -// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::llvm::ArrayRef<::mlir::Type> c, ::mlir::ValueRange a); -// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::TypeRange c, ::mlir::ValueRange a); +// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // Check default value of `attributes` for the `genUseOperandAsResultTypeCollectiveParamBuilder` builder def NS_IOp : NS_Op<"op_with_same_operands_and_result_types_trait", [SameOperandsAndResultType]> { @@ -228,8 +228,8 @@ def NS_IOp : NS_Op<"op_with_same_operands_and_result_types_trait", [SameOperands } // CHECK_LABEL: class NS_IOp : // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b); -// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::Value a, ::mlir::Value b); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); @@ -241,8 +241,8 @@ def NS_JOp : NS_Op<"op_with_InferTypeOpInterface_interface", [DeclareOpInterface // CHECK_LABEL: class NS_JOp : // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b); -// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::Value a, ::mlir::Value b); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // Check that default builders can be suppressed. diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td index 6849220..ac8f1cb 100644 --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -14,7 +14,7 @@ def OpA : NS_Op<"one_normal_result_op", []> { } // CHECK-LABEL: void OpA::build -// CHECK: ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands +// CHECK: ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands // CHECK: assert(resultTypes.size() == 1u && "mismatched number of return types"); // CHECK-NEXT: odsState.addTypes(resultTypes); @@ -39,7 +39,7 @@ def OpC : NS_Op<"three_normal_result_op", []> { // CHECK-NEXT: odsState.addTypes(resultType1) // CHECK-NEXT: odsState.addTypes(z) -// CHECK: void OpC::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes) { +// CHECK: void OpC::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes) { // CHECK-NEXT: assert(resultTypes.size() == 3u && "mismatched number of results"); // CHECK-NEXT: odsState.addTypes(resultTypes); @@ -67,7 +67,7 @@ def OpF : NS_Op<"one_variadic_result_op", []> { } // CHECK-LABEL: void OpF::build -// CHECK-SAME: ::llvm::ArrayRef<::mlir::Type> x +// CHECK-SAME: ::mlir::TypeRange x // CHECK-NOT: assert // CHECK: odsState.addTypes(x); @@ -78,12 +78,12 @@ def OpG : NS_Op<"one_normal_and_one_variadic_result_op", []> { // CHECK-LABEL: OpG definitions -// CHECK: void OpG::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type x, ::llvm::ArrayRef<::mlir::Type> y) +// CHECK: void OpG::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type x, ::mlir::TypeRange y) // CHECK-NEXT: odsState.addTypes(x); // CHECK-NEXT: odsState.addTypes(y); // CHECK: void OpG::build -// CHECK: ::llvm::ArrayRef<::mlir::Type> resultTypes +// CHECK: ::mlir::TypeRange resultTypes // CHECK: assert(resultTypes.size() >= 1u && "mismatched number of return types"); // CHECK-NEXT: odsState.addTypes(resultTypes); diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp index 4efdaf6..4fe3cd1 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -1455,8 +1455,9 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, let results = (outs Variadic:$result_tensors); let regions = (region AnyRegion:$region); + let skipDefaultBuilders = 1; let builders = [ OpBuilder< - "OpBuilder &b, OperationState &result," + "OpBuilder &b, OperationState &result, " "ValueRange inputs, ValueRange outputBuffers", [{{ result.addOperands(inputs); @@ -1493,6 +1494,14 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, TypeRange(outputBuffers), TypeRange(initTensors), resultTensorTypes); + }]>, OpBuilder< + "OpBuilder &b, OperationState &result, TypeRange resultTensorTypes," + "ValueRange operands, ArrayRef attributes = {{}", + [{{ + result.addOperands(operands); + result.addAttributes(attributes); + result.addTypes(resultTensorTypes); + (void)result.addRegion(); }]> ]; let printer = [{{ return ::printNamedStructuredOp(p, *this); }]; diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index f89547d..18b5d26 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1252,7 +1252,7 @@ void OpEmitter::genCollectiveParamBuilder() { SmallVector paramList; paramList.emplace_back("::mlir::OpBuilder &", ""); paramList.emplace_back("::mlir::OperationState &", builderOpState); - paramList.emplace_back("::llvm::ArrayRef<::mlir::Type>", "resultTypes"); + paramList.emplace_back("::mlir::TypeRange", "resultTypes"); paramList.emplace_back("::mlir::ValueRange", "operands"); // Provide default value for `attributes` when its the last parameter StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}"; @@ -1322,8 +1322,8 @@ void OpEmitter::buildParamList(SmallVectorImpl ¶mList, if (resultName.empty()) resultName = std::string(formatv("resultType{0}", i)); - StringRef type = result.isVariadic() ? "::llvm::ArrayRef<::mlir::Type>" - : "::mlir::Type"; + StringRef type = + result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type"; OpMethodParameter::Property properties = OpMethodParameter::PP_None; if (result.isOptional()) properties = OpMethodParameter::PP_Optional; @@ -1333,7 +1333,7 @@ void OpEmitter::buildParamList(SmallVectorImpl ¶mList, } } break; case TypeParamKind::Collective: { - paramList.emplace_back("::llvm::ArrayRef<::mlir::Type>", "resultTypes"); + paramList.emplace_back("::mlir::TypeRange", "resultTypes"); resultTypeNames.push_back("resultTypes"); } break; } -- 2.7.4