From 7eef3ea5f4fe4f4cc461b191bac031e3962d0347 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Fri, 28 Oct 2022 09:56:59 +0200 Subject: [PATCH] Revert "[mlir][linalg] Add nicer builders for `map` and `reduce`." This reverts commit aebde280476943e58f5bcd9993fdd7e36cdbe47e. --- .../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 13 -- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 138 ++++++++------------- 2 files changed, 53 insertions(+), 98 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 510f883..1692a0f 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -267,12 +267,6 @@ def MapOp : LinalgStructuredBase_Op<"map", [ let results = (outs Variadic:$result); let regions = (region SizedRegion<1>:$mapper); - let builders = [ - OpBuilder<(ins "ValueRange":$inputs, "Value":$init, - "function_ref", - CArg<"ArrayRef", "{}">:$attributes)> - ]; - let extraClassDeclaration = structuredOpsBaseDecls # [{ // Implement functions necessary for LinalgStructuredInterface. SmallVector getIteratorTypesArray(); @@ -347,13 +341,6 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [ let results = (outs Variadic); let regions = (region SizedRegion<1>:$combiner); - let builders = [ - OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$inits, - "ArrayRef":$dimensions, - "function_ref", - CArg<"ArrayRef", "{}">:$attributes)> - ]; - let extraClassDeclaration = structuredOpsBaseDecls # [{ // Declare functions necessary for LinalgStructuredInterface. SmallVector getIteratorTypesArray(); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 443bc5c..03a959a 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -661,26 +661,6 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results, // GenericOp //===----------------------------------------------------------------------===// -static void buildGenericRegion( - OpBuilder &builder, OperationState &result, ValueRange inputs, - ValueRange outputs, - function_ref bodyBuild) { - SmallVector blockArgTypes; - SmallVector blockArgLocs; - for (ValueRange container : {inputs, outputs}) { - for (Value v : container) { - blockArgTypes.push_back(getElementTypeOrSelf(v)); - blockArgLocs.push_back(v.getLoc()); - } - } - - OpBuilder::InsertionGuard guard(builder); - auto ®ion = *result.regions.front(); - Block *bodyBlock = - builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs); - bodyBuild(builder, result.location, bodyBlock->getArguments()); -} - void GenericOp::getAsmBlockArgumentNames(Region ®ion, OpAsmSetValueNameFn setNameFn) { for (Value v : getRegionInputArgs()) @@ -698,8 +678,23 @@ void GenericOp::build( build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps, iteratorTypes, doc, libraryCall); result.addAttributes(attributes); - if (bodyBuild) - buildGenericRegion(builder, result, inputs, outputs, bodyBuild); + if (!bodyBuild) + return; + + SmallVector blockArgTypes; + SmallVector blockArgLocs; + for (ValueRange container : {inputs, outputs}) { + for (Value v : container) { + blockArgTypes.push_back(getElementTypeOrSelf(v)); + blockArgLocs.push_back(v.getLoc()); + } + } + + OpBuilder::InsertionGuard guard(builder); + auto ®ion = *result.regions.front(); + Block *bodyBlock = + builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs); + bodyBuild(builder, result.location, bodyBlock->getArguments()); } void GenericOp::build( @@ -1334,22 +1329,6 @@ void MapOp::getAsmResultNames(function_ref setNameFn) { setNameFn(getResults().front(), "mapped"); } -void MapOp::build( - OpBuilder &builder, OperationState &result, ValueRange inputs, Value init, - function_ref bodyBuild, - ArrayRef attributes) { - build(builder, result, TypeRange{}, inputs, init); - result.addAttributes(attributes); - - // Add output types for `RankedTensorType` output arguments. - Type initType = init.getType(); - if (initType.isa()) - result.addTypes(initType); - - if (bodyBuild) - buildGenericRegion(builder, result, inputs, /*outputs=*/{}, bodyBuild); -} - ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { if (parseDstStyleOp(parser, result)) return failure(); @@ -1457,25 +1436,6 @@ void ReduceOp::getAsmResultNames( setNameFn(getResults().front(), "reduced"); } -void ReduceOp::build( - OpBuilder &builder, OperationState &result, ValueRange inputs, - ValueRange inits, ArrayRef dimensions, - function_ref bodyBuild, - ArrayRef attributes) { - build(builder, result, TypeRange{}, inputs, inits, dimensions); - result.addAttributes(attributes); - - // Add output types for `RankedTensorType` output arguments. - for (Value init : inits) { - Type initType = init.getType(); - if (initType.isa()) - result.addTypes(initType); - } - - if (bodyBuild) - buildGenericRegion(builder, result, inputs, inits, bodyBuild); -} - SmallVector ReduceOp::getIteratorTypesArray() { int64_t inputRank = getInputs()[0].getType().cast().getRank(); SmallVector iteratorTypes(inputRank, @@ -1658,32 +1618,45 @@ TransposeOp::getRegionBuilder() { }; } -void TransposeOp::build(::mlir::OpBuilder &builder, - ::mlir::OperationState &result, Value input, Value init, - DenseI64ArrayAttr permutation, - ArrayRef attributes) { - result.addOperands(input); - result.addOperands(init); - result.addAttribute(getPermutationAttrName(result.name), permutation); - result.addAttributes(attributes); +void TransposeOp::createRegion(::mlir::OpBuilder &opBuilder, + ::mlir::OperationState &odsState) { + Region *region = odsState.addRegion(); - // Add output types for `RankedTensorType` output arguments. - Type initType = init.getType(); - if (initType.isa()) - result.addTypes(initType); + SmallVector argTypes; + SmallVector argLocs; + for (auto t : odsState.operands) { + argTypes.push_back(getElementTypeOrSelf(t)); + argLocs.push_back(opBuilder.getUnknownLoc()); + } - buildGenericRegion(builder, result, input, init, - [&](OpBuilder &b, Location loc, ValueRange args) { - b.create(loc, args[0]); - }); + // RAII. + OpBuilder::InsertionGuard guard(opBuilder); + Block *body = + opBuilder.createBlock(region, /*insertPt=*/{}, argTypes, argLocs); + + ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder); + getRegionBuilder()(b, *body, odsState.attributes.getAttrs()); } -void TransposeOp::build(::mlir::OpBuilder &builder, - ::mlir::OperationState &result, Value input, Value init, - ArrayRef permutation, +void TransposeOp::build(::mlir::OpBuilder &odsBuilder, + ::mlir::OperationState &odsState, Value input, + Value init, DenseI64ArrayAttr permutation, ArrayRef attributes) { - build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation), - attributes); + odsState.addOperands(input); + odsState.addOperands(init); + odsState.addAttribute(getPermutationAttrName(odsState.name), permutation); + odsState.addAttributes(attributes); + odsState.addTypes(init.getType()); + + createRegion(odsBuilder, odsState); +} + +void TransposeOp::build(::mlir::OpBuilder &odsBuilder, + ::mlir::OperationState &odsState, Value input, + Value init, ArrayRef permutation, + ArrayRef attributes) { + build(odsBuilder, odsState, input, init, + odsBuilder.getDenseI64ArrayAttr(permutation), attributes); } ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) { @@ -1693,13 +1666,8 @@ ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) { }))) return failure(); - (void)result.addRegion(); - OpBuilder builder(parser.getContext()); - buildGenericRegion(builder, result, /*inputs=*/result.operands, - /*outputs=*/{}, - [&](OpBuilder &b, Location loc, ValueRange args) { - b.create(loc, args[0]); - }); + OpBuilder opBuilder(parser.getContext()); + createRegion(opBuilder, result); return success(); } -- 2.7.4