From: Tobias Gysi Date: Thu, 4 Nov 2021 15:51:58 +0000 (+0000) Subject: [mlir][linalg] Add support for transitive fusion. X-Git-Tag: upstream/15.0.7~26725 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=29c31cb79b57594381aa15bcebe8c71b9fa64aef;p=platform%2Fupstream%2Fllvm.git [mlir][linalg] Add support for transitive fusion. Extend fusion on tensors to fuse producers greedily. Reviewed By: nicolasvasilache, hanchung Differential Revision: https://reviews.llvm.org/D110262 --- diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index b32d8e1..0924a5c 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -212,6 +212,7 @@ private: bool isEmpty(); /// Returns true if the tile loop nest invariants are satisfied: + /// - The `rootOp` has been tiled at least once. /// - The number of tile loop operations and dimensions match. /// - The innermost tile loop is the parent of `tiledOp`. /// - The tile loops are directly nested. @@ -233,8 +234,8 @@ private: bool hasOtherUses(BlockArgument bbArg, tensor::ExtractSliceOp sliceOp); LinalgOp rootOp; - SmallVector loopOps; - SmallVector loopDims; + SmallVector tileLoopOps; + DenseMap> tiledRootAndFusedOpsLoops; }; /// Tiles `consumerOp` and fuses its dependencies if possible. Uses the diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp index bfac63b..7156515 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -42,19 +42,62 @@ static SmallVector getTiledSliceDims(OpOperand *consumerOperand, AffineMap indexingMap = consumerOp.getTiedIndexingMap(consumerOperand); // Search the slice dimensions tiled by a tile loop dimension. - DenseSet tiledSliceDims; + DenseSet tiledSliceDimIndices; for (auto en : enumerate(indexingMap.getResults())) { for (auto tiledLoopDim : tiledLoopDims) { if (en.value().isFunctionOfDim(tiledLoopDim)) - tiledSliceDims.insert(en.index()); + tiledSliceDimIndices.insert(en.index()); } } - return {tiledSliceDims.begin(), tiledSliceDims.end()}; + return {tiledSliceDimIndices.begin(), tiledSliceDimIndices.end()}; +} + +/// Given a vector of `tiledSliceDimIndices` that represent the tiled dimensions +/// of the producer result slice returns the tiled producer loop dimensions. +/// Example: +/// ``` +/// %res = linalg.fill(%cst, %input) +/// scf.for %i +/// scf.for %j +/// %slice = tensor.extract_slice %res[%i, %j] +/// ``` +/// getTiledProducerLoops(%res, [0, 1]) returns the loop indices [0, 1]. +static SmallVector +getTiledProducerLoops(OpResult producerResult, + ArrayRef tiledSliceDimIndices) { + LinalgOp producerOp = producerResult.getOwner(); + + // Get the indexing map of the `producerOp` output operand that matches + // ´producerResult´. + AffineMap producerIndexingMap = producerOp.getTiedIndexingMap( + producerOp.getOutputOperand(producerResult.getResultNumber())); + + // Keep only the tiled result slice dimensions of `producerIndexingMap`. + AffineMap tiledProducerIndexingSubMap = + producerIndexingMap.getSubMap(SmallVector( + tiledSliceDimIndices.begin(), tiledSliceDimIndices.end())); + + // Compute the producer loop indices mapped to the tiled result slice + // dimensions. As the output indexing map of structured operations are + // projected permutations, `tiledProducerIndexingSubMap` has to be a + // projected permutation as well. We can thus obtain the producer loop indices + // by getting the positions of the result dimensions. + // Example: + // (d0, d1, d2) -> (d0, d2) has the result positions [0, 2]. + assert(tiledProducerIndexingSubMap.isProjectedPermutation() && + "expect slice and producer loop dimensions map one-to-one"); + SmallVector tiledProducerLoopIndices; + transform(llvm::seq(0, tiledProducerIndexingSubMap.getNumResults()), + std::back_inserter(tiledProducerLoopIndices), [&](unsigned idx) { + return tiledProducerIndexingSubMap.getDimPosition(idx); + }); + + return tiledProducerLoopIndices; } /// Returns the producer fused in place of `sliceOp`. Tile the producer operands -/// along the `tiledSliceDims` and clone the producer. Consider the case of -/// fusion of an output tensor: +/// along the `tiledSliceDimIndices` and clone the producer. Consider the case +/// of fusion of an output tensor: /// ``` /// %1 = producer ins(...) outs(%0) /// %2 = consumer ins(...) outs(%1) @@ -84,7 +127,8 @@ static SmallVector getTiledSliceDims(OpOperand *consumerOperand, /// producer is fused into a consumer and fold away unused iter_args. static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult, tensor::ExtractSliceOp sliceOp, - ArrayRef tiledSliceDims, + ArrayRef tiledSliceDimIndices, + ArrayRef tiledProducerLoopIndices, OpOperand *iterArg) { // Clone the producer after `sliceOp` since the slice may be reused to pass in // the producer result. @@ -102,23 +146,16 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult, [](Range range) { return range.size; }); SmallVector sliceOpRanges = sliceOp.getOrCreateRanges(b, loc); - // Get the producer result indexing map. - AffineMap producerIndexingMap = producerOp.getTiedIndexingMap( - producerOp.getOutputOperand(producerResult.getResultNumber())); - // Tile the producer operands given the `sliceOp` ranges. Iterate the - // `tiledSliceDims` and store the tile offset and size for the tiled slice - // dimension. Assumes the mapping from slice dimensions to producer loops is a - // permutation. + // `tiledSliceDimIndices` and store the tile offset and size for the tiled + // slice dimension. auto zero = b.create(loc, 0); SmallVector tileIvs(producerOp.getNumLoops(), nullptr); SmallVector tileSizes(producerOp.getNumLoops(), zero); SmallVector allIvs(producerOp.getNumLoops(), nullptr); - for (int64_t tiledSliceDim : tiledSliceDims) { - AffineExpr result = producerIndexingMap.getResults()[tiledSliceDim]; - assert(result.isa() && - "expect producer indexing map is a projected permutation"); - int64_t tiledProducerLoop = result.cast().getPosition(); + for (auto it : zip(tiledSliceDimIndices, tiledProducerLoopIndices)) { + int64_t tiledSliceDim = std::get<0>(it); + int64_t tiledProducerLoop = std::get<1>(it); tileIvs[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].offset; tileSizes[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].size; allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop]; @@ -156,22 +193,26 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult, // TileLoopNest specific helpers. //===----------------------------------------------------------------------===// -bool TileLoopNest::isEmpty() { return loopOps.empty(); } +bool TileLoopNest::isEmpty() { return tileLoopOps.empty(); } bool TileLoopNest::isValid() { - // Check if the number of `tileLoopOps` and `tileLoopDims` match. - if (loopOps.size() != loopDims.size()) + // Check if `rootOp` has been tiled at least once. + if (isEmpty() || tiledRootAndFusedOpsLoops.count(rootOp) == 0) + return false; + + // Check if the number of loop operations and dimensions match. + if (tileLoopOps.size() != tiledRootAndFusedOpsLoops[rootOp].size()) return false; // Check if the innermost tile loop is the parent of `tiledOp`. - if (rootOp->getParentOp() != loopOps.back()) + if (rootOp->getParentOp() != tileLoopOps.back()) return false; // Check if the tile loops are directly nested. - return std::adjacent_find(loopOps.begin(), loopOps.end(), + return std::adjacent_find(tileLoopOps.begin(), tileLoopOps.end(), [](Operation *op1, Operation *op2) { return op1 != op2->getParentOp(); - }) == loopOps.end(); + }) == tileLoopOps.end(); } SmallVector TileLoopNest::getTiedBBArgs(BlockArgument bbArg) { @@ -179,7 +220,7 @@ SmallVector TileLoopNest::getTiedBBArgs(BlockArgument bbArg) { SmallVector bbArgs; // Search all tile loop block arguments from inner to outer. - for (auto tileLoop : reverse(loopOps)) { + for (auto tileLoop : reverse(tileLoopOps)) { if (bbArg.getOwner()->getParentOp() != tileLoop) return {}; bbArgs.push_back(bbArg); @@ -194,9 +235,9 @@ SmallVector TileLoopNest::getTiedBBArgs(BlockArgument bbArg) { OpOperand *TileLoopNest::getTiedIterArg(BlockArgument bbArg) { // Search all block arguments and return the matching iteration argument. SmallVector bbArgs = getTiedBBArgs(bbArg); - if (bbArgs.size() != loopOps.size()) + if (bbArgs.size() != tileLoopOps.size()) return nullptr; - return &loopOps.front().getOpOperandForRegionIterArg(bbArgs.front()); + return &tileLoopOps.front().getOpOperandForRegionIterArg(bbArgs.front()); } bool TileLoopNest::hasOtherUses(BlockArgument bbArg, @@ -255,24 +296,29 @@ LogicalResult TileLoopNest::tileRootOp(OpBuilder &b, if (!isEmpty()) rootOp->replaceAllUsesWith(tiledRootOp->tensorResults); + // Transfer the stored `rootOp` loop dimensions if it has been tiled before. + if (tiledRootAndFusedOpsLoops.count(rootOp) != 0) { + tiledRootAndFusedOpsLoops[tiledRootOp->op] = + tiledRootAndFusedOpsLoops[rootOp]; + } + // Update the root operation and append the loops and tile loop dimensions. rootOp = tiledRootOp->op; - loopOps.append(tiledRootOp->loops.begin(), tiledRootOp->loops.end()); + tileLoopOps.append(tiledRootOp->loops.begin(), tiledRootOp->loops.end()); for (auto en : enumerate(tileSizes)) { // Copy only the tiled loop dimensions with non-zero tile size. if (en.value() == 0) continue; - loopDims.push_back(tileInterchange[en.index()]); + tiledRootAndFusedOpsLoops[rootOp].push_back(tileInterchange[en.index()]); } assert(isValid() && "expect tile loop nest to be valid after tiling"); - return success(); } FailureOr TileLoopNest::fuseProducer(OpBuilder &b, - OpOperand *rootOpOperand) { - assert(rootOpOperand->getOwner() == rootOp && - "expect the root op to be the owner of the operand to fuse"); + OpOperand *consumerOpOperand) { + assert(tiledRootAndFusedOpsLoops.count(consumerOpOperand->getOwner()) != 0 && + "expect the operand owner is the root operation or a fused producer"); assert(this->isValid() && "expect the tile loop nest to satisfy all invariants"); @@ -280,13 +326,16 @@ FailureOr TileLoopNest::fuseProducer(OpBuilder &b, if (isEmpty()) return failure(); - // Check `rootOpOperand` is defined by an ExtractSliceOp. - auto sliceOp = rootOpOperand->get().getDefiningOp(); + // Check `consumerOpOperand` is defined by an ExtractSliceOp. + auto sliceOp = + consumerOpOperand->get().getDefiningOp(); if (!sliceOp) return failure(); - // Check `sliceOp` is tiled by the tile loop nest. - if (sliceOp->getParentOp() != rootOp->getParentOp()) + // Check `sliceOp` and `consumerOp` are in the same block. + LinalgOp consumerOp = consumerOpOperand->getOwner(); + if (sliceOp->getBlock() != rootOp->getBlock() || + consumerOp->getBlock() != rootOp->getBlock()) return failure(); // Check if the producer is a LinalgOp possibly passed by iteration argument. @@ -302,19 +351,24 @@ FailureOr TileLoopNest::fuseProducer(OpBuilder &b, if (!producerResult || !isa(producerResult.getOwner())) return failure(); - // Compute the tiled producer slice dimensions given the tiled root operation - // loop dimensions `loopDims`. - SmallVector tiledSliceDims = - getTiledSliceDims(rootOpOperand, loopDims); - if (tiledSliceDims.empty()) + // Compute the tiled producer slice dimensions given the tiled consumer loops. + SmallVector tiledSliceDimIndices = getTiledSliceDims( + consumerOpOperand, tiledRootAndFusedOpsLoops[consumerOp]); + if (tiledSliceDimIndices.empty()) return failure(); + // Compute the tiled producer loop indices. + SmallVector tiledProducerLoopIndices = + getTiledProducerLoops(producerResult, tiledSliceDimIndices); + // Tile the producer operands and clone the producer in place of `sliceOp`. LinalgOp clonedOp = - getTiledProducer(b, producerResult, sliceOp, tiledSliceDims, iterArg); + getTiledProducer(b, producerResult, sliceOp, tiledSliceDimIndices, + tiledProducerLoopIndices, iterArg); + tiledRootAndFusedOpsLoops[clonedOp] = tiledProducerLoopIndices; // Cast the `clonedOp` result to gap type mismatches before canonicalization. - Type consumerOperandType = rootOpOperand->get().getType(); + Type consumerOperandType = consumerOpOperand->get().getType(); Value newResult = clonedOp->getResult(producerResult.getResultNumber()); if (newResult.getType() != consumerOperandType) { OpBuilder::InsertionGuard guard(b); @@ -330,7 +384,7 @@ FailureOr TileLoopNest::fuseProducer(OpBuilder &b, ValueRange TileLoopNest::getRootOpReplacementResults() { assert(!isEmpty() && "expect tile loop nest to be non-empty"); - return loopOps.front()->getOpResults(); + return tileLoopOps.front()->getOpResults(); } //===----------------------------------------------------------------------===// @@ -359,14 +413,25 @@ mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp, }); int64_t split = std::distance(iterTypes.begin(), it); + // Helper to fuse the producers greedily using a queue of fusion candidates. + auto fuseProducersGreedily = [&](ArrayRef operands) { + SmallVector candidates(operands.begin(), operands.end()); + while (!candidates.empty()) { + FailureOr fusedProducer = + tileLoopNest.fuseProducer(b, candidates.pop_back_val()); + if (failed(fusedProducer)) + continue; + candidates.append(fusedProducer->getInputAndOutputOperands()); + } + }; + // Tile the outer parallel loops and fuse the output operands. SmallVector outerTileSizes; outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split); outerTileSizes.append(tileSizes.size() - split, 0); if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange))) return failure(); - for (OpOperand *opOperand : tileLoopNest.getRootOp().getOutputOperands()) - (void)tileLoopNest.fuseProducer(b, opOperand); + fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands()); // Tile the remaining loops and fuse the input operands. SmallVector innerTileSizes; @@ -374,10 +439,7 @@ mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp, innerTileSizes.append(tileSizes.begin() + split, tileSizes.end()); if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange))) return failure(); - SmallVector inputOperands = - tileLoopNest.getRootOp().getInputOperands(); - for (OpOperand *opOperand : tileLoopNest.getRootOp().getInputOperands()) - (void)tileLoopNest.fuseProducer(b, opOperand); + fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands()); return tileLoopNest; } diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir new file mode 100644 index 0000000..1578d23 --- /dev/null +++ b/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir @@ -0,0 +1,83 @@ +// RUN: mlir-opt %s -linalg-tile-and-fuse-tensor-ops="tile-sizes=4,4,0,0 tile-interchange=0,1,2,3" -cse --canonicalize -split-input-file | FileCheck %s + +// CHECK: fuse_conv_chain +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<2x2xf32> +// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<11x11xf32> +// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<10x10xf32> +// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<9x9xf32> +// CHECK-SAME: %[[ARG4:[0-9a-zA-Z]*]]: tensor<8x8xf32> +builtin.func @fuse_conv_chain(%arg0: tensor<2x2xf32>, + %arg1: tensor<11x11xf32>, + %arg2: tensor<10x10xf32>, + %arg3: tensor<9x9xf32>, + %arg4: tensor<8x8xf32>) -> tensor<8x8xf32> { + %cst = arith.constant 1.0 : f32 + + // Do not tile the filter fill since the filter dimensions are not tiled. + // CHECK: %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]]) + %0 = linalg.fill(%cst, %arg0) : f32, tensor<2x2xf32> -> tensor<2x2xf32> + + // Fuse all other operations. + // CHECK: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[ARG4]] + // CHECK: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG6:.*]] = %[[ARG5]] + + // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG1]] + // CHECK-SAME: %[[IV0]], %[[IV1]] + // CHECK: %[[T2:.*]] = tensor.extract_slice %[[ARG2]] + // CHECK-SAME: %[[IV0]], %[[IV1]] + // CHECK: %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]]) + // CHECK: %[[T4:.*]] = linalg.conv_2d ins(%[[T1]], %[[T0]] : {{.*}} outs(%[[T3]] + %1 = linalg.fill(%cst, %arg2) : f32, tensor<10x10xf32> -> tensor<10x10xf32> + %2 = linalg.conv_2d ins(%arg1, %0 : tensor<11x11xf32>, tensor<2x2xf32>) outs(%1 : tensor<10x10xf32>) -> tensor<10x10xf32> + + // CHECK: %[[T5:.*]] = tensor.extract_slice %[[ARG3]] + // CHECK-SAME: %[[IV0]], %[[IV1]] + // CHECK: %[[T6:.*]] = linalg.fill(%{{.*}}, %[[T5]]) + // CHECK: %[[T7:.*]] = linalg.conv_2d ins(%[[T4]], %[[T0]] : {{.*}} outs(%[[T6]] + %3 = linalg.fill(%cst, %arg3) : f32, tensor<9x9xf32> -> tensor<9x9xf32> + %4 = linalg.conv_2d ins(%2, %0 : tensor<10x10xf32>, tensor<2x2xf32>) outs(%3 : tensor<9x9xf32>) -> tensor<9x9xf32> + + // Use the argument passed in by iteration argument. + // CHECK: %[[T8:.*]] = tensor.extract_slice %[[ARG6]] + // CHECK-SAME: %[[IV0]], %[[IV1]] + // CHECK: %[[T9:.*]] = linalg.fill(%{{.*}}, %[[T8]]) + // CHECK: %[[T5:.*]] = linalg.conv_2d ins(%[[T7]], %[[T0]] {{.*}} outs(%[[T9]] + %5 = linalg.fill(%cst, %arg4) : f32, tensor<8x8xf32> -> tensor<8x8xf32> + %6 = linalg.conv_2d ins(%4, %0 : tensor<9x9xf32>, tensor<2x2xf32>) outs(%5 : tensor<8x8xf32>) -> tensor<8x8xf32> + return %6 : tensor<8x8xf32> +} + +// ----- + +// CHECK: fuse_matmul_chain +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<8x8xf32> +builtin.func @fuse_matmul_chain(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { + %c0 = arith.constant 0 : index + %c12 = arith.constant 12 : index + %c25 = arith.constant 25 : index + %c24 = arith.constant 24 : index + %c4 = arith.constant 4 : index + %cst = arith.constant 0.000000e+00 : f32 + + // Do not tile rhs fill of the producer matmul since none of its loop dimension is tiled. + // CHECK: %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]]) + %0 = linalg.fill(%cst, %arg0) : f32, tensor<8x8xf32> -> tensor<8x8xf32> + + // CHECK: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG1:.*]] = %[[ARG0]] + // CHECK: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[ARG1]] + + // Only the outermost loop of the producer matmul is tiled. + // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG0]] + // CHECK-SAME: %[[IV0]], 0 + // CHECK: %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]]) + // CHECK: %[[T3:.*]] = linalg.matmul ins(%[[T2]], %[[T0]] {{.*}} + %1 = linalg.matmul ins(%0, %0 : tensor<8x8xf32>, tensor<8x8xf32>) outs(%0 : tensor<8x8xf32>) -> tensor<8x8xf32> + + // Use the argument passed in by iteration argument. + // CHECK: %[[T4:.*]] = tensor.extract_slice %[[ARG2]] + // CHECK-SAME: %[[IV0]], %[[IV1]] + // CHECK: %[[T5:.*]] = linalg.fill(%{{.*}}, %[[T4]]) + // CHECK: %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[T5]] + %2 = linalg.matmul ins(%1, %0 : tensor<8x8xf32>, tensor<8x8xf32>) outs(%0 : tensor<8x8xf32>) -> tensor<8x8xf32> + return %2 : tensor<8x8xf32> +}