LoopFusionUtils CL 2/n: Factor out and generalize slice union computation.
authorAndy Davis <andydavis@google.com>
Wed, 29 May 2019 21:02:14 +0000 (14:02 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:08:52 +0000 (20:08 -0700)
    *) Factors slice union computation out of LoopFusion into Analysis/Utils (where other iteration slice utilities exist).
    *) Generalizes slice union computation to take the union of slices computed on all loads/stores pairs between source and destination loop nests.
    *) Fixes a bug in FlatAffineConstraints::addSliceBounds where redundant constraints were added.
    *) Takes care of a TODO to expose FlatAffineConstraints::mergeAndAlignIds as a public method.

--

PiperOrigin-RevId: 250561529

mlir/include/mlir/Analysis/AffineStructures.h
mlir/include/mlir/Analysis/Utils.h
mlir/lib/Analysis/AffineStructures.cpp
mlir/lib/Analysis/Utils.cpp
mlir/lib/Transforms/LoopFusion.cpp
mlir/lib/Transforms/TestLoopFusion.cpp
mlir/lib/Transforms/Utils/LoopFusionUtils.cpp

index 1cff429..aadace0 100644 (file)
@@ -541,6 +541,25 @@ public:
   ///     <= 15}, output = {0 <= d0 <= 6, 1 <= d1 <= 15}.
   LogicalResult unionBoundingBox(const FlatAffineConstraints &other);
 
+  /// Returns 'true' if this constraint system and 'other' are in the same
+  /// space, i.e., if they are associated with the same set of identifiers,
+  /// appearing in the same order. Returns 'false' otherwise.
+  bool areIdsAlignedWithOther(const FlatAffineConstraints &other);
+
+  /// Merge and align the identifiers of 'this' and 'other' starting at
+  /// 'offset', so that both constraint systems get the union of the contained
+  /// identifiers that is dimension-wise and symbol-wise unique; both
+  /// constraint systems are updated so that they have the union of all
+  /// identifiers, with this's original identifiers appearing first followed by
+  /// any of other's identifiers that didn't appear in 'this'. Local
+  /// identifiers of each system are by design separate/local and are placed
+  /// one after other (this's followed by other's).
+  //  Eg: Input: 'this'  has ((%i %j) [%M %N])
+  //             'other' has (%k, %j) [%P, %N, %M])
+  //      Output: both 'this', 'other' have (%i, %j, %k) [%M, %N, %P]
+  //
+  void mergeAndAlignIdsWithOther(unsigned offset, FlatAffineConstraints *other);
+
   unsigned getNumConstraints() const {
     return getNumInequalities() + getNumEqualities();
   }
index 34eb627..d6bf0c6 100644 (file)
@@ -92,6 +92,14 @@ LogicalResult getBackwardComputationSliceState(
     const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
     unsigned dstLoopDepth, ComputationSliceState *sliceState);
 
+/// Computes in 'sliceUnion' the union of all slice bounds computed at
+/// 'dstLoopDepth' between all pairs in 'srcOps' and 'dstOp' which access the
+/// same memref. Returns 'success' if union was computed, 'failure' otherwise.
+LogicalResult computeSliceUnion(ArrayRef<Operation *> srcOps,
+                                ArrayRef<Operation *> dstOps,
+                                unsigned dstLoopDepth,
+                                ComputationSliceState *sliceUnion);
+
 /// Creates a clone of the computation contained in the loop nest surrounding
 /// 'srcOpInst', slices the iteration space of src loop based on slice bounds
 /// in 'sliceState', and inserts the computation slice at the beginning of the
