From: Oleg Shyshkov Date: Fri, 28 Oct 2022 09:31:58 +0000 (+0200) Subject: Revert "Revert "[mlir][linalg] Add nicer builders for `map` and `reduce`."" X-Git-Tag: upstream/17.0.6~29231 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ad89eb5b1fccf002eb59dfbab0fdb515ea3e65b7;p=platform%2Fupstream%2Fllvm.git Revert "Revert "[mlir][linalg] Add nicer builders for `map` and `reduce`."" This reverts commit 7eef3ea5f4fe4f4cc461b191bac031e3962d0347. --- diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 1692a0f..510f883 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -267,6 +267,12 @@ def MapOp : LinalgStructuredBase_Op<"map", [ let results = (outs Variadic:$result); let regions = (region SizedRegion<1>:$mapper); + let builders = [ + OpBuilder<(ins "ValueRange":$inputs, "Value":$init, + "function_ref", + CArg<"ArrayRef", "{}">:$attributes)> + ]; + let extraClassDeclaration = structuredOpsBaseDecls # [{ // Implement functions necessary for LinalgStructuredInterface. SmallVector getIteratorTypesArray(); @@ -341,6 +347,13 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [ let results = (outs Variadic); let regions = (region SizedRegion<1>:$combiner); + let builders = [ + OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$inits, + "ArrayRef":$dimensions, + "function_ref", + CArg<"ArrayRef", "{}">:$attributes)> + ]; + let extraClassDeclaration = structuredOpsBaseDecls # [{ // Declare functions necessary for LinalgStructuredInterface. SmallVector getIteratorTypesArray(); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 896fcf4..e9f2630 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -661,6 +661,26 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results, // GenericOp //===----------------------------------------------------------------------===// +static void buildGenericRegion( + OpBuilder &builder, OperationState &result, ValueRange inputs, + ValueRange outputs, + function_ref bodyBuild) { + SmallVector blockArgTypes; + SmallVector blockArgLocs; + for (ValueRange container : {inputs, outputs}) { + for (Value v : container) { + blockArgTypes.push_back(getElementTypeOrSelf(v)); + blockArgLocs.push_back(v.getLoc()); + } + } + + OpBuilder::InsertionGuard guard(builder); + auto ®ion = *result.regions.front(); + Block *bodyBlock = + builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs); + bodyBuild(builder, result.location, bodyBlock->getArguments()); +} + void GenericOp::getAsmBlockArgumentNames(Region ®ion, OpAsmSetValueNameFn setNameFn) { for (Value v : getRegionInputArgs()) @@ -678,23 +698,8 @@ void GenericOp::build( build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps, iteratorTypes, doc, libraryCall); result.addAttributes(attributes); - if (!bodyBuild) - return; - - SmallVector blockArgTypes; - SmallVector blockArgLocs; - for (ValueRange container : {inputs, outputs}) { - for (Value v : container) { - blockArgTypes.push_back(getElementTypeOrSelf(v)); - blockArgLocs.push_back(v.getLoc()); - } - } - - OpBuilder::InsertionGuard guard(builder); - auto ®ion = *result.regions.front(); - Block *bodyBlock = - builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs); - bodyBuild(builder, result.location, bodyBlock->getArguments()); + if (bodyBuild) + buildGenericRegion(builder, result, inputs, outputs, bodyBuild); } void GenericOp::build( @@ -1329,6 +1334,22 @@ void MapOp::getAsmResultNames(function_ref setNameFn) { setNameFn(getResults().front(), "mapped"); } +void MapOp::build( + OpBuilder &builder, OperationState &result, ValueRange inputs, Value init, + function_ref bodyBuild, + ArrayRef attributes) { + build(builder, result, TypeRange{}, inputs, init); + result.addAttributes(attributes); + + // Add output types for `RankedTensorType` output arguments. + Type initType = init.getType(); + if (initType.isa()) + result.addTypes(initType); + + if (bodyBuild) + buildGenericRegion(builder, result, inputs, /*outputs=*/{}, bodyBuild); +} + ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { if (parseDstStyleOp(parser, result)) return failure(); @@ -1436,6 +1457,25 @@ void ReduceOp::getAsmResultNames( setNameFn(getResults().front(), "reduced"); } +void ReduceOp::build( + OpBuilder &builder, OperationState &result, ValueRange inputs, + ValueRange inits, ArrayRef dimensions, + function_ref bodyBuild, + ArrayRef attributes) { + build(builder, result, TypeRange{}, inputs, inits, dimensions); + result.addAttributes(attributes); + + // Add output types for `RankedTensorType` output arguments. + for (Value init : inits) { + Type initType = init.getType(); + if (initType.isa()) + result.addTypes(initType); + } + + if (bodyBuild) + buildGenericRegion(builder, result, inputs, inits, bodyBuild); +} + SmallVector ReduceOp::getIteratorTypesArray() { int64_t inputRank = getInputs()[0].getType().cast().getRank(); SmallVector iteratorTypes(inputRank, @@ -1618,45 +1658,32 @@ TransposeOp::getRegionBuilder() { }; } -void TransposeOp::createRegion(::mlir::OpBuilder &opBuilder, - ::mlir::OperationState &odsState) { - Region *region = odsState.addRegion(); - - SmallVector argTypes; - SmallVector argLocs; - for (auto t : odsState.operands) { - argTypes.push_back(getElementTypeOrSelf(t)); - argLocs.push_back(opBuilder.getUnknownLoc()); - } - - // RAII. - OpBuilder::InsertionGuard guard(opBuilder); - Block *body = - opBuilder.createBlock(region, /*insertPt=*/{}, argTypes, argLocs); - - ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder); - getRegionBuilder()(b, *body, odsState.attributes.getAttrs()); -} - -void TransposeOp::build(::mlir::OpBuilder &odsBuilder, - ::mlir::OperationState &odsState, Value input, - Value init, DenseI64ArrayAttr permutation, +void TransposeOp::build(::mlir::OpBuilder &builder, + ::mlir::OperationState &result, Value input, Value init, + DenseI64ArrayAttr permutation, ArrayRef attributes) { - odsState.addOperands(input); - odsState.addOperands(init); - odsState.addAttribute(getPermutationAttrName(odsState.name), permutation); - odsState.addAttributes(attributes); - odsState.addTypes(init.getType()); + result.addOperands(input); + result.addOperands(init); + result.addAttribute(getPermutationAttrName(result.name), permutation); + result.addAttributes(attributes); + + // Add output types for `RankedTensorType` output arguments. + Type initType = init.getType(); + if (initType.isa()) + result.addTypes(initType); - createRegion(odsBuilder, odsState); + buildGenericRegion(builder, result, input, init, + [&](OpBuilder &b, Location loc, ValueRange args) { + b.create(loc, args[0]); + }); } -void TransposeOp::build(::mlir::OpBuilder &odsBuilder, - ::mlir::OperationState &odsState, Value input, - Value init, ArrayRef permutation, +void TransposeOp::build(::mlir::OpBuilder &builder, + ::mlir::OperationState &result, Value input, Value init, + ArrayRef permutation, ArrayRef attributes) { - build(odsBuilder, odsState, input, init, - odsBuilder.getDenseI64ArrayAttr(permutation), attributes); + build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation), + attributes); } ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) { @@ -1666,8 +1693,13 @@ ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) { }))) return failure(); - OpBuilder opBuilder(parser.getContext()); - createRegion(opBuilder, result); + (void)result.addRegion(); + OpBuilder builder(parser.getContext()); + buildGenericRegion(builder, result, /*inputs=*/result.operands, + /*outputs=*/{}, + [&](OpBuilder &b, Location loc, ValueRange args) { + b.create(loc, args[0]); + }); return success(); }