From 048764f23a380fd6f8cc562a0008dcc6095fb594 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Sun, 2 Jul 2023 14:43:14 +0100 Subject: [PATCH] [mlir][transform] Allow arbitrary indices to be scalable This change lifts the limitation that only the trailing dimensions/sizes in dynamic index lists can be scalable. It allows us to extend `MaskedVectorizeOp` and `TileOp` from the Transform dialect so that the following is allowed: %1, %loops:3 = transform.structured.tile %0 [[4], [4], 4] This is also a follow up for https://reviews.llvm.org/D153372 that will enable the following (middle vector dimension is scalable): transform.structured.masked_vectorize %0 vector_sizes [2, [4], 8] To facilate this change, the hooks for parsing and printing dynamic index lists are updated accordingly (`printDynamicIndexList` and `parseDynamicIndexList`, respectively). `MaskedVectorizeOp` and `TileOp` are updated to include an array of attribute of bools that captures whether the corresponding vector dimension/tile size, respectively, are scalable or not. This change is a part of a larger effort to enable scalable vectorisation in Linalg. See this RFC for more context: * https://discourse.llvm.org/t/rfc-scalable-vectorisation-in-linalg/ Differential Revision: https://reviews.llvm.org/D154336 --- .../Linalg/TransformOps/LinalgTransformOps.td | 9 ++-- mlir/include/mlir/Interfaces/ViewLikeInterface.h | 48 ++++++++++-------- .../Linalg/TransformOps/LinalgTransformOps.cpp | 22 ++++---- mlir/lib/Dialect/SCF/IR/SCF.cpp | 32 ++++++------ mlir/lib/Dialect/Transform/Utils/Utils.cpp | 3 +- mlir/lib/Interfaces/ViewLikeInterface.cpp | 58 ++++++++-------------- mlir/test/Dialect/Linalg/transform-op-tile.mlir | 22 -------- mlir/test/Dialect/Transform/ops.mlir | 8 +++ 8 files changed, 86 insertions(+), 116 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 7caae2b..c33be09 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1686,7 +1686,7 @@ def TileOp : Op:$dynamic_sizes, DefaultValuedOptionalAttr:$static_sizes, DefaultValuedOptionalAttr:$interchange, - DefaultValuedOptionalAttr:$last_tile_size_scalable); + DefaultValuedOptionalAttr:$scalable_sizes); let results = (outs TransformHandleTypeInterface:$tiled_linalg_op, Variadic:$loops); let builders = [ @@ -2008,9 +2008,10 @@ def MaskedVectorizeOp : Op:$vector_sizes, UnitAttr:$vectorize_nd_extract, + DefaultValuedOptionalAttr: + $scalable_sizes, DefaultValuedOptionalAttr: - $static_vector_sizes, - DefaultValuedOptionalAttr:$last_vector_size_scalable); + $static_vector_sizes); let results = (outs); let assemblyFormat = [{ @@ -2018,7 +2019,7 @@ def MaskedVectorizeOp : Op($vector_sizes, $static_vector_sizes, type($vector_sizes), - $last_vector_size_scalable) + $scalable_sizes) attr-dict `:` type($target) }]; diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h index fad380d..9c99c4c 100644 --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -52,13 +52,15 @@ namespace mlir { /// integer attributes in a list. E.g. /// `[%arg0 : index, 7, 42, %arg42 : i32]`. /// -/// If `isTrailingIdxScalable` is true, then wrap the trailing index with -/// square brackets, e.g. `[42]`, to denote scalability. This would normally be -/// used for scalable tile or vector sizes. +/// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable. +/// This notation is similar to how scalable dims are marked when defining +/// Vectors. For each value in `integers`, the corresponding `bool` in +/// `scalables` encodes whether it's a scalable index. If `scalables` is +/// empty then assume that all indices are non-scalable. void printDynamicIndexList( OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef integers, TypeRange valueTypes = TypeRange(), - BoolAttr isTrailingIdxScalable = {}, + ArrayRef scalables = {}, AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square); /// Parser hook for custom directive in assemblyFormat. @@ -78,41 +80,43 @@ void printDynamicIndexList( /// `kDynamic`]" /// 2. `ssa` is filled with "[%arg0, %arg1]". /// -/// Trailing indices can be scalable. For example, "42" in "[7, [42]]" is -/// scalable. This notation is similar to how scalable dims are marked when -/// defining Vectors. If /p isTrailingIdxScalable is null, scalable indices are -/// not allowed/expected. When it's not null, this hook will set the -/// corresponding value to: -/// * true if the trailing idx is scalable, -/// * false otherwise. +/// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable. +/// This notation is similar to how scalable dims are marked when defining +/// Vectors. For each value in `integers`, the corresponding `bool` in +/// `scalables` encodes whether it's a scalable index. ParseResult parseDynamicIndexList( OpAsmParser &parser, SmallVectorImpl &values, - DenseI64ArrayAttr &integers, bool *isTrailingIdxScalable = nullptr, + DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalables, SmallVectorImpl *valueTypes = nullptr, AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square); inline ParseResult parseDynamicIndexList( OpAsmParser &parser, SmallVectorImpl &values, - DenseI64ArrayAttr &integers, SmallVectorImpl &valueTypes, + DenseI64ArrayAttr &integers, SmallVectorImpl *valueTypes = nullptr, AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { - return parseDynamicIndexList(parser, values, integers, - /*isTrailingIdxScalable=*/nullptr, &valueTypes, + DenseBoolArrayAttr scalables = {}; + return parseDynamicIndexList(parser, values, integers, scalables, valueTypes, delimiter); } inline ParseResult parseDynamicIndexList( OpAsmParser &parser, SmallVectorImpl &values, DenseI64ArrayAttr &integers, SmallVectorImpl &valueTypes, - BoolAttr &isTrailingIdxScalable, + AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { + DenseBoolArrayAttr scalables = {}; + return parseDynamicIndexList(parser, values, integers, scalables, + &valueTypes, delimiter); +} +inline ParseResult parseDynamicIndexList( + OpAsmParser &parser, + SmallVectorImpl &values, + DenseI64ArrayAttr &integers, SmallVectorImpl &valueTypes, + DenseBoolArrayAttr &scalables, AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { - bool scalable = false; - auto res = parseDynamicIndexList(parser, values, integers, &scalable, - &valueTypes, delimiter); - auto scalableAttr = parser.getBuilder().getBoolAttr(scalable); - isTrailingIdxScalable = scalableAttr; - return res; + return parseDynamicIndexList(parser, values, integers, scalables, &valueTypes, + delimiter); } /// Verify that a the `values` has as many elements as the number of entries in diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 781e48a..78f82c9 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2434,7 +2434,7 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter, SmallVector tiled; SmallVector, 4> loops; loops.resize(getLoops().size()); - bool scalable = getLastTileSizeScalable(); + auto scalableSizes = getScalableSizes(); for (auto [i, op] : llvm::enumerate(targets)) { auto tilingInterface = dyn_cast(op); auto dpsInterface = dyn_cast(op); @@ -2453,12 +2453,10 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter, SmallVector sizes; sizes.reserve(tileSizes.size()); unsigned dynamicIdx = 0; - unsigned trailingIdx = getMixedSizes().size() - 1; for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) { if (auto attr = llvm::dyn_cast_if_present(ofr)) { - // Only the trailing tile size is allowed to be scalable atm. - if (scalable && (ofrIdx == trailingIdx)) { + if (scalableSizes[ofrIdx]) { auto val = b.create( getLoc(), attr.cast().getInt()); Value vscale = @@ -2560,9 +2558,10 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser, DenseI64ArrayAttr staticSizes; FunctionType functionalType; llvm::SMLoc operandLoc; - bool scalable = false; + DenseBoolArrayAttr scalableSizes; + if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) || - parseDynamicIndexList(parser, dynamicSizes, staticSizes, &scalable) || + parseDynamicIndexList(parser, dynamicSizes, staticSizes, scalableSizes) || parseOptionalInterchange(parser, result) || parser.parseColonType(functionalType)) return ParseResult::failure(); @@ -2585,9 +2584,7 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser, return failure(); } - auto scalableAttr = parser.getBuilder().getBoolAttr(scalable); - result.addAttribute(getLastTileSizeScalableAttrName(result.name), - scalableAttr); + result.addAttribute(getScalableSizesAttrName(result.name), scalableSizes); result.addAttribute(getStaticSizesAttrName(result.name), staticSizes); result.addTypes(functionalType.getResults()); @@ -2597,7 +2594,7 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser, void TileOp::print(OpAsmPrinter &p) { p << ' ' << getTarget(); printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(), - /*valueTypes=*/{}, getLastTileSizeScalableAttr(), + /*valueTypes=*/{}, getScalableSizesAttr(), OpAsmParser::Delimiter::Square); printOptionalInterchange(p, getInterchange()); p << " : "; @@ -3144,15 +3141,14 @@ DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply( } // TODO: Check that the correct number of vectorSizes was provided. - SmallVector scalableVecDims(vectorSizes.size(), false); - scalableVecDims.back() = getLastVectorSizeScalable(); for (Operation *target : targets) { if (!isa(target)) { return mlir::emitSilenceableFailure(target->getLoc()) << "Unsupported Op, cannot vectorize"; } - if (failed(linalg::vectorize(rewriter, target, vectorSizes, scalableVecDims, + if (failed(linalg::vectorize(rewriter, target, vectorSizes, + getScalableSizes(), getVectorizeNdExtract()))) { return mlir::emitSilenceableFailure(target->getLoc()) << "Attempted to vectorize, but failed"; diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 4f805d6..97c086b 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1254,20 +1254,20 @@ void ForallOp::print(OpAsmPrinter &p) { if (isNormalized()) { p << ") in "; printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(), - /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{}, + /*valueTypes=*/{}, /*scalables=*/{}, OpAsmParser::Delimiter::Paren); } else { p << ") = "; printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(), - /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{}, + /*valueTypes=*/{}, /*scalables=*/{}, OpAsmParser::Delimiter::Paren); p << " to "; printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(), - /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{}, + /*valueTypes=*/{}, /*scalables=*/{}, OpAsmParser::Delimiter::Paren); p << " step "; printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(), - /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{}, + /*valueTypes=*/{}, /*scalables=*/{}, OpAsmParser::Delimiter::Paren); } printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs"); @@ -1299,9 +1299,9 @@ ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) { dynamicSteps; if (succeeded(parser.parseOptionalKeyword("in"))) { // Parse upper bounds. - if (parseDynamicIndexList( - parser, dynamicUbs, staticUbs, /*isTrailingIdxScalable=*/nullptr, - /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) || + if (parseDynamicIndexList(parser, dynamicUbs, staticUbs, + /*valueTypes=*/nullptr, + OpAsmParser::Delimiter::Paren) || parser.resolveOperands(dynamicUbs, indexType, result.operands)) return failure(); @@ -1311,26 +1311,26 @@ ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) { } else { // Parse lower bounds. if (parser.parseEqual() || - parseDynamicIndexList( - parser, dynamicLbs, staticLbs, /*isTrailingIdxScalable=*/nullptr, - /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) || + parseDynamicIndexList(parser, dynamicLbs, staticLbs, + /*valueTypes=*/nullptr, + OpAsmParser::Delimiter::Paren) || parser.resolveOperands(dynamicLbs, indexType, result.operands)) return failure(); // Parse upper bounds. if (parser.parseKeyword("to") || - parseDynamicIndexList( - parser, dynamicUbs, staticUbs, /*isTrailingIdxScalable=*/nullptr, - /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) || + parseDynamicIndexList(parser, dynamicUbs, staticUbs, + /*valueTypes=*/nullptr, + OpAsmParser::Delimiter::Paren) || parser.resolveOperands(dynamicUbs, indexType, result.operands)) return failure(); // Parse step values. if (parser.parseKeyword("step") || - parseDynamicIndexList( - parser, dynamicSteps, staticSteps, /*scalable=*/nullptr, - /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) || + parseDynamicIndexList(parser, dynamicSteps, staticSteps, + /*valueTypes=*/nullptr, + OpAsmParser::Delimiter::Paren) || parser.resolveOperands(dynamicSteps, indexType, result.operands)) return failure(); } diff --git a/mlir/lib/Dialect/Transform/Utils/Utils.cpp b/mlir/lib/Dialect/Transform/Utils/Utils.cpp index e751642..d516a56 100644 --- a/mlir/lib/Dialect/Transform/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Transform/Utils/Utils.cpp @@ -42,6 +42,5 @@ ParseResult mlir::transform::parsePackedOrDynamicIndexList( return success(); } - return parseDynamicIndexList(parser, values, integers, - /*isTrailingIdxScalable=*/nullptr, &valueTypes); + return parseDynamicIndexList(parser, values, integers, &valueTypes); } diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp index 0f75cc1..667f66b 100644 --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -102,8 +102,7 @@ static char getRightDelimiter(AsmParser::Delimiter delimiter) { void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef integers, - TypeRange valueTypes, - BoolAttr isTrailingIdxScalable, + TypeRange valueTypes, ArrayRef scalables, AsmParser::Delimiter delimiter) { char leftDelimiter = getLeftDelimiter(delimiter); char rightDelimiter = getRightDelimiter(delimiter); @@ -113,33 +112,24 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, return; } - int64_t trailingScalableInteger; - if (isTrailingIdxScalable && isTrailingIdxScalable.getValue()) { - // ATM only the trailing idx can be scalable - trailingScalableInteger = integers.back(); - integers = integers.drop_back(); - } - - unsigned idx = 0; + unsigned dynamicValIdx = 0; + unsigned scalableIndexIdx = 0; llvm::interleaveComma(integers, printer, [&](int64_t integer) { + if (not scalables.empty() && scalables[scalableIndexIdx]) + printer << "["; if (ShapedType::isDynamic(integer)) { - printer << values[idx]; + printer << values[dynamicValIdx]; if (!valueTypes.empty()) - printer << " : " << valueTypes[idx]; - ++idx; + printer << " : " << valueTypes[dynamicValIdx]; + ++dynamicValIdx; } else { printer << integer; } - }); + if (!scalables.empty() && scalables[scalableIndexIdx]) + printer << "]"; - // Print the trailing scalable index - if (isTrailingIdxScalable && isTrailingIdxScalable.getValue()) { - if (!integers.empty()) - printer << ", "; - printer << "["; - printer << trailingScalableInteger; - printer << "]"; - } + scalableIndexIdx++; + }); printer << rightDelimiter; } @@ -147,25 +137,17 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, ParseResult mlir::parseDynamicIndexList( OpAsmParser &parser, SmallVectorImpl &values, - DenseI64ArrayAttr &integers, bool *isTrailingIdxScalable, + DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalables, SmallVectorImpl *valueTypes, AsmParser::Delimiter delimiter) { SmallVector integerVals; - bool foundScalable = false; + SmallVector scalableVals; auto parseIntegerOrValue = [&]() { OpAsmParser::UnresolvedOperand operand; auto res = parser.parseOptionalOperand(operand); - // If `foundScalable` has already been set to `true` then a non-trailing - // index was identified as scalable. - if (foundScalable) { - parser.emitError(parser.getNameLoc()) - << "non-trailing index cannot be scalable"; - return failure(); - } - - if (isTrailingIdxScalable && parser.parseOptionalLSquare().succeeded()) - foundScalable = true; + // When encountering `[`, assume that this is a scalable index. + scalableVals.push_back(parser.parseOptionalLSquare().succeeded()); if (res.has_value() && succeeded(res.value())) { values.push_back(operand); @@ -178,7 +160,10 @@ ParseResult mlir::parseDynamicIndexList( return failure(); integerVals.push_back(integer); } - if (foundScalable && parser.parseOptionalRSquare().failed()) + + // If this is assumed to be a scalable index, verify that there's a closing + // `]`. + if (scalableVals.back() && parser.parseOptionalRSquare().failed()) return failure(); return success(); }; @@ -187,8 +172,7 @@ ParseResult mlir::parseDynamicIndexList( return parser.emitError(parser.getNameLoc()) << "expected SSA value or integer"; integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals); - if (isTrailingIdxScalable) - *isTrailingIdxScalable = foundScalable; + scalables = parser.getBuilder().getDenseBoolArrayAttr(scalableVals); return success(); } diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir index 3300e86..8b44977 100644 --- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir @@ -220,25 +220,3 @@ transform.sequence failures(propagate) { %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1, %loops:3 = transform.structured.tile %0 [4, 4, [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) } - -// ----- - -// TODO: Add support for for specyfying more than one scalable tile size - -func.func @scalable_and_fixed_length_tile( - %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) - -> tensor<128x128xf32> { - %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) - outs(%arg2: tensor<128x128xf32>) - -> tensor<128x128xf32> - - return %0 : tensor<128x128xf32> -} - -transform.sequence failures(propagate) { -^bb0(%arg1: !transform.any_op): - %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - // expected-error @below {{non-trailing index cannot be scalable}} - // expected-error @below {{expected SSA value or integer}} - %1, %loops:3 = transform.structured.tile %0 [4, [4], [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) -} diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir index 7ddfcc6..dc35a9a 100644 --- a/mlir/test/Dialect/Transform/ops.mlir +++ b/mlir/test/Dialect/Transform/ops.mlir @@ -105,3 +105,11 @@ transform.sequence failures(propagate) { %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op transform.structured.tile %0 [4, 4, [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) } + +// CHECK: transform.sequence +// CHECK: transform.structured.tile %0{{\[}}[2], 4, 8] +transform.sequence failures(propagate) { +^bb0(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.structured.tile %0 [[2], 4, 8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) +} -- 2.7.4