index 3b7d5a0..9a821a0 100644 (file)
@@ -482,13 +482,20 @@ void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value *id) {
 
 /// Checks if two constraint systems are in the same space, i.e., if they are
 /// associated with the same set of identifiers, appearing in the same order.
-bool areIdsAligned(const FlatAffineConstraints &A,
-                   const FlatAffineConstraints &B) {
+static bool areIdsAligned(const FlatAffineConstraints &A,
+                          const FlatAffineConstraints &B) {
   return A.getNumDimIds() == B.getNumDimIds() &&
          A.getNumSymbolIds() == B.getNumSymbolIds() &&
          A.getNumIds() == B.getNumIds() && A.getIds().equals(B.getIds());
 }
 
+/// Calls areIdsAligned to check if two constraint systems have the same set
+/// of identifiers in the same order.
+bool FlatAffineConstraints::areIdsAlignedWithOther(
+    const FlatAffineConstraints &other) {
+  return areIdsAligned(*this, other);
+}
+
 /// Checks if the SSA values associated with `cst''s identifiers are unique.
 static bool LLVM_ATTRIBUTE_UNUSED
 areIdsUnique(const FlatAffineConstraints &cst) {
@@ -527,7 +534,6 @@ static void swapId(FlatAffineConstraints *A, unsigned posA, unsigned posB) {
 //  Eg: Input: A has ((%i %j) [%M %N]) and B has (%k, %j) [%P, %N, %M])
 //      Output: both A, B have (%i, %j, %k) [%M, %N, %P]
 //
-// TODO(mlir-team): expose this function at some point.
 static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A,
                              FlatAffineConstraints *B) {
   assert(offset <= A->getNumDimIds() && offset <= B->getNumDimIds());
@@ -604,6 +610,12 @@ static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A,
   assert(areIdsAligned(*A, *B) && "IDs expected to be aligned");
 }
 
+// Call 'mergeAndAlignIds' to align constraint systems of 'this' and 'other'.
+void FlatAffineConstraints::mergeAndAlignIdsWithOther(
+    unsigned offset, FlatAffineConstraints *other) {
+  mergeAndAlignIds(offset, this, other);
+}
+
 // This routine may add additional local variables if the flattened expression
 // corresponding to the map has such variables due to mod's, ceildiv's, and
 // floordiv's in it.
@@ -1745,18 +1757,12 @@ LogicalResult FlatAffineConstraints::addSliceBounds(
       if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/true,
                                       /*lower=*/true)))
         return failure();
-      if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/true,
-                                      /*lower=*/true)))
-        return failure();
       continue;
     }
 
     if (lbMap && failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false,
                                              /*lower=*/true)))
       return failure();
-    if (lbMap && failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false,
-                                             /*lower=*/true)))
-      return failure();
 
     if (ubMap && failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false,
                                              /*lower=*/false)))
index 2a46c0e..3026074 100644 (file)
@@ -28,6 +28,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/StandardOps/Ops.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -481,6 +482,153 @@ static Operation *getInstAtPosition(ArrayRef<unsigned> positions,
   return nullptr;
 }
 
