From 973e133b769773c89ce4b8bbfd6c77612d2ff9d4 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Fri, 12 Feb 2021 13:50:10 +0000 Subject: [PATCH] [mlir][Linalg] Improve region support in Linalg ops. This revision takes advantage of the newly extended `ref` directive in assembly format to allow better region handling for LinalgOps. Specifically, FillOp and CopyOp now build their regions explicitly which allows retiring older behavior that relied on specific op knowledge in both lowering to loops and vectorization. Differential Revision: https://reviews.llvm.org/D96598 --- .../mlir/Dialect/Linalg/IR/LinalgInterfaces.td | 31 +- .../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 42 ++- .../LinalgToStandard/LinalgToStandard.cpp | 8 +- mlir/lib/Dialect/Linalg/EDSC/Builders.cpp | 13 +- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 363 ++++++++++++++------- .../Dialect/Linalg/Transforms/Generalization.cpp | 2 +- mlir/lib/Dialect/Linalg/Transforms/Loops.cpp | 46 +-- .../Dialect/Linalg/Transforms/Vectorization.cpp | 54 +-- mlir/test/Transforms/copy-removal.mlir | 9 +- .../mlir-linalg-ods-gen/test-linalg-ods-gen.tc | 6 +- .../mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp | 22 +- 11 files changed, 316 insertions(+), 280 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index c26f022..95656eb 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -1056,20 +1056,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { //===------------------------------------------------------------------===// // Other static interface methods. //===------------------------------------------------------------------===// - StaticInterfaceMethod< - /*desc=*/[{ - Create an operation of the current type with the given location, - operands, and attributes. - }], - /*retTy=*/"Operation *", - /*methodName=*/"create", - (ins "OpBuilder &":$builder, "Location":$loc, "TypeRange":$resultTypes, - "ValueRange":$operands, - "ArrayRef":$attributes), [{ - return builder.create( - loc, resultTypes, operands, attributes); - }] - >, InterfaceMethod< /*desc=*/[{ Clone the current operation with the given location and operands. This @@ -1082,14 +1068,13 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes, "ValueRange":$operands), [{ - BlockAndValueMapping map; - unsigned numRegions = $_op->getNumRegions(); - Operation *res = create(b, loc, resultTypes, operands, $_op->getAttrs()); - assert(res->getNumRegions() == numRegions && "inconsistent # regions"); - for (unsigned ridx = 0; ridx < numRegions; ++ridx) - $_op->getRegion(ridx).cloneInto( - &res->getRegion(ridx), map); - return res; + BlockAndValueMapping bvm; + OperationState state( + loc, ConcreteOp::getOperationName(), operands, resultTypes, + $_op->getAttrs()); + for (Region &r : $_op->getRegions()) + r.cloneInto(state.addRegion(), bvm); + return b.createOperation(state); }] >, StaticInterfaceMethod< @@ -1098,7 +1083,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { Returns a null function if this named op does not define a region builder. }], - /*retTy=*/"std::function", + /*retTy=*/"std::function", /*methodName=*/"getRegionBuilder", (ins), [{ return ConcreteOp::getRegionBuilder(); }] diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 8988a3a..05a6bb7 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -110,14 +110,13 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> { AnyStridedMemRef:$output, OptionalAttr:$inputPermutation, OptionalAttr:$outputPermutation); + let regions = (region AnyRegion:$region); - // TODO: this should go away once the usage of OptionalAttr triggers emission - // of builders with default arguments left unspecified. - let builders = [OpBuilderDAG<(ins "Value":$input, "Value":$output), - [{ - return build( - $_builder, $_state, input, output, AffineMapAttr(), AffineMapAttr()); - }]>]; + let builders = [ + OpBuilderDAG<(ins "Value":$input, "Value":$output, + CArg<"AffineMap", "AffineMap()">:$inputPermutation, + CArg<"AffineMap", "AffineMap()">:$outputPermutation, + CArg<"ArrayRef", "{}">:$attrs)>]; let extraClassDeclaration = structuredOpsDecls # [{ ValueRange inputs() { return getOperands().take_front(); } @@ -146,24 +145,31 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> { Value getSource() { return input();} Value getTarget() { return output(); } - static std::function getRegionBuilder() { - return nullptr; + static void regionBuilder(Block &block, ValueRange captures); + static std::function + getRegionBuilder() { + return ®ionBuilder; } + static unsigned getNumRegionArgs() { return 2; } }]; let verifier = [{ return ::verify(*this); }]; let assemblyFormat = [{ - `(` operands `)` attr-dict `:` type(operands) + `(` $input `,` $output `)` attr-dict `:` + type($input) `,` type($output) + custom($region, ref(type($input)), ref(type($input))) }]; let hasFolder = 1; let hasCanonicalizer = 1; + let skipDefaultBuilders = 1; } def FillOp : LinalgStructured_Op<"fill", []> { let arguments = (ins AnyShaped:$output, AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value); let results = (outs Optional:$result); + let regions = (region AnyRegion:$region); let extraClassDeclaration = structuredOpsDecls # [{ ValueRange inputs() { return {}; } ValueRange outputs() { return getOperands().take_front(); } @@ -183,13 +189,18 @@ def FillOp : LinalgStructured_Op<"fill", []> { extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)}); } - static std::function getRegionBuilder() { - return nullptr; + static void regionBuilder(Block &block, ValueRange captures); + static std::function + getRegionBuilder() { + return ®ionBuilder; } + static unsigned getNumRegionArgs() { return 1; } }]; let assemblyFormat = [{ - `(` operands `)` attr-dict `:` type(operands) (`->` type($result)^)? + `(` $output `,` $value `)` attr-dict `:` + type($output) `,` type($value) (`->` type($result)^)? + custom($region, ref(type($output)), ref($value)) }]; let builders = [ @@ -268,7 +279,8 @@ class PoolingBase_Op props> return padding().getValue().getValue({i, 1}); } - static std::function getRegionBuilder() { + static std::function getRegionBuilder() + { return nullptr; } }]; @@ -519,7 +531,7 @@ class GenericOpBase : LinalgStructuredBase_Opstr() : "op_has_no_registered_library_name"; } - static std::function getRegionBuilder() { + static std::function getRegionBuilder() { return nullptr; } }]; diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp index 8b53ecb..276b124 100644 --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -154,7 +154,13 @@ LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite( if (in == op.input() && out == op.output()) return failure(); - rewriter.replaceOpWithNewOp(op, in, out); + auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); + if (!libraryCallName) + return failure(); + + rewriter.replaceOpWithNewOp( + op, libraryCallName.getValue(), TypeRange(), + createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), {in, out})); return success(); } diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp index 8bb104d..46e42e2 100644 --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -27,8 +27,6 @@ Operation *mlir::edsc::makeGenericLinalgOp( ArrayRef outputs, TypeRange resultTensorTypes, function_ref regionBuilder, ArrayRef otherValues, ArrayRef otherAttributes) { - OpBuilder &builder = edsc::ScopedContext::getBuilderRef(); - // Build maps SmallVector, 4> exprsList; exprsList.reserve(inputs.size() + outputs.size()); @@ -54,13 +52,10 @@ Operation *mlir::edsc::makeGenericLinalgOp( resultTensorTypes, inputValues, outputValues, - builder.getAffineMapArrayAttr(maps), - builder.getStrArrayAttr(iteratorStrTypes), - StringAttr() /*doc*/, - StringAttr() /*library_call*/, - ArrayAttr() /*sparse*/ - /* TODO: other attributes in op */ - ) + maps, + iteratorStrTypes, + ""/*doc*/, + ""/*library_call*/) .getOperation(); // clang-format on diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 49e3c2b..989a164 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -33,32 +33,53 @@ using namespace mlir; using namespace mlir::linalg; /// Forward declarations. + +/// Generic entry point to create the block for the region of a LinalgOp. +/// This is used by both named structured ops created by ods-gen and by manually +/// defined C++ ops. +/// This is used by both builders and parsers. +/// This function creates the block in the region with arguments corresponding +/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted +/// to be ShapedType. +template +static void fillStructuredOpRegion( + OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, + TypeRange outputTypes, ValueRange captures = {}, + std::function errorHandler = [](unsigned, + unsigned) {}); + +/// Generic entry point to create both the region and the block of a LinalgOp. template -static void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder, - OperationState &result, - TypeRange inputTypes, - TypeRange outputTypes); +static void +createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result, + TypeRange inputTypes, TypeRange outputTypes, + ValueRange captures = {}); +/// Common parsing and printing used for both named structured ops created by +/// ods-gen and by manually defined C++ ops. Does not handle regions. static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl &inputTypes, SmallVectorImpl &outputTypes); +template +static void printCommonStructuredOpParts(OpAsmPrinter &p, + NamedStructuredOpType op); +/// Specific parsing and printing for named structured ops created by ods-gen. template static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, - TypeRange inputTypes, TypeRange outputTypes); + TypeRange inputTypes, TypeRange outputTypes, + ArrayRef captures = {}); + static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl &resultTypes); template -static ParseResult parseNamedStructuredOp(OpAsmParser &parser, - OperationState &result); - -template -static void printCommonStructuredOpParts(OpAsmPrinter &p, - NamedStructuredOpType op); +static ParseResult +parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, + ArrayRef captures = {}); static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes); @@ -84,13 +105,135 @@ static LogicalResult foldMemRefCast(Operation *op) { } //===----------------------------------------------------------------------===// +// CopyOp +//===----------------------------------------------------------------------===// +void CopyOp::regionBuilder(Block &block, ValueRange captures) { + using namespace edsc::intrinsics; + assert(block.getNumArguments() == 2 && "CopyOp regionBuilder expects 2 args"); + (linalg_yield(block.getArgument(0))); +} + +void CopyOp::build(OpBuilder &builder, OperationState &result, Value input, + Value output, AffineMap inputPermutation, + AffineMap outputPermutation, + ArrayRef namedAttrs) { + result.addOperands({input, output}); + result.addAttributes(namedAttrs); + if (inputPermutation) + result.addAttribute("inputPermutation", + AffineMapAttr::get(inputPermutation)); + if (outputPermutation) + result.addAttribute("outputPermutation", + AffineMapAttr::get(outputPermutation)); + result.addRegion(); + fillStructuredOpRegion(builder, *result.regions.front(), + TypeRange{input.getType()}, + TypeRange{output.getType()}); +} + +ParseResult parseCopyOpRegion(OpAsmParser &parser, Region &r, Type inputType, + Type outputType) { + OpBuilder opBuilder(parser.getBuilder().getContext()); + fillStructuredOpRegion(opBuilder, r, TypeRange{inputType}, + TypeRange{outputType}); + return success(); +} + +/// CopyOp region is elided when printing. +void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {} + +static LogicalResult verify(CopyOp op) { + auto outputViewType = op.getOutputShapedType(0); + auto inputViewType = op.getInputShapedType(0); + if (inputViewType.getElementType() != outputViewType.getElementType()) + return op.emitOpError("expects views of the same type"); + if (inputViewType.getRank() != outputViewType.getRank()) + return op.emitOpError("expects views of the same rank"); + auto rank = op.getNumParallelLoops(); + auto inputPermutationMap = op.inputPermutation(); + if (inputPermutationMap) { + if (inputPermutationMap->getNumInputs() != rank) + return op.emitOpError("expects optional input_permutation map of rank ") + << rank; + if (!inputPermutationMap->isPermutation()) + return op.emitOpError( + "expects optional input_permutation map to be a permutation"); + } + auto outputPermutationMap = op.outputPermutation(); + if (outputPermutationMap) { + if (outputPermutationMap->getNumInputs() != rank) + return op.emitOpError("expects optional output_permutation map of rank ") + << rank; + if (!outputPermutationMap->isPermutation()) + return op.emitOpError( + "expects optional output_permutation map to be a permutation"); + } + if (rank == 0 && inputPermutationMap) + return op.emitOpError("expected no input permutation when rank == 0"); + if (rank == 0 && outputPermutationMap) + return op.emitOpError("expected no output permutation when rank == 0"); + return success(); +} + +void CopyOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), input(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), output(), + SideEffects::DefaultResource::get()); +} + +//===----------------------------------------------------------------------===// // FillOp //===----------------------------------------------------------------------===// +void FillOp::regionBuilder(Block &block, ValueRange captures) { + using namespace edsc::intrinsics; + assert(captures.size() == 1 && "FillOp regionBuilder expects 1 capture"); + (linalg_yield(captures)); +} void FillOp::build(OpBuilder &builder, OperationState &result, Value output, Value value) { build(builder, result, output.getType().dyn_cast(), output, value); + fillStructuredOpRegion(builder, *result.regions.front(), TypeRange{}, + TypeRange{output.getType()}, value); +} + +ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type outputType, + OpAsmParser::OperandType valueRef) { + OpBuilder opBuilder(parser.getBuilder().getContext()); + // Resolve `valueRef` into `value` at parse time so we can build the region + // with captures. + SmallVector value; + parser.resolveOperand(valueRef, getElementTypeOrSelf(outputType), value); + fillStructuredOpRegion(opBuilder, r, TypeRange{}, + TypeRange{outputType}, value); + return success(); +} + +/// FillOp region is elided when printing. +void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Value) {} + +static LogicalResult verify(FillOp op) { + auto viewType = op.getOutputShapedType(0); + auto fillType = op.value().getType(); + if (viewType.getElementType() != fillType) + return op.emitOpError("expects fill type to match view elemental type"); + if (!op.getNumResults() && !viewType.isa()) { + return op.emitOpError( + "expected fill op with no result value to use memref type"); + } + return success(); +} + +void FillOp::getEffects( + SmallVectorImpl> + &effects) { + if (output().getType().isa()) + effects.emplace_back(MemoryEffects::Write::get(), output(), + SideEffects::DefaultResource::get()); } //===----------------------------------------------------------------------===// @@ -397,7 +540,6 @@ static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); } // InitTensorOp //===----------------------------------------------------------------------===// - static LogicalResult verify(InitTensorOp op) { RankedTensorType resultType = op.getType(); SmallVector staticSizes = llvm::to_vector<4>(llvm::map_range( @@ -1396,68 +1538,6 @@ static LogicalResult verify(linalg::YieldOp op) { /////// Operations corresponding to library calls defined with Tablegen //////// -void FillOp::getEffects( - SmallVectorImpl> - &effects) { - if (output().getType().isa()) - effects.emplace_back(MemoryEffects::Write::get(), output(), - SideEffects::DefaultResource::get()); -} - -static LogicalResult verify(FillOp op) { - auto viewType = op.getOutputShapedType(0); - auto fillType = op.value().getType(); - if (viewType.getElementType() != fillType) - return op.emitOpError("expects fill type to match view elemental type"); - if (!op.getNumResults() && !viewType.isa()) { - return op.emitOpError( - "expected fill op with no result value to use memref type"); - } - return success(); -} - -void CopyOp::getEffects( - SmallVectorImpl> - &effects) { - effects.emplace_back(MemoryEffects::Read::get(), input(), - SideEffects::DefaultResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), output(), - SideEffects::DefaultResource::get()); -} - -static LogicalResult verify(CopyOp op) { - auto outputViewType = op.getOutputShapedType(0); - auto inputViewType = op.getInputShapedType(0); - if (inputViewType.getElementType() != outputViewType.getElementType()) - return op.emitOpError("expects views of the same type"); - if (inputViewType.getRank() != outputViewType.getRank()) - return op.emitOpError("expects views of the same rank"); - auto rank = op.getNumParallelLoops(); - auto inputPermutationMap = op.inputPermutation(); - if (inputPermutationMap) { - if (inputPermutationMap->getNumInputs() != rank) - return op.emitOpError("expects optional input_permutation map of rank ") - << rank; - if (!inputPermutationMap->isPermutation()) - return op.emitOpError( - "expects optional input_permutation map to be a permutation"); - } - auto outputPermutationMap = op.outputPermutation(); - if (outputPermutationMap) { - if (outputPermutationMap->getNumInputs() != rank) - return op.emitOpError("expects optional output_permutation map of rank ") - << rank; - if (!outputPermutationMap->isPermutation()) - return op.emitOpError( - "expects optional output_permutation map to be a permutation"); - } - if (rank == 0 && inputPermutationMap) - return op.emitOpError("expected no input permutation when rank == 0"); - if (rank == 0 && outputPermutationMap) - return op.emitOpError("expected no output permutation when rank == 0"); - return success(); -} - template static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op, ArrayRef attrs, @@ -1690,14 +1770,25 @@ OpFoldResult TensorReshapeOp::fold(ArrayRef operands) { } //===----------------------------------------------------------------------===// -// Auto-generated Linalg named ops. +// Support for named Linalg ops defined in ods-gen. //===----------------------------------------------------------------------===// +/// Generic entry point to create the block for the region of a LinalgOp. +/// This is used by both named structured ops created by ods-gen and by manually +/// defined C++ ops. +/// This is used by both builders and parsers. +/// This function creates the block in the region with arguments corresponding +/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted +/// to be ShapedType. template -static void buildNamedStructuredOpRegionAndAttributesImpl( - OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, - TypeRange outputTypes, - std::function errorHandler) { +static void +fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, + TypeRange inputTypes, TypeRange outputTypes, + ValueRange captures, + std::function errorHandler) { + assert(llvm::all_of(inputTypes, [](Type t) { return t.isa(); })); + assert(llvm::all_of(outputTypes, [](Type t) { return t.isa(); })); + // TODO: atm all operands go through getElementTypeOrSelf, // reconsider when we have evidence we need to. SmallVector argTypes; @@ -1707,7 +1798,7 @@ static void buildNamedStructuredOpRegionAndAttributesImpl( // RAII. OpBuilder::InsertionGuard guard(opBuilder); - Block *body = opBuilder.createBlock(®ion, {}, argTypes); + Block *body = opBuilder.createBlock(®ion, /*insertPt=*/{}, argTypes); unsigned actual = body->getNumArguments(); unsigned expected = NamedStructuredOpType::getNumRegionArgs(); if (expected != actual) @@ -1715,53 +1806,30 @@ static void buildNamedStructuredOpRegionAndAttributesImpl( opBuilder.setInsertionPointToStart(body); mlir::edsc::ScopedContext scope(opBuilder, opBuilder.getUnknownLoc()); - NamedStructuredOpType::regionBuilder(*body); + NamedStructuredOpType::regionBuilder(*body, captures); // indexing_maps is an auto-generated method. // iterator_types is an auto-generated method. } +/// Generic entry point to create both the region and the block of a LinalgOp. template -void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder, - OperationState &result, - TypeRange inputTypes, - TypeRange outputTypes) { +void createAndFillStructuredOpRegion(OpBuilder &opBuilder, + OperationState &result, + TypeRange inputTypes, + TypeRange outputTypes, + ValueRange captures) { Region ®ion = *result.addRegion(); - buildNamedStructuredOpRegionAndAttributesImpl( - opBuilder, region, inputTypes, outputTypes, + fillStructuredOpRegion( + opBuilder, region, inputTypes, outputTypes, captures, [&](unsigned expected, unsigned actual) { - llvm::errs() << "region expects " << expected << " args, got " - << actual; assert(expected != actual && "incorrect number of arguments"); }); } -template -static ParseResult -parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, - TypeRange inputTypes, TypeRange outputTypes) { - ParseResult res = success(); - OpBuilder opBuilder(parser.getBuilder().getContext()); - buildNamedStructuredOpRegionAndAttributesImpl( - opBuilder, region, inputTypes, outputTypes, - [&](unsigned expected, unsigned actual) { - res = parser.emitError(parser.getCurrentLocation(), - llvm::formatv("region expects {0} args, got {1}", - expected, actual)); - }); - return res; -} - -static ParseResult -parseNamedStructuredOpResults(OpAsmParser &parser, - SmallVectorImpl &resultTypes) { - if (succeeded(parser.parseOptionalArrow())) - if (parser.parseTypeList(resultTypes)) - return failure(); - return success(); -} - +/// Common parsing used for both named structured ops created by ods-gen and by +/// manually defined C++ ops. Does not handle regions. static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl &inputTypes, @@ -1802,8 +1870,56 @@ parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, } template -static ParseResult parseNamedStructuredOp(OpAsmParser &parser, - OperationState &result) { +static void printCommonStructuredOpParts(OpAsmPrinter &p, + NamedStructuredOpType op) { + if (!op.inputs().empty()) + p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")"; + if (!op.outputs().empty()) + p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")"; +} + +//===----------------------------------------------------------------------===// +// Specific parsing and printing for named structured ops created by ods-gen. +//===----------------------------------------------------------------------===// + +template +static ParseResult +parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, + TypeRange inputTypes, TypeRange outputTypes, + ArrayRef captures) { + ParseResult res = success(); + OpBuilder opBuilder(parser.getBuilder().getContext()); + // Resolve `captures` into `capturedValues` at parse time so we can build the + // region with captures. + SmallVector capturedValues; + fillStructuredOpRegion( + opBuilder, region, inputTypes, outputTypes, capturedValues, + [&](unsigned expected, unsigned actual) { + res = parser.emitError( + parser.getCurrentLocation(), + llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated " + "region expects {0} args, got {1}", + expected, actual)); + region.front().dump(); + }); + return res; +} + +static ParseResult +parseNamedStructuredOpResults(OpAsmParser &parser, + SmallVectorImpl &resultTypes) { + if (succeeded(parser.parseOptionalArrow())) + if (parser.parseTypeList(resultTypes)) + return failure(); + return success(); +} + +template +static ParseResult +parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, + ArrayRef captures) { + // TODO: Enable when ods-gen supports captures. + assert(captures.empty() && "unexpected captures for named structured ops"); SmallVector inputTypes, outputTypes; if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) return failure(); @@ -1817,7 +1933,7 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser, std::unique_ptr region = std::make_unique(); if (parseNamedStructuredOpRegion( - parser, *region, inputTypes, outputTypes)) + parser, *region, inputTypes, outputTypes, captures)) return failure(); result.addRegion(std::move(region)); @@ -1832,15 +1948,6 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p, } template -static void printCommonStructuredOpParts(OpAsmPrinter &p, - NamedStructuredOpType op) { - if (!op.inputs().empty()) - p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")"; - if (!op.outputs().empty()) - p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")"; -} - -template static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) { p << op.getOperationName(); p.printOptionalAttrDict(op.getAttrs(), @@ -1861,6 +1968,10 @@ static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) { return verifyGenericOp(op); } +//===----------------------------------------------------------------------===// +// Canonicalizers and Folders. +//===----------------------------------------------------------------------===// + namespace { struct EraseDeadLinalgOp : public RewritePattern { EraseDeadLinalgOp(PatternBenefit benefit = 1) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp index 0be1c55..69de55c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -49,7 +49,7 @@ static linalg::GenericOp createGenericOpFromNamedOp(linalg::LinalgOp namedOp, indexingMaps, iterators, [®ionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) { edsc::ScopedContext scope(bodyBuilder, loc); - regionBuilder(*bodyBuilder.getBlock()); + regionBuilder(*bodyBuilder.getBlock(), /*captures=*/{}); }); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp index 391562b..d09d3e0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -52,14 +52,6 @@ static SmallVector makeCanonicalAffineApplies(OpBuilder &b, return res; } -static SmallVector permuteIvs(ArrayRef ivs, - Optional permutation) { - return permutation ? applyMapToValues(ScopedContext::getBuilderRef(), - ScopedContext::getLocation(), - permutation.getValue(), ivs) - : SmallVector(ivs.begin(), ivs.end()); -} - template static void inlineRegionAndEmitStore(OpType op, ArrayRef indexedValues, ArrayRef> indexing, @@ -178,40 +170,6 @@ static void emitScalarImplementation(ArrayRef allIvs, outputBuffers); } -template -static void emitScalarImplementation(ArrayRef allIvs, CopyOp copyOp) { - assert(copyOp.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - auto nPar = copyOp.getNumParallelLoops(); - assert(nPar == allIvs.size()); - auto inputIvs = - permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation()); - auto outputIvs = - permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation()); - SmallVector iivs(inputIvs.begin(), inputIvs.end()); - SmallVector oivs(outputIvs.begin(), outputIvs.end()); - IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0)); - // Emit the proper scalar assignment, whether we are dealing with a 0-D or - // an n-D loop nest; with or without permutations. - // clang-format off - nPar > 0 ? O(oivs) = I(iivs) : - O() = I(); - // clang-format on -} - -template -static void emitScalarImplementation(ArrayRef allIvs, FillOp fillOp) { - assert(fillOp.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - auto nPar = fillOp.getNumParallelLoops(); - assert(nPar == allIvs.size()); - auto ivs = SmallVector(allIvs.begin(), allIvs.begin() + nPar); - IndexedValueType O(fillOp.getOutputBuffer(0)); - // Emit the proper scalar assignment, whether we are dealing with a 0-D or - // an n-D loop nest; with or without permutations. - nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value(); -} - // Create a padded view into the given `input` tensor using the 'indices' // to access the tensor. `skipPadding` lists the dimensions for which no padding // is needed e.g. the non-spatial dimensions for convolutions. @@ -533,8 +491,8 @@ linalgOpToLoopsImpl(Operation *op, OpBuilder &builder, assert(iterArgs.empty() && "unexpected iterArgs"); allIvs.append(ivs.begin(), ivs.end()); llvm::TypeSwitch(op) - .Case([&](auto op) { + .Case([&](auto op) { emitScalarImplementation(allIvs, op); }) .Default([&](Operation *op) { assert(false && "unexpected op"); }); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 49d323a..bfd2884 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -267,7 +267,7 @@ static Optional vectorizeAsLinalgGeneric( llvm::map_range(linalgOp.getShapedOperandTypes(), [](ShapedType t) { return t.getElementType(); })); block->addArguments(elementTypes); - linalgOp.getRegionBuilder()(*block); + linalgOp.getRegionBuilder()(*block, /*captures=*/{}); } Block *block = ®ion->front(); @@ -333,24 +333,26 @@ static bool hasOnlyScalarElementwiseOp(Region &r) { // Return true if the op is an element-wise linalg op. static bool isElementwise(Operation *op) { - auto genericOp = dyn_cast(op); - if (!genericOp) + auto linalgOp = dyn_cast(op); + if (!linalgOp) return false; - if (genericOp.getNumLoops() != genericOp.getNumParallelLoops()) + if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops()) return false; // TODO: relax the restrictions on indexing map. - for (unsigned i = 0, e = genericOp.getNumOutputs(); i < e; i++) { - if (!genericOp.getOutputIndexingMap(i).isIdentity()) + for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) { + if (!linalgOp.getOutputIndexingMap(i).isIdentity()) return false; } // Currently bound the input indexing map to minor identity as other // permutations might require adding transpose ops to convert the vector read // to the right shape. - for (unsigned i = 0, e = genericOp.getNumInputs(); i < e; i++) { - if (!genericOp.getInputIndexingMap(i).isMinorIdentity()) + for (unsigned i = 0, e = linalgOp.getNumInputs(); i < e; i++) { + if (!linalgOp.getInputIndexingMap(i).isMinorIdentity()) return false; } - return hasOnlyScalarElementwiseOp(genericOp.getRegion()); + if (linalgOp->getNumRegions() != 1) + return false; + return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0)); } static Optional vectorizeContraction(OpBuilder &builder, @@ -393,9 +395,6 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) { for (Type outputTensorType : linalgOp.getOutputTensorTypes()) if (!outputTensorType.cast().hasStaticShape()) return failure(); - - if (isa(op)) - return success(); if (isElementwise(op)) return success(); return success(isaContractionOpInterface(linalgOp)); @@ -407,43 +406,12 @@ Optional mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, return llvm::None; edsc::ScopedContext scope(builder, op->getLoc()); - // In the case of 0-D memrefs, return null and special case to scalar load or - // store later. - if (auto fillOp = dyn_cast(op)) { - // Vectorize fill as a vector.broadcast. - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: " - << "Rewrite linalg.fill as vector.broadcast: " << *op); - VectorizedLinalgOp res; - if (Value v = buildVectorWrite(builder, fillOp.value(), fillOp.output())) - res.tensorResults.push_back(v); - return res; - } - if (auto copyOp = dyn_cast(op)) { - // Vectorize copy as a vector.transfer_read+vector.transfer_write. - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: " - << "Rewrite linalg.copy as vector.transfer_read + " - "vector.transfer_write: " - << *op); - Value vector = buildVectorRead(builder, copyOp.input()); - VectorizedLinalgOp res; - if (Value v = buildVectorWrite(builder, vector, copyOp.output())) - res.tensorResults.push_back(v); - return res; - } if (isElementwise(op)) { LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: " << "Vectorize linalg op as a generic: " << *op); return vectorizeAsLinalgGeneric(builder, cast(op)); } - // TODO: as soon as Copy and FillOp. get a region builder, replace all the - // above by: - // if (isa(op) || isElementwise(op)) { - // LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: " - // << "Vectorize linalg op as a generic: " << *op); - // return vectorizeAsLinalgGeneric(builder, cast(op)); - // } - return vectorizeContraction(builder, cast(op)); } diff --git a/mlir/test/Transforms/copy-removal.mlir b/mlir/test/Transforms/copy-removal.mlir index a66006f..1432037 100644 --- a/mlir/test/Transforms/copy-removal.mlir +++ b/mlir/test/Transforms/copy-removal.mlir @@ -1,5 +1,4 @@ -// RUN: mlir-opt -copy-removal -split-input-file %s -//| FileCheck %s +// RUN: mlir-opt -copy-removal -split-input-file %s | FileCheck %s // All linalg copies except the linalg.copy(%1, %9) must be removed since the // defining operation of %1 and its DeallocOp have been defined in another block. @@ -256,7 +255,7 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>, %result: memref<2xf32>) %tmp2 = math.exp %gen2_arg0 : f32 linalg.yield %tmp2 : f32 } - "linalg.copy"(%temp, %result) : (memref<2xf32>, memref<2xf32>) -> () + linalg.copy(%temp, %result) : memref<2xf32>, memref<2xf32> dealloc %temp : memref<2xf32> // CHECK: return return @@ -292,7 +291,7 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>){ linalg.yield %tmp2 : f32 } // CHECK: linalg.copy - "linalg.copy"(%temp, %to) : (memref<2xf32>, memref<2xf32>) -> () + linalg.copy(%temp, %to) : memref<2xf32>, memref<2xf32> dealloc %temp : memref<2xf32> return } @@ -355,7 +354,7 @@ func @check_with_affine_dialect(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg } // CHECK-NOT: linalg.copy // CHECK-NOT: dealloc - "linalg.copy"(%0, %arg2) : (memref<4xf32>, memref<4xf32>) -> () + linalg.copy(%0, %arg2) : memref<4xf32>, memref<4xf32> dealloc %0 : memref<4xf32> //CHECK: return return diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc index a16a2b8..b197ba3d 100644 --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -23,7 +23,7 @@ // IMPL-NEXT: map2 = simplifyAffineMap(map2); // IMPL-NEXT: return {{.+}}.getAffineMapArrayAttr({ map0, map1, map2 }); // -// IMPL: void Test1Op::regionBuilder(Block &block) { +// IMPL: void Test1Op::regionBuilder(Block &block, ValueRange captures) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); // IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]); // IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]); @@ -47,7 +47,7 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) { // IMPL: AffineMap::get(3, 3, {d2, d1}, context) // IMPL: AffineMap::get(3, 3, {d0, d1}, context) // -// IMPL: Test2Op::regionBuilder(Block &block) { +// IMPL: Test2Op::regionBuilder(Block &block, ValueRange captures) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); // IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]); // IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]); @@ -71,7 +71,7 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) { // IMPL: AffineMap::get(4, 4, {d3, d2}, context) // IMPL: AffineMap::get(4, 4, {d0, d1, d2}, context) // -// IMPL: Test3Op::regionBuilder(Block &block) { +// IMPL: Test3Op::regionBuilder(Block &block, ValueRange captures) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); // IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]); // IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]); diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp index 0934967..4f57322c 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -1871,11 +1871,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, $_builder.getI32VectorAttr({{ static_cast(inputs.size()), static_cast(outputs.size())})); - buildNamedStructuredOpRegionAndAttributes<{0}>( + createAndFillStructuredOpRegion<{0}>( $_builder, $_state, TypeRange(inputs), - TypeRange(outputs)); + TypeRange(outputs)/*, TODO: support captures*/); }]>, OpBuilderDAG< (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, @@ -1889,11 +1889,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, $_builder.getI32VectorAttr({{ static_cast(inputs.size()), static_cast(outputs.size())})); - buildNamedStructuredOpRegionAndAttributes<{0}>( + createAndFillStructuredOpRegion<{0}>( $_builder, $_state, TypeRange(inputs), - TypeRange(outputs)); + TypeRange(outputs)/*, TODO: support captures*/); }]>, OpBuilderDAG< (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands, @@ -1907,7 +1907,9 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, {6} ]; let printer = [{{ return ::printNamedStructuredOp(p, *this); }]; - let parser = [{{ return ::parseNamedStructuredOp<{0}>(parser, result); }]; + let parser = [{{ + return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/); + }]; let hasFolder = 1; let hasCanonicalizer = 1; @@ -1915,8 +1917,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, // Auto-generated. ArrayAttr iterator_types(); ArrayAttr indexing_maps(); - static void regionBuilder(Block &block); - static std::function getRegionBuilder() {{ + static void regionBuilder(Block &block, ValueRange captures); + static std::function getRegionBuilder() {{ return regionBuilder; } @@ -1980,11 +1982,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, $_builder.getI32VectorAttr({{ static_cast(inputs.size()), static_cast(outputs.size())})); - buildNamedStructuredOpRegionAndAttributes<{0}>( + createAndFillStructuredOpRegion<{0}>( $_builder, $_state, TypeRange(inputs), - TypeRange(outputs)); + TypeRange(outputs)/*, TODO: support captures*/); {2} }]> )FMT"; @@ -2311,7 +2313,7 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName, }; const char *regionBuilderFmt = R"FMT( - void {0}::regionBuilder(Block &block) { + void {0}::regionBuilder(Block &block, ValueRange captures) { using namespace edsc; using namespace intrinsics; auto args = block.getArguments(); -- 2.7.4