Refactor/update memref-dep-check's addMemRefAccessConstraints and
authorUday Bondhugula <bondhugula@google.com>
Tue, 18 Dec 2018 04:16:37 +0000 (20:16 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 21:31:58 +0000 (14:31 -0700)
addDomainConstraints; add support for mod/div for dependence testing.

- add support for mod/div expressions in dependence analysis
- refactor addMemRefAccessConstraints to use getFlattenedAffineExprs (instead
  of getFlattenedAffineExpr); update addDomainConstraints.
- rename AffineExprFlattener::cst -> localVarCst

PiperOrigin-RevId: 225933306

mlir/lib/Analysis/AffineAnalysis.cpp
mlir/lib/Analysis/AffineStructures.cpp
mlir/test/Transforms/memref-dependence-check.mlir

index 12d70cacb55416e48fc8b2cc8595edfc45a49ee2..9c344d4fb87f34068a0cf0279070b2eb93aa9327 100644 (file)
@@ -120,7 +120,7 @@ namespace {
 // is more efficient than creating a new flattener for each expression since
 // common idenical div and mod expressions appearing across different
 // expressions are mapped to the local identifier (same column position in
-// 'cst').
+// 'localVarCst').
 struct AffineExprFlattener : public AffineExprVisitor<AffineExprFlattener> {
 public:
   // Flattend expression layout: [dims, symbols, locals, constant]
@@ -129,9 +129,10 @@ public:
   // will be, and linearize this to std::vector<int64_t> to prevent
   // SmallVector moves on re-allocation.
   std::vector<SmallVector<int64_t, 32>> operandExprStack;
-  // Constraints connecting newly introduced local variables to existing
-  // (dimensional and symbolic) ones.
-  FlatAffineConstraints cst;
+  // Constraints connecting newly introduced local variables (for mod's and
+  // div's) to existing (dimensional and symbolic) ones. These are always
+  // inequalities.
+  FlatAffineConstraints localVarCst;
 
   unsigned numDims;
   unsigned numSymbols;
@@ -153,7 +154,7 @@ public:
       : numDims(numDims), numSymbols(numSymbols), numLocals(0),
         context(context) {
     operandExprStack.reserve(8);
-    cst.reset(numDims, numSymbols, numLocals);
+    localVarCst.reset(numDims, numSymbols, numLocals);
   }
 
   void visitMulExpr(AffineBinaryOpExpr expr) {
@@ -214,9 +215,9 @@ public:
     if ((loc = findLocalId(floorDiv)) == -1) {
       addLocalId(floorDiv);
       lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
-      // Update cst:  0 <= expr1 - c * expr2  <= c - 1.
-      cst.addConstantLowerBound(lhs, 0);
-      cst.addConstantUpperBound(lhs, rhsConst - 1);
+      // Update localVarCst:  0 <= expr1 - c * expr2  <= c - 1.
+      localVarCst.addConstantLowerBound(lhs, 0);
+      localVarCst.addConstantUpperBound(lhs, rhsConst - 1);
     } else {
       // Reuse the existing local id.
       lhs[getLocalVarStartIndex() + loc] = -rhsConst;
@@ -305,14 +306,14 @@ private:
       bound[getLocalVarStartIndex() + numLocals - 1] = rhsConst;
       if (!isCeil) {
         // q = lhs floordiv c  <=>  c*q <= lhs <= c*q + c - 1.
-        cst.addLowerBound(lhs, bound);
+        localVarCst.addLowerBound(lhs, bound);
         bound[bound.size() - 1] = rhsConst - 1;
-        cst.addUpperBound(lhs, bound);
+        localVarCst.addUpperBound(lhs, bound);
       } else {
         // q = lhs ceildiv c  <=>  c*q - (c - 1) <= lhs <= c*q.
-        cst.addUpperBound(lhs, bound);
+        localVarCst.addUpperBound(lhs, bound);
         bound[bound.size() - 1] = -(rhsConst - 1);
-        cst.addLowerBound(lhs, bound);
+        localVarCst.addLowerBound(lhs, bound);
       }
     }
     // Set the expression on stack to the local var introduced to capture the
@@ -333,7 +334,7 @@ private:
     }
     localExprs.push_back(localExpr);
     numLocals++;
-    cst.addLocalId(cst.getNumLocalIds());
+    localVarCst.addLocalId(localVarCst.getNumLocalIds());
   }
 
   int findLocalId(AffineExpr localExpr) {
@@ -409,9 +410,9 @@ AffineExpr mlir::composeWithUnboundedMap(AffineExpr e, AffineMap g) {
 static bool getFlattenedAffineExprs(
     ArrayRef<AffineExpr> exprs, unsigned numDims, unsigned numSymbols,
     std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
-    FlatAffineConstraints *cst) {
+    FlatAffineConstraints *localVarCst) {
   if (exprs.empty()) {
-    cst->reset(numDims, numSymbols);
+    localVarCst->reset(numDims, numSymbols);
     return true;
   }
 
@@ -435,8 +436,8 @@ static bool getFlattenedAffineExprs(
     flattenedExprs->push_back(flattenedExpr);
     flattener.operandExprStack.pop_back();
   }
-  if (cst)
-    cst->clearAndCopyFrom(flattener.cst);
+  if (localVarCst)
+    localVarCst->clearAndCopyFrom(flattener.localVarCst);
 
   return true;
 }
@@ -447,10 +448,10 @@ static bool getFlattenedAffineExprs(
 bool mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
                                   unsigned numSymbols,
                                   llvm::SmallVectorImpl<int64_t> *flattenedExpr,
-                                  FlatAffineConstraints *cst) {
+                                  FlatAffineConstraints *localVarCst) {
   std::vector<SmallVector<int64_t, 8>> flattenedExprs;
   bool ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols,
-                                       &flattenedExprs, cst);
+                                       &flattenedExprs, localVarCst);
   *flattenedExpr = flattenedExprs[0];
   return ret;
 }
@@ -460,24 +461,26 @@ bool mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
 /// handled yet).
 bool mlir::getFlattenedAffineExprs(
     AffineMap map, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
-    FlatAffineConstraints *cst) {
+    FlatAffineConstraints *localVarCst) {
   if (map.getNumResults() == 0) {
-    cst->reset(map.getNumDims(), map.getNumSymbols());
+    localVarCst->reset(map.getNumDims(), map.getNumSymbols());
     return true;
   }
   return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(),
-                                   map.getNumSymbols(), flattenedExprs, cst);
+                                   map.getNumSymbols(), flattenedExprs,
+                                   localVarCst);
 }
 
 bool mlir::getFlattenedAffineExprs(
     IntegerSet set, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
-    FlatAffineConstraints *cst) {
+    FlatAffineConstraints *localVarCst) {
   if (set.getNumConstraints() == 0) {
-    cst->reset(set.getNumDims(), set.getNumSymbols());
+    localVarCst->reset(set.getNumDims(), set.getNumSymbols());
     return true;
   }
   return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(),
-                                   set.getNumSymbols(), flattenedExprs, cst);
+                                   set.getNumSymbols(), flattenedExprs,
+                                   localVarCst);
 }
 
 /// Returns the sequence of AffineApplyOp OperationStmts operation in
