Some loop fusion code cleanup/simplification post cl/229575126
authorUday Bondhugula <bondhugula@google.com>
Wed, 16 Jan 2019 21:13:00 +0000 (13:13 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 22:23:43 +0000 (15:23 -0700)
- enforce the assumptions better / in a simpler way

PiperOrigin-RevId: 229612424

mlir/lib/Transforms/LoopFusion.cpp

index cdd1c77f302a9a18bcc67e465ac7683da8678557..804acba0d5aee5f7f7b8c62c5a62cd9028960e9b 100644 (file)
@@ -465,8 +465,8 @@ static uint64_t getComputeCost(
 } // end anonymous namespace
 
 static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
-  assert(lbMap.getNumResults() == 1);
-  assert(ubMap.getNumResults() == 1);
+  assert(lbMap.getNumResults() == 1 && "expected single result bound map");
+  assert(ubMap.getNumResults() == 1 && "expected single result bound map");
   assert(lbMap.getNumDims() == ubMap.getNumDims());
   assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
   // TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'.
@@ -560,33 +560,16 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
   return loopDepth;
 }
 
-// Returns true if 'map' is a single result constant or single result
-// dim expr where its corresponding loop IV in 'operands' has zero constant
-// lower bound.
-static bool hasZeroMinValue(AffineMap map, ArrayRef<Value *> operands) {
-  if (map.isSingleConstant() && map.getSingleConstantResult() == 0)
-    return true;
-  if (map.getNumResults() != 1 || !map.getResult(0).isa<AffineDimExpr>())
-    return false;
-  // Get operand position of single dim expr result.
-  unsigned pos = map.getResult(0).cast<AffineDimExpr>().getPosition();
-  // Check if loop IV at 'pos' has zero constant lower bound.
-  auto *operand = operands[pos];
-  assert(isa<ForInst>(operand));
-  auto *forInst = cast<ForInst>(operand);
-  return forInst->hasConstantLowerBound() &&
-         forInst->getConstantLowerBound() == 0;
-}
-// Returns the slice bound union of 'sliceStateA' and 'sliceStateB' in
-// 'sliceStateB'.
+// Returns the slice union of 'sliceStateA' and 'sliceStateB' in 'sliceStateB'
+// using a rectangular bounding box.
 // TODO(andydavis) This function assumes that lower bounds for 'sliceStateA'
 // and 'sliceStateB' are aligned.
 // Specifically, when taking the union of overlapping intervals, it assumes
 // that both intervals start at zero. Support needs to be added to take into
 // account interval start offset when computing the union.
 // TODO(andydavis) Move this function to an analysis library.
-static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA,
-                               ComputationSliceState *sliceStateB) {
+static bool getSliceUnion(const ComputationSliceState &sliceStateA,
+                          ComputationSliceState *sliceStateB) {
   assert(sliceStateA.lbs.size() == sliceStateB->lbs.size());
   assert(sliceStateA.ubs.size() == sliceStateB->ubs.size());
 
@@ -597,10 +580,7 @@ static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA,
       assert(ubMapA == AffineMap::Null());
       continue;
     }
-    assert(ubMapA != AffineMap::Null());
-    // Validate that constant lower bounds are aligned at zero.
-    if (!hasZeroMinValue(lbMapA, sliceStateA.lbOperands[i]))
-      return false;
+    assert(ubMapA && "expected non-null ub map");
 
     AffineMap lbMapB = sliceStateB->lbs[i];
     AffineMap ubMapB = sliceStateB->ubs[i];
@@ -611,8 +591,13 @@ static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA,
       sliceStateB->ubs[i] = ubMapA;
       continue;
     }
-    // Validate that constant lower bounds are aligned at zero.
-    if (!hasZeroMinValue(lbMapB, sliceStateB->lbOperands[i]))
+
+    // TODO(andydavis) Change this code to take the min across all lower bounds
+    // and max across all upper bounds for each dimension. This code can for
+    // cases where a unique min or max could not be statically determined.
+
+    // Assumption: both lower bounds are the same.
+    if (lbMapA != lbMapB)
       return false;
 
     // Add bound with the largest trip count to union.
@@ -620,9 +605,7 @@ static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA,
     Optional<uint64_t> tripCountB = getConstDifference(lbMapB, ubMapB);
     if (!tripCountA.hasValue() || !tripCountB.hasValue())
       return false;
-    // TODO(andydavis) Change this code to take the min across all lower bounds
-    // and max across all upper bounds for each dimension. This code can for
-    // cases where a unique min or max could not be statically determined.
+
     if (tripCountA.getValue() > tripCountB.getValue()) {
       sliceStateB->lbs[i] = lbMapA;
       sliceStateB->ubs[i] = ubMapA;
@@ -720,7 +703,7 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
                                                   &tmpSliceState))
         return false;
       // Compute slice boun dunion of 'tmpSliceState' and 'sliceStates[i - 1]'.
-      getSliceBoundUnion(tmpSliceState, &sliceStates[i - 1]);
+      getSliceUnion(tmpSliceState, &sliceStates[i - 1]);
     }
     // Build trip count map for computation slice.
     sliceTripCountMap.clear();