Temp change in FlatAffineConstraints::getSliceBounds() to deal with TODO in
authorUday Bondhugula <bondhugula@google.com>
Wed, 27 Feb 2019 01:32:47 +0000 (17:32 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 23:45:23 +0000 (16:45 -0700)
LoopFusion

- getConstDifference in LoopFusion is pending a refactoring to handle bounds
  with min's and max's; it currently asserts on some useful test cases that we
  want to experiment with. This CL changes getSliceBounds to be more
  conservative so as to not trigger the assertion. Filed b/126426796 to track this.

PiperOrigin-RevId: 235826538

mlir/lib/Analysis/AffineStructures.cpp
mlir/lib/Transforms/LoopFusion.cpp

index 32129166afa6ec35fce036e4d34d466f9b42b547..276db4712c53e6974f3a478fab8135c4d3e4e382 100644 (file)
@@ -1433,9 +1433,12 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context,
     if (expr)
       expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols);
 
+    AffineMap &lbMap = (*lbMaps)[pos];
+    AffineMap &ubMap = (*ubMaps)[pos];
+
     if (expr) {
-      (*lbMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr, {});
-      (*ubMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr + 1, {});
+      lbMap = AffineMap::get(numMapDims, numMapSymbols, expr, {});
+      ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + 1, {});
     } else {
       // TODO(bondhugula): Whenever there have local identifiers in the
       // dependence constraints, we'll conservatively over-approximate, since we
@@ -1448,38 +1451,40 @@ void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context,
           // redundant loop bounds.
           tmpClone->removeRedundantInequalities();
         }
-        std::tie((*lbMaps)[pos], (*ubMaps)[pos]) =
-            tmpClone->getLowerAndUpperBound(pos, num, getNumDimIds(), {},
-                                            context);
+        std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound(
+            pos, num, getNumDimIds(), {}, context);
       }
 
       // If the above fails, we'll just use the constant lower bound and the
       // constant upper bound (if they exist) as the slice bounds.
-      if (!(*lbMaps)[pos]) {
+      // TODO(b/126426796): being conservative for the moment in cases that
+      // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is
+      // fixed (b/126426796).
+      if (!lbMap || lbMap.getNumResults() > 1) {
         LLVM_DEBUG(llvm::dbgs()
                    << "WARNING: Potentially over-approximating slice lb\n");
         auto lbConst = getConstantLowerBound(pos);
         if (lbConst.hasValue()) {
-          (*lbMaps)[pos] = AffineMap::get(
+          lbMap = AffineMap::get(
               numMapDims, numMapSymbols,
               getAffineConstantExpr(lbConst.getValue(), context), {});
         }
       }
-      if (!(*ubMaps)[pos]) {
+      if (!ubMap || ubMap.getNumResults() > 1) {
         LLVM_DEBUG(llvm::dbgs()
                    << "WARNING: Potentially over-approximating slice ub\n");
         auto ubConst = getConstantUpperBound(pos);
         if (ubConst.hasValue()) {
-          (*ubMaps)[pos] = AffineMap::get(
+          (ubMap) = AffineMap::get(
               numMapDims, numMapSymbols,
               getAffineConstantExpr(ubConst.getValue() + 1, context), {});
         }
       }
     }
     LLVM_DEBUG(llvm::dbgs() << "lb map for pos = " << Twine(pos) << ", expr: ");
-    LLVM_DEBUG((*lbMaps)[pos].dump(););
+    LLVM_DEBUG(lbMap.dump(););
     LLVM_DEBUG(llvm::dbgs() << "ub map for pos = " << Twine(pos) << ", expr: ");
-    LLVM_DEBUG((*ubMaps)[pos].dump(););
+    LLVM_DEBUG(ubMap.dump(););
   }
 }
 
index 72176bacf9a1d56c707f42e41756b841e6c8fb29..0f4e45c372ae03bfa57621bb92f672363d422337 100644 (file)
@@ -704,13 +704,12 @@ static int64_t getComputeCost(
 
 } // end anonymous namespace
 
+// TODO(andydavis,b/126426796): extend this to handle multiple result maps.
 static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
   assert(lbMap.getNumResults() == 1 && "expected single result bound map");
   assert(ubMap.getNumResults() == 1 && "expected single result bound map");
   assert(lbMap.getNumDims() == ubMap.getNumDims());
   assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
-  // TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'.
-  // ub_expr - lb_expr
   AffineExpr lbExpr(lbMap.getResult(0));
   AffineExpr ubExpr(ubMap.getResult(0));
   auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),