From 1de0f97fff7b7f5fae21374e77d35c5c311c9f39 Mon Sep 17 00:00:00 2001 From: Andy Davis Date: Wed, 29 May 2019 14:02:14 -0700 Subject: [PATCH] LoopFusionUtils CL 2/n: Factor out and generalize slice union computation. *) Factors slice union computation out of LoopFusion into Analysis/Utils (where other iteration slice utilities exist). *) Generalizes slice union computation to take the union of slices computed on all loads/stores pairs between source and destination loop nests. *) Fixes a bug in FlatAffineConstraints::addSliceBounds where redundant constraints were added. *) Takes care of a TODO to expose FlatAffineConstraints::mergeAndAlignIds as a public method. -- PiperOrigin-RevId: 250561529 --- mlir/include/mlir/Analysis/AffineStructures.h | 19 ++++ mlir/include/mlir/Analysis/Utils.h | 8 ++ mlir/lib/Analysis/AffineStructures.cpp | 24 +++-- mlir/lib/Analysis/Utils.cpp | 148 ++++++++++++++++++++++++++ mlir/lib/Transforms/LoopFusion.cpp | 86 ++------------- mlir/lib/Transforms/TestLoopFusion.cpp | 6 +- mlir/lib/Transforms/Utils/LoopFusionUtils.cpp | 57 ++++++++-- 7 files changed, 250 insertions(+), 98 deletions(-) diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index 1cff429..aadace0 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -541,6 +541,25 @@ public: /// <= 15}, output = {0 <= d0 <= 6, 1 <= d1 <= 15}. LogicalResult unionBoundingBox(const FlatAffineConstraints &other); + /// Returns 'true' if this constraint system and 'other' are in the same + /// space, i.e., if they are associated with the same set of identifiers, + /// appearing in the same order. Returns 'false' otherwise. + bool areIdsAlignedWithOther(const FlatAffineConstraints &other); + + /// Merge and align the identifiers of 'this' and 'other' starting at + /// 'offset', so that both constraint systems get the union of the contained + /// identifiers that is dimension-wise and symbol-wise unique; both + /// constraint systems are updated so that they have the union of all + /// identifiers, with this's original identifiers appearing first followed by + /// any of other's identifiers that didn't appear in 'this'. Local + /// identifiers of each system are by design separate/local and are placed + /// one after other (this's followed by other's). + // Eg: Input: 'this' has ((%i %j) [%M %N]) + // 'other' has (%k, %j) [%P, %N, %M]) + // Output: both 'this', 'other' have (%i, %j, %k) [%M, %N, %P] + // + void mergeAndAlignIdsWithOther(unsigned offset, FlatAffineConstraints *other); + unsigned getNumConstraints() const { return getNumInequalities() + getNumEqualities(); } diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index 34eb627..d6bf0c6 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -92,6 +92,14 @@ LogicalResult getBackwardComputationSliceState( const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, unsigned dstLoopDepth, ComputationSliceState *sliceState); +/// Computes in 'sliceUnion' the union of all slice bounds computed at +/// 'dstLoopDepth' between all pairs in 'srcOps' and 'dstOp' which access the +/// same memref. Returns 'success' if union was computed, 'failure' otherwise. +LogicalResult computeSliceUnion(ArrayRef srcOps, + ArrayRef dstOps, + unsigned dstLoopDepth, + ComputationSliceState *sliceUnion); + /// Creates a clone of the computation contained in the loop nest surrounding /// 'srcOpInst', slices the iteration space of src loop based on slice bounds /// in 'sliceState', and inserts the computation slice at the beginning of the diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 3b7d5a0..9a821a0 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -482,13 +482,20 @@ void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value *id) { /// Checks if two constraint systems are in the same space, i.e., if they are /// associated with the same set of identifiers, appearing in the same order. -bool areIdsAligned(const FlatAffineConstraints &A, - const FlatAffineConstraints &B) { +static bool areIdsAligned(const FlatAffineConstraints &A, + const FlatAffineConstraints &B) { return A.getNumDimIds() == B.getNumDimIds() && A.getNumSymbolIds() == B.getNumSymbolIds() && A.getNumIds() == B.getNumIds() && A.getIds().equals(B.getIds()); } +/// Calls areIdsAligned to check if two constraint systems have the same set +/// of identifiers in the same order. +bool FlatAffineConstraints::areIdsAlignedWithOther( + const FlatAffineConstraints &other) { + return areIdsAligned(*this, other); +} + /// Checks if the SSA values associated with `cst''s identifiers are unique. static bool LLVM_ATTRIBUTE_UNUSED areIdsUnique(const FlatAffineConstraints &cst) { @@ -527,7 +534,6 @@ static void swapId(FlatAffineConstraints *A, unsigned posA, unsigned posB) { // Eg: Input: A has ((%i %j) [%M %N]) and B has (%k, %j) [%P, %N, %M]) // Output: both A, B have (%i, %j, %k) [%M, %N, %P] // -// TODO(mlir-team): expose this function at some point. static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A, FlatAffineConstraints *B) { assert(offset <= A->getNumDimIds() && offset <= B->getNumDimIds()); @@ -604,6 +610,12 @@ static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A, assert(areIdsAligned(*A, *B) && "IDs expected to be aligned"); } +// Call 'mergeAndAlignIds' to align constraint systems of 'this' and 'other'. +void FlatAffineConstraints::mergeAndAlignIdsWithOther( + unsigned offset, FlatAffineConstraints *other) { + mergeAndAlignIds(offset, this, other); +} + // This routine may add additional local variables if the flattened expression // corresponding to the map has such variables due to mod's, ceildiv's, and // floordiv's in it. @@ -1745,18 +1757,12 @@ LogicalResult FlatAffineConstraints::addSliceBounds( if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/true, /*lower=*/true))) return failure(); - if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/true, - /*lower=*/true))) - return failure(); continue; } if (lbMap && failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false, /*lower=*/true))) return failure(); - if (lbMap && failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false, - /*lower=*/true))) - return failure(); if (ubMap && failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false, /*lower=*/false))) diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 2a46c0e..3026074 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -28,6 +28,7 @@ #include "mlir/IR/Builders.h" #include "mlir/StandardOps/Ops.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -481,6 +482,153 @@ static Operation *getInstAtPosition(ArrayRef positions, return nullptr; } +// Returns the MemRef accessed by load or store 'op'. +static Value *getLoadOrStoreMemRef(Operation *op) { + if (auto loadOp = dyn_cast(op)) + return loadOp.getMemRef(); + return cast(op).getMemRef(); +} + +// Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'. +LogicalResult addMissingLoopIVBounds(SmallPtrSet &ivs, + FlatAffineConstraints *cst) { + for (unsigned i = 0, e = cst->getNumDimIds(); i < e; ++i) { + auto *value = cst->getIdValue(i); + if (ivs.count(value) == 0) { + assert(isForInductionVar(value)); + auto loop = getForInductionVarOwner(value); + if (failed(cst->addAffineForOpDomain(loop))) + return failure(); + } + } + return success(); +} + +/// Computes in 'sliceUnion' the union of all slice bounds computed at +/// 'dstLoopDepth' between all pairs in 'srcOps' and 'dstOp' which access the +/// same memref. Returns 'Success' if union was computed, 'failure' otherwise. +LogicalResult mlir::computeSliceUnion(ArrayRef srcOps, + ArrayRef dstOps, + unsigned dstLoopDepth, + ComputationSliceState *sliceUnion) { + unsigned numSrcOps = srcOps.size(); + unsigned numDstOps = dstOps.size(); + assert(numSrcOps > 0 && numDstOps > 0); + + // Compute the intersection of 'srcMemrefToOps' and 'dstMemrefToOps'. + llvm::SmallDenseSet memrefIntersection; + for (auto *srcOp : srcOps) { + auto *srcMemRef = getLoadOrStoreMemRef(srcOp); + for (auto *dstOp : dstOps) { + if (srcMemRef == getLoadOrStoreMemRef(dstOp)) + memrefIntersection.insert(srcMemRef); + } + } + // Return failure if 'memrefIntersection' is empty. + if (memrefIntersection.empty()) + return failure(); + + // Compute the union of slice bounds between all pairs in 'srcOps' and + // 'dstOps' in 'sliceUnionCst'. + FlatAffineConstraints sliceUnionCst; + assert(sliceUnionCst.getNumDimAndSymbolIds() == 0); + for (unsigned i = 0; i < numSrcOps; ++i) { + MemRefAccess srcAccess(srcOps[i]); + for (unsigned j = 0; j < numDstOps; ++j) { + MemRefAccess dstAccess(dstOps[j]); + if (srcAccess.memref != dstAccess.memref) + continue; + // Compute slice bounds for 'srcAccess' and 'dstAccess'. + ComputationSliceState tmpSliceState; + if (failed(mlir::getBackwardComputationSliceState( + srcAccess, dstAccess, dstLoopDepth, &tmpSliceState))) { + LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bounds\n."); + return failure(); + } + + if (sliceUnionCst.getNumDimAndSymbolIds() == 0) { + // Initialize 'sliceUnionCst' with the bounds computed in previous step. + if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) { + LLVM_DEBUG(llvm::dbgs() + << "Unable to compute slice bound constraints\n."); + return failure(); + } + assert(sliceUnionCst.getNumDimAndSymbolIds() > 0); + continue; + } + + // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'. + FlatAffineConstraints tmpSliceCst; + if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) { + LLVM_DEBUG(llvm::dbgs() + << "Unable to compute slice bound constraints\n."); + return failure(); + } + + // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed. + if (!sliceUnionCst.areIdsAlignedWithOther(tmpSliceCst)) { + + // Pre-constraint id alignment: record loop IVs used in each constraint + // system. + SmallPtrSet sliceUnionIVs; + for (unsigned k = 0, l = sliceUnionCst.getNumDimIds(); k < l; ++k) + sliceUnionIVs.insert(sliceUnionCst.getIdValue(k)); + SmallPtrSet tmpSliceIVs; + for (unsigned k = 0, l = tmpSliceCst.getNumDimIds(); k < l; ++k) + tmpSliceIVs.insert(tmpSliceCst.getIdValue(k)); + + sliceUnionCst.mergeAndAlignIdsWithOther(/*offset=*/0, &tmpSliceCst); + + // Post-constraint id alignment: add loop IV bounds missing after + // id alignment to constraint systems. This can occur if one constraint + // system uses an loop IV that is not used by the other. The call + // to unionBoundingBox below expects constraints for each Loop IV, even + // if they are the unsliced full loop bounds added here. + if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst))) + return failure(); + if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst))) + return failure(); + } + // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'. + if (failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) { + LLVM_DEBUG(llvm::dbgs() + << "Unable to compute union bounding box of slice bounds." + "\n."); + return failure(); + } + } + } + + // Store 'numSrcLoopIvs' before converting dst loop IVs to dims. + unsigned numSrcLoopIVs = sliceUnionCst.getNumDimIds(); + + // Convert any dst loop IVs which are symbol identifiers to dim identifiers. + sliceUnionCst.convertLoopIVSymbolsToDims(); + sliceUnion->clearBounds(); + sliceUnion->lbs.resize(numSrcLoopIVs, AffineMap()); + sliceUnion->ubs.resize(numSrcLoopIVs, AffineMap()); + + // Get slice bounds from slice union constraints 'sliceUnionCst'. + sliceUnionCst.getSliceBounds(numSrcLoopIVs, srcOps[0]->getContext(), + &sliceUnion->lbs, &sliceUnion->ubs); + + // Add slice bound operands of union. + SmallVector sliceBoundOperands; + sliceUnionCst.getIdValues(numSrcLoopIVs, + sliceUnionCst.getNumDimAndSymbolIds(), + &sliceBoundOperands); + + // Copy src loop IVs from 'sliceUnionCst' to 'sliceUnion'. + sliceUnion->ivs.clear(); + sliceUnionCst.getIdValues(0, numSrcLoopIVs, &sliceUnion->ivs); + + // Give each bound its own copy of 'sliceBoundOperands' for subsequent + // canonicalization. + sliceUnion->lbOperands.resize(numSrcLoopIVs, sliceBoundOperands); + sliceUnion->ubOperands.resize(numSrcLoopIVs, sliceBoundOperands); + return success(); +} + const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier"; // Computes memref dependence between 'srcAccess' and 'dstAccess', projects // out any dst loop IVs at depth greater than 'dstLoopDepth', and computes slice diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 1f475f1..7eb2c72 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1192,82 +1192,6 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, return true; } -// Computes the union of all slice bounds computed between 'srcOpInst' -// and each load op in 'dstLoadOpInsts' at 'dstLoopDepth', and returns -// the union in 'sliceState'. Returns true on success, false otherwise. -// TODO(andydavis) Move this to a loop fusion utility function. -static bool getSliceUnion(Operation *srcOpInst, - ArrayRef dstLoadOpInsts, - unsigned numSrcLoopIVs, unsigned dstLoopDepth, - ComputationSliceState *sliceState) { - MemRefAccess srcAccess(srcOpInst); - unsigned numDstLoadOpInsts = dstLoadOpInsts.size(); - assert(numDstLoadOpInsts > 0); - // Compute the slice bounds between 'srcOpInst' and 'dstLoadOpInsts[0]'. - if (failed(mlir::getBackwardComputationSliceState( - srcAccess, MemRefAccess(dstLoadOpInsts[0]), dstLoopDepth, - sliceState))) - return false; - // Handle the common case of one dst load without a copy. - if (numDstLoadOpInsts == 1) - return true; - - // Initialize 'sliceUnionCst' with the bounds computed in previous step. - FlatAffineConstraints sliceUnionCst; - if (failed(sliceState->getAsConstraints(&sliceUnionCst))) { - LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bound constraints\n."); - return false; - } - - // Compute the union of slice bounds between 'srcOpInst' and each load - // in 'dstLoadOpInsts' in range [1, numDstLoadOpInsts), in 'sliceUnionCst'. - for (unsigned i = 1; i < numDstLoadOpInsts; ++i) { - MemRefAccess dstAccess(dstLoadOpInsts[i]); - // Compute slice bounds for 'srcOpInst' and 'dstLoadOpInsts[i]'. - ComputationSliceState tmpSliceState; - if (failed(mlir::getBackwardComputationSliceState( - srcAccess, dstAccess, dstLoopDepth, &tmpSliceState))) { - LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bounds\n."); - return false; - } - - // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'. - FlatAffineConstraints tmpSliceCst; - if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) { - LLVM_DEBUG(llvm::dbgs() - << "Unable to compute slice bound constraints\n."); - return false; - } - // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'. - if (failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) { - LLVM_DEBUG(llvm::dbgs() - << "Unable to compute union bounding box of slice bounds.\n."); - return false; - } - } - - // Convert any dst loop IVs which are symbol identifiers to dim identifiers. - sliceUnionCst.convertLoopIVSymbolsToDims(); - - sliceState->clearBounds(); - sliceState->lbs.resize(numSrcLoopIVs, AffineMap()); - sliceState->ubs.resize(numSrcLoopIVs, AffineMap()); - - // Get slice bounds from slice union constraints 'sliceUnionCst'. - sliceUnionCst.getSliceBounds(numSrcLoopIVs, srcOpInst->getContext(), - &sliceState->lbs, &sliceState->ubs); - // Add slice bound operands of union. - SmallVector sliceBoundOperands; - sliceUnionCst.getIdValues(numSrcLoopIVs, - sliceUnionCst.getNumDimAndSymbolIds(), - &sliceBoundOperands); - // Give each bound its own copy of 'sliceBoundOperands' for subsequent - // canonicalization. - sliceState->lbOperands.resize(numSrcLoopIVs, sliceBoundOperands); - sliceState->ubOperands.resize(numSrcLoopIVs, sliceBoundOperands); - return true; -} - // Checks the profitability of fusing a backwards slice of the loop nest // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'. // The argument 'srcStoreOpInst' is used to calculate the storage reduction on @@ -1404,10 +1328,11 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, DenseMap computeCostMap; for (unsigned i = maxDstLoopDepth; i >= 1; --i) { // Compute the union of slice bounds of all ops in 'dstLoadOpInsts'. - if (!getSliceUnion(srcOpInst, dstLoadOpInsts, numSrcLoopIVs, i, - &sliceStates[i - 1])) { + if (failed(mlir::computeSliceUnion({srcOpInst}, dstLoadOpInsts, + /*dstLoopDepth=*/i, + &sliceStates[i - 1]))) { LLVM_DEBUG(llvm::dbgs() - << "getSliceUnion failed for loopDepth: " << i << "\n"); + << "computeSliceUnion failed for loopDepth: " << i << "\n"); continue; } @@ -1813,9 +1738,10 @@ public: continue; // TODO(andydavis) Remove assert and surrounding code when // canFuseLoops is fully functional. + mlir::ComputationSliceState sliceUnion; FusionResult result = mlir::canFuseLoops( cast(srcNode->op), cast(dstNode->op), - bestDstLoopDepth, /*srcSlice=*/nullptr); + bestDstLoopDepth, &sliceUnion); assert(result.value == FusionResult::Success); (void)result; diff --git a/mlir/lib/Transforms/TestLoopFusion.cpp b/mlir/lib/Transforms/TestLoopFusion.cpp index 9ace2fb..638cf91 100644 --- a/mlir/lib/Transforms/TestLoopFusion.cpp +++ b/mlir/lib/Transforms/TestLoopFusion.cpp @@ -76,8 +76,10 @@ static void testDependenceCheck(SmallVector &loops, unsigned i, unsigned j, unsigned loopDepth) { AffineForOp srcForOp = loops[i]; AffineForOp dstForOp = loops[j]; - FusionResult result = mlir::canFuseLoops(srcForOp, dstForOp, loopDepth, - /*srcSlice=*/nullptr); + mlir::ComputationSliceState sliceUnion; + // TODO(andydavis) Test at deeper loop depths current loop depth + 1. + FusionResult result = + mlir::canFuseLoops(srcForOp, dstForOp, loopDepth + 1, &sliceUnion); if (result.value == FusionResult::FailBlockDependence) { srcForOp.getOperation()->emitRemark("block-level dependence preventing" " fusion of loop nest ") diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp index 9de6766..cb1d9d1 100644 --- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp @@ -40,9 +40,10 @@ using namespace mlir; -// Gathers all load and store operations in 'opA' into 'values', where +// Gathers all load and store memref accesses in 'opA' into 'values', where // 'values[memref] == true' for each store operation. -static void getLoadsAndStores(Operation *opA, DenseMap &values) { +static void getLoadAndStoreMemRefAccesses(Operation *opA, + DenseMap &values) { opA->walk([&](Operation *op) { if (auto loadOp = dyn_cast(op)) { if (values.count(loadOp.getMemRef()) == 0) @@ -73,7 +74,7 @@ static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) { // Record memref values from all loads/store in loop nest rooted at 'opA'. // Map from memref value to bool which is true if store, false otherwise. DenseMap values; - getLoadsAndStores(opA, values); + getLoadAndStoreMemRefAccesses(opA, values); // For each 'opX' in block in range ('opA', 'opB'), check if there is a data // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref @@ -99,7 +100,7 @@ static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) { // Record memref values from all loads/store in loop nest rooted at 'opB'. // Map from memref value to bool which is true if store, false otherwise. DenseMap values; - getLoadsAndStores(opB, values); + getLoadAndStoreMemRefAccesses(opB, values); // For each 'opX' in block in range ('opA', 'opB') in reverse order, // check if there is a data dependence from 'opX' to 'opB': @@ -176,8 +177,22 @@ static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp, return forOpB.getOperation(); } +// Gathers all load and store ops in loop nest rooted at 'forOp' into +// 'loadAndStoreOps'. +static bool +gatherLoadsAndStores(AffineForOp forOp, + SmallVectorImpl &loadAndStoreOps) { + bool hasIfOp = false; + forOp.getOperation()->walk([&](Operation *op) { + if (isa(op) || isa(op)) + loadAndStoreOps.push_back(op); + else if (isa(op)) + hasIfOp = true; + }); + return !hasIfOp; +} + // TODO(andydavis) Add support for the following features in subsequent CLs: -// *) Computing union of slices computed between src/dst loads and stores. // *) Compute dependences of unfused src/dst loops. // *) Compute dependences of src/dst loop as if they were fused. // *) Check for fusion preventing dependences (e.g. a dependence which changes @@ -185,18 +200,46 @@ static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp, FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth, ComputationSliceState *srcSlice) { - // Return 'false' if 'srcForOp' and 'dstForOp' are not in the same block. + // Return 'failure' if 'dstLoopDepth == 0'. + if (dstLoopDepth == 0) { + LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n."); + return FusionResult::FailPrecondition; + } + // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block. auto *block = srcForOp.getOperation()->getBlock(); if (block != dstForOp.getOperation()->getBlock()) { LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n."); return FusionResult::FailPrecondition; } - // Return 'false' if no valid insertion point for fused loop nest in 'block' + // Return 'failure' if no valid insertion point for fused loop nest in 'block' // exists which would preserve dependences. if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) { LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n."); return FusionResult::FailBlockDependence; } + + // Gather all load and store ops in 'srcForOp'. + SmallVector srcLoadAndStoreOps; + if (!gatherLoadsAndStores(srcForOp, srcLoadAndStoreOps)) { + LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n."); + return FusionResult::FailPrecondition; + } + + // Gather all load and store ops in 'dstForOp'. + SmallVector dstLoadAndStoreOps; + if (!gatherLoadsAndStores(dstForOp, dstLoadAndStoreOps)) { + LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n."); + return FusionResult::FailPrecondition; + } + + // Compute union of computation slices computed from all pairs in + // {'srcLoadAndStoreOps', 'dstLoadAndStoreOps'}. + if (failed(mlir::computeSliceUnion(srcLoadAndStoreOps, dstLoadAndStoreOps, + dstLoopDepth, srcSlice))) { + LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n"); + return FusionResult::FailPrecondition; + } + return FusionResult::Success; } -- 2.7.4