+// Returns the MemRef accessed by load or store 'op'.
+static Value *getLoadOrStoreMemRef(Operation *op) {
+  if (auto loadOp = dyn_cast<LoadOp>(op))
+    return loadOp.getMemRef();
+  return cast<StoreOp>(op).getMemRef();
+}
+
+// Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'.
+LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value *, 8> &ivs,
+                                     FlatAffineConstraints *cst) {
+  for (unsigned i = 0, e = cst->getNumDimIds(); i < e; ++i) {
+    auto *value = cst->getIdValue(i);
+    if (ivs.count(value) == 0) {
+      assert(isForInductionVar(value));
+      auto loop = getForInductionVarOwner(value);
+      if (failed(cst->addAffineForOpDomain(loop)))
+        return failure();
+    }
+  }
+  return success();
+}
+
+/// Computes in 'sliceUnion' the union of all slice bounds computed at
+/// 'dstLoopDepth' between all pairs in 'srcOps' and 'dstOp' which access the
+/// same memref. Returns 'Success' if union was computed, 'failure' otherwise.
+LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> srcOps,
+                                      ArrayRef<Operation *> dstOps,
+                                      unsigned dstLoopDepth,
+                                      ComputationSliceState *sliceUnion) {
+  unsigned numSrcOps = srcOps.size();
+  unsigned numDstOps = dstOps.size();
+  assert(numSrcOps > 0 && numDstOps > 0);
+
+  // Compute the intersection of 'srcMemrefToOps' and 'dstMemrefToOps'.
+  llvm::SmallDenseSet<Value *> memrefIntersection;
+  for (auto *srcOp : srcOps) {
+    auto *srcMemRef = getLoadOrStoreMemRef(srcOp);
+    for (auto *dstOp : dstOps) {
+      if (srcMemRef == getLoadOrStoreMemRef(dstOp))
+        memrefIntersection.insert(srcMemRef);
+    }
+  }
+  // Return failure if 'memrefIntersection' is empty.
+  if (memrefIntersection.empty())
+    return failure();
+
+  // Compute the union of slice bounds between all pairs in 'srcOps' and
+  // 'dstOps' in 'sliceUnionCst'.
+  FlatAffineConstraints sliceUnionCst;
+  assert(sliceUnionCst.getNumDimAndSymbolIds() == 0);
+  for (unsigned i = 0; i < numSrcOps; ++i) {
+    MemRefAccess srcAccess(srcOps[i]);
+    for (unsigned j = 0; j < numDstOps; ++j) {
+      MemRefAccess dstAccess(dstOps[j]);
+      if (srcAccess.memref != dstAccess.memref)
+        continue;
+      // Compute slice bounds for 'srcAccess' and 'dstAccess'.
+      ComputationSliceState tmpSliceState;
+      if (failed(mlir::getBackwardComputationSliceState(
+              srcAccess, dstAccess, dstLoopDepth, &tmpSliceState))) {
+        LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bounds\n.");
+        return failure();
+      }
+
+      if (sliceUnionCst.getNumDimAndSymbolIds() == 0) {
+        // Initialize 'sliceUnionCst' with the bounds computed in previous step.
+        if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << "Unable to compute slice bound constraints\n.");
+          return failure();
+        }
+        assert(sliceUnionCst.getNumDimAndSymbolIds() > 0);
+        continue;
+      }
+
+      // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
+      FlatAffineConstraints tmpSliceCst;
+      if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "Unable to compute slice bound constraints\n.");
+        return failure();
+      }
+
+      // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed.
+      if (!sliceUnionCst.areIdsAlignedWithOther(tmpSliceCst)) {
+
+        // Pre-constraint id alignment: record loop IVs used in each constraint
+        // system.
+        SmallPtrSet<Value *, 8> sliceUnionIVs;
+        for (unsigned k = 0, l = sliceUnionCst.getNumDimIds(); k < l; ++k)
+          sliceUnionIVs.insert(sliceUnionCst.getIdValue(k));
+        SmallPtrSet<Value *, 8> tmpSliceIVs;
+        for (unsigned k = 0, l = tmpSliceCst.getNumDimIds(); k < l; ++k)
+          tmpSliceIVs.insert(tmpSliceCst.getIdValue(k));
+
+        sliceUnionCst.mergeAndAlignIdsWithOther(/*offset=*/0, &tmpSliceCst);
+
+        // Post-constraint id alignment: add loop IV bounds missing after
+        // id alignment to constraint systems. This can occur if one constraint
+        // system uses an loop IV that is not used by the other. The call
+        // to unionBoundingBox below expects constraints for each Loop IV, even
+        // if they are the unsliced full loop bounds added here.
+        if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst)))
+          return failure();
+        if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst)))
+          return failure();
+      }
+      // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
+      if (failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "Unable to compute union bounding box of slice bounds."
+                      "\n.");
+        return failure();
+      }
+    }
+  }
+
+  // Store 'numSrcLoopIvs' before converting dst loop IVs to dims.
+  unsigned numSrcLoopIVs = sliceUnionCst.getNumDimIds();
+
+  // Convert any dst loop IVs which are symbol identifiers to dim identifiers.
+  sliceUnionCst.convertLoopIVSymbolsToDims();
+  sliceUnion->clearBounds();
+  sliceUnion->lbs.resize(numSrcLoopIVs, AffineMap());
+  sliceUnion->ubs.resize(numSrcLoopIVs, AffineMap());
+
+  // Get slice bounds from slice union constraints 'sliceUnionCst'.
+  sliceUnionCst.getSliceBounds(numSrcLoopIVs, srcOps[0]->getContext(),
+                               &sliceUnion->lbs, &sliceUnion->ubs);
+
+  // Add slice bound operands of union.
+  SmallVector<Value *, 4> sliceBoundOperands;
+  sliceUnionCst.getIdValues(numSrcLoopIVs,
+                            sliceUnionCst.getNumDimAndSymbolIds(),
+                            &sliceBoundOperands);
+
+  // Copy src loop IVs from 'sliceUnionCst' to 'sliceUnion'.
+  sliceUnion->ivs.clear();
+  sliceUnionCst.getIdValues(0, numSrcLoopIVs, &sliceUnion->ivs);
+
+  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
+  // canonicalization.
+  sliceUnion->lbOperands.resize(numSrcLoopIVs, sliceBoundOperands);
+  sliceUnion->ubOperands.resize(numSrcLoopIVs, sliceBoundOperands);
+  return success();
+}
+
 const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";
 // Computes memref dependence between 'srcAccess' and 'dstAccess', projects
 // out any dst loop IVs at depth greater than 'dstLoopDepth', and computes slice
