From d7904a702fe80e482b7fbb132c46863afd6eb3be Mon Sep 17 00:00:00 2001 From: Lorenzo Chelini Date: Sun, 8 Jan 2023 14:04:18 +0100 Subject: [PATCH] [MLIR] Fold outer dims permutation to pack when propagating Instead of folding the transpose into the linalg.generic, keep the transposition in the packing operation, effectively making the linalg.generic transparent to the propagation. Additionally, if the init operand of the generic has users, pack the init and pass it as the operand to the generic. Reviewed By: hanchung Differential Revision: https://reviews.llvm.org/D141483 --- .../Linalg/Transforms/DataLayoutPropagation.cpp | 72 +++++++++++++------ .../Dialect/Linalg/data-layout-propagation.mlir | 80 ++++++++++++++++++---- 2 files changed, 119 insertions(+), 33 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 5e54097..5660704 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -87,11 +87,39 @@ static PackInfo getPackingInfoFromConsumer( return packInfo; } +static SmallVector computeOuterDims(ArrayRef perm, + ArrayRef exprs) { + // Compute `outer_dims_perm`. See example: + // current exprs : (d0, d1, d2, d3) -> (d2, d3) + // perm : [0, 3, 1, 2] + // First map d2, d3 with their position in the array as: + // currentPositionTileLoops: dim | pos + // d2 | 0 + // d3 | 1 + // then scan `perm` in order and get the `outer_dims_perm` + // to be used, here it would be [1, 0]. 
+ assert(!perm.empty() && "expect perm not to be empty"); + assert(!exprs.empty() && "expect exprs not to be empty"); + if (exprs.size() == 1) + return {}; + SmallVector outerDimsPerm; + DenseMap currentPositionTileLoops; + for (auto [pos, expr] : llvm::enumerate(exprs)) { + unsigned posInDomain = expr.cast().getPosition(); + currentPositionTileLoops[posInDomain] = pos; + } + for (int64_t loopIdx : perm) { + if (currentPositionTileLoops.count(loopIdx)) + outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx)); + } + return outerDimsPerm; +} + /// Returns a tuple for packed operand and indexing_map with the assumptions: /// 1) The generic op is the producer of the pack op. /// 2) The generic op has only one result. /// If the operand is a scalar or packing dimensions are all irrelevant to the -/// operand, the opreand and the updated indexing map will be returned. +/// operand, the operand and the updated indexing map will be returned. /// Otherwise, it returns the packed operand and the updated indexing map. E.g., /// /// #map0 = affine_map<(d0, d1) -> (d0, d1)> @@ -148,16 +176,26 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo, exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos])); } - // Step 2. Fold transpose variants (i.e., outerDimsPerm) into generic op. - // TODO: should we propagate the permutation of outer dims to the pack op? + // Step 2. Handle outer dim permutations. SmallVector outerDimsPerm; if (!packInfo.outerDimsOnDomainPerm.empty()) { + outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs); + + // Step 2.1: Fold transpose into the linalg.generic. 
SmallVector inversedOuterPerm = invertPermutationVector(packInfo.outerDimsOnDomainPerm); for (auto i : llvm::seq(0, origIndexingMap.getNumResults())) { int64_t dimPos = exprs[i].cast().getPosition(); exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]); } + // Step 2.2: Undo the transposition on `exprs` and propagate the + // transposition on the pack using outerDimsPerm. + if (!outerDimsPerm.empty()) { + SmallVector auxVec = exprs; + for (const auto &en : enumerate(outerDimsPerm)) + auxVec[en.index()] = exprs[en.value()]; + exprs = auxVec; + } } auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext()); @@ -254,9 +292,7 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter, indexingMaps.push_back(packedIndexingMap); } - int64_t numLoops = genericOp.getNumLoops(); int64_t numInnerLoops = packInfo.getNumTiledLoops(); - int64_t newNumLoops = numLoops + numInnerLoops; SmallVector iterTypes = genericOp.getIteratorTypesArray(); iterTypes.append(numInnerLoops, utils::IteratorType::parallel); @@ -265,24 +301,18 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter, auto [packedOutOperand, packedOutIndexingMap] = getOrCreatePackedViewOfOperand(rewriter, loc, packInfo, genericOp, opOperand); - SmallVector outExprs( - packedOutIndexingMap.getResults().drop_back(numInnerLoops)); - // Apply transpose to the indexing map, because we'll replace the init operand - // with the destination of pack op. 
- auto outerDimsPerm = packOp.getOuterDimsPerm(); - if (!outerDimsPerm.empty()) { - applyPermutationToVector(outExprs, outerDimsPerm); - } - for (int i = 0; i < numInnerLoops; ++i) - outExprs.push_back(rewriter.getAffineDimExpr(numLoops + i)); - AffineMap outMap = - AffineMap::get(newNumLoops, 0, outExprs, rewriter.getContext()); - indexingMaps.push_back(outMap); + indexingMaps.push_back(packedOutIndexingMap); + // We'll replace the init operand with the destination of pack op if the init + // operand has no users in the body of the linalg.generic (pure elementwise). + // If it has users, we need to pack the init operand too and replace the init + // with the packing result. + Value dest = (genericOp.getRegionOutputArgs()[0].use_empty()) + ? packOp.getDest() + : packedOutOperand; auto newGenericOp = rewriter.create( - loc, packOp.getDestType(), inputOperands, packOp.getDest(), indexingMaps, - iterTypes, /*bodyBuild=*/nullptr, - linalg::getPrunedAttributeList(genericOp)); + loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes, + /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp)); rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(), newGenericOp.getRegion().begin()); return newGenericOp; diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir index bb84272..cd9d3ac 100644 --- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir +++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir @@ -96,17 +96,16 @@ func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %dest: ten into %dest : tensor<128x256xi32> -> tensor<16x4x32x16xi32> return %pack : tensor<16x4x32x16xi32> } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK: func.func 
@elem_pack_transpose_outer_dims // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] -// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x32x16xi32> +// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32> // CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]] -// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16] -// CHECK-SAME: into %[[ARG0_EMPTY]] : tensor<128x256xi32> -> tensor<4x16x32x16xi32> +// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] +// CHECK-SAME: into %[[ARG0_EMPTY]] : tensor<128x256xi32> -> tensor<16x4x32x16xi32> // CHECK: %[[ELEM:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]]] // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"] // CHECK-SAME: ins(%[[PACK_ARG0]] // CHECK-SAME: outs(%[[DEST]] @@ -131,17 +130,16 @@ func.func @elem_pack_transpose_inner_and_outer_dims(%arg0: tensor<128x256xi32>, into %dest : tensor<128x256xi32> -> tensor<16x4x16x32xi32> return %pack : tensor<16x4x16x32xi32> } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK: func.func @elem_pack_transpose_inner_and_outer_dims // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] -// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32> +// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x16x32xi32> // CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]] -// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32] +// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32] // CHECK-SAME: into %[[ARG0_EMPTY]] // CHECK: %[[ELEM:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: 
indexing_maps = [#[[MAP0]], #[[MAP0]]] // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"] // CHECK-SAME: ins(%[[PACK_ARG0]] // CHECK-SAME: outs(%[[DEST]] @@ -285,7 +283,7 @@ func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100x #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> (d0)> #map2 = affine_map<(d0, d1) -> (d1)> -func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<200x4x16x100x16x32xi32>) -> tensor<200x4x16x100x16x32xi32> +func.func @transpose_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<200x4x16x100x16x32xi32>) -> tensor<200x4x16x100x16x32xi32> { %init_transpose = tensor.empty() : tensor<100x200x128x256xi32> %transpose = linalg.generic { @@ -308,3 +306,61 @@ func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100x into %dest : tensor<100x200x128x256xi32> -> tensor<200x4x16x100x16x32xi32> return %4 : tensor<200x4x16x100x16x32xi32> } + +// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d5)> +// CHECK: func.func @transpose_pack_with_outer_dims +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<200x4x16x100x16x32xi32> +// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]] +// CHECK-SAME: outer_dims_perm = [2, 1, 3, 0] inner_dims_pos = [3, 1] inner_tiles = [16, 32] +// CHECK-SAME: into %[[ARG0_EMPTY]] +// CHECK: %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<4x32xi32> +// CHECK: %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]] +// CHECK-SAME: inner_dims_pos = [0] inner_tiles = 
[32] +// CHECK-SAME: into %[[ARG2_EMPTY]] +// CHECK: %[[RES:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]], #[[MAP]]] +// CHECK-SAME: ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]] +// CHECK-SAME: outs(%[[DEST]] + +// ----- + +#map0 = affine_map<(d0, d1) -> (d0, d1)> +func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %init: tensor<128x256xi32>) -> tensor<16x4x32x16xi32>{ + %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]} + ins(%arg0 : tensor<128x256xi32>) + outs(%init : tensor<128x256xi32>) { + ^bb0(%arg3: i32, %arg4: i32): + %4 = arith.addi %arg3, %arg4 : i32 + linalg.yield %4 : i32 + } -> tensor<128x256xi32> + %empty = tensor.empty() : tensor<16x4x32x16xi32> + %pack = tensor.pack %elem + outer_dims_perm = [1, 0] + inner_dims_pos = [0, 1] + inner_tiles = [32, 16] + into %empty : tensor<128x256xi32> -> tensor<16x4x32x16xi32> + return %pack : tensor<16x4x32x16xi32> +} + +// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK: func.func @elem_pack_transpose_outer_dims +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32> +// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]] +// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] +// CHECK-SAME: into %[[ARG0_EMPTY]] +// CHECK: %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32> +// CHECK: %[[PACKED_ARG1:.+]] = tensor.pack %[[ARG1]] +// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] +// CHECK-SAME: into %[[ARG1_EMPTY]] +// CHECK: %[[RES:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]] +// CHECK-SAME: ins(%[[PACKED_ARG0]] +// CHECK-SAME: outs(%[[PACKED_ARG1]] -- 2.7.4