From 55cf0de35efd6cb81d6a21fee35186f6fb6864c2 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 26 Jan 2023 12:30:54 -0800 Subject: [PATCH] [mlir][Linalg] Adding a greedy packing transform dialect op. This PR adds a `pack_greedily` transform operation that infers the packing for gemm subcomputations embedded within any LinalgOp and packs accordingly. A normalization step guarantees that we get the innermost op dimensions in one of `8` possible `(m, n, k)` orders, specified as a parameter, from which we can emit all packed forms. The current implementation takes an arbitrary LinalgOp and tries to pack it along the specified dimensions with specified sizes and inner dim permutation. This achieves a new level of normalization and generalization for any `n-D` LinalgOp that contains a gemm embedded within it: we will always see a predictable packed form for any of these ops. Differential Revision: https://reviews.llvm.org/D142661 --- .../Linalg/TransformOps/LinalgTransformOps.td | 80 +++++- .../Linalg/TransformOps/LinalgTransformOps.cpp | 290 ++++++++++++++++++++- .../Dialect/Linalg/transform-pack-greedily.mlir | 228 ++++++++++++++++ 3 files changed, 584 insertions(+), 14 deletions(-) create mode 100644 mlir/test/Dialect/Linalg/transform-pack-greedily.mlir diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 30506fa..bd2fe3b 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -370,7 +370,7 @@ def MultiTileSizesOp : Op, - DeclareOpInterfaceMethods,]> { + DeclareOpInterfaceMethods]> { let description = [{ Pack a LinalgOp by applying a data tiling transformation on the op and packing the operands according to the `packed_sizes` specification. 
@@ -454,6 +454,84 @@ def PackOp : Op, + DeclareOpInterfaceMethods]> { + let description = [{ + Target a Linalg op and rewrite it into packed LinalgOp form by trying to + infer whether a known suboperation is embedded. + + Different packing strategies are applied in order, when one applies + successfully, the transform returns: + 1. Gemm packing: Try to infer a gemm operation embedded in the target op. + Specifically, this looks for 2 parallel dimensions that participate in + an outer-product and 1 reduction dimension. + These dimensions are referred to as (m, n, k) to match canonical gemm + terminology. + The packed sizes for (m, n, k) are specified by `gemm_packed_sizes`. + The ordering of the packed dimensions (mm, nn, kk) is specified by the + `gemm_inner_dims_order` attribute. + + Packing occurs as follows: + 1. Find the dimensions to pack according to the strategy. + 2. The target is converted to linalg.generic form. + 3. An interchange transform is applied to isolate the dimensions to pack as + the most minor indexing dimensions of the linalg.generic. The most minor + dimensions are themselves ordered according to `inner_dims_order`. + 4. Packing is performed by `packed_sizes` and following `inner_dims_order`. + + By normalizing the most minor dimensions to `inner_dims_order`, the transform + guarantees that packing immediately generates inner dimensions in a desirable + layout. + + Outer dimension layout permutations are not controlled by this transform op + at the moment and can be obtained by composing with the pack_transpose + transformation. + + #### Return modes + + This operation ignores non-Linalg ops and drops them in the return. + It returns the list of packed Linalg ops or the original op when all available + packing strategies failed to apply. + }]; + + // TODO: Transform_ConcreteOpType needs interface. 
+ let arguments = (ins TransformHandleTypeInterface:$target, + Variadic:$gemm_packed_sizes, + DefaultValuedAttr + :$static_gemm_packed_sizes, + DefaultValuedAttr + :$gemm_inner_dims_order); + let results = (outs Transform_ConcreteOpType<"linalg.generic">:$packed_op); + + let builders = [ + OpBuilder<(ins "Value":$target, + "ArrayRef":$mixedGemmPackedSizes, + CArg<"ArrayRef", "{}">:$gemmDimsInnerDimsOrder)> + ]; + + let assemblyFormat = [{ + $target + oilist( + `gemm_packed_sizes` `=` custom($gemm_packed_sizes, + $static_gemm_packed_sizes) + `gemm_inner_dims_order` `=` $gemm_inner_dims_order + ) + attr-dict + `:` functional-type($target, results) + }]; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + /// Returns the list of tile sizes, which may be static (Attribute) or + /// dynamic (Value). + SmallVector getMixedGemmPackedSizes(); + }]; +} + +//===----------------------------------------------------------------------===// // PackTransposeOp //===----------------------------------------------------------------------===// def PackTransposeOp : Op interchangeVector = getIteratorInterchange(); @@ -730,7 +732,7 @@ transform::MatchOp::apply(transform::TransformResults &results, if (getInterface().has_value()) { auto iface = getInterface().value(); if (iface == transform::MatchInterfaceEnum::LinalgOp && - !isa(op)) + !isa(op)) return; if (iface == transform::MatchInterfaceEnum::TilingInterface && isa(op)) @@ -885,7 +887,7 @@ void transform::PackOp::build(OpBuilder &builder, OperationState &result, // attributes for multiple variadic operands. In the absence of this, horrible // bugs ensue. 
Type linalgOpHType = transform::OperationType::get( - builder.getContext(), linalg::GenericOp::getOperationName()); + builder.getContext(), GenericOp::getOperationName()); build(builder, result, /*resultType=*/linalgOpHType, /*target=*/target, @@ -908,7 +910,7 @@ transform::PackOp::apply(transform::TransformResults &transformResults, return DiagnosedSilenceableFailure::success(); } // Fail on multi-op handles. - auto linalgOp = dyn_cast(targetOps.front()); + auto linalgOp = dyn_cast(targetOps.front()); if (targetOps.size() != 1 || !linalgOp) { return emitSilenceableError() << "requires target to map to exactly 1 LinalgOp (got " @@ -947,6 +949,268 @@ void transform::PackOp::getEffects( } //===---------------------------------------------------------------------===// +// PackGreedilyOp. +//===---------------------------------------------------------------------===// + +LogicalResult transform::PackGreedilyOp::verify() { + if (!isPermutationVector(getGemmInnerDimsOrder())) { + return emitOpError() << getGemmInnerDimsOrderAttrName() + << " is not a valid permutation"; + } + // TODO: relax to allow empty once we have another strategy than just gemm. 
+ if (getGemmInnerDimsOrder().size() != 3 || + getMixedGemmPackedSizes().size() != 3) { + return emitOpError() << " needs 3 entries for gemm_packed_sizes and " + << getGemmInnerDimsOrderAttrName() + << " order for the gemm strategy"; + } + return success(); +} + +namespace { +auto par = utils::IteratorType::parallel; +auto red = utils::IteratorType::reduction; +} // namespace + +/// Return the set of AffineDimExpr +static DenseSet +findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand, + utils::IteratorType iter) { + DenseSet res; + assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner"); + AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand); + for (AffineExpr e : indexingMap.getResults()) { + if (auto d = e.dyn_cast()) { + if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter && + llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) { + return e.isFunctionOfDim(d.getPosition()); + }) == 1) + res.insert(d.getPosition()); + } + } + return res; +} + +struct GemmDimsForPacking { + int64_t mPos, nPos, kPos; +}; +/// Greedily look for 2 parallel (m and n) and 1 reduction (k) dimension that +/// form a gemm. Such dimensions are such that: +/// 1. The m dimension is involved in an outer-product along LHS +/// (i.e. it is a permutation on RES and LHS and does not appear in RHS). +/// 2. The n dimension is involved in an outer-product along RHS +/// (i.e. it is a permutation on RES and RHS and does not appear in LHS). +/// 3. The k dimension appears as a permutation on LHS and RHS. +/// 4. m, n and k appear only once in any given indexing. +/// +/// This allows detecting that some gemm is embedded within `linalgOp`. +/// +/// When multiple possibilities for selecting m, n and k appear, we just pick +/// an arbitrary one (i.e. the first in a DenseSet). +// TODO: Better heuristic (e.g pick dims based on packing-based metric). 
+static FailureOr getGemmDims(LinalgOp linalgOp) { + assert(linalgOp.getNumDpsInits() == 1 && "wrong number of dps inits"); + assert(linalgOp.getNumDpsInputs() == 2 && "wrong number of dps inputs"); + + DenseSet a = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInputOperand(0), par); + DenseSet b = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInputOperand(1), par); + DenseSet c = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInitOperand(0), par); + + // A & C - B are the iterators involved in an outer-product along A (the LHS). + DenseSet ac = a; + llvm::set_intersect(ac, c); + llvm::set_subtract(ac, b); + // B & C - A are the iterators involved in an outer-product along B (the RHS). + DenseSet bc = b; + llvm::set_intersect(bc, c); + llvm::set_subtract(bc, a); + + // Note: if we ever need them, A & B & C would be "batch" dimensions. + + // A & B red are the reduction dimensions. + DenseSet ra = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInputOperand(0), red); + DenseSet rb = findPermutationsIndexingOperand( + linalgOp, linalgOp.getDpsInputOperand(1), red); + llvm::set_intersect(ra, rb); + + if (ac.empty() || bc.empty() || ra.empty()) + return failure(); + + // Pick the first one in each set. + // TODO: Better heuristic (e.g pick dims based on packing-based metric). + return GemmDimsForPacking{*ac.begin(), *bc.begin(), *ra.begin()}; +} + +/// Return a permutation vector of size permSize that would result in moving +/// positions into desiredPositions. +/// +/// For example, permSize == 5, positions = {2, 4}, desiredPositions = {1, 0} +/// would result in a {4, 2, 0, 1, 3} permutation vector. 
+static SmallVector +computePermutationVector(int64_t permSize, ArrayRef positions, + ArrayRef desiredPositions) { + SmallVector res(permSize, -1); + DenseSet seen; + for (auto [pos, desiredPos] : llvm::zip(positions, desiredPositions)) { + res[desiredPos] = pos; + seen.insert(pos); + } + int64_t nextPos = 0; + for (int64_t &entry : res) { + if (entry != -1) + continue; + while (seen.contains(nextPos)) + ++nextPos; + entry = nextPos; + ++nextPos; + } + return res; +} + +/// Pack a LinalgOp by greedily inferring gemm dimensions (m, n, k) +/// where m and n are proper parallel dimensions and k is a proper reduction +/// dimension. +/// Packing occurs by rewriting the op as a linalg.generic and calling +/// linalg::pack by `mnkPackedSizes`. +/// The order of the packed dimensions is customizable: the `mnkOrder` is a +/// permutation of {0, 1, 2} to reorder {m, n, k} into one of the 8 possible +/// forms. +/// The outer dimensions of the operands are not permuted at this time, this is +/// left for future work. +static FailureOr +packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp, + ArrayRef mnkPackedSizes, + ArrayRef mnkOrder) { + assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes"); + assert(mnkOrder.size() == 3 && "unexpected mnkOrder size"); + assert(isPermutationVector(mnkOrder) && "expected a permutation"); + + int64_t numLoops = linalgOp.getNumLoops(); + if (numLoops <= 2) { + return rewriter.notifyMatchFailure(linalgOp, + "need 3+ loops to find a gemm to pack"); + } + + // Locally adjust the desired iterator position of mnk and packing sizes. + int64_t numPackedDims = mnkPackedSizes.size(); + SmallVector mmnnkkPos(numPackedDims); + for (int64_t i = 0, e = numPackedDims; i < e; ++i) + mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i]; + SmallVector packedSizes(mnkPackedSizes.size()); + for (int64_t i = 0, e = numPackedDims; i < e; ++i) + packedSizes[mnkOrder[i]] = mnkPackedSizes[i]; + + // 1. Infer dims that are important for gemm. 
+ FailureOr res = getGemmDims(linalgOp); + if (failed(res)) { + return rewriter.notifyMatchFailure(linalgOp, + "couldn't infer gemm iterators"); + } + + // 2. Normalize linalgOp to an kmn-matmul-like with [red, par, par] most + // minor iterators. If we wanted a different normalization order, this is + // where it would have to start. + int64_t mPos = res->mPos, nPos = res->nPos, kPos = res->kPos; + LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL(); + DBGS() << "Start packing generic op greedily with (m@" << mPos + << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp + << "\n";); + + // 2.a. Rewrite as a generic. + auto genericOp = dyn_cast(linalgOp.getOperation()); + if (!genericOp) { + FailureOr generalizeResult = + generalizeNamedOp(rewriter, linalgOp); + assert(succeeded(generalizeResult) && "unexpected failure generalizing op"); + genericOp = *generalizeResult; + } + + // 2.b. Interchange to move the dimensions (k, m, n) as most-minor iterators. + // Note that this only normalized the iteration order and does not change the + // indexings of any operand. + SmallVector permutation = + computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos); + LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: "); DBGSNL();); + // Sign .. unsigned pollution. + SmallVector unsignedPerm(permutation.begin(), permutation.end()); + FailureOr interchangeResult = + interchangeGenericOp(rewriter, genericOp, unsignedPerm); + assert(succeeded(interchangeResult) && "unexpected failure interchanging op"); + genericOp = *interchangeResult; + LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";); + + // At this point, the op iterators are normalized to {leading, k, m, n}. + // The layouts induced by packing will always be: + // - LHS{leading_lhs, kk, mm} + // - RHS{leading_rhs, kk, nn} + // - RES{leading_res, mm, nn} + // If we wanted to change the packed order, we would reorder (k, m, n) to + // something else above. 
+ // + // Additional permutations of the outer dims of the operands (i.e. + // leading_lhs, leading_rhs and leading_res) could follow by computing the + // desired outerPerm for each operand. + // This is left for future work. + + // Add leading zeros to match numLoops. + SmallVector adjustedPackedSizes(numLoops - packedSizes.size(), + rewriter.getIndexAttr(0)); + llvm::append_range(adjustedPackedSizes, packedSizes); + + // TODO: If we wanted to give the genericOp a name after packing, after + // calling `pack` would be a good time. + return linalg::pack(rewriter, genericOp, adjustedPackedSizes); +} + +DiagnosedSilenceableFailure +PackGreedilyOp::apply(transform::TransformResults &transformResults, + transform::TransformState &state) { + ArrayRef targetOps = state.getPayloadOps(getTarget()); + + SmallVector results; + IRRewriter rewriter(getContext()); + for (Operation *op : targetOps) { + auto linalgOp = dyn_cast(op); + if (!linalgOp) + continue; + // linalgOp will be replaced and the insertion point may be invalidated if + // we set it before -> set it after. + rewriter.setInsertionPointAfter(linalgOp); + // Failing to pack greedily is perfectly fine. + // In the future we will want to order packings according to some metric. 
+ FailureOr gemm = packGemmGreedily( + /*rewriter=*/rewriter, + /*linalgOp=*/linalgOp, + /*mnkPackedSizes=*/getMixedGemmPackedSizes(), + /*mnkOrder=*/getGemmInnerDimsOrder()); + if (succeeded(gemm)) { + results.push_back(*gemm); + continue; + } + results.push_back(linalgOp); + } + transformResults.set(getPackedOp().cast(), results); + return DiagnosedSilenceableFailure::success(); +} + +SmallVector PackGreedilyOp::getMixedGemmPackedSizes() { + Builder b(getContext()); + return getMixedValues(getStaticGemmPackedSizes(), getGemmPackedSizes(), b); +} + +void transform::PackGreedilyOp::getEffects( + SmallVectorImpl &effects) { + transform::consumesHandle(getTarget(), effects); + transform::onlyReadsHandle(getGemmPackedSizes(), effects); + transform::producesHandle(getPackedOp(), effects); + transform::modifiesPayload(effects); +} + +//===---------------------------------------------------------------------===// // PackTransposeOp //===---------------------------------------------------------------------===// @@ -1030,7 +1294,7 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults, return emitSilenceableError() << "requires target to map to a " "tensor.pack or tensor.unpack"; } - LinalgOp linalgOpTarget = dyn_cast(linalgOps.front()); + LinalgOp linalgOpTarget = dyn_cast(linalgOps.front()); if (!linalgOpTarget) return emitSilenceableError() << "requires a LinalgOp target"; @@ -1102,7 +1366,7 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults, //===---------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::PadOp::applyToOne(linalg::LinalgOp target, +transform::PadOp::applyToOne(LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { // Convert the integer packing flags to booleans. 
@@ -1214,7 +1478,7 @@ LogicalResult transform::PadOp::verify() { //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::PromoteOp::applyToOne(linalg::LinalgOp target, +transform::PromoteOp::applyToOne(LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { LinalgPromotionOptions promotionOptions; @@ -1308,7 +1572,7 @@ LogicalResult transform::ReplaceOp::verify() { //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::ScalarizeOp::applyToOne(linalg::LinalgOp target, +transform::ScalarizeOp::applyToOne(LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { scf::SCFTilingOptions tilingOptions; @@ -1560,7 +1824,7 @@ void transform::SplitReductionOp::build( } DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne( - linalg::LinalgOp target, transform::ApplyToEachResultList &results, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { ControlSplitReductionFn splitFn = [&](LinalgOp) { return linalg::SplitReductionOptions{int64_t(getSplitFactor()), @@ -1605,7 +1869,7 @@ void transform::TileReductionUsingScfOp::build( } DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne( - linalg::LinalgOp target, transform::ApplyToEachResultList &results, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { TrivialPatternRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); @@ -1649,7 +1913,7 @@ void transform::TileReductionUsingForeachThreadOp::build( DiagnosedSilenceableFailure transform::TileReductionUsingForeachThreadOp::applyToOne( - linalg::LinalgOp target, transform::ApplyToEachResultList &results, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { TrivialPatternRewriter 
rewriter(getContext()); rewriter.setInsertionPoint(target); diff --git a/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir b/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir new file mode 100644 index 0000000..42f3a6c --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir @@ -0,0 +1,228 @@ +// RUN: mlir-opt %s -test-transform-dialect-interpreter --split-input-file | FileCheck %s + +!A_mk = tensor<1023x255xf32> +!B_kn = tensor<255x127xf32> +!C_mn = tensor<1023x127xf32> + +// Normalized dims are: ( k, m, n)(kk, mm, nn) +// CHECK-DAG: #[[$mk_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)> +// CHECK-DAG: #[[$kn_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)> +// CHECK-DAG: #[[$mn_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)> + +// CHECK-LABEL: @matmul_mk_kn_mn( +func.func @matmul_mk_kn_mn(%A : !A_mk, %B : !B_kn, %C : !C_mn) -> !C_mn { + // CHECK: linalg.generic + // CHECK-SAME: indexing_maps = [#[[$mk_kkmm]], #[[$kn_kknn]], #[[$mn_mmnn]]] + // CHECK-SAME: ["reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]} + // CHECK-SAME: ins(%{{.*}} : tensor<128x8x32x8xf32>, tensor<8x8x32x16xf32>) + // CHECK-SAME: outs(%{{.*}} : tensor<128x8x8x16xf32>) + %0 = linalg.matmul ins(%A, %B : !A_mk, !B_kn) outs(%C : !C_mn) -> !C_mn + return %0 : !C_mn +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %matmul = transform.structured.match ops{["linalg.matmul"]} in %module_op + : (!pdl.operation) -> !transform.op<"linalg.matmul"> + transform.structured.pack_greedily %matmul + gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0] + : (!transform.op<"linalg.matmul">) -> !transform.op<"linalg.generic"> +} + +// ----- + +!A_mk = tensor<1023x255xf32> +!B_nk = tensor<127x255xf32> +!C_nm = tensor<127x1023xf32> + +#mkn_accesses = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (n, k)>, + affine_map<(m, n, k) -> (n, m)> +] 
+#mkn_trait = { + indexing_maps = #mkn_accesses, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// Normalized dims are: ( k, m, n)(kk, mm, nn) +// CHECK-DAG: #[[$km_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)> +// CHECK-DAG: #[[$kn_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)> +// CHECK-DAG: #[[$mn_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)> + +// CHECK-LABEL: @matmul_mk_nk_nm( +func.func @matmul_mk_nk_nm(%A : !A_mk, %B : !B_nk, %C : !C_nm) -> !C_nm { + // CHECK: linalg.generic + // CHECK-SAME: indexing_maps = [#[[$mk_kkmm]], #[[$kn_kknn]], #[[$mn_mmnn]]] + // CHECK-SAME: ["reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]} + // CHECK-SAME: ins(%{{.*}} : tensor<128x8x32x8xf32>, tensor<8x8x32x16xf32>) + // CHECK-SAME: outs(%{{.*}} : tensor<8x128x8x16xf32>) + %0 = linalg.generic #mkn_trait ins(%A, %B : !A_mk, !B_nk) outs(%C : !C_nm) { + ^bb0(%a: f32, %b: f32, %c: f32): + %d = arith.mulf %a, %b : f32 + %e = arith.addf %c, %d : f32 + linalg.yield %e : f32 + } -> !C_nm + return %0 : !C_nm +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic"> + transform.structured.pack_greedily %generic + gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0] + : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic"> +} + +// ----- + +!A_mk = tensor<1023x255xf32> +!B_nk = tensor<127x255xf32> +!C_nm = tensor<127x1023xf32> + +#mkn_accesses = [ + affine_map<(k, m, n) -> (m, k)>, + affine_map<(k, m, n) -> (n, k)>, + affine_map<(k, m, n) -> (n, m)> +] +#mkn_trait = { + indexing_maps = #mkn_accesses, + iterator_types = ["reduction", "parallel", "parallel"] +} + +// Normalized dims are: ( k, m, n)(kk, mm, nn) +// CHECK-DAG: #[[$mk_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)> +// 
CHECK-DAG: #[[$kn_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)> +// CHECK-DAG: #[[$mn_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)> + +// CHECK-LABEL: @matmul_mk_nk_nm_transposed( +func.func @matmul_mk_nk_nm_transposed(%A : !A_mk, %B : !B_nk, %C : !C_nm) -> !C_nm { + // CHECK: linalg.generic + // CHECK-SAME: indexing_maps = [#[[$mk_kkmm]], #[[$kn_kknn]], #[[$mn_mmnn]]] + // CHECK-SAME: ["reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]} + // CHECK-SAME: ins(%{{.*}} : tensor<128x8x32x8xf32>, tensor<8x8x32x16xf32>) + // CHECK-SAME: outs(%{{.*}} : tensor<8x128x8x16xf32>) + %0 = linalg.generic #mkn_trait ins(%A, %B : !A_mk, !B_nk) outs(%C : !C_nm) { + ^bb0(%a: f32, %b: f32, %c: f32): + %d = arith.mulf %a, %b : f32 + %e = arith.addf %c, %d : f32 + linalg.yield %e : f32 + } -> !C_nm + return %0 : !C_nm +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic"> + transform.structured.pack_greedily %generic + gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0] + : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic"> +} + +// ----- + +!A_bmkm2 = tensor<42x1023x255x33xf32> +!B_nkb = tensor<127x255x42xf32> +!C_nbm = tensor<127x42x1023xf32> + +#mkn_accesses = [ + affine_map<(k, m, n, b, m2) -> (b, m, k, m2)>, + affine_map<(k, m, n, b, m2) -> (n, k, b)>, + affine_map<(k, m, n, b, m2) -> (n, b, m)> +] +#mkn_trait = { + indexing_maps = #mkn_accesses, + iterator_types = ["reduction", "parallel", "parallel", "parallel", "parallel"] +} + +// Normalized dims are: ( ?, ?, k, m, n)(kk, mm, nn) +// CHECK-DAG: #[[$bmkm2_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d3, d2, d1, d5, d6)> +// CHECK-DAG: #[[$nkb_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d2, d0, d5, d7)> +// CHECK-DAG: #[[$nbm_mmnn:.*]] 
= affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d0, d3, d6, d7)> + +// CHECK-LABEL: @contraction_bmkm2_nkb_nbm( +func.func @contraction_bmkm2_nkb_nbm(%A : !A_bmkm2, %B : !B_nkb, %C : !C_nbm) -> !C_nbm { + // CHECK: linalg.generic + // CHECK-SAME: indexing_maps = [#[[$bmkm2_kkmm]], #[[$nkb_kknn]], #[[$nbm_mmnn]]] + // CHECK-SAME: ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]} + // CHECK-SAME: ins(%{{.*}} : tensor<42x128x8x33x32x8xf32>, tensor<8x8x42x32x16xf32>) + // CHECK-SAME: outs(%{{.*}} : tensor<8x42x128x8x16xf32>) + %0 = linalg.generic #mkn_trait ins(%A, %B : !A_bmkm2, !B_nkb) outs(%C : !C_nbm) { + ^bb0(%a: f32, %b: f32, %c: f32): + %d = arith.mulf %a, %b : f32 + %e = arith.addf %c, %d : f32 + linalg.yield %e : f32 + } -> !C_nbm + return %0 : !C_nbm +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic"> + transform.structured.pack_greedily %generic + gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0] + : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic"> +} + +// ----- + +// Conv linguo: h w kh kw c n f cc nn ff +// Normalized dims are: ( ?, ?, ?, ?, k, m, n)(kk, mm, nn) +// n c h + kh w + kw cc nn +// CHECK-DAG: #[[$M1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (d5, d4, d0 + d2, d1 + d3, d7, d8)> +// f c kh kw cc ff +// CHECK-DAG: #[[$M2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (d6, d4, d2, d3, d7, d9)> +// n f h w nn ff +// CHECK-DAG: #[[$M3:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (d5, d6, d0, d1, d8, d9)> + +// CHECK-LABEL: @conv_2d_nchw_fchw +func.func @conv_2d_nchw_fchw(%arg0: tensor, %arg2: tensor) -> tensor { + %c0 = arith.constant dense<0.1> : tensor<16x47x3x3xf32> + // CHECK: linalg.generic + // CHECK-SAME: indexing_maps = [#[[$M1]], #[[$M2]], 
#[[$M3]]] + // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "reduction", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel"] + // CHECK-SAME: ins(%{{.*}} : tensor, tensor<1x2x3x3x32x16xf32>) + // CHECK-SAME: outs(%{{.*}} : tensor) + %0 = linalg.conv_2d_nchw_fchw + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%arg0, %c0: tensor, tensor<16x47x3x3xf32>) + outs(%arg2: tensor) -> tensor + return %0 : tensor +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %conv = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %module_op + : (!pdl.operation) -> !transform.op<"linalg.conv_2d_nchw_fchw"> + transform.structured.pack_greedily %conv + gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0] + : (!transform.op<"linalg.conv_2d_nchw_fchw">) -> !transform.op<"linalg.generic"> +} + + +// ----- + +// These should fail to pack for now as they don't contain a contraction. +// CHECK-LABEL: @reduce_and_map +func.func @reduce_and_map(%arg0: tensor<10x100xf32>, + %arg1: tensor<10x100xf32>, %output: tensor<10xf32>) -> tensor<10xf32> { + %map_init = tensor.empty() : tensor<10x100xf32> + // CHECK: linalg.map + %mapped = linalg.map { arith.addf } + ins(%arg0, %arg1 : tensor<10x100xf32>, tensor<10x100xf32>) + outs(%map_init : tensor<10x100xf32>) + // CHECK: linalg.reduce + %res = linalg.reduce { arith.addf } + ins(%mapped: tensor<10x100xf32>) + outs(%output: tensor<10xf32>) + dimensions = [1] + return %res : tensor<10xf32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic"> + transform.structured.pack_greedily %generic + gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0] + : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic"> +} -- 2.7.4