index 1f475f1..7eb2c72 100644 (file)
@@ -1192,82 +1192,6 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
   return true;
 }
 
-// Computes the union of all slice bounds computed between 'srcOpInst'
-// and each load op in 'dstLoadOpInsts' at 'dstLoopDepth', and returns
-// the union in 'sliceState'. Returns true on success, false otherwise.
-// TODO(andydavis) Move this to a loop fusion utility function.
-static bool getSliceUnion(Operation *srcOpInst,
-                          ArrayRef<Operation *> dstLoadOpInsts,
-                          unsigned numSrcLoopIVs, unsigned dstLoopDepth,
-                          ComputationSliceState *sliceState) {
-  MemRefAccess srcAccess(srcOpInst);
-  unsigned numDstLoadOpInsts = dstLoadOpInsts.size();
-  assert(numDstLoadOpInsts > 0);
-  // Compute the slice bounds between 'srcOpInst' and 'dstLoadOpInsts[0]'.
-  if (failed(mlir::getBackwardComputationSliceState(
-          srcAccess, MemRefAccess(dstLoadOpInsts[0]), dstLoopDepth,
-          sliceState)))
-    return false;
-  // Handle the common case of one dst load without a copy.
-  if (numDstLoadOpInsts == 1)
-    return true;
-
-  // Initialize 'sliceUnionCst' with the bounds computed in previous step.
-  FlatAffineConstraints sliceUnionCst;
-  if (failed(sliceState->getAsConstraints(&sliceUnionCst))) {
-    LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bound constraints\n.");
-    return false;
-  }
-
-  // Compute the union of slice bounds between 'srcOpInst' and each load
-  // in 'dstLoadOpInsts' in range [1, numDstLoadOpInsts), in 'sliceUnionCst'.
-  for (unsigned i = 1; i < numDstLoadOpInsts; ++i) {
-    MemRefAccess dstAccess(dstLoadOpInsts[i]);
-    // Compute slice bounds for 'srcOpInst' and 'dstLoadOpInsts[i]'.
-    ComputationSliceState tmpSliceState;
-    if (failed(mlir::getBackwardComputationSliceState(
-            srcAccess, dstAccess, dstLoopDepth, &tmpSliceState))) {
-      LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bounds\n.");
-      return false;
-    }
-
-    // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
-    FlatAffineConstraints tmpSliceCst;
-    if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
-      LLVM_DEBUG(llvm::dbgs()
-                 << "Unable to compute slice bound constraints\n.");
-      return false;
-    }
-    // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
-    if (failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
-      LLVM_DEBUG(llvm::dbgs()
-                 << "Unable to compute union bounding box of slice bounds.\n.");
-      return false;
-    }
-  }
-
-  // Convert any dst loop IVs which are symbol identifiers to dim identifiers.
-  sliceUnionCst.convertLoopIVSymbolsToDims();
-
-  sliceState->clearBounds();
-  sliceState->lbs.resize(numSrcLoopIVs, AffineMap());
-  sliceState->ubs.resize(numSrcLoopIVs, AffineMap());
-
-  // Get slice bounds from slice union constraints 'sliceUnionCst'.
-  sliceUnionCst.getSliceBounds(numSrcLoopIVs, srcOpInst->getContext(),
-                               &sliceState->lbs, &sliceState->ubs);
-  // Add slice bound operands of union.
-  SmallVector<Value *, 4> sliceBoundOperands;
-  sliceUnionCst.getIdValues(numSrcLoopIVs,
-                            sliceUnionCst.getNumDimAndSymbolIds(),
-                            &sliceBoundOperands);
-  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
-  // canonicalization.
-  sliceState->lbOperands.resize(numSrcLoopIVs, sliceBoundOperands);
-  sliceState->ubOperands.resize(numSrcLoopIVs, sliceBoundOperands);
-  return true;
-}
-
 // Checks the profitability of fusing a backwards slice of the loop nest
 // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
 // The argument 'srcStoreOpInst' is used to calculate the storage reduction on
