Return dependence result enum to distiguish between dependence result and error cases...
authorAndy Davis <andydavis@google.com>
Mon, 10 Jun 2019 17:50:08 +0000 (10:50 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 11 Jun 2019 17:12:36 +0000 (10:12 -0700)
PiperOrigin-RevId: 252437616

mlir/include/mlir/Analysis/AffineAnalysis.h
mlir/lib/Analysis/AffineAnalysis.cpp
mlir/lib/Analysis/TestMemRefDependenceCheck.cpp
mlir/lib/Analysis/Utils.cpp
mlir/lib/Transforms/LoopFusion.cpp
mlir/lib/Transforms/MemRefDataFlowOpt.cpp

index 1b92bd1..bb25a65 100644 (file)
@@ -94,19 +94,34 @@ struct DependenceComponent {
 /// Checks whether two accesses to the same memref access the same element.
 /// Each access is specified using the MemRefAccess structure, which contains
 /// the operation, indices and memref associated with the access. Returns
-/// 'false' if it can be determined conclusively that the accesses do not
+/// 'NoDependence' if it can be determined conclusively that the accesses do not
 /// access the same memref element. If 'allowRAR' is true, will consider
 /// read-after-read dependences (typically used by applications trying to
 /// optimize input reuse).
 // TODO(andydavis) Wrap 'dependenceConstraints' and 'dependenceComponents' into
 // a single struct.
 // TODO(andydavis) Make 'dependenceConstraints' optional arg.
-bool checkMemrefAccessDependence(
+struct DependenceResult {
+  enum ResultEnum {
+    HasDependence, // A dependence exists between 'srcAccess' and 'dstAccess'.
+    NoDependence,  // No dependence exists between 'srcAccess' and 'dstAccess'.
+    Failure,       // Dependence check failed due to unsupported cases.
+  } value;
+  DependenceResult(ResultEnum v) : value(v) {}
+};
+
+DependenceResult checkMemrefAccessDependence(
     const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
     unsigned loopDepth, FlatAffineConstraints *dependenceConstraints,
     llvm::SmallVector<DependenceComponent, 2> *dependenceComponents,
     bool allowRAR = false);
 
+/// Utility function that returns true if the provided DependenceResult
+/// corresponds to a dependence result.
+inline bool hasDependence(DependenceResult result) {
+  return result.value == DependenceResult::HasDependence;
+}
+
 /// Returns in 'depCompsVec', dependence components for dependences between all
 /// load and store ops in loop nest rooted at 'forOp', at loop depths in range
 /// [1, maxLoopDepth].
index a9dce13..fc8c712 100644 (file)
@@ -681,8 +681,10 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
 
 // Builds a flat affine constraint system to check if there exists a dependence
 // between memref accesses 'srcAccess' and 'dstAccess'.
-// Returns 'false' if the accesses can be definitively shown not to access the
-// same element. Returns 'true' otherwise.
+// Returns 'NoDependence' if the accesses can be definitively shown not to
+// access the same element.
+// Returns 'HasDependence' if the accesses do access the same element.
+// Returns 'Failure' if an error or unsupported case was encountered.
 // If a dependence exists, returns in 'dependenceComponents' a direction
 // vector for the dependence, with a component for each loop IV in loops
 // common to both accesses (see Dependence in AffineAnalysis.h for details).
@@ -764,7 +766,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
 //
 //
 // TODO(andydavis) Support AffineExprs mod/floordiv/ceildiv.
-bool mlir::checkMemrefAccessDependence(
+DependenceResult mlir::checkMemrefAccessDependence(
     const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
     unsigned loopDepth, FlatAffineConstraints *dependenceConstraints,
     llvm::SmallVector<DependenceComponent, 2> *dependenceComponents,
@@ -774,13 +776,14 @@ bool mlir::checkMemrefAccessDependence(
   LLVM_DEBUG(srcAccess.opInst->dump(););
   LLVM_DEBUG(dstAccess.opInst->dump(););
 
-  // Return 'false' if these accesses do not acces the same memref.
+  // Return 'NoDependence' if these accesses do not access the same memref.
   if (srcAccess.memref != dstAccess.memref)
-    return false;
-  // Return 'false' if one of these accesses is not a StoreOp.
+    return DependenceResult::NoDependence;
+
+  // Return 'NoDependence' if one of these accesses is not a StoreOp.
   if (!allowRAR && !isa<StoreOp>(srcAccess.opInst) &&
       !isa<StoreOp>(dstAccess.opInst))
-    return false;
+    return DependenceResult::NoDependence;
 
   // Get composed access function for 'srcAccess'.
   AffineValueMap srcAccessMap;
@@ -793,14 +796,14 @@ bool mlir::checkMemrefAccessDependence(
   // Get iteration domain for the 'srcAccess' operation.
   FlatAffineConstraints srcDomain;
   if (failed(getInstIndexSet(srcAccess.opInst, &srcDomain)))
-    return false;
+    return DependenceResult::Failure;
 
   // Get iteration domain for 'dstAccess' operation.
   FlatAffineConstraints dstDomain;
   if (failed(getInstIndexSet(dstAccess.opInst, &dstDomain)))
-    return false;
+    return DependenceResult::Failure;
 
-  // Return 'false' if loopDepth > numCommonLoops and if the ancestor operation
+  // Return 'NoDependence' if loopDepth > numCommonLoops and if the ancestor
   // operation of 'srcAccess' does not properly dominate the ancestor
   // operation of 'dstAccess' in the same common operation block.
   // Note: this check is skipped if 'allowRAR' is true, because because RAR
@@ -810,7 +813,7 @@ bool mlir::checkMemrefAccessDependence(
   if (!allowRAR && loopDepth > numCommonLoops &&
       !srcAppearsBeforeDstInAncestralBlock(srcAccess, dstAccess, srcDomain,
                                            numCommonLoops)) {
-    return false;
+    return DependenceResult::NoDependence;
   }
   // Build dim and symbol position maps for each access from access operand
   // Value to position in merged contstraint system.
@@ -830,7 +833,7 @@ bool mlir::checkMemrefAccessDependence(
   // local variables for mod/div exprs are supported.
   if (failed(addMemRefAccessConstraints(srcAccessMap, dstAccessMap, valuePosMap,
                                         dependenceConstraints)))
-    return true;
+    return DependenceResult::Failure;
 
   // Add 'src' happens before 'dst' ordering constraints.
   addOrderingConstraints(srcDomain, dstDomain, loopDepth,
@@ -839,9 +842,9 @@ bool mlir::checkMemrefAccessDependence(
   addDomainConstraints(srcDomain, dstDomain, valuePosMap,
                        dependenceConstraints);
 
-  // Return false if the solution space is empty: no dependence.
+  // Return 'NoDependence' if the solution space is empty: no dependence.
   if (dependenceConstraints->isEmpty()) {
-    return false;
+    return DependenceResult::NoDependence;
   }
 
   // Compute dependence direction vector and return true.
@@ -852,7 +855,7 @@ bool mlir::checkMemrefAccessDependence(
 
   LLVM_DEBUG(llvm::dbgs() << "Dependence polyhedron:\n");
   LLVM_DEBUG(dependenceConstraints->dump());
-  return true;
+  return DependenceResult::HasDependence;
 }
 
 /// Gathers dependence components for dependences between all ops in loop nest
@@ -880,10 +883,10 @@ void mlir::getDependenceComponents(
         llvm::SmallVector<DependenceComponent, 2> depComps;
         // TODO(andydavis,bondhugula) Explore whether it would be profitable
         // to pre-compute and store deps instead of repeatedly checking.
-        if (checkMemrefAccessDependence(srcAccess, dstAccess, d,
-                                        &dependenceConstraints, &depComps)) {
+        DependenceResult result = checkMemrefAccessDependence(
+            srcAccess, dstAccess, d, &dependenceConstraints, &depComps);
+        if (hasDependence(result))
           depCompsVec->push_back(depComps);
-        }
       }
     }
   }
index 2b0f1ab..4456ac2 100644 (file)
@@ -93,9 +93,11 @@ static void checkDependences(ArrayRef<Operation *> loadsAndStores) {
       for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
         FlatAffineConstraints dependenceConstraints;
         llvm::SmallVector<DependenceComponent, 2> dependenceComponents;
-        bool ret = checkMemrefAccessDependence(srcAccess, dstAccess, d,
-                                               &dependenceConstraints,
-                                               &dependenceComponents);
+        DependenceResult result = checkMemrefAccessDependence(
+            srcAccess, dstAccess, d, &dependenceConstraints,
+            &dependenceComponents);
+        assert(result.value != DependenceResult::Failure);
+        bool ret = hasDependence(result);
         // TODO(andydavis) Print dependence type (i.e. RAW, etc) and print
         // distance vectors as: ([2, 3], [0, 10]). Also, shorten distance
         // vectors from ([1, 1], [3, 3]) to (1, 3).
index aa84236..e5418fc 100644 (file)
@@ -640,9 +640,10 @@ LogicalResult mlir::getBackwardComputationSliceState(
   bool readReadAccesses =
       isa<LoadOp>(srcAccess.opInst) && isa<LoadOp>(dstAccess.opInst);
   FlatAffineConstraints dependenceConstraints;
-  if (!checkMemrefAccessDependence(
-          srcAccess, dstAccess, /*loopDepth=*/1, &dependenceConstraints,
-          /*dependenceComponents=*/nullptr, /*allowRAR=*/readReadAccesses)) {
+  DependenceResult result = checkMemrefAccessDependence(
+      srcAccess, dstAccess, /*loopDepth=*/1, &dependenceConstraints,
+      /*dependenceComponents=*/nullptr, /*allowRAR=*/readReadAccesses);
+  if (!hasDependence(result)) {
     return failure();
   }
   // Get loop nest surrounding src operation.
@@ -922,9 +923,10 @@ bool mlir::isLoopParallel(AffineForOp forOp) {
     for (auto *dstOpInst : loadAndStoreOpInsts) {
       MemRefAccess dstAccess(dstOpInst);
       FlatAffineConstraints dependenceConstraints;
-      if (checkMemrefAccessDependence(srcAccess, dstAccess, depth,
-                                      &dependenceConstraints,
-                                      /*dependenceComponents=*/nullptr))
+      DependenceResult result = checkMemrefAccessDependence(
+          srcAccess, dstAccess, depth, &dependenceConstraints,
+          /*dependenceComponents=*/nullptr);
+      if (result.value != DependenceResult::NoDependence)
         return false;
     }
   }
index 0f39e52..829b1b2 100644 (file)
@@ -954,9 +954,10 @@ static unsigned getMaxLoopDepth(ArrayRef<Operation *> loadOpInsts,
       for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
         FlatAffineConstraints dependenceConstraints;
         // TODO(andydavis) Cache dependence analysis results, check cache here.
-        if (checkMemrefAccessDependence(srcAccess, dstAccess, d,
-                                        &dependenceConstraints,
-                                        /*dependenceComponents=*/nullptr)) {
+        DependenceResult result = checkMemrefAccessDependence(
+            srcAccess, dstAccess, d, &dependenceConstraints,
+            /*dependenceComponents=*/nullptr);
+        if (hasDependence(result)) {
           // Store minimum loop depth and break because we want the min 'd' at
           // which there is a dependence.
           loopDepth = std::min(loopDepth, d - 1);
index 45a11ef..c5676af 100644 (file)
@@ -131,9 +131,10 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) {
     unsigned nsLoops = getNumCommonSurroundingLoops(*loadOpInst, *storeOpInst);
     // Dependences at loop depth <= minSurroundingLoops do NOT matter.
     for (unsigned d = nsLoops + 1; d > minSurroundingLoops; d--) {
-      if (!checkMemrefAccessDependence(srcAccess, destAccess, d,
-                                       &dependenceConstraints,
-                                       /*dependenceComponents=*/nullptr))
+      DependenceResult result = checkMemrefAccessDependence(
+          srcAccess, destAccess, d, &dependenceConstraints,
+          /*dependenceComponents=*/nullptr);
+      if (!hasDependence(result))
         continue;
       depSrcStores.push_back(storeOpInst);
       // Check if this store is a candidate for forwarding; we only forward if