From 4521b113978d9ddaaae038e3cdd9d8902e2392f9 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 22 Feb 2023 05:24:25 -0800 Subject: [PATCH] [mlir][Linalg] Reimplement hoisting on tensors as a subset-based transformation This revision significantly rewrites hoisting on tensors. Previously, `vector.transfer_read/write` and `tensor.extract/insert_slice` would be clumped together when looking for candidate pairs. This would significantly increase the complexity of the logic and would not apply independently to `tensor.extract/insert_slice`. The new implementation decouples the cases and starts to cast the problem as a generic matching subset extract/insert, which will be future proof when other such operation pairs are introduced. Lastly, the implementation makes the distinction clear between `vector.transfer_read/write` for which we allow bypasses of the disjoint subsets from `tensor.extract/insert_slice` for which we do not yet allow it. This can be extended in the future and unified once we have subset disjunction implemented more generally. The algorithm can be rewritten to be less of a fixed point with interspersed canonicalizations. As a consequence, the test explicitly adds a canonicalization to clean up the IR and verify we end up in the same state. That extra canonicalization exhibited that one of the uses in one of the tests was dead, so we fix the appropriate test. Differential Revision: https://reviews.llvm.org/D144656 --- .../Linalg/TransformOps/LinalgTransformOps.td | 51 ++ .../mlir/Dialect/Linalg/Transforms/Hoisting.h | 111 +++- .../mlir/Dialect/Tensor/IR/TensorOps.td | 34 +- .../mlir/Dialect/Utils/StaticValueUtils.h | 2 + .../TransformOps/LinalgTransformOps.cpp | 30 +- .../Dialect/Linalg/Transforms/CMakeLists.txt | 1 + .../Dialect/Linalg/Transforms/Hoisting.cpp | 373 +----------- .../Linalg/Transforms/SubsetHoisting.cpp | 553 ++++++++++++++++++ mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 10 + mlir/test/Dialect/Linalg/hoisting.mlir | 193 ++++-- 10 files changed, 932 insertions(+), 426 deletions(-) create mode 100644 mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 41c5daf6744d..4aacd68e3bc9 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1725,6 +1725,10 @@ def HoistRedundantVectorTransfersOp : dominated by the transfer_write (i.e. no aliasing between the write and the read across the loop) + WARNING: This hoisting does not model parallelism and is generally incorrect + when used on distributed loops with memref semantics! + TODO: obsolete and should be retired. + #### Return modes: The operation always succeeds and returns a handle to the transformed @@ -1823,4 +1827,51 @@ def ConvertConv2DToImg2ColOp : Op { + let description = [{ + Hoists supported tensor subset extract/insert operation pairs out of + immediately enclosing loop iteratively, if the following conditions + are true: + 1. The 2 ops access the same tensor subset. + 2. All operands are invariant under the enclosing loop. + + The supported subset extract/insert operation pairs currently comprise: + - tensor.extract_slice / tensor.insert_slice + - vector.transfer_read / vector.transfer_write on tensors + + Only scf.for loops are currently supported. + + When applied to: + 1. an scf.for loop, hoist out of this loop only. + 2. a non-loop op, apply hoisting to all the contained loop ops. + + #### Return modes: + + The operation always succeeds and returns a handle to the transformed + function op. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$transformed); + + let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) "; + + let builders = [ + OpBuilder<(ins "Value":$target)>, + ]; + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + #endif // LINALG_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h index 355106ddd917..24cb754d65e1 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h @@ -10,9 +10,13 @@ #define MLIR_DIALECT_LINALG_TRANSFORMS_HOISTING_H_ namespace mlir { +class RewriterBase; namespace func { class FuncOp; } // namespace func +namespace scf { +class ForOp; +} // namespace scf namespace linalg { @@ -28,11 +32,112 @@ namespace linalg { /// function on the candidate loop above which to hoist. Hoisting the transfers /// results in scf::ForOp yielding the value that originally transited through /// memory. -// TODO: generalize on a per-need basis. +/// +/// WARNING: This hoisting does not model parallelism and is generally incorrect +/// when used on distributed loops with memref semantics! void hoistRedundantVectorTransfers(func::FuncOp func); -/// Same behavior as `hoistRedundantVectorTransfers` but works on tensors -/// instead of buffers. +/// Greedily hoist redundant subset extract/insert operations on tensors outside +/// of `forOp`. The logic follows: +/// 1. Look for a write walking back from the `forOp` yield. +/// 2. Check the uses of the matching block argument and look for a matching +/// read (i.e. extract_slice of transfer_read) with matching indices. +/// 3. In the case of a transfer_write, we can bypass other non-conflicting +/// operations and find more hoisting opportunities. +/// 4. Hoist the read/write pair and update the tensor SSA links. +/// +/// Return the unmodified `forOp` if no hoisting occured. +/// Return a new scf::ForOp if hoisting on tensors occured. +/// +/// After this transformation the returned scf::ForOp may have unused arguments +/// that can be removed by application of canonicalization patterns. +/// +/// Example: +/// ======== +/// IR Resembling: +/// +/// ``` +/// %0 = scf.for %i = %l to %u step %s iter_args(%a0 = %t0)->(tensor<10xf32>) { +/// %1 = scf.for %j = %l to %u step %s iter_args(%a6 = %a0)->(tensor<10xf32>) { +/// %e = tensor.extract_slice %a6[%i][%sz][1]: tensor<10xf32> to tensor +/// %r = vector.transfer_read %e[%c0], %cst: tensor, vector<4xf32> +/// %u = "some_use"(%r) : (vector<4xf32>) -> vector<4xf32> +/// %w = vector.transfer_write %u, %e[%c0] : vector<4xf32>, tensor +/// %st = tensor.insert_slice %w into %a6[%i][%sz][1] +/// : tensor into tensor<10xf32> +/// scf.yield %st: tensor<10xf32> +/// } +/// scf.yield %1: tensor<10xf32> +/// } +/// ``` +/// +/// Progressively hoists to: +/// +/// ``` +/// %0 = scf.for %i = %l to %u step %s iter_args(%a0 = %t0) -> (tensor<10xf32>){ +/// %e = tensor.extract_slice %a0[%i][%sz][1]: tensor<10xf32> to tensor +/// %1:2 = scf.for %j = %l to %u step %s iter_args(%a6 = a0, %a7 = %e) +/// -> (tensor<10xf32>, tensor) { +/// %r = vector.transfer_read %a7[%c0], %cst: tensor, vector<4xf32> +/// %u = "some_use"(%r) : (vector<4xf32>) -> vector<4xf32> +/// %w = vector.transfer_write %u, %a7[%c0] : vector<4xf32>, tensor +/// scf.yield %a6, %w: tensor<10xf32>, tensor +/// } +/// %st = tensor.insert_slice %1#1 into %1#0[%i][%sz][1] +/// : tensor into tensor<10xf32> +/// scf.yield %1: tensor<10xf32> +/// } +/// ``` +/// +/// and +/// +/// ``` +/// %0 = scf.for %i = %l to %u step %s iter_args(%a0 = %t0) -> (tensor<10xf32>){ +/// %e = tensor.extract_slice %a0[%i][%sz][1]: tensor<10xf32> to tensor +/// %r = vector.transfer_read %a7[%c0], %cst: tensor, vector<4xf32> +/// %1:3 = scf.for %j = %l to %u step %s iter_args(%a6 = a0, %a7 = %e, %a7 = r) +/// -> (tensor<10xf32>, tensor, vector<4xf32>) { +/// %u = "some_use"(%r) : (vector<4xf32>) -> vector<4xf32> +/// scf.yield %a6, %a7, %u: tensor<10xf32>, tensor, vector<4xf32> +/// } +/// %w = vector.transfer_write %1#2, %1#1[%c0] : vector<4xf32>, tensor +/// %st = tensor.insert_slice %w into %1#0[%i][%sz][1] +/// : tensor into tensor<10xf32> +/// scf.yield %1: tensor<10xf32> +/// } +/// ``` +/// +/// It can then canonicalize to: +/// +/// ``` +/// %0 = scf.for %i = %l to %u step %s iter_args(%a0 = %t0) -> (tensor<10xf32>){ +/// %e = tensor.extract_slice %a0[%i][%sz][1]: tensor<10xf32> to tensor +/// %r = vector.transfer_read %a7[%c0], %cst: tensor, vector<4xf32> +/// %1 = scf.for %j = %l to %u step %s iter_args(%a7 = r) +/// -> (tensor<10xf32>, tensor, vector<4xf32>) { +/// %u = "some_use"(%r) : (vector<4xf32>) -> vector<4xf32> +/// scf.yield %u: vector<4xf32> +/// } +/// %w = vector.transfer_write %1, %e[%c0] : vector<4xf32>, tensor +/// %st = tensor.insert_slice %w into %a0[%i][%sz][1] +/// : tensor into tensor<10xf32> +/// scf.yield %1: tensor<10xf32> +/// } +/// ``` +/// +// TODO: This should be further generalized along a few different axes: +// - Other loops than scf.ForOp that operate on tensors (both sequential and +// parallel loops). +// - Other subset extract/insert pairs than tensor.extract/insert_slice and +// vector.transfer_read/write. +// - More general areSubsetDisjoint analysis/interface to work across all +// subset op types and allow bypassing non-WAW-conflicting operations in +// more cases. +scf::ForOp hoistRedundantSubsetExtractInsert(RewriterBase &rewriter, + scf::ForOp forOp); + +/// Call into `hoistRedundantSubsetInsertExtract` without a RewriterBase. +// TODO: obsolete and should be retired void hoistRedundantVectorTransfersOnTensor(func::FuncOp func); } // namespace linalg diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 9652d7de5f7c..80c0ba5e754a 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -30,7 +30,17 @@ class Tensor_OpWithOffsetSizesAndStrides traits = []> : Tensor_Op { code extraBaseClassDeclaration = [{ - /// Returns the dynamic sizes for this subview operation if specified. + /// Return the type of the base tensor operand. + ::mlir::RankedTensorType getSourceType() { + return getSource().getType().cast(); + } + + /// Return the type of the result tensor. + ::mlir::RankedTensorType getResultType() { + return getResult().getType().cast(); + } + + /// Return the dynamic sizes for this subview operation if specified. ::mlir::Operation::operand_range getDynamicSizes() { return getSizes(); } /// Return the list of Range (i.e. offset, size, stride). Each @@ -105,7 +115,7 @@ def Tensor_DimOp : Tensor_Op<"dim", [ %c0 = arith.constant 0 : index %x = tensor.dim %A, %c0 : tensor<4x?xf32> - // Returns the dynamic dimension of %A. + // Return the dynamic dimension of %A. %c1 = arith.constant 1 : index %y = tensor.dim %A, %c1 : memref<4x?xf32> @@ -361,14 +371,10 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice", ]; let extraClassDeclaration = extraBaseClassDeclaration # [{ - /// Returns the type of the base tensor operand. - RankedTensorType getSourceType() { - return getSource().getType().cast(); - } - /// The result of an extract_slice is always a tensor. + // TODO: deprecate RankedTensorType getType() { - return getResult().getType().cast(); + return getResultType(); } /// Compute the rank-reduction mask that can be applied to map the source @@ -834,25 +840,21 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [ ]; let extraClassDeclaration = extraBaseClassDeclaration # [{ - /// Returns the type of the base tensor operand. - RankedTensorType getSourceType() { - return getSource().getType().cast(); - } - /// The result of a insert_slice is always a tensor. + // TODO: Deprecate this method. RankedTensorType getType() { - return getResult().getType().cast(); + return getResultType(); } /// The `dest` type is the same as the result type. RankedTensorType getDestType() { - return getType(); + return getResultType(); } /// Return the expected rank of each of the`static_offsets`, `static_sizes` /// and `static_strides` attributes. std::array getArrayAttrMaxRanks() { - unsigned rank = getType().getRank(); + unsigned rank = getResultType().getRank(); return {rank, rank, rank}; } diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h index c37d35134dce..100699c7f7fd 100644 --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -82,6 +82,8 @@ bool isConstantIntValue(OpFoldResult ofr, int64_t value); /// that come from the fact there is no IndexAttr and that IndexType have no /// bitwidth. bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2); +bool isEqualConstantIntOrValueArray(ArrayRef ofrs1, + ArrayRef ofrs2); /// Helper function to convert a vector of `OpFoldResult`s into a vector of /// `Value`s. For each `OpFoldResult` in `valueOrAttrVec` return the fold diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 5a8f9816aefd..6baf392f9508 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3097,8 +3097,10 @@ DiagnosedSilenceableFailure transform::HoistRedundantVectorTransfersOp::applyToOne( func::FuncOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { + // WARNING: This hoisting does not model parallelism and is generally + // incorrect when used on distributed loops with memref semantics! + // TODO: obsolete and should be retired. linalg::hoistRedundantVectorTransfers(target); - linalg::hoistRedundantVectorTransfersOnTensor(target); results.push_back(target); return DiagnosedSilenceableFailure::success(); } @@ -3136,6 +3138,32 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne( return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// HoistRedundantTensorSubsetsOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::HoistRedundantTensorSubsetsOp::applyToOne( + Operation *target, transform::ApplyToEachResultList &results, + transform::TransformState &state) { + IRRewriter rewriter(target->getContext()); + auto forOp = dyn_cast(target); + if (forOp) { + scf::ForOp newForOp = + linalg::hoistRedundantSubsetExtractInsert(rewriter, forOp); + results.push_back(newForOp); + return DiagnosedSilenceableFailure::success(); + } + + // TODO: walking in some reverse / inside-out order would be more efficient + // and would capture more cases. + target->walk([&](scf::ForOp forOp) { + hoistRedundantSubsetExtractInsert(rewriter, forOp); + }); + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index adcc87f42dab..8ad28f9be2ab 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms Promotion.cpp Split.cpp SplitReduction.cpp + SubsetHoisting.cpp SwapExtractSliceWithFillPatterns.cpp Tiling.cpp TilingInterfaceImpl.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp index f51b4ffe9999..9bab4ffb4c99 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -43,374 +43,13 @@ using llvm::dbgs; using namespace mlir; using namespace mlir::linalg; -namespace { -/// Represents a unit of hoistable TransferWriteOp. This may comprise other -/// instructions that need to be hoisted too. -struct HoistableWrite { - vector::TransferWriteOp transferWriteOp; - tensor::InsertSliceOp insertSliceOp; -}; -/// Represents a unit of hoistable TransferReadOp. This may comprise other -/// instructions that need to be hoisted too. -struct HoistableRead { - vector::TransferReadOp transferReadOp; - tensor::ExtractSliceOp extractSliceOp; -}; -} // namespace - -/// Return true if op1 and op2 are the same constant or the same SSA value. -static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) { - auto getConstantIntValue = [](OpFoldResult ofr) -> std::optional { - Attribute attr = ofr.dyn_cast(); - // Note: isa+cast-like pattern allows writing the condition below as 1 line. - if (!attr && ofr.get().getDefiningOp()) - attr = ofr.get().getDefiningOp().getValue(); - if (auto intAttr = attr.dyn_cast_or_null()) - return intAttr.getValue().getSExtValue(); - return std::nullopt; - }; - auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2); - if (cst1 && cst2 && *cst1 == *cst2) - return true; - auto v1 = op1.dyn_cast(), v2 = op2.dyn_cast(); - return v1 && v2 && v1 == v2; -} - -/// Return true is all offsets, sizes and strides are equal. -static bool sameOffsetsSizesAndStrides(tensor::ExtractSliceOp s, - tensor::InsertSliceOp si) { - if (s.getStaticOffsets().size() != si.getStaticOffsets().size()) - return false; - if (s.getStaticSizes().size() != si.getStaticSizes().size()) - return false; - if (s.getStaticStrides().size() != si.getStaticStrides().size()) - return false; - for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets())) - if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it))) - return false; - for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes())) - if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it))) - return false; - for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides())) - if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it))) - return false; - return true; -} - -/// Look for a HoistableRead, in the given tensor uses, accessing the same -/// offset as the HoistableWrite. -static HoistableRead findMatchingTransferRead(HoistableWrite write, - Value srcTensor) { - assert(write.transferWriteOp && - "expected hoistable write to have a .transfer_write"); - - LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: " - << *write.transferWriteOp.getOperation() << "\n"); - if (write.insertSliceOp) - LLVM_DEBUG(DBGS() << "findMatchingTransferRead inserSliceOp: " - << *write.insertSliceOp.getOperation() << "\n"); - SmallVector users(srcTensor.getUsers().begin(), - srcTensor.getUsers().end()); - while (!users.empty()) { - Operation *user = users.pop_back_val(); - LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user - << "\n"); - - // If HoistableWrite involves a InsertSliceOp, we need to find a - // matching ExtractSliceOp. - tensor::ExtractSliceOp sliceOp; - Operation *maybeTransferReadUser = user; - if (write.insertSliceOp) { - sliceOp = dyn_cast(user); - if (!sliceOp || sliceOp.getResult().getType() != - write.insertSliceOp.getSource().getType()) - continue; - - LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: " - << *sliceOp << " vs " << *write.insertSliceOp << "\n"); - if (!sameOffsetsSizesAndStrides(sliceOp, write.insertSliceOp)) - continue; - - LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n"); - // If we got here, sliceOp is hoistable iff it has exactly 2 uses: - // 1. the transfer_write we want to hoist. - // 2. a matching transfer_read. - // Anything else, we skip. - bool skip = false; - Operation *otherUser = nullptr; - for (Operation *u : sliceOp->getUsers()) { - if (u == write.transferWriteOp) - continue; - if (otherUser) { - skip = true; - break; - } - otherUser = u; - } - if (skip || !otherUser) - continue; - maybeTransferReadUser = otherUser; - } - - LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser - << "\n"); - auto read = dyn_cast(maybeTransferReadUser); - if (read && read.getIndices() == write.transferWriteOp.getIndices() && - read.getVectorType() == write.transferWriteOp.getVectorType()) - return HoistableRead{read, sliceOp}; - - if (isa(user)) { - // If we find a write with disjoint indices recurse through its uses. - if (vector::isDisjointTransferIndices( - cast(user), - cast( - write.transferWriteOp.getOperation()))) { - users.append(user->getUsers().begin(), user->getUsers().end()); - } - } - } - return HoistableRead(); -} - -/// Check if the chunk of data inserted by the HoistableWrite are read by any -/// other op than the HoistableRead candidate. -static bool tensorChunkAccessedByUnknownOp(HoistableWrite write, - HoistableRead candidateRead, - BlockArgument tensorArg) { - // Make sure none of the other uses read the part of the tensor modified - // by the transfer_write. - llvm::SmallVector uses; - uses.push_back(tensorArg.getUses()); - while (!uses.empty()) { - for (OpOperand &use : uses.pop_back_val()) { - Operation *user = use.getOwner(); - // Skip the candidate use, only inspect the "other" uses. - if (user == candidateRead.transferReadOp || - user == candidateRead.extractSliceOp || - user == write.transferWriteOp || user == write.insertSliceOp) - continue; - // Consider all transitive uses through a extract_slice / insert_slice. - // TODO: atm we just bail because a stronger analysis is needed for these - // cases. - if (isa(user)) - return true; - // Consider all transitive uses through a vector.transfer_write. - if (auto writeUser = dyn_cast(user)) { - uses.push_back(writeUser->getResult(0).getUses()); - continue; - } - // Consider all nested uses through an scf::ForOp. We may have - // pass-through tensor arguments left from previous level of - // hoisting. - if (auto forUser = dyn_cast(user)) { - Value arg = forUser.getLoopBody().getArgument( - use.getOperandNumber() - forUser.getNumControlOperands() + - /*iv value*/ 1); - uses.push_back(arg.getUses()); - continue; - } - // Follow the use yield as long as it doesn't escape the original - // region. - scf::YieldOp yieldUser = dyn_cast(user); - if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor( - yieldUser->getParentOp())) { - Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber()); - uses.push_back(ret.getUses()); - continue; - } - auto read = dyn_cast(user); - if (!read || !vector::isDisjointTransferIndices( - cast(read.getOperation()), - cast( - write.transferWriteOp.getOperation()))) { - return true; - } - } - } - return false; -} - -/// Return the `forOp`-invariant HoistableWrite that produces `yieldOperand`. -/// Return the null HoistableWrite() if it is not comprised of a -/// vector.transfer_write + optional insert_slice or if any of the indexings -/// is `forOp`-dependent. -static HoistableWrite -getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp, - OpOperand &yieldOperand) { - Value v = yieldOperand.get(); - if (auto write = v.getDefiningOp()) { - // Indexing must not depend on `forOp`. - for (Value operand : write.getIndices()) - if (!forOp.isDefinedOutsideOfLoop(operand)) - return HoistableWrite(); - - return HoistableWrite{write, nullptr}; - } - - if (auto insertSliceOp = v.getDefiningOp()) { - // Inserted slice must come from vector.transfer_write. - auto write = - insertSliceOp.getSource().getDefiningOp(); - if (!write) - return HoistableWrite(); - - // Tensor inserted into must be a BBArg at position matching yieldOperand's. - auto bbArg = insertSliceOp.getDest().dyn_cast(); - if (!bbArg || bbArg.getOwner()->getParentOp() != forOp || - bbArg.getArgNumber() != /*num iv=*/1 + yieldOperand.getOperandNumber()) - return HoistableWrite(); - - // Indexing inserted into must not depend on `forOp`. - for (Value operand : insertSliceOp->getOperands().drop_front( - tensor::InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex())) - if (!forOp.isDefinedOutsideOfLoop(operand)) - return HoistableWrite(); - - return HoistableWrite{write, insertSliceOp}; - } - - return HoistableWrite(); -} - -/// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair. -static void hoistReadWrite(HoistableRead read, HoistableWrite write, - BlockArgument tensorBBArg) { - scf::ForOp forOp = cast(tensorBBArg.getOwner()->getParentOp()); - assert(read.transferReadOp && write.transferWriteOp && - "expected transfer_read and transfer_write ops to be set"); - assert(((read.extractSliceOp && write.insertSliceOp) || - (!read.extractSliceOp && !write.insertSliceOp)) && - "expected matching extract_slice / insert_slice"); - LLVM_DEBUG(DBGS() << "In forOp:\n" - << *forOp.getOperation() - << "\nHoist: " << *read.transferReadOp.getOperation() - << "\nHoist: " << *write.transferWriteOp.getOperation() - << "\nInvolving: " << tensorBBArg << "\n"); - - // If a read slice is present, hoist it. - if (read.extractSliceOp) - forOp.moveOutOfLoop(read.extractSliceOp); - - // Hoist the transfer_read op. - forOp.moveOutOfLoop(read.transferReadOp); - - // TODO: don't hardcode /*numIvs=*/1. - assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1); - unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1; - - // Update the source tensor. - if (read.extractSliceOp) - read.extractSliceOp.getSourceMutable().assign( - forOp.getInitArgs()[initArgNumber]); - else - read.transferReadOp.getSourceMutable().assign( - forOp.getInitArgs()[initArgNumber]); - - // Hoist write after. - if (write.insertSliceOp) - write.insertSliceOp->moveAfter(forOp); - write.transferWriteOp->moveAfter(forOp); - - // Update the yield. - auto yieldOp = cast(forOp.getRegion().front().getTerminator()); - if (write.insertSliceOp) - yieldOp->setOperand(initArgNumber, write.insertSliceOp.getDest()); - else - yieldOp->setOperand(initArgNumber, write.transferWriteOp.getSource()); - - // Rewrite `loop` with additional new yields. - OpBuilder b(read.transferReadOp); - NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc, - ArrayRef newBBArgs) { - return SmallVector{write.transferWriteOp.getVector()}; - }; - auto newForOp = replaceLoopWithNewYields( - b, forOp, read.transferReadOp.getVector(), yieldFn); - - // Transfer write has been hoisted, need to update the vector and tensor - // source. Replace the result of the loop to use the new tensor created - // outside the loop. - // Depending on whether a insert_slice is present or not, it carries the - // update on the tensor operands. - if (write.insertSliceOp) { - newForOp.getResult(initArgNumber) - .replaceAllUsesWith(write.insertSliceOp.getResult()); - write.transferWriteOp.getSourceMutable().assign( - read.extractSliceOp.getResult()); - write.insertSliceOp.getDestMutable().assign( - read.extractSliceOp.getSource()); - } else { - newForOp.getResult(initArgNumber) - .replaceAllUsesWith(write.transferWriteOp.getResult()); - write.transferWriteOp.getSourceMutable().assign( - newForOp.getResult(initArgNumber)); - } - - // Always update with the newly yield tensor and vector. - write.transferWriteOp.getVectorMutable().assign(newForOp.getResults().back()); -} - -// To hoist transfer op on tensor the logic can be significantly simplified -// compared to the case on buffer. The transformation follows this logic: -// 1. Look for transfer_write with a single use from ForOp yield -// 2. Check the uses of the matching block argument and look for a transfer_read -// with the same indices. -// 3. Check that all the other uses of the tensor argument are either disjoint -// tensor_read or transfer_write. For transfer_write uses recurse to make sure -// the new tensor has the same restrictions on its uses. -// 4. Hoist the tensor_read/tensor_write and update the tensor SSA links. -// After this transformation the scf.forOp may have unused arguments that can be -// remove by the canonicalization pass. void mlir::linalg::hoistRedundantVectorTransfersOnTensor(func::FuncOp func) { - bool changed = true; - while (changed) { - changed = false; - func.walk([&](scf::ForOp forOp) { - Operation *yield = forOp.getBody()->getTerminator(); - for (const auto &it : llvm::enumerate(forOp.getRegionIterArgs())) { - OpOperand &ret = yield->getOpOperand(it.index()); - HoistableWrite write = - getLoopInvariantTransferWriteOpDefining(forOp, ret); - if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse()) - continue; - LLVM_DEBUG(dbgs() << "\n"; - DBGS() << "Candidate write for hoisting: " - << *write.transferWriteOp.getOperation() << "\n"); - if (write.insertSliceOp) - LLVM_DEBUG(DBGS() << "Candidate insert_slice for hoisting: " - << *write.insertSliceOp.getOperation() << "\n"); - if (llvm::any_of(write.transferWriteOp.getIndices(), - [&forOp](Value index) { - return !forOp.isDefinedOutsideOfLoop(index); - })) - continue; - // Find a read with the same type and indices. - HoistableRead matchingRead = - findMatchingTransferRead(write, it.value()); - // Make sure none of the other uses read the part of the tensor modified - // by the transfer_write. - if (!matchingRead.transferReadOp || - tensorChunkAccessedByUnknownOp(write, matchingRead, it.value())) - continue; - - LLVM_DEBUG(DBGS() << "Start hoisting\n"); - hoistReadWrite(matchingRead, write, it.value()); - changed = true; - forOp.erase(); - - // Need to interrupt and restart: erasing the loop messes up the walk. - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - // Apply canonicalization so the newForOp + yield folds immediately, thus - // cleaning up the IR and potentially enabling more hoisting. - if (changed) { - RewritePatternSet patterns(func->getContext()); - scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext()); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); - } - } + IRRewriter rewriter(func->getContext()); + // TODO: walking in some reverse / inside-out order would be more efficient + // and would capture more cases. + func.walk([&](scf::ForOp forOp) { + hoistRedundantSubsetExtractInsert(rewriter, forOp); + }); } void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp new file mode 100644 index 000000000000..c0355a14d366 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp @@ -0,0 +1,553 @@ +//===- SubsetHoisting.cpp - Linalg hoisting transformations----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements functions concerned with hoisting invariant subset +// operations in the context of Linalg transformations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" + +#define DEBUG_TYPE "subset-hoisting" + +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") + +using namespace mlir; +using namespace mlir::linalg; + +/// Return true if the location of the subset defined by the op is invariant of +/// the loop iteration. +static bool +isSubsetLocationLoopInvariant(scf::ForOp forOp, + vector::TransferWriteOp transferWriteOp) { + for (Value operand : transferWriteOp.getIndices()) + if (!forOp.isDefinedOutsideOfLoop(operand)) + return false; + return true; +} + +/// Return true if the location of the subset defined by the op is invariant of +/// the loop iteration. +static bool isSubsetLocationLoopInvariant(scf::ForOp forOp, + tensor::InsertSliceOp insertSliceOp) { + for (Value operand : insertSliceOp->getOperands().drop_front( + tensor::InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex())) + if (!forOp.isDefinedOutsideOfLoop(operand)) + return false; + return true; +} + +/// Given an `srcTensor` that is a block argument belong to a loop. +/// Greedily look for the first read that can be hoisted out of the loop (i.e. +/// that satisfied the conditions): +/// - The read is of type `tensor.extract_slice`. +/// - The read is one of the uses of `srcTensor`. +/// - The read is to the same subset that `tensor.insert_slice` writes. +// TODO: Unify implementations once the "bypassing behavior" is the same. +static FailureOr +findHoistableMatchingExtractSlice(RewriterBase &rewriter, + tensor::InsertSliceOp insertSliceOp, + BlockArgument srcTensor) { + assert(srcTensor.getType().isa() && "not a ranked tensor"); + + auto forOp = cast(srcTensor.getOwner()->getParentOp()); + + LLVM_DEBUG(DBGS() << "--find matching read for: " << insertSliceOp << "\n"; + DBGS() << "--amongst users of: " << srcTensor << "\n"); + + SmallVector users(srcTensor.getUsers()); + if (forOp.isDefinedOutsideOfLoop(insertSliceOp.getDest())) + llvm::append_range(users, insertSliceOp.getDest().getUsers()); + + for (Operation *user : users) { + LLVM_DEBUG(DBGS() << "----inspect user: " << *user << "\n"); + auto extractSliceOp = dyn_cast(user); + // Skip ops other than extract_slice with an exact matching of their tensor + // subset. + if (extractSliceOp) { + auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; }; + if (extractSliceOp.getResultType() != insertSliceOp.getSourceType() || + !extractSliceOp.isSameAs(insertSliceOp, isSame)) { + LLVM_DEBUG(DBGS() << "------not a matching extract_slice\n"; + DBGS() << *user << " vs " << *insertSliceOp << "\n"); + continue; + } + + // Skip insert_slice whose vector is defined within the loop: we need to + // hoist that definition first otherwise dominance violations trigger. + if (!extractSliceOp.getSource().isa() && + !forOp.isDefinedOutsideOfLoop(extractSliceOp.getSource())) { + LLVM_DEBUG(DBGS() << "------transfer_read vector is loop-dependent\n"); + continue; + } + return extractSliceOp; + } + + // TODO: Look through disjoint subsets, similar to vector.transfer_write + // and unify implementations. + } + + LLVM_DEBUG(DBGS() << "----no matching extract_slice"); + return failure(); +} + +/// Given an `srcTensor` that is a block argument belong to a loop. +/// Greedily look for the first read that can be hoisted out of the loop (i.e. +/// that satisfied the conditions): +/// - The read is of type `tensor.transfer_read`. +/// - The read is one of the uses of `srcTensor`. +/// - The read is to the same subset that `tensor.transfer_write` writes. +// TODO: Unify implementations once the "bypassing behavior" is the same. +static FailureOr +findHoistableMatchingTransferRead(RewriterBase &rewriter, + vector::TransferWriteOp transferWriteOp, + BlockArgument srcTensor) { + if (!srcTensor.getType().isa()) + return failure(); + + auto forOp = cast(srcTensor.getOwner()->getParentOp()); + + LLVM_DEBUG(DBGS() << "--find matching read for: " << transferWriteOp << "\n"; + DBGS() << "--amongst users of: " << srcTensor << "\n";); + + // vector.transfer_write is a bit peculiar: we look through dependencies + // to disjoint tensor subsets. This requires a while loop. + // TODO: Look through disjoint subsets for tensor.insert_slice and unify + // implementations. + SmallVector users(srcTensor.getUsers()); + // TODO: transferWriteOp.getSource is actually the destination tensor!! + if (forOp.isDefinedOutsideOfLoop(transferWriteOp.getSource())) + llvm::append_range(users, transferWriteOp.getSource().getUsers()); + while (!users.empty()) { + Operation *user = users.pop_back_val(); + LLVM_DEBUG(DBGS() << "----inspect user: " << *user << "\n"); + auto read = dyn_cast(user); + if (read) { + // Skip ops other than transfer_read with an exact matching subset. + if (read.getIndices() != transferWriteOp.getIndices() || + read.getVectorType() != transferWriteOp.getVectorType()) { + LLVM_DEBUG(DBGS() << "------not a transfer_read that matches the " + "transfer_write: " + << *user << "\n\t(vs " << *transferWriteOp << ")\n"); + continue; + } + + // transfer_read may be of a vector that is defined within the loop: we + // traverse it by virtue of bypassing disjoint subset operations rooted at + // a bbArg and yielding a matching yield. + if (!read.getSource().isa() && + !forOp.isDefinedOutsideOfLoop(read.getSource())) { + LLVM_DEBUG(DBGS() << "------transfer_read vector appears loop " + "dependent but will be tested for disjointness as " + "part of the bypass analysis\n"); + } + LLVM_DEBUG(DBGS() << "------found match\n"); + return read; + } + + // As an optimization, we look further through dependencies to disjoint + // tensor subsets. This creates more opportunities to find a matching read. + if (isa(user)) { + // If we find a write with disjoint indices append all its uses. + // TODO: Generalize areSubsetsDisjoint and allow other bypass than + // just vector.transfer_write - vector.transfer_write. + if (vector::isDisjointTransferIndices( + cast(user), + cast( + transferWriteOp.getOperation()))) { + LLVM_DEBUG(DBGS() << "----follow through disjoint write\n"); + users.append(user->getUsers().begin(), user->getUsers().end()); + } else { + LLVM_DEBUG(DBGS() << "----skip non-disjoint write\n"); + } + } + } + + LLVM_DEBUG(DBGS() << "--no matching transfer_read\n"); + return rewriter.notifyMatchFailure(transferWriteOp, + "no matching transfer_read"); +} + +/// Return the `vector.transfer_write` that produces `yieldOperand`, if: +/// - The write operates on tensors. +/// - All indices are defined outside of the loop. +/// Return failure otherwise. +/// +/// This is sufficient condition to hoist the `vector.transfer_write`; other +/// operands can always be yielded by the loop where needed. +// TODO: generalize beyond scf::ForOp. +// TODO: Unify implementations once the "bypassing behavior" is the same. +static FailureOr +getLoopInvariantTransferWriteDefining(RewriterBase &rewriter, scf::ForOp forOp, + BlockArgument bbArg, + OpOperand &yieldOperand) { + assert(bbArg.getArgNumber() == + forOp.getNumInductionVars() + yieldOperand.getOperandNumber() && + "bbArg and yieldOperand must match"); + assert(isa(yieldOperand.getOwner()) && "must be an scf.yield"); + + Value v = yieldOperand.get(); + auto transferWriteOp = v.getDefiningOp(); + if (!transferWriteOp) + return rewriter.notifyMatchFailure(v.getLoc(), "not a transfer_write"); + + if (transferWriteOp->getNumResults() == 0) { + return rewriter.notifyMatchFailure(v.getLoc(), + "unsupported transfer_write on buffers"); + } + + // We do not explicitly check that the destination is a BBarg that matches the + // yield operand as this would prevent us from bypassing other non-conflicting + // writes. + + // Indexing must not depend on `forOp`. + if (!isSubsetLocationLoopInvariant(forOp, transferWriteOp)) + return rewriter.notifyMatchFailure( + v.getLoc(), "transfer_write indexing is loop-dependent"); + + return transferWriteOp; +} + +/// Return the `tensor.insert_slice` that produces `yieldOperand`, if: +/// 1. Its destination tensor is a block argument of the `forOp`. +/// 2. The unique use of its result is a yield with operand number matching +/// the block argument. +/// 3. All indices are defined outside of the loop. +/// Return failure otherwise. +/// +/// This is sufficient condition to hoist the `tensor.insert_slice`; other +/// operands can always be yielded by the loop where needed. +/// Note: 1. + 2. ensure that the yield / iter_args cycle results in proper +/// semantics (i.e. no ping-ping between iter_args across iterations). +// TODO: generalize beyond scf::ForOp. +// TODO: Unify implementations once the "bypassing behavior" is the same. +static FailureOr +getLoopInvariantInsertSliceDefining(RewriterBase &rewriter, scf::ForOp forOp, + BlockArgument bbArg, + OpOperand &yieldOperand) { + assert(bbArg.getArgNumber() == + forOp.getNumInductionVars() + yieldOperand.getOperandNumber() && + "bbArg and yieldOperand must match"); + assert(isa(yieldOperand.getOwner()) && "must be an scf.yield"); + + Value v = yieldOperand.get(); + auto insertSliceOp = v.getDefiningOp(); + if (!insertSliceOp) + return rewriter.notifyMatchFailure(v.getLoc(), "not an insert_slice"); + + // Tensor inserted into must be a BBArg at position matching yield operand. + // TODO: In the future we should not perform this check if we want to bypass + // other non-conflicting writes. + if (bbArg != insertSliceOp.getDest()) + return rewriter.notifyMatchFailure(v.getLoc(), "not a matching bbarg"); + + // Indexing inserted into must not depend on `forOp`. + if (!isSubsetLocationLoopInvariant(forOp, insertSliceOp)) + return rewriter.notifyMatchFailure( + v.getLoc(), "insert_slice indexing is loop-dependent"); + + return insertSliceOp; +} + +/// Check if the chunk of data inserted by the `writeOp` is read by any other +/// op than the candidateReadOp. This conflicting operation prevents hoisting, +/// return it or nullptr if none is found. +// TODO: Generalize subset disjunction analysis/interface. +// TODO: Support more subset op types. +static Operation *isTensorChunkAccessedByUnknownOp(Operation *writeOp, + Operation *candidateReadOp, + BlockArgument tensorArg) { + // Make sure none of the other uses read the part of the tensor modified + // by the transfer_write. + llvm::SmallVector uses; + uses.push_back(tensorArg.getUses()); + while (!uses.empty()) { + for (OpOperand &use : uses.pop_back_val()) { + Operation *user = use.getOwner(); + // Skip the candidate use, only inspect the "other" uses. + if (user == candidateReadOp || user == writeOp) + continue; + + // TODO: Consider all transitive uses through + // extract_slice/insert_slice. Atm we just bail because a stronger + // analysis is needed for these cases. + if (isa(user)) + return user; + + // Consider all transitive uses through a vector.transfer_write. + if (isa(writeOp)) { + if (auto writeUser = dyn_cast(user)) { + uses.push_back(writeUser->getResult(0).getUses()); + continue; + } + } + + // Consider all nested uses through an scf::ForOp. We may have + // pass-through tensor arguments left from previous level of + // hoisting. + if (auto forUser = dyn_cast(user)) { + Value arg = forUser.getLoopBody().getArgument( + use.getOperandNumber() - forUser.getNumControlOperands() + + /*iv value*/ 1); + uses.push_back(arg.getUses()); + continue; + } + + // Follow the use yield, only if it doesn't escape the original region. + scf::YieldOp yieldUser = dyn_cast(user); + if (yieldUser && + writeOp->getParentOp()->isAncestor(yieldUser->getParentOp())) { + Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber()); + uses.push_back(ret.getUses()); + continue; + } + + // If the write is a vector::TransferWriteOp, it may have been bypassed + // and we need to check subset disjunction + if (isa(writeOp)) { + auto read = dyn_cast(user); + if (!read || !vector::isDisjointTransferIndices( + cast(read.getOperation()), + cast(writeOp))) { + return user; + } + } + } + } + return nullptr; +} + +/// Mechanical hoisting of a matching read / write pair. +/// Return the newly created scf::ForOp with an extra yields. +// TODO: Unify implementations once the "bypassing behavior" is the same. +static scf::ForOp hoistTransferReadWrite( + RewriterBase &rewriter, vector::TransferReadOp transferReadOp, + vector::TransferWriteOp transferWriteOp, BlockArgument tensorBBArg) { + scf::ForOp forOp = cast(tensorBBArg.getOwner()->getParentOp()); + LLVM_DEBUG(DBGS() << "--Start hoisting\n"; + DBGS() << "--Hoist read : " << transferReadOp << "\n"; + DBGS() << "--Hoist write: " << transferWriteOp << "\n"; + DBGS() << "--Involving : " << tensorBBArg << "\n"); + + // TODO: don't hardcode /*numIvs=*/1. + assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1); + int64_t initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1; + + // 1. Hoist the read op. Thanks to our previous checks we know this will not + // trigger dominance violations once BBArgs are updated. + // TODO: should the rewriter ever want to track this move ? + transferReadOp->moveBefore(forOp); + if (!forOp.isDefinedOutsideOfLoop(transferReadOp.getSource())) { + rewriter.startRootUpdate(transferReadOp); + transferReadOp.getSourceMutable().assign( + forOp.getInitArgs()[initArgNumber]); + rewriter.finalizeRootUpdate(transferReadOp); + } + + // 2. Rewrite `loop` with an additional yield. This is the quantity that is + // computed iteratively but whose storage has become loop-invariant. + NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc, + ArrayRef newBBArgs) { + return SmallVector{transferWriteOp.getVector()}; + }; + auto newForOp = replaceLoopWithNewYields( + rewriter, forOp, {transferReadOp.getVector()}, yieldFn); + rewriter.eraseOp(forOp); + + // 3. Update the yield. Invariant: initArgNumber is the destination tensor. + auto yieldOp = + cast(newForOp.getRegion().front().getTerminator()); + // TODO: transferWriteOp.getSource is actually the destination tensor!! + rewriter.startRootUpdate(yieldOp); + yieldOp->setOperand(initArgNumber, transferWriteOp.getSource()); + rewriter.finalizeRootUpdate(yieldOp); + + // 4. Hoist write after and make uses of newForOp.getResult(initArgNumber) + // flow through it. + // TODO: should the rewriter ever want to track this move ? + transferWriteOp->moveAfter(newForOp); + rewriter.startRootUpdate(transferWriteOp); + transferWriteOp.getVectorMutable().assign(newForOp.getResults().back()); + // TODO: transferWriteOp.getSource is actually the destination tensor!! + transferWriteOp.getSourceMutable().assign(newForOp.getResult(initArgNumber)); + rewriter.finalizeRootUpdate(transferWriteOp); + rewriter.replaceAllUsesExcept(newForOp.getResult(initArgNumber), + transferWriteOp.getResult(), transferWriteOp); + return newForOp; +} + +/// Mechanical hoisting of a matching read / write pair. +/// Return the newly created scf::ForOp with an extra yields. +// TODO: Unify implementations once the "bypassing behavior" is the same. +static scf::ForOp hoistExtractInsertSlice(RewriterBase &rewriter, + tensor::ExtractSliceOp extractSliceOp, + tensor::InsertSliceOp insertSliceOp, + BlockArgument tensorBBArg) { + scf::ForOp forOp = cast(tensorBBArg.getOwner()->getParentOp()); + LLVM_DEBUG(DBGS() << "--Start hoisting\n"; + DBGS() << "--Hoist read : " << extractSliceOp << "\n"; + DBGS() << "--Hoist write: " << insertSliceOp << "\n"; + DBGS() << "--Involving : " << tensorBBArg << "\n"); + + // TODO: don't hardcode /*numIvs=*/1. + assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1); + int64_t initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1; + + // 1. Hoist the read op. Thanks to our previous checks we know this will not + // trigger dominance violations once BBArgs are updated. + // TODO: should the rewriter ever want to track this move ? + extractSliceOp->moveBefore(forOp); + if (!forOp.isDefinedOutsideOfLoop(extractSliceOp.getSource())) { + assert(extractSliceOp.getSource() == tensorBBArg && + "extractSlice source not defined above must be the tracked bbArg"); + rewriter.startRootUpdate(extractSliceOp); + extractSliceOp.getSourceMutable().assign( + forOp.getInitArgs()[initArgNumber]); + rewriter.finalizeRootUpdate(extractSliceOp); + } + + // 2. Rewrite `loop` with an additional yield. This is the quantity that is + // computed iteratively but whose storage has become loop-invariant. + NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc, + ArrayRef newBBArgs) { + return SmallVector{insertSliceOp.getSource()}; + }; + auto newForOp = replaceLoopWithNewYields(rewriter, forOp, + extractSliceOp.getResult(), yieldFn); + rewriter.eraseOp(forOp); + + // 3. Update the yield. Invariant: initArgNumber is the destination tensor. + auto yieldOp = + cast(newForOp.getRegion().front().getTerminator()); + // TODO: should the rewriter ever want to track this ? + rewriter.startRootUpdate(yieldOp); + yieldOp->setOperand(initArgNumber, insertSliceOp.getDest()); + rewriter.finalizeRootUpdate(yieldOp); + + // 4. Hoist write after and make uses of newForOp.getResult(initArgNumber) + // flow through it. + // TODO: should the rewriter ever want to track this move ? + insertSliceOp->moveAfter(newForOp); + rewriter.startRootUpdate(insertSliceOp); + insertSliceOp.getSourceMutable().assign(newForOp.getResults().back()); + insertSliceOp.getDestMutable().assign(newForOp.getResult(initArgNumber)); + rewriter.finalizeRootUpdate(insertSliceOp); + rewriter.replaceAllUsesExcept(newForOp.getResult(initArgNumber), + insertSliceOp.getResult(), insertSliceOp); + return newForOp; +} + +/// Greedily hoist redundant subset extract/insert operations on tensors +/// outside `forOp`. +/// Return the unmodified `forOp` if no hoisting occurred. +/// Return a new scf::ForOp if hoisting on tensors occurred. +scf::ForOp +mlir::linalg::hoistRedundantSubsetExtractInsert(RewriterBase &rewriter, + scf::ForOp forOp) { + LLVM_DEBUG(DBGS() << "Enter hoistRedundantSubsetExtractInsert scf.for\n"); + Operation *yield = forOp.getBody()->getTerminator(); + + LLVM_DEBUG(DBGS() << "\n"; DBGS() << "Consider " << forOp << "\n"); + + scf::ForOp newForOp = forOp; + do { + forOp = newForOp; + for (const auto &it : llvm::enumerate(forOp.getRegionIterArgs())) { + LLVM_DEBUG(DBGS() << "Consider " << it.value() << "\n"); + + // 1. Find a loop invariant subset write yielding `ret` that we can + // consider for hoisting. + // TODO: TypeSwitch when we add more cases. + OpOperand &ret = yield->getOpOperand(it.index()); + FailureOr transferWriteOp = + getLoopInvariantTransferWriteDefining(rewriter, forOp, it.value(), + ret); + FailureOr insertSliceOp = + getLoopInvariantInsertSliceDefining(rewriter, forOp, it.value(), ret); + if (failed(transferWriteOp) && failed(insertSliceOp)) { + LLVM_DEBUG(DBGS() << "no loop invariant write defining iter_args " + << it.value() << "\n"); + continue; + } + + Operation *writeOp = succeeded(transferWriteOp) + ? transferWriteOp->getOperation() + : insertSliceOp->getOperation(); + + // 2. Only accept writes with a single use (i.e. the yield). + if (!writeOp->hasOneUse()) { + LLVM_DEBUG(DBGS() << "write with more than 1 use " << *writeOp << "\n"); + continue; + } + + LLVM_DEBUG(DBGS() << "Write to hoist: " << *writeOp << "\n"); + + // 3. Find a matching read that can also be hoisted. + Operation *matchingReadOp = nullptr; + // TODO: TypeSwitch. + if (succeeded(transferWriteOp)) { + auto maybeTransferRead = findHoistableMatchingTransferRead( + rewriter, *transferWriteOp, it.value()); + if (succeeded(maybeTransferRead)) + matchingReadOp = maybeTransferRead->getOperation(); + } else if (succeeded(insertSliceOp)) { + auto maybeExtractSlice = findHoistableMatchingExtractSlice( + rewriter, *insertSliceOp, it.value()); + if (succeeded(maybeExtractSlice)) + matchingReadOp = maybeExtractSlice->getOperation(); + } else { + llvm_unreachable("unexpected case"); + } + if (!matchingReadOp) { + LLVM_DEBUG(DBGS() << "No matching read\n"); + continue; + } + + // 4. Make sure no other use reads the part of the modified tensor. + // This is necessary to guard against hazards when non-conflicting subset + // ops are bypassed. + Operation *maybeUnknownOp = + isTensorChunkAccessedByUnknownOp(writeOp, matchingReadOp, it.value()); + if (maybeUnknownOp) { + LLVM_DEBUG(DBGS() << "Tensor chunk accessed by unknown op, skip: " + << *maybeUnknownOp << "\n"); + continue; + } + + // 5. Perform the actual mechanical hoisting. + // TODO: TypeSwitch. + LLVM_DEBUG(DBGS() << "Read to hoist: " << *matchingReadOp << "\n"); + if (succeeded(transferWriteOp)) { + newForOp = hoistTransferReadWrite( + rewriter, cast(matchingReadOp), + *transferWriteOp, it.value()); + } else if (succeeded(insertSliceOp)) { + newForOp = hoistExtractInsertSlice( + rewriter, cast(matchingReadOp), + *insertSliceOp, it.value()); + } else { + llvm_unreachable("unexpected case"); + } + break; + } + } while (forOp != newForOp); + + return newForOp; +} diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index 294dc810507b..45ea541660fb 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -136,6 +136,16 @@ bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) { return v1 && v1 == v2; } +bool isEqualConstantIntOrValueArray(ArrayRef ofrs1, + ArrayRef ofrs2) { + if (ofrs1.size() != ofrs2.size()) + return false; + for (auto [ofr1, ofr2] : llvm::zip_equal(ofrs1, ofrs2)) + if (!isEqualConstantIntOrValue(ofr1, ofr2)) + return false; + return true; +} + /// Helper function to convert a vector of `OpFoldResult`s into a vector of /// `Value`s. For each `OpFoldResult` in `valueOrAttrVec` return the fold result /// if it casts to a `Value` or create an index-type constant if it casts to diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir index 8830a4f42721..aeecb8cf95f8 100644 --- a/mlir/test/Dialect/Linalg/hoisting.mlir +++ b/mlir/test/Dialect/Linalg/hoisting.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-transform-dialect-interpreter --split-input-file --allow-unregistered-dialect %s | FileCheck %s +// RUN: mlir-opt -test-transform-dialect-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s // CHECK-LABEL: func @hoist_vector_transfer_pairs( // CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref, @@ -29,7 +29,7 @@ func.func @hoist_vector_transfer_pairs( // CHECK: vector.transfer_read %{{.*}} : memref, vector<5xf32> // CHECK: "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32> // CHECK: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32> -// CHECK: "some_use"(%[[MEMREF2]]) : (memref) -> vector<3xf32> +// CHECK: "some_use"(%[[MEMREF2]], %{{.*}}) : (memref, vector<3xf32>) -> vector<3xf32> // CHECK: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32> // CHECK: "some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32> // CHECK: vector.transfer_write %{{.*}} : vector<3xf32>, memref @@ -56,7 +56,7 @@ func.func @hoist_vector_transfer_pairs( "some_crippling_use"(%memref5) : (memref) -> () %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32> %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32> - %u2 = "some_use"(%memref2) : (memref) -> vector<3xf32> + %u2 = "some_use"(%memref2, %r2) : (memref, vector<3xf32>) -> vector<3xf32> %u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32> %u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32> %u5 = "some_use"(%r5) : (vector<6xf32>) -> vector<6xf32> @@ -173,6 +173,51 @@ transform.sequence failures(propagate) { // ----- +// CHECK-LABEL: func @hoist_vector_transfer_pairs_in_affine_loops( +// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]+]]: memref<64x64xi32>, +// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]+]]: memref<64x64xi32>, +// CHECK-SAME: %[[MEMREF2:[a-zA-Z0-9]+]]: memref<64x64xi32>) { +// CHECK: %[[C0:.*]] = arith.constant 0 : i32 +// CHECK: affine.for %[[I:.*]] = 0 to 64 { +// CHECK: affine.for %[[J:.*]] = 0 to 64 step 16 { +// CHECK: %[[R0:.*]] = vector.transfer_read %[[MEMREF2]][%[[I]], %[[J]]], %[[C0]] : memref<64x64xi32>, vector<16xi32> +// CHECK: %[[R:.*]] = affine.for %[[K:.*]] = 0 to 64 iter_args(%[[ACC:.*]] = %[[R0]]) -> (vector<16xi32>) { +// CHECK: %[[AV:.*]] = vector.transfer_read %[[MEMREF0]][%[[I]], %[[K]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32> +// CHECK: %[[BV:.*]] = vector.transfer_read %[[MEMREF1]][%[[K]], %[[J]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32> +// CHECK: %[[T0:.*]] = arith.muli %[[AV]], %[[BV]] : vector<16xi32> +// CHECK: %[[T1:.*]] = arith.addi %[[ACC]], %[[T0]] : vector<16xi32> +// CHECK: affine.yield %[[T1]] : vector<16xi32> +// CHECK: } +// CHECK: vector.transfer_write %[[R]], %[[MEMREF2]][%[[I]], %[[J]]] : vector<16xi32>, memref<64x64xi32> +// CHECK: } +// CHECK: } +func.func @hoist_vector_transfer_pairs_in_affine_loops(%memref0: memref<64x64xi32>, %memref1: memref<64x64xi32>, %memref2: memref<64x64xi32>) { + %c0_i32 = arith.constant 0 : i32 + affine.for %arg3 = 0 to 64 { + affine.for %arg4 = 0 to 64 step 16 { + affine.for %arg5 = 0 to 64 { + %0 = vector.transfer_read %memref0[%arg3, %arg5], %c0_i32 {permutation_map = affine_map<(d0, d1) -> (0)>} : memref<64x64xi32>, vector<16xi32> + %1 = vector.transfer_read %memref1[%arg5, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32> + %2 = vector.transfer_read %memref2[%arg3, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32> + %3 = arith.muli %0, %1 : vector<16xi32> + %4 = arith.addi %2, %3 : vector<16xi32> + vector.transfer_write %4, %memref2[%arg3, %arg4] : vector<16xi32>, memref<64x64xi32> + } + } + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 + : (!pdl.operation) -> !pdl.operation + transform.structured.hoist_redundant_vector_transfers %0 + : (!pdl.operation) -> !pdl.operation +} + +// ----- + // CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor func.func @hoist_vector_transfer_pairs_tensor( %tensor0: tensor, %tensor1: tensor, %tensor2: tensor, @@ -256,7 +301,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation - transform.structured.hoist_redundant_vector_transfers %0 + transform.structured.hoist_redundant_tensor_subsets %0 : (!pdl.operation) -> !pdl.operation } @@ -351,7 +396,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation - transform.structured.hoist_redundant_vector_transfers %0 + transform.structured.hoist_redundant_tensor_subsets %0 : (!pdl.operation) -> !pdl.operation } @@ -468,7 +513,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation - transform.structured.hoist_redundant_vector_transfers %0 + transform.structured.hoist_redundant_tensor_subsets %0 : (!pdl.operation) -> !pdl.operation } @@ -501,6 +546,8 @@ func.func @hoist_vector_transfer_write_pairs_disjoint_tensor( %r00 = vector.transfer_read %arg5[%c0, %c0], %cst: tensor, vector<2xf32> %u00 = "some_use"(%r00) : (vector<2xf32>) -> vector<2xf32> %w10 = vector.transfer_write %u00, %arg5[%c0, %c0] : vector<2xf32>, tensor + + // Hoist by properly bypassing the disjoint write %w10. %r01 = vector.transfer_read %w10[%c0, %c3], %cst: tensor, vector<2xf32> %u01 = "some_use"(%r01) : (vector<2xf32>) -> vector<2xf32> %w11 = vector.transfer_write %u01, %w10[%c0, %c3] : vector<2xf32>, tensor @@ -513,51 +560,119 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation - transform.structured.hoist_redundant_vector_transfers %0 + transform.structured.hoist_redundant_tensor_subsets %0 : (!pdl.operation) -> !pdl.operation } // ----- -// CHECK-LABEL: func @hoist_vector_transfer_pairs_in_affine_loops( -// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]+]]: memref<64x64xi32>, -// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]+]]: memref<64x64xi32>, -// CHECK-SAME: %[[MEMREF2:[a-zA-Z0-9]+]]: memref<64x64xi32>) { -// CHECK: %[[C0:.*]] = arith.constant 0 : i32 -// CHECK: affine.for %[[I:.*]] = 0 to 64 { -// CHECK: affine.for %[[J:.*]] = 0 to 64 step 16 { -// CHECK: %[[R0:.*]] = vector.transfer_read %[[MEMREF2]][%[[I]], %[[J]]], %[[C0]] : memref<64x64xi32>, vector<16xi32> -// CHECK: %[[R:.*]] = affine.for %[[K:.*]] = 0 to 64 iter_args(%[[ACC:.*]] = %[[R0]]) -> (vector<16xi32>) { -// CHECK: %[[AV:.*]] = vector.transfer_read %[[MEMREF0]][%[[I]], %[[K]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32> -// CHECK: %[[BV:.*]] = vector.transfer_read %[[MEMREF1]][%[[K]], %[[J]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32> -// CHECK: %[[T0:.*]] = arith.muli %[[AV]], %[[BV]] : vector<16xi32> -// CHECK: %[[T1:.*]] = arith.addi %[[ACC]], %[[T0]] : vector<16xi32> -// CHECK: affine.yield %[[T1]] : vector<16xi32> -// CHECK: } -// CHECK: vector.transfer_write %[[R]], %[[MEMREF2]][%[[I]], %[[J]]] : vector<16xi32>, memref<64x64xi32> -// CHECK: } -// CHECK: } -func.func @hoist_vector_transfer_pairs_in_affine_loops(%memref0: memref<64x64xi32>, %memref1: memref<64x64xi32>, %memref2: memref<64x64xi32>) { - %c0_i32 = arith.constant 0 : i32 - affine.for %arg3 = 0 to 64 { - affine.for %arg4 = 0 to 64 step 16 { - affine.for %arg5 = 0 to 64 { - %0 = vector.transfer_read %memref0[%arg3, %arg5], %c0_i32 {permutation_map = affine_map<(d0, d1) -> (0)>} : memref<64x64xi32>, vector<16xi32> - %1 = vector.transfer_read %memref1[%arg5, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32> - %2 = vector.transfer_read %memref2[%arg3, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32> - %3 = arith.muli %0, %1 : vector<16xi32> - %4 = arith.addi %2, %3 : vector<16xi32> - vector.transfer_write %4, %memref2[%arg3, %arg4] : vector<16xi32>, memref<64x64xi32> - } +// CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor_and_slices_static_large_tensor +// CHECK-SAME: %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<100x100xf32>, +// CHECK-SAME: %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<200x200xf32>, +// CHECK-SAME: %[[TENSOR2:[a-zA-Z0-9]*]]: tensor<300x300xf32> +func.func @hoist_vector_transfer_pairs_tensor_and_slices_static_large_tensor( + %tensor0: tensor<100x100xf32>, %tensor1: tensor<200x200xf32>, %tensor2: tensor<300x300xf32>, + %val: index, %lb : index, %ub : index, %step: index) -> + ( + tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32> + ) { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.0 : f32 + + // CHECK: scf.for %[[I:.*]] = {{.*}} iter_args( + // CHECK-SAME: %[[TENSOR0_ARG:[0-9a-zA-Z]+]] = %[[TENSOR0]], + // CHECK-SAME: %[[TENSOR1_ARG:[0-9a-zA-Z]+]] = %[[TENSOR1]], + // CHECK-SAME: %[[TENSOR2_ARG:[0-9a-zA-Z]+]] = %[[TENSOR2]] + // CHECK-SAME: ) -> + // CHECK-SAME: (tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32> + %0:3 = scf.for %i = %lb to %ub step %step + iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2) + -> (tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>) { + + // Hoisted + // CHECK: %[[ST0:.*]] = tensor.extract_slice %[[TENSOR0_ARG]][%[[I]], %[[I]]]{{.*}}: tensor<100x100xf32> to tensor + // CHECK: %[[V0:.*]] = vector.transfer_read %[[ST0]]{{.*}} : tensor, vector<1xf32> + + // CHECK: %[[R:.*]]:3 = scf.for %[[J:.*]] = {{.*}} iter_args( + // CHECK-SAME: %[[TENSOR1_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR1_ARG]] + // CHECK-SAME: %[[TENSOR2_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR2_ARG]] + // CHECK-SAME: %[[V0_ARG_L2:[0-9a-zA-Z]+]] = %[[V0]] + // CHECK-SAME: ) -> + // CHECK-SAME: (tensor<200x200xf32>, tensor<300x300xf32>, vector<1xf32> + %1:3 = scf.for %j = %lb to %ub step %step + iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2) + -> (tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>) { + // Hoists. + %st0 = tensor.extract_slice %arg6[%i, %i][%step, %step][1, 1] : tensor<100x100xf32> to tensor + %r0 = vector.transfer_read %st0[%c0, %c0], %cst: tensor, vector<1xf32> + + // CHECK: %[[ST1:.*]] = tensor.extract_slice %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<200x200xf32> to tensor + // CHECK: %[[V1:.*]] = vector.transfer_read %[[ST1]]{{.*}} : tensor, vector<2xf32> + // Does not hoist (slice depends on %j) + %st1 = tensor.extract_slice %arg7[%j, %c0][%step, %step][1, 1] : tensor<200x200xf32> to tensor + %r1 = vector.transfer_read %st1[%c0, %c0], %cst: tensor, vector<2xf32> + + // CHECK: %[[ST2:.*]] = tensor.extract_slice %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<300x300xf32> to tensor + // CHECK: %[[V2:.*]] = vector.transfer_read %[[ST2]]{{.*}} : tensor, vector<3xf32> + // Does not hoist, 2 slice %arg8. + %st2 = tensor.extract_slice %arg8[%i, %c0][%step, %step][1, 1] : tensor<300x300xf32> to tensor + %r2 = vector.transfer_read %st2[%c0, %c0], %cst: tensor, vector<3xf32> + + // CHECK: %[[U0:.*]] = "some_use"(%[[V0_ARG_L2]]) : (vector<1xf32>) -> vector<1xf32> + // CHECK: %[[U1:.*]] = "some_use"(%[[V1]]) : (vector<2xf32>) -> vector<2xf32> + // CHECK: %[[U2:.*]] = "some_use"(%[[V2]]) : (vector<3xf32>) -> vector<3xf32> + %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32> + %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32> + %u2 = "some_use"(%r2) : (vector<3xf32>) -> vector<3xf32> + + // Hoists + %w0 = vector.transfer_write %u0, %st0[%c0, %c0] : vector<1xf32>, tensor + + // CHECK-DAG: %[[STI1:.*]] = vector.transfer_write %[[U1]], %{{.*}} : vector<2xf32>, tensor + // Does not hoist (associated slice depends on %j). + %w1 = vector.transfer_write %u1, %st1[%i, %i] : vector<2xf32>, tensor + + // CHECK-DAG: %[[STI2:.*]] = vector.transfer_write %[[U2]], %{{.*}} : vector<3xf32>, tensor + // Does not hoist, 2 slice / insert_slice for %arg8. + %w2 = vector.transfer_write %u2, %st2[%c0, %c0] : vector<3xf32>, tensor + + // Hoists. + %sti0 = tensor.insert_slice %w0 into %arg6[%i, %i][%step, %step][1, 1] : tensor into tensor<100x100xf32> + + // CHECK-DAG: tensor.insert_slice %[[STI1]] into %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor into tensor<200x200xf32> + // Does not hoist (depends on %j). + %sti1 = tensor.insert_slice %w1 into %arg7[%j, %c0][%step, %step][1, 1] : tensor into tensor<200x200xf32> + + // CHECK-DAG: tensor.insert_slice %[[STI2]] into %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor into tensor<300x300xf32> + // Does not hoist, 2 slice / insert_slice for %arg8. + %sti2 = tensor.insert_slice %w2 into %arg8[%i, %c0][%step, %step][1, 1] : tensor into tensor<300x300xf32> + // Extract with a different stride to make sure we cannot fold this extract with the above insert. + %st22 = tensor.extract_slice %sti2[%i, %c0][%step, %step][2, 1] : tensor<300x300xf32> to tensor + %sti22 = tensor.insert_slice %st22 into %arg8[%i, %c0][%step, %step][1, 1] : tensor into tensor<300x300xf32> + + // CHECK: scf.yield {{.*}} : tensor<200x200xf32>, tensor<300x300xf32>, vector<1xf32> + // CHECK: } + scf.yield %sti0, %sti1, %sti22: + tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32> } + + // Hoisted + // CHECK: %[[STI0:.*]] = vector.transfer_write %[[R]]#2, %[[ST0]]{{.*}} : vector<1xf32>, tensor + // CHECK: tensor.insert_slice %[[STI0]] into %[[TENSOR0_ARG]][%[[I]], %[[I]]]{{.*}} : tensor into tensor<100x100xf32> + + // CHECK: scf.yield {{.*}} : tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32> + scf.yield %1#0, %1#1, %1#2 : + tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32> + + // CHECK: } } - return + return %0#0, %0#1, %0#2 : tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32> } transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation - transform.structured.hoist_redundant_vector_transfers %0 + transform.structured.hoist_redundant_tensor_subsets %0 : (!pdl.operation) -> !pdl.operation } -- 2.34.1