@@ -1404,10 +1328,11 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
   DenseMap<Operation *, int64_t> computeCostMap;
   for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
     // Compute the union of slice bounds of all ops in 'dstLoadOpInsts'.
-    if (!getSliceUnion(srcOpInst, dstLoadOpInsts, numSrcLoopIVs, i,
-                       &sliceStates[i - 1])) {
+    if (failed(mlir::computeSliceUnion({srcOpInst}, dstLoadOpInsts,
+                                       /*dstLoopDepth=*/i,
+                                       &sliceStates[i - 1]))) {
       LLVM_DEBUG(llvm::dbgs()
-                 << "getSliceUnion failed for loopDepth: " << i << "\n");
+                 << "computeSliceUnion failed for loopDepth: " << i << "\n");
       continue;
     }
 
@@ -1813,9 +1738,10 @@ public:
             continue;
           // TODO(andydavis) Remove assert and surrounding code when
           // canFuseLoops is fully functional.
+          mlir::ComputationSliceState sliceUnion;
           FusionResult result = mlir::canFuseLoops(
               cast<AffineForOp>(srcNode->op), cast<AffineForOp>(dstNode->op),
-              bestDstLoopDepth, /*srcSlice=*/nullptr);
+              bestDstLoopDepth, &sliceUnion);
           assert(result.value == FusionResult::Success);
           (void)result;
 
