// is more efficient than creating a new flattener for each expression since
// common identical div and mod expressions appearing across different
// expressions are mapped to the local identifier (same column position in
-// 'cst').
+// 'localVarCst').
struct AffineExprFlattener : public AffineExprVisitor<AffineExprFlattener> {
public:
// Flattened expression layout: [dims, symbols, locals, constant]
// will be, and linearize this to std::vector<int64_t> to prevent
// SmallVector moves on re-allocation.
std::vector<SmallVector<int64_t, 32>> operandExprStack;
- // Constraints connecting newly introduced local variables to existing
- // (dimensional and symbolic) ones.
- FlatAffineConstraints cst;
+ // Constraints connecting newly introduced local variables (for mod's and
+ // div's) to existing (dimensional and symbolic) ones. These are always
+ // inequalities.
+ FlatAffineConstraints localVarCst;
unsigned numDims;
unsigned numSymbols;
: numDims(numDims), numSymbols(numSymbols), numLocals(0),
context(context) {
operandExprStack.reserve(8);
- cst.reset(numDims, numSymbols, numLocals);
+ localVarCst.reset(numDims, numSymbols, numLocals);
}
void visitMulExpr(AffineBinaryOpExpr expr) {
if ((loc = findLocalId(floorDiv)) == -1) {
addLocalId(floorDiv);
lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
- // Update cst: 0 <= expr1 - c * expr2 <= c - 1.
- cst.addConstantLowerBound(lhs, 0);
- cst.addConstantUpperBound(lhs, rhsConst - 1);
+ // Update localVarCst: 0 <= expr1 - c * expr2 <= c - 1.
+ localVarCst.addConstantLowerBound(lhs, 0);
+ localVarCst.addConstantUpperBound(lhs, rhsConst - 1);
} else {
// Reuse the existing local id.
lhs[getLocalVarStartIndex() + loc] = -rhsConst;
bound[getLocalVarStartIndex() + numLocals - 1] = rhsConst;
if (!isCeil) {
// q = lhs floordiv c <=> c*q <= lhs <= c*q + c - 1.
- cst.addLowerBound(lhs, bound);
+ localVarCst.addLowerBound(lhs, bound);
bound[bound.size() - 1] = rhsConst - 1;
- cst.addUpperBound(lhs, bound);
+ localVarCst.addUpperBound(lhs, bound);
} else {
// q = lhs ceildiv c <=> c*q - (c - 1) <= lhs <= c*q.
- cst.addUpperBound(lhs, bound);
+ localVarCst.addUpperBound(lhs, bound);
bound[bound.size() - 1] = -(rhsConst - 1);
- cst.addLowerBound(lhs, bound);
+ localVarCst.addLowerBound(lhs, bound);
}
}
// Set the expression on stack to the local var introduced to capture the
}
localExprs.push_back(localExpr);
numLocals++;
- cst.addLocalId(cst.getNumLocalIds());
+ localVarCst.addLocalId(localVarCst.getNumLocalIds());
}
int findLocalId(AffineExpr localExpr) {
static bool getFlattenedAffineExprs(
ArrayRef<AffineExpr> exprs, unsigned numDims, unsigned numSymbols,
std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
- FlatAffineConstraints *cst) {
+ FlatAffineConstraints *localVarCst) {
if (exprs.empty()) {
- cst->reset(numDims, numSymbols);
+ localVarCst->reset(numDims, numSymbols);
return true;
}
flattenedExprs->push_back(flattenedExpr);
flattener.operandExprStack.pop_back();
}
- if (cst)
- cst->clearAndCopyFrom(flattener.cst);
+ if (localVarCst)
+ localVarCst->clearAndCopyFrom(flattener.localVarCst);
return true;
}
bool mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
unsigned numSymbols,
llvm::SmallVectorImpl<int64_t> *flattenedExpr,
- FlatAffineConstraints *cst) {
+ FlatAffineConstraints *localVarCst) {
std::vector<SmallVector<int64_t, 8>> flattenedExprs;
bool ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols,
- &flattenedExprs, cst);
+ &flattenedExprs, localVarCst);
*flattenedExpr = flattenedExprs[0];
return ret;
}
/// handled yet).
bool mlir::getFlattenedAffineExprs(
AffineMap map, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
- FlatAffineConstraints *cst) {
+ FlatAffineConstraints *localVarCst) {
if (map.getNumResults() == 0) {
- cst->reset(map.getNumDims(), map.getNumSymbols());
+ localVarCst->reset(map.getNumDims(), map.getNumSymbols());
return true;
}
return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(),
- map.getNumSymbols(), flattenedExprs, cst);
+ map.getNumSymbols(), flattenedExprs,
+ localVarCst);
}
bool mlir::getFlattenedAffineExprs(
IntegerSet set, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
- FlatAffineConstraints *cst) {
+ FlatAffineConstraints *localVarCst) {
if (set.getNumConstraints() == 0) {
- cst->reset(set.getNumDims(), set.getNumSymbols());
+ localVarCst->reset(set.getNumDims(), set.getNumSymbols());
return true;
}
return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(),
- set.getNumSymbols(), flattenedExprs, cst);
+ set.getNumSymbols(), flattenedExprs,
+ localVarCst);
}
/// Returns the sequence of AffineApplyOp OperationStmts operation in
unsigned dstNumSymbols = dstCtx.domain.getNumSymbolIds();
unsigned dstNumIds = dstNumDims + dstNumSymbols;
- unsigned outputNumDims = dependenceDomain->getNumDimIds();
- unsigned outputNumSymbols = dependenceDomain->getNumSymbolIds();
- unsigned outputNumIds = outputNumDims + outputNumSymbols;
-
- SmallVector<int64_t, 4> ineq;
- ineq.resize(outputNumIds + 1);
+ SmallVector<int64_t, 4> ineq(dependenceDomain->getNumCols());
// Add inequalities from src domain.
for (unsigned i = 0; i < srcNumIneq; ++i) {
// Zero fill.
ineq[valuePosMap.getSrcDimOrSymPos(srcCtx.values[j])] =
srcCtx.domain.atIneq(i, j);
// Set constant term.
- ineq[outputNumIds] = srcCtx.domain.atIneq(i, srcNumIds);
+ ineq[ineq.size() - 1] = srcCtx.domain.atIneq(i, srcNumIds);
// Add inequality constraint.
dependenceDomain->addInequality(ineq);
}
ineq[valuePosMap.getDstDimOrSymPos(dstCtx.values[j])] =
dstCtx.domain.atIneq(i, j);
// Set constant term.
- ineq[outputNumIds] = dstCtx.domain.atIneq(i, dstNumIds);
+ ineq[ineq.size() - 1] = dstCtx.domain.atIneq(i, dstNumIds);
// Add inequality constraint.
dependenceDomain->addInequality(ineq);
}
// a0 -c0 (a1 - c1) (a1 - c2) = 0
// b0 -f0 (b1 - f1) (b1 - f2) = 0
//
-// Returns false if any AffineExpr cannot be flattened (which will be removed
-// when mod/floor/ceil support is added). Returns true otherwise.
+// Returns false if any AffineExpr cannot be flattened (due to it being
+// semi-affine). Returns true otherwise.
static bool
addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
const AffineValueMap &dstAccessMap,
assert(srcMap.getNumResults() == dstMap.getNumResults());
unsigned numResults = srcMap.getNumResults();
- unsigned srcNumDims = srcMap.getNumDims();
- unsigned srcNumSymbols = srcMap.getNumSymbols();
- unsigned srcNumIds = srcNumDims + srcNumSymbols;
ArrayRef<MLValue *> srcOperands = srcAccessMap.getOperands();
-
- unsigned dstNumDims = dstMap.getNumDims();
- unsigned dstNumSymbols = dstMap.getNumSymbols();
- unsigned dstNumIds = dstNumDims + dstNumSymbols;
ArrayRef<MLValue *> dstOperands = dstAccessMap.getOperands();
- unsigned outputNumDims = dependenceDomain->getNumDimIds();
- unsigned outputNumSymbols = dependenceDomain->getNumSymbolIds();
- unsigned outputNumIds = outputNumDims + outputNumSymbols;
+ std::vector<SmallVector<int64_t, 8>> srcFlatExprs;
+ std::vector<SmallVector<int64_t, 8>> destFlatExprs;
+ FlatAffineConstraints srcLocalVarCst, destLocalVarCst;
+ // Get flattened expressions for the source and destination maps.
+ if (!getFlattenedAffineExprs(srcMap, &srcFlatExprs, &srcLocalVarCst) ||
+ !getFlattenedAffineExprs(dstMap, &destFlatExprs, &destLocalVarCst))
+ return false;
+
+ unsigned numLocalIdsToAdd =
+ srcLocalVarCst.getNumLocalIds() + destLocalVarCst.getNumLocalIds();
+ for (unsigned i = 0; i < numLocalIdsToAdd; i++) {
+ dependenceDomain->addLocalId(dependenceDomain->getNumLocalIds());
+ }
+
+ unsigned numDims = dependenceDomain->getNumDimIds();
+ unsigned numSymbols = dependenceDomain->getNumSymbolIds();
+ unsigned numSrcLocalIds = srcLocalVarCst.getNumLocalIds();
- SmallVector<int64_t, 4> eq(outputNumIds + 1);
- SmallVector<int64_t, 4> flattenedExpr;
+ // Equality to add.
+ SmallVector<int64_t, 8> eq(dependenceDomain->getNumCols());
for (unsigned i = 0; i < numResults; ++i) {
// Zero fill.
std::fill(eq.begin(), eq.end(), 0);
- // Get flattened AffineExpr for result 'i' from src access function.
- auto srcExpr = srcMap.getResult(i);
- flattenedExpr.clear();
- if (!getFlattenedAffineExpr(srcExpr, srcNumDims, srcNumSymbols,
- &flattenedExpr))
- return false;
+
+ // Flattened AffineExpr for src result 'i'.
+ const auto &srcFlatExpr = srcFlatExprs[i];
// Set identifier coefficients from src access function.
- for (unsigned j = 0, e = srcOperands.size(); j < e; ++j)
- eq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] = flattenedExpr[j];
+ unsigned j, e;
+ for (j = 0, e = srcOperands.size(); j < e; ++j)
+ eq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] = srcFlatExpr[j];
+ // Local terms.
+ for (e = srcFlatExpr.size() - 1; j < e; j++) {
+ eq[numDims + numSymbols + j] = srcFlatExpr[j];
+ }
// Set constant term.
- eq[outputNumIds] = flattenedExpr[srcNumIds];
+ eq[eq.size() - 1] = srcFlatExpr[srcFlatExpr.size() - 1];
- // Get flattened AffineExpr for result 'i' from dst access function.
- auto dstExpr = dstMap.getResult(i);
- flattenedExpr.clear();
- if (!getFlattenedAffineExpr(dstExpr, dstNumDims, dstNumSymbols,
- &flattenedExpr))
- return false;
+ // Flattened AffineExpr for dest result 'i'.
+ const auto &destFlatExpr = destFlatExprs[i];
// Set identifier coefficients from dst access function.
for (unsigned j = 0, e = dstOperands.size(); j < e; ++j)
- eq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] -= flattenedExpr[j];
+ eq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] -= destFlatExpr[j];
+ // Local terms.
+ for (e = destFlatExpr.size() - 1; j < e; j++) {
+ eq[numDims + numSymbols + numSrcLocalIds + j] = destFlatExpr[j];
+ }
// Set constant term.
- eq[outputNumIds] -= flattenedExpr[dstNumIds];
+ eq[eq.size() - 1] -= destFlatExpr[destFlatExpr.size() - 1];
+
// Add equality constraint.
dependenceDomain->addEquality(eq);
}
addEqForConstOperands(srcOperands);
// Add equality constraints for any dst symbols defined by constant ops.
addEqForConstOperands(dstOperands);
+
+ // TODO(bondhugula): add srcLocalVarCst, destLocalVarCst to the dependence
+ // domain.
return true;
}