@@ -760,12 +763,7 @@ static void addDomainConstraints(const IterationDomainContext &srcCtx,
   unsigned dstNumSymbols = dstCtx.domain.getNumSymbolIds();
   unsigned dstNumIds = dstNumDims + dstNumSymbols;
 
-  unsigned outputNumDims = dependenceDomain->getNumDimIds();
-  unsigned outputNumSymbols = dependenceDomain->getNumSymbolIds();
-  unsigned outputNumIds = outputNumDims + outputNumSymbols;
-
-  SmallVector<int64_t, 4> ineq;
-  ineq.resize(outputNumIds + 1);
+  SmallVector<int64_t, 4> ineq(dependenceDomain->getNumCols());
   // Add inequalities from src domain.
   for (unsigned i = 0; i < srcNumIneq; ++i) {
     // Zero fill.
@@ -775,7 +773,7 @@ static void addDomainConstraints(const IterationDomainContext &srcCtx,
       ineq[valuePosMap.getSrcDimOrSymPos(srcCtx.values[j])] =
           srcCtx.domain.atIneq(i, j);
     // Set constant term.
-    ineq[outputNumIds] = srcCtx.domain.atIneq(i, srcNumIds);
+    ineq[ineq.size() - 1] = srcCtx.domain.atIneq(i, srcNumIds);
     // Add inequality constraint.
     dependenceDomain->addInequality(ineq);
   }
@@ -788,7 +786,7 @@ static void addDomainConstraints(const IterationDomainContext &srcCtx,
       ineq[valuePosMap.getDstDimOrSymPos(dstCtx.values[j])] =
           dstCtx.domain.atIneq(i, j);
     // Set constant term.
-    ineq[outputNumIds] = dstCtx.domain.atIneq(i, dstNumIds);
+    ineq[ineq.size() - 1] = dstCtx.domain.atIneq(i, dstNumIds);
     // Add inequality constraint.
     dependenceDomain->addInequality(ineq);
   }
@@ -815,8 +813,8 @@ static void addDomainConstraints(const IterationDomainContext &srcCtx,
 //   a0     -c0      (a1 - c1)  (a1 - c2) = 0
 //   b0     -f0      (b1 - f1)  (b1 - f2) = 0
 //
-// Returns false if any AffineExpr cannot be flattened (which will be removed
-// when mod/floor/ceil support is added). Returns true otherwise.
+// Returns false if any AffineExpr cannot be flattened (due to it being
+// semi-affine). Returns true otherwise.
 static bool
 addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
                            const AffineValueMap &dstAccessMap,
@@ -827,48 +825,58 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
   assert(srcMap.getNumResults() == dstMap.getNumResults());
   unsigned numResults = srcMap.getNumResults();
 
-  unsigned srcNumDims = srcMap.getNumDims();
-  unsigned srcNumSymbols = srcMap.getNumSymbols();
-  unsigned srcNumIds = srcNumDims + srcNumSymbols;
   ArrayRef<MLValue *> srcOperands = srcAccessMap.getOperands();
-
-  unsigned dstNumDims = dstMap.getNumDims();
-  unsigned dstNumSymbols = dstMap.getNumSymbols();
-  unsigned dstNumIds = dstNumDims + dstNumSymbols;
   ArrayRef<MLValue *> dstOperands = dstAccessMap.getOperands();
 
-  unsigned outputNumDims = dependenceDomain->getNumDimIds();
-  unsigned outputNumSymbols = dependenceDomain->getNumSymbolIds();
-  unsigned outputNumIds = outputNumDims + outputNumSymbols;
+  std::vector<SmallVector<int64_t, 8>> srcFlatExprs;
+  std::vector<SmallVector<int64_t, 8>> destFlatExprs;
+  FlatAffineConstraints srcLocalVarCst, destLocalVarCst;
+  // Get flattened expressions for the source destination maps.
+  if (!getFlattenedAffineExprs(srcMap, &srcFlatExprs, &srcLocalVarCst) ||
+      !getFlattenedAffineExprs(dstMap, &destFlatExprs, &destLocalVarCst))
+    return false;
+
+  unsigned numLocalIdsToAdd =
+      srcLocalVarCst.getNumLocalIds() + destLocalVarCst.getNumLocalIds();
+  for (unsigned i = 0; i < numLocalIdsToAdd; i++) {
+    dependenceDomain->addLocalId(dependenceDomain->getNumLocalIds());
+  }
+
+  unsigned numDims = dependenceDomain->getNumDimIds();
+  unsigned numSymbols = dependenceDomain->getNumSymbolIds();
+  unsigned numSrcLocalIds = srcLocalVarCst.getNumLocalIds();
 
-  SmallVector<int64_t, 4> eq(outputNumIds + 1);
-  SmallVector<int64_t, 4> flattenedExpr;
+  // Equality to add.
+  SmallVector<int64_t, 8> eq(dependenceDomain->getNumCols());
   for (unsigned i = 0; i < numResults; ++i) {
     // Zero fill.
     std::fill(eq.begin(), eq.end(), 0);
-    // Get flattened AffineExpr for result 'i' from src access function.
-    auto srcExpr = srcMap.getResult(i);
-    flattenedExpr.clear();
-    if (!getFlattenedAffineExpr(srcExpr, srcNumDims, srcNumSymbols,
-                                &flattenedExpr))
-      return false;
+
+    // Flattened AffineExpr for src result 'i'.
+    const auto &srcFlatExpr = srcFlatExprs[i];
     // Set identifier coefficients from src access function.
-    for (unsigned j = 0, e = srcOperands.size(); j < e; ++j)
-      eq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] = flattenedExpr[j];
+    unsigned j, e;
+    for (j = 0, e = srcOperands.size(); j < e; ++j)
+      eq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] = srcFlatExpr[j];
+    // Local terms.
+    for (e = srcFlatExpr.size() - 1; j < e; j++) {
+      eq[numDims + numSymbols + j] = srcFlatExpr[j];
+    }
     // Set constant term.
-    eq[outputNumIds] = flattenedExpr[srcNumIds];
+    eq[eq.size() - 1] = srcFlatExpr[srcFlatExpr.size() - 1];
 
-    // Get flattened AffineExpr for result 'i' from dst access function.
-    auto dstExpr = dstMap.getResult(i);
-    flattenedExpr.clear();
-    if (!getFlattenedAffineExpr(dstExpr, dstNumDims, dstNumSymbols,
-                                &flattenedExpr))
-      return false;
+    // Flattened AffineExpr for dest result 'i'.
+    const auto &destFlatExpr = destFlatExprs[i];
     // Set identifier coefficients from dst access function.
     for (unsigned j = 0, e = dstOperands.size(); j < e; ++j)
-      eq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] -= flattenedExpr[j];
+      eq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] -= destFlatExpr[j];
+    // Local terms.
+    for (e = destFlatExpr.size() - 1; j < e; j++) {
+      eq[numDims + numSymbols + numSrcLocalIds + j] = destFlatExpr[j];
+    }
     // Set constant term.