index 9ace2fb..638cf91 100644 (file)
@@ -76,8 +76,10 @@ static void testDependenceCheck(SmallVector<AffineForOp, 2> &loops, unsigned i,
                                 unsigned j, unsigned loopDepth) {
   AffineForOp srcForOp = loops[i];
   AffineForOp dstForOp = loops[j];
-  FusionResult result = mlir::canFuseLoops(srcForOp, dstForOp, loopDepth,
-                                           /*srcSlice=*/nullptr);
+  mlir::ComputationSliceState sliceUnion;
+  // TODO(andydavis) Test at deeper loop depths current loop depth + 1.
+  FusionResult result =
+      mlir::canFuseLoops(srcForOp, dstForOp, loopDepth + 1, &sliceUnion);
   if (result.value == FusionResult::FailBlockDependence) {
     srcForOp.getOperation()->emitRemark("block-level dependence preventing"
                                         " fusion of loop nest ")
index 9de6766..cb1d9d1 100644 (file)
 
 using namespace mlir;
 
-// Gathers all load and store operations in 'opA' into 'values', where
+// Gathers all load and store memref accesses in 'opA' into 'values', where
 // 'values[memref] == true' for each store operation.
-static void getLoadsAndStores(Operation *opA, DenseMap<Value *, bool> &values) {
+static void getLoadAndStoreMemRefAccesses(Operation *opA,
+                                          DenseMap<Value *, bool> &values) {
   opA->walk([&](Operation *op) {
     if (auto loadOp = dyn_cast<LoadOp>(op)) {
       if (values.count(loadOp.getMemRef()) == 0)
@@ -73,7 +74,7 @@ static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) {
   // Record memref values from all loads/store in loop nest rooted at 'opA'.
   // Map from memref value to bool which is true if store, false otherwise.
   DenseMap<Value *, bool> values;
-  getLoadsAndStores(opA, values);
+  getLoadAndStoreMemRefAccesses(opA, values);
 
   // For each 'opX' in block in range ('opA', 'opB'), check if there is a data
   // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref
@@ -99,7 +100,7 @@ static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
   // Record memref values from all loads/store in loop nest rooted at 'opB'.
   // Map from memref value to bool which is true if store, false otherwise.
   DenseMap<Value *, bool> values;
-  getLoadsAndStores(opB, values);
+  getLoadAndStoreMemRefAccesses(opB, values);
 
   // For each 'opX' in block in range ('opA', 'opB') in reverse order,
   // check if there is a data dependence from 'opX' to 'opB':
@@ -176,8 +177,22 @@ static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
   return forOpB.getOperation();
 }
 
+// Gathers all load and store ops in loop nest rooted at 'forOp' into
+// 'loadAndStoreOps'.
+static bool
+gatherLoadsAndStores(AffineForOp forOp,
+                     SmallVectorImpl<Operation *> &loadAndStoreOps) {
+  bool hasIfOp = false;
+  forOp.getOperation()->walk([&](Operation *op) {
+    if (isa<LoadOp>(op) || isa<StoreOp>(op))
+      loadAndStoreOps.push_back(op);
+    else if (isa<AffineIfOp>(op))
+      hasIfOp = true;
+  });
+  return !hasIfOp;
+}
+
 // TODO(andydavis) Add support for the following features in subsequent CLs:
-// *) Computing union of slices computed between src/dst loads and stores.
 // *) Compute dependences of unfused src/dst loops.
 // *) Compute dependences of src/dst loop as if they were fused.
 // *) Check for fusion preventing dependences (e.g. a dependence which changes
@@ -185,18 +200,46 @@ static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
 FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
                                 unsigned dstLoopDepth,
                                 ComputationSliceState *srcSlice) {
-  // Return 'false' if 'srcForOp' and 'dstForOp' are not in the same block.
+  // Return 'failure' if 'dstLoopDepth == 0'.
+  if (dstLoopDepth == 0) {
+    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n.");
+    return FusionResult::FailPrecondition;
+  }
+  // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
   auto *block = srcForOp.getOperation()->getBlock();
   if (block != dstForOp.getOperation()->getBlock()) {
     LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n.");
     return FusionResult::FailPrecondition;
   }
 
-  // Return 'false' if no valid insertion point for fused loop nest in 'block'
+  // Return 'failure' if no valid insertion point for fused loop nest in 'block'
   // exists which would preserve dependences.
   if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
     LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n.");
     return FusionResult::FailBlockDependence;
   }
+
+  // Gather all load and store ops in 'srcForOp'.
+  SmallVector<Operation *, 4> srcLoadAndStoreOps;
+  if (!gatherLoadsAndStores(srcForOp, srcLoadAndStoreOps)) {
+    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n.");
+    return FusionResult::FailPrecondition;
+  }
+
+  // Gather all load and store ops in 'dstForOp'.
+  SmallVector<Operation *, 4> dstLoadAndStoreOps;
+  if (!gatherLoadsAndStores(dstForOp, dstLoadAndStoreOps)) {
+    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n.");
+    return FusionResult::FailPrecondition;
+  }
+
+  // Compute union of computation slices computed from all pairs in
+  // {'srcLoadAndStoreOps', 'dstLoadAndStoreOps'}.
+  if (failed(mlir::computeSliceUnion(srcLoadAndStoreOps, dstLoadAndStoreOps,
+                                     dstLoopDepth, srcSlice))) {
+    LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
+    return FusionResult::FailPrecondition;
+  }
+
   return FusionResult::Success;
 }