-    eq[outputNumIds] -= flattenedExpr[dstNumIds];
+    eq[eq.size() - 1] -= destFlatExpr[destFlatExpr.size() - 1];
+
     // Add equality constraint.
     dependenceDomain->addEquality(eq);
   }
@@ -894,6 +902,9 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
   addEqForConstOperands(srcOperands);
   // Add equality constraints for any dst symbols defined by constant ops.
   addEqForConstOperands(dstOperands);
+
+  // TODO(bondhugula): add srcLocalVarCst, destLocalVarCst to the dependence
+  // domain.
   return true;
 }
 
index 9d14405427a6c915348b3ca805f514d4af3bac67..d3c3e4ab04f19a8c081021ea99f0e13095b8ba45 100644 (file)
@@ -518,13 +518,13 @@ FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
 
   // Flatten expressions and add them to the constraint system.
   std::vector<SmallVector<int64_t, 8>> flatExprs;
-  FlatAffineConstraints cst;
-  if (!getFlattenedAffineExprs(set, &flatExprs, &cst)) {
+  FlatAffineConstraints localVarCst;
+  if (!getFlattenedAffineExprs(set, &flatExprs, &localVarCst)) {
     assert(false && "flattening unimplemented for semi-affine integer sets");
     return;
   }
   assert(flatExprs.size() == set.getNumConstraints());
-  for (unsigned l = 0, e = cst.getNumLocalIds(); l < e; l++) {
+  for (unsigned l = 0, e = localVarCst.getNumLocalIds(); l < e; l++) {
     addLocalId(getNumLocalIds());
   }
 
@@ -538,7 +538,7 @@ FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
     }
   }
   // Add the other constraints involving local id's from flattening.
-  append(cst);
+  append(localVarCst);
 }
 
 void FlatAffineConstraints::reset(unsigned numReservedInequalities,
@@ -1282,13 +1282,13 @@ bool FlatAffineConstraints::addBoundsFromForStmt(const ForStmt &forStmt) {
     auto boundMap =
         lower ? forStmt.getLowerBoundMap() : forStmt.getUpperBoundMap();
 
-    FlatAffineConstraints cst;
+    FlatAffineConstraints localVarCst;
     std::vector<SmallVector<int64_t, 8>> flatExprs;
-    if (!getFlattenedAffineExprs(boundMap, &flatExprs, &cst)) {
+    if (!getFlattenedAffineExprs(boundMap, &flatExprs, &localVarCst)) {
       LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n");
       return false;
     }
-    if (cst.getNumLocalIds() > 0) {
+    if (localVarCst.getNumLocalIds() > 0) {
       LLVM_DEBUG(llvm::dbgs()
                  << "loop bounds with mod/floordiv expr's not yet supported\n");
       return false;
index 26f2738c9940ad629ea5a6c55ec0d7ecfbb8780b..eb22510f20ad7507a2f90cd20e87b7cd74bfaacf 100644 (file)
@@ -111,9 +111,9 @@ mlfunc @store_load_different_symbols(%arg0 : index, %arg1 : index) {
 mlfunc @store_load_diff_element_affine_apply_const() {
   %m = alloc() : memref<100xf32>
   %c1 = constant 1 : index
-  %c7 = constant 7.0 : f32
+  %c8 = constant 8.0 : f32
   %a0 = affine_apply (d0) -> (d0) (%c1)
-  store %c7, %m[%a0] : memref<100xf32>
+  store %c8, %m[%a0] : memref<100xf32>
   // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}}
   // expected-note@-2 {{dependence from 0 to 1 at depth 1 = false}}
   %a1 = affine_apply (d0) -> (d0 + 1) (%c1)
@@ -565,3 +565,28 @@ mlfunc @war_raw_waw_deps() {
   }
   return
 }
+
+// -----
+// CHECK-LABEL: mlfunc @mod_deps() {
+mlfunc @mod_deps() {
+  %m = alloc() : memref<100xf32>
+  %c7 = constant 7.0 : f32
+  for %i0 = 0 to 10 {
+    %a0 = affine_apply (d0) -> (d0 mod 2) (%i0)
+    // Results are conservative here since constraint information after
+    // flattening isn't being completely added. Will be done in the next CL.
+    // The third and the fifth dependence below shouldn't have existed.
+    %v0 = load %m[%a0] : memref<100xf32>
+    // expected-note@-1 {{dependence from 0 to 0 at depth 1 = false}}
+    // expected-note@-2 {{dependence from 0 to 0 at depth 2 = false}}
+    // expected-note@-3 {{dependence from 0 to 1 at depth 1 = [1, 9]}}
+    // expected-note@-4 {{dependence from 0 to 1 at depth 2 = false}}
+    %a1 = affine_apply (d0) -> ( (d0 + 1) mod 2) (%i0)
+    store %c7, %m[%a1] : memref<100xf32>
+    // expected-note@-1 {{dependence from 1 to 0 at depth 1 = [1, 9]}}
+    // expected-note@-2 {{dependence from 1 to 0 at depth 2 = false}}
+    // expected-note@-3 {{dependence from 1 to 1 at depth 1 = [2, 9]}}
+    // expected-note@-4 {{dependence from 1 to 1 at depth 2 = false}}
+  }
+  return
+}