/// O(VC) time.
void removeRedundantConstraints();
+ /// Converts identifiers in the column range [idStart, idLimit) to local
+ /// variables.
+ void convertDimToLocal(unsigned dimStart, unsigned dimLimit);
+
/// Merge local ids of `this` and `other`. This is done by appending local ids
/// of `other` to `this` and inserting local ids of `this` to `other` at start
/// of its local ids.
numLocals + 1,
numDims, numSymbols, numLocals, valArgs) {}
+ FlatAffineValueConstraints(const FlatAffineConstraints &fac,
+ ArrayRef<Optional<Value>> valArgs = {})
+ : FlatAffineConstraints(fac) {
+ assert(valArgs.empty() || valArgs.size() == numIds);
+ if (valArgs.empty())
+ values.resize(numIds, None);
+ else
+ values.append(valArgs.begin(), valArgs.end());
+ }
+
/// Create a flat affine constraint system from an AffineValueMap or a list of
/// these. The constructed system will only include equalities.
explicit FlatAffineValueConstraints(const AffineValueMap &avm);
using FlatAffineConstraints::insertDimId;
unsigned insertSymbolId(unsigned pos, ValueRange vals);
using FlatAffineConstraints::insertSymbolId;
- unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override;
+ virtual unsigned insertId(IdKind kind, unsigned pos,
+ unsigned num = 1) override;
unsigned insertId(IdKind kind, unsigned pos, ValueRange vals);
/// Append identifiers of the specified kind after the last identifier of that
/// Removes identifiers in the column range [idStart, idLimit), and copies any
/// remaining valid data into place, updates member variables, and resizes
/// arrays as needed.
- void removeIdRange(unsigned idStart, unsigned idLimit) override;
+ virtual void removeIdRange(unsigned idStart, unsigned idLimit) override;
/// Eliminates the identifier at the specified position using Fourier-Motzkin
/// variable elimination, but uses Gaussian elimination if there is an
SmallVector<Optional<Value>, 8> values;
};
+/// A FlatAffineRelation represents a set of ordered pairs (domain -> range)
+/// where "domain" and "range" are tuples of identifiers. The relation is
+/// represented as a FlatAffineValueConstraints with separation of dimension
+/// identifiers into domain and range. The identifiers are stored as:
+/// [domainIds, rangeIds, symbolIds, localIds, constant].
+class FlatAffineRelation : public FlatAffineValueConstraints {
+public:
+ FlatAffineRelation(unsigned numReservedInequalities,
+ unsigned numReservedEqualities, unsigned numReservedCols,
+ unsigned numDomainDims, unsigned numRangeDims,
+ unsigned numSymbols, unsigned numLocals,
+ ArrayRef<Optional<Value>> valArgs = {})
+ : FlatAffineValueConstraints(
+ numReservedInequalities, numReservedEqualities, numReservedCols,
+ numDomainDims + numRangeDims, numSymbols, numLocals, valArgs),
+ numDomainDims(numDomainDims), numRangeDims(numRangeDims) {}
+
+ FlatAffineRelation(unsigned numDomainDims = 0, unsigned numRangeDims = 0,
+ unsigned numSymbols = 0, unsigned numLocals = 0)
+ : FlatAffineValueConstraints(numDomainDims + numRangeDims, numSymbols,
+ numLocals),
+ numDomainDims(numDomainDims), numRangeDims(numRangeDims) {}
+
+ FlatAffineRelation(unsigned numDomainDims, unsigned numRangeDims,
+ FlatAffineValueConstraints &fac)
+ : FlatAffineValueConstraints(fac), numDomainDims(numDomainDims),
+ numRangeDims(numRangeDims) {}
+
+ FlatAffineRelation(unsigned numDomainDims, unsigned numRangeDims,
+ FlatAffineConstraints &fac)
+ : FlatAffineValueConstraints(fac), numDomainDims(numDomainDims),
+ numRangeDims(numRangeDims) {}
+
+ /// Returns a set corresponding to the domain/range of the affine relation.
+ FlatAffineValueConstraints getDomainSet() const;
+ FlatAffineValueConstraints getRangeSet() const;
+
+ /// Returns the number of identifiers corresponding to domain/range of
+ /// relation.
+ inline unsigned getNumDomainDims() const { return numDomainDims; }
+ inline unsigned getNumRangeDims() const { return numRangeDims; }
+
+ /// Given affine relation `other: (domainOther -> rangeOther)`, this operation
+ /// takes the composition of `other` on `this: (domainThis -> rangeThis)`.
+ /// The resulting relation represents tuples of the form: `domainOther ->
+ /// rangeThis`.
+ void compose(const FlatAffineRelation &other);
+
+ /// Swap domain and range of the relation.
+ /// `(domain -> range)` is converted to `(range -> domain)`.
+ void inverse();
+
+ /// Insert `num` identifiers of the specified kind after the `pos` identifier
+ /// of that kind. The coefficient columns corresponding to the added
+ /// identifiers are initialized to zero.
+ void insertDomainId(unsigned pos, unsigned num = 1);
+ void insertRangeId(unsigned pos, unsigned num = 1);
+
+ /// Append `num` identifiers of the specified kind after the last identifier
+ /// of that kind. The coefficient columns corresponding to the added
+ /// identifiers are initialized to zero.
+ void appendDomainId(unsigned num = 1);
+ void appendRangeId(unsigned num = 1);
+
+protected:
+ // Number of dimension identifers corresponding to domain identifers.
+ unsigned numDomainDims;
+
+ // Number of dimension identifers corresponding to range identifers.
+ unsigned numRangeDims;
+
+ /// Removes identifiers in the column range [idStart, idLimit), and copies any
+ /// remaining valid data into place, updates member variables, and resizes
+ /// arrays as needed.
+ void removeIdRange(unsigned idStart, unsigned idLimit) override;
+};
+
/// Flattens 'expr' into 'flattenedExpr', which contains the coefficients of the
/// dimensions, symbols, and additional variables that represent floor divisions
/// of dimensions, symbols, and in turn other floor divisions. Returns failure
ValueRange dims, ValueRange syms,
SmallVector<Value> *newSyms = nullptr);
+/// Builds a relation from the given AffineMap/AffineValueMap `map`, containing
+/// all pairs of the form `operands -> result` that satisfy `map`. `rel` is set
+/// to the relation built. For example, give the AffineMap:
+///
+/// (d0, d1)[s0] -> (d0 + s0, d0 - s0)
+///
+/// the resulting relation formed is:
+///
+/// (d0, d1) -> (r1, r2)
+/// [d0 d1 r1 r2 s0 const]
+/// 1 0 -1 0 1 0 = 0
+/// 0 1 0 -1 -1 0 = 0
+///
+/// For AffineValueMap, the domain and symbols have Value set corresponding to
+/// the Value in `map`. Returns failure if the AffineMap could not be flattened
+/// (i.e., semi-affine is not yet handled).
+LogicalResult getRelationFromMap(AffineMap &map, FlatAffineRelation &rel);
+LogicalResult getRelationFromMap(const AffineValueMap &map,
+ FlatAffineRelation &rel);
+
} // end namespace mlir.
#endif // MLIR_ANALYSIS_AFFINESTRUCTURES_H
return getIndexSet(ops, indexSet);
}
-namespace {
-// ValuePositionMap manages the mapping from Values which represent dimension
-// and symbol identifiers from 'src' and 'dst' access functions to positions
-// in new space where some Values are kept separate (using addSrc/DstValue)
-// and some Values are merged (addSymbolValue).
-// Position lookups return the absolute position in the new space which
-// has the following format:
-//
-// [src-dim-identifiers] [dst-dim-identifiers] [symbol-identifiers]
-//
-// Note: access function non-IV dimension identifiers (that have 'dimension'
-// positions in the access function position space) are assigned as symbols
-// in the output position space. Convenience access functions which lookup
-// an Value in multiple maps are provided (i.e. getSrcDimOrSymPos) to handle
-// the common case of resolving positions for all access function operands.
-//
-// TODO: Generalize this: could take a template parameter for the number of maps
-// (3 in the current case), and lookups could take indices of maps to check. So
-// getSrcDimOrSymPos would be "getPos(value, {0, 2})".
-class ValuePositionMap {
-public:
- void addSrcValue(Value value) {
- if (addValueAt(value, &srcDimPosMap, numSrcDims))
- ++numSrcDims;
- }
- void addDstValue(Value value) {
- if (addValueAt(value, &dstDimPosMap, numDstDims))
- ++numDstDims;
- }
- void addSymbolValue(Value value) {
- if (addValueAt(value, &symbolPosMap, numSymbols))
- ++numSymbols;
- }
- unsigned getSrcDimOrSymPos(Value value) const {
- return getDimOrSymPos(value, srcDimPosMap, 0);
- }
- unsigned getDstDimOrSymPos(Value value) const {
- return getDimOrSymPos(value, dstDimPosMap, numSrcDims);
- }
- unsigned getSymPos(Value value) const {
- auto it = symbolPosMap.find(value);
- assert(it != symbolPosMap.end());
- return numSrcDims + numDstDims + it->second;
- }
-
- unsigned getNumSrcDims() const { return numSrcDims; }
- unsigned getNumDstDims() const { return numDstDims; }
- unsigned getNumDims() const { return numSrcDims + numDstDims; }
- unsigned getNumSymbols() const { return numSymbols; }
-
-private:
- bool addValueAt(Value value, DenseMap<Value, unsigned> *posMap,
- unsigned position) {
- auto it = posMap->find(value);
- if (it == posMap->end()) {
- (*posMap)[value] = position;
- return true;
- }
- return false;
- }
- unsigned getDimOrSymPos(Value value,
- const DenseMap<Value, unsigned> &dimPosMap,
- unsigned dimPosOffset) const {
- auto it = dimPosMap.find(value);
- if (it != dimPosMap.end()) {
- return dimPosOffset + it->second;
- }
- it = symbolPosMap.find(value);
- assert(it != symbolPosMap.end());
- return numSrcDims + numDstDims + it->second;
- }
-
- unsigned numSrcDims = 0;
- unsigned numDstDims = 0;
- unsigned numSymbols = 0;
- DenseMap<Value, unsigned> srcDimPosMap;
- DenseMap<Value, unsigned> dstDimPosMap;
- DenseMap<Value, unsigned> symbolPosMap;
-};
-} // namespace
-
-// Builds a map from Value to identifier position in a new merged identifier
-// list, which is the result of merging dim/symbol lists from src/dst
-// iteration domains, the format of which is as follows:
-//
-// [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers, const_term]
-//
-// This method populates 'valuePosMap' with mappings from operand Values in
-// 'srcAccessMap'/'dstAccessMap' (as well as those in 'srcDomain'/'dstDomain')
-// to the position of these values in the merged list.
-static void buildDimAndSymbolPositionMaps(
- const FlatAffineValueConstraints &srcDomain,
- const FlatAffineValueConstraints &dstDomain,
- const AffineValueMap &srcAccessMap, const AffineValueMap &dstAccessMap,
- ValuePositionMap *valuePosMap,
- FlatAffineValueConstraints *dependenceConstraints) {
-
- // IsDimState is a tri-state boolean. It is used to distinguish three
- // different cases of the values passed to updateValuePosMap.
- // - When it is TRUE, we are certain that all values are dim values.
- // - When it is FALSE, we are certain that all values are symbol values.
- // - When it is UNKNOWN, we need to further check whether the value is from a
- // loop IV to determine its type (dim or symbol).
-
- // We need this enumeration because sometimes we cannot determine whether a
- // Value is a symbol or a dim by the information from the Value itself. If a
- // Value appears in an affine map of a loop, we can determine whether it is a
- // dim or not by the function `isForInductionVar`. But when a Value is in the
- // affine set of an if-statement, there is no way to identify its category
- // (dim/symbol) by itself. Fortunately, the Values to be inserted into
- // `valuePosMap` come from `srcDomain` and `dstDomain`, and they hold such
- // information of Value category: `srcDomain` and `dstDomain` organize Values
- // by their category, such that the position of each Value stored in
- // `srcDomain` and `dstDomain` marks which category that a Value belongs to.
- // Therefore, we can separate Values into dim and symbol groups before passing
- // them to the function `updateValuePosMap`. Specifically, when passing the
- // dim group, we set IsDimState to TRUE; otherwise, we set it to FALSE.
- // However, Values from the operands of `srcAccessMap` and `dstAccessMap` are
- // not explicitly categorized into dim or symbol, and we have to rely on
- // `isForInductionVar` to make the decision. IsDimState is set to UNKNOWN in
- // this case.
- enum IsDimState { TRUE, FALSE, UNKNOWN };
-
- // This function places each given Value (in `values`) under a respective
- // category in `valuePosMap`. Specifically, the placement rules are:
- // 1) If `isDim` is FALSE, then every value in `values` are inserted into
- // `valuePosMap` as symbols.
- // 2) If `isDim` is UNKNOWN and the value of the current iteration is NOT an
- // induction variable of a for-loop, we treat it as symbol as well.
- // 3) For other cases, we decide whether to add a value to the `src` or the
- // `dst` section of the dim category simply by the boolean value `isSrc`.
- auto updateValuePosMap = [&](ArrayRef<Value> values, bool isSrc,
- IsDimState isDim) {
- for (unsigned i = 0, e = values.size(); i < e; ++i) {
- auto value = values[i];
- if (isDim == FALSE || (isDim == UNKNOWN && !isForInductionVar(value))) {
- assert(isValidSymbol(value) &&
- "access operand has to be either a loop IV or a symbol");
- valuePosMap->addSymbolValue(value);
- } else {
- if (isSrc)
- valuePosMap->addSrcValue(value);
- else
- valuePosMap->addDstValue(value);
- }
- }
- };
-
- // Collect values from the src and dst domains. For each domain, we separate
- // the collected values into dim and symbol parts.
- SmallVector<Value, 4> srcDimValues, dstDimValues, srcSymbolValues,
- dstSymbolValues;
- srcDomain.getValues(0, srcDomain.getNumDimIds(), &srcDimValues);
- dstDomain.getValues(0, dstDomain.getNumDimIds(), &dstDimValues);
- srcDomain.getValues(srcDomain.getNumDimIds(),
- srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues);
- dstDomain.getValues(dstDomain.getNumDimIds(),
- dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues);
-
- // Update value position map with dim values from src iteration domain.
- updateValuePosMap(srcDimValues, /*isSrc=*/true, /*isDim=*/TRUE);
- // Update value position map with dim values from dst iteration domain.
- updateValuePosMap(dstDimValues, /*isSrc=*/false, /*isDim=*/TRUE);
- // Update value position map with symbols from src iteration domain.
- updateValuePosMap(srcSymbolValues, /*isSrc=*/true, /*isDim=*/FALSE);
- // Update value position map with symbols from dst iteration domain.
- updateValuePosMap(dstSymbolValues, /*isSrc=*/false, /*isDim=*/FALSE);
- // Update value position map with identifiers from src access function.
- updateValuePosMap(srcAccessMap.getOperands(), /*isSrc=*/true,
- /*isDim=*/UNKNOWN);
- // Update value position map with identifiers from dst access function.
- updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false,
- /*isDim=*/UNKNOWN);
-}
-
-// Sets up dependence constraints columns appropriately, in the format:
-// [src-dim-ids, dst-dim-ids, symbol-ids, local-ids, const_term]
-static void
-initDependenceConstraints(const FlatAffineValueConstraints &srcDomain,
- const FlatAffineValueConstraints &dstDomain,
- const AffineValueMap &srcAccessMap,
- const AffineValueMap &dstAccessMap,
- const ValuePositionMap &valuePosMap,
- FlatAffineValueConstraints *dependenceConstraints) {
- // Calculate number of equalities/inequalities and columns required to
- // initialize FlatAffineValueConstraints for 'dependenceDomain'.
- unsigned numIneq =
- srcDomain.getNumInequalities() + dstDomain.getNumInequalities();
- AffineMap srcMap = srcAccessMap.getAffineMap();
- assert(srcMap.getNumResults() == dstAccessMap.getAffineMap().getNumResults());
- unsigned numEq = srcMap.getNumResults();
- unsigned numDims = srcDomain.getNumDimIds() + dstDomain.getNumDimIds();
- unsigned numSymbols = valuePosMap.getNumSymbols();
- unsigned numLocals = srcDomain.getNumLocalIds() + dstDomain.getNumLocalIds();
- unsigned numIds = numDims + numSymbols + numLocals;
- unsigned numCols = numIds + 1;
-
- // Set flat affine constraints sizes and reserving space for constraints.
- dependenceConstraints->reset(numIneq, numEq, numCols, numDims, numSymbols,
- numLocals);
-
- // Set values corresponding to dependence constraint identifiers.
- SmallVector<Value, 4> srcLoopIVs, dstLoopIVs;
- srcDomain.getValues(0, srcDomain.getNumDimIds(), &srcLoopIVs);
- dstDomain.getValues(0, dstDomain.getNumDimIds(), &dstLoopIVs);
-
- dependenceConstraints->setValues(0, srcLoopIVs.size(), srcLoopIVs);
- dependenceConstraints->setValues(
- srcLoopIVs.size(), srcLoopIVs.size() + dstLoopIVs.size(), dstLoopIVs);
-
- // Set values for the symbolic identifier dimensions. `isSymbolDetermined`
- // indicates whether we are certain that the `values` passed in are all
- // symbols. If `isSymbolDetermined` is true, then we treat every Value in
- // `values` as a symbol; otherwise, we let the function `isForInductionVar` to
- // distinguish whether a Value in `values` is a symbol or not.
- auto setSymbolIds = [&](ArrayRef<Value> values,
- bool isSymbolDetermined = true) {
- for (auto value : values) {
- if (isSymbolDetermined || !isForInductionVar(value)) {
- assert(isValidSymbol(value) && "expected symbol");
- dependenceConstraints->setValue(valuePosMap.getSymPos(value), value);
- }
- }
- };
-
- // We are uncertain about whether all operands in `srcAccessMap` and
- // `dstAccessMap` are symbols, so we set `isSymbolDetermined` to false.
- setSymbolIds(srcAccessMap.getOperands(), /*isSymbolDetermined=*/false);
- setSymbolIds(dstAccessMap.getOperands(), /*isSymbolDetermined=*/false);
-
- SmallVector<Value, 8> srcSymbolValues, dstSymbolValues;
- srcDomain.getValues(srcDomain.getNumDimIds(),
- srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues);
- dstDomain.getValues(dstDomain.getNumDimIds(),
- dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues);
- // Since we only take symbol Values out of `srcDomain` and `dstDomain`,
- // `isSymbolDetermined` is kept to its default value: true.
- setSymbolIds(srcSymbolValues);
- setSymbolIds(dstSymbolValues);
-
- for (unsigned i = 0, e = dependenceConstraints->getNumDimAndSymbolIds();
- i < e; i++)
- assert(dependenceConstraints->hasValue(i));
-}
-
-// Adds iteration domain constraints from 'srcDomain' and 'dstDomain' into
-// 'dependenceDomain'.
-// Uses 'valuePosMap' to determine the position in 'dependenceDomain' to which a
-// srcDomain/dstDomain Value maps.
-static void addDomainConstraints(const FlatAffineValueConstraints &srcDomain,
- const FlatAffineValueConstraints &dstDomain,
- const ValuePositionMap &valuePosMap,
- FlatAffineValueConstraints *dependenceDomain) {
- unsigned depNumDimsAndSymbolIds = dependenceDomain->getNumDimAndSymbolIds();
-
- SmallVector<int64_t, 4> cst(dependenceDomain->getNumCols());
-
- auto addDomain = [&](bool isSrc, bool isEq, unsigned localOffset) {
- const FlatAffineValueConstraints &domain = isSrc ? srcDomain : dstDomain;
- unsigned numCsts =
- isEq ? domain.getNumEqualities() : domain.getNumInequalities();
- unsigned numDimAndSymbolIds = domain.getNumDimAndSymbolIds();
- auto at = [&](unsigned i, unsigned j) -> int64_t {
- return isEq ? domain.atEq(i, j) : domain.atIneq(i, j);
- };
- auto map = [&](unsigned i) -> int64_t {
- return isSrc ? valuePosMap.getSrcDimOrSymPos(domain.getValue(i))
- : valuePosMap.getDstDimOrSymPos(domain.getValue(i));
- };
-
- for (unsigned i = 0; i < numCsts; ++i) {
- // Zero fill.
- std::fill(cst.begin(), cst.end(), 0);
- // Set coefficients for identifiers corresponding to domain.
- for (unsigned j = 0; j < numDimAndSymbolIds; ++j)
- cst[map(j)] = at(i, j);
- // Local terms.
- for (unsigned j = 0, e = domain.getNumLocalIds(); j < e; j++)
- cst[depNumDimsAndSymbolIds + localOffset + j] =
- at(i, numDimAndSymbolIds + j);
- // Set constant term.
- cst[cst.size() - 1] = at(i, domain.getNumCols() - 1);
- // Add constraint.
- if (isEq)
- dependenceDomain->addEquality(cst);
- else
- dependenceDomain->addInequality(cst);
- }
- };
-
- // Add equalities from src domain.
- addDomain(/*isSrc=*/true, /*isEq=*/true, /*localOffset=*/0);
- // Add inequalities from src domain.
- addDomain(/*isSrc=*/true, /*isEq=*/false, /*localOffset=*/0);
- // Add equalities from dst domain.
- addDomain(/*isSrc=*/false, /*isEq=*/true,
- /*localOffset=*/srcDomain.getNumLocalIds());
- // Add inequalities from dst domain.
- addDomain(/*isSrc=*/false, /*isEq=*/false,
- /*localOffset=*/srcDomain.getNumLocalIds());
-}
-
-// Adds equality constraints that equate src and dst access functions
-// represented by 'srcAccessMap' and 'dstAccessMap' for each result.
-// Requires that 'srcAccessMap' and 'dstAccessMap' have the same results count.
-// For example, given the following two accesses functions to a 2D memref:
-//
-// Source access function:
-// (a0 * d0 + a1 * s0 + a2, b0 * d0 + b1 * s0 + b2)
-//
-// Destination access function:
-// (c0 * d0 + c1 * s0 + c2, f0 * d0 + f1 * s0 + f2)
-//
-// This method constructs the following equality constraints in
-// 'dependenceDomain', by equating the access functions for each result
-// (i.e. each memref dim). Notice that 'd0' for the destination access function
-// is mapped into 'd0' in the equality constraint:
-//
-// d0 d1 s0 c
-// -- -- -- --
-// a0 -c0 (a1 - c1) (a2 - c2) = 0
-// b0 -f0 (b1 - f1) (b2 - f2) = 0
-//
-// Returns failure if any AffineExpr cannot be flattened (due to it being
-// semi-affine). Returns success otherwise.
-static LogicalResult
-addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
- const AffineValueMap &dstAccessMap,
- const ValuePositionMap &valuePosMap,
- FlatAffineValueConstraints *dependenceDomain) {
- AffineMap srcMap = srcAccessMap.getAffineMap();
- AffineMap dstMap = dstAccessMap.getAffineMap();
- assert(srcMap.getNumResults() == dstMap.getNumResults());
- unsigned numResults = srcMap.getNumResults();
-
- unsigned srcNumIds = srcMap.getNumDims() + srcMap.getNumSymbols();
- ArrayRef<Value> srcOperands = srcAccessMap.getOperands();
-
- unsigned dstNumIds = dstMap.getNumDims() + dstMap.getNumSymbols();
- ArrayRef<Value> dstOperands = dstAccessMap.getOperands();
-
- std::vector<SmallVector<int64_t, 8>> srcFlatExprs;
- std::vector<SmallVector<int64_t, 8>> destFlatExprs;
- FlatAffineValueConstraints srcLocalVarCst, destLocalVarCst;
- // Get flattened expressions for the source destination maps.
- if (failed(getFlattenedAffineExprs(srcMap, &srcFlatExprs, &srcLocalVarCst)) ||
- failed(getFlattenedAffineExprs(dstMap, &destFlatExprs, &destLocalVarCst)))
- return failure();
-
- unsigned domNumLocalIds = dependenceDomain->getNumLocalIds();
- unsigned srcNumLocalIds = srcLocalVarCst.getNumLocalIds();
- unsigned dstNumLocalIds = destLocalVarCst.getNumLocalIds();
- unsigned numLocalIdsToAdd = srcNumLocalIds + dstNumLocalIds;
- dependenceDomain->appendLocalId(numLocalIdsToAdd);
-
- unsigned numDims = dependenceDomain->getNumDimIds();
- unsigned numSymbols = dependenceDomain->getNumSymbolIds();
- unsigned numSrcLocalIds = srcLocalVarCst.getNumLocalIds();
- unsigned newLocalIdOffset = numDims + numSymbols + domNumLocalIds;
-
- // Equality to add.
- SmallVector<int64_t, 8> eq(dependenceDomain->getNumCols());
- for (unsigned i = 0; i < numResults; ++i) {
- // Zero fill.
- std::fill(eq.begin(), eq.end(), 0);
-
- // Flattened AffineExpr for src result 'i'.
- const auto &srcFlatExpr = srcFlatExprs[i];
- // Set identifier coefficients from src access function.
- for (unsigned j = 0, e = srcOperands.size(); j < e; ++j)
- eq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] = srcFlatExpr[j];
- // Local terms.
- for (unsigned j = 0, e = srcNumLocalIds; j < e; j++)
- eq[newLocalIdOffset + j] = srcFlatExpr[srcNumIds + j];
- // Set constant term.
- eq[eq.size() - 1] = srcFlatExpr[srcFlatExpr.size() - 1];
-
- // Flattened AffineExpr for dest result 'i'.
- const auto &destFlatExpr = destFlatExprs[i];
- // Set identifier coefficients from dst access function.
- for (unsigned j = 0, e = dstOperands.size(); j < e; ++j)
- eq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] -= destFlatExpr[j];
- // Local terms.
- for (unsigned j = 0, e = dstNumLocalIds; j < e; j++)
- eq[newLocalIdOffset + numSrcLocalIds + j] = -destFlatExpr[dstNumIds + j];
- // Set constant term.
- eq[eq.size() - 1] -= destFlatExpr[destFlatExpr.size() - 1];
-
- // Add equality constraint.
- dependenceDomain->addEquality(eq);
- }
-
- // Add equality constraints for any operands that are defined by constant ops.
- auto addEqForConstOperands = [&](ArrayRef<Value> operands) {
- for (unsigned i = 0, e = operands.size(); i < e; ++i) {
- if (isForInductionVar(operands[i]))
- continue;
- auto symbol = operands[i];
- assert(isValidSymbol(symbol));
- // Check if the symbol is a constant.
- if (auto cOp = symbol.getDefiningOp<arith::ConstantIndexOp>())
- dependenceDomain->addBound(FlatAffineConstraints::EQ,
- valuePosMap.getSymPos(symbol), cOp.value());
- }
- };
-
- // Add equality constraints for any src symbols defined by constant ops.
- addEqForConstOperands(srcOperands);
- // Add equality constraints for any dst symbols defined by constant ops.
- addEqForConstOperands(dstOperands);
-
- // By construction (see flattener), local var constraints will not have any
- // equalities.
- assert(srcLocalVarCst.getNumEqualities() == 0 &&
- destLocalVarCst.getNumEqualities() == 0);
- // Add inequalities from srcLocalVarCst and destLocalVarCst into the
- // dependence domain.
- SmallVector<int64_t, 8> ineq(dependenceDomain->getNumCols());
- for (unsigned r = 0, e = srcLocalVarCst.getNumInequalities(); r < e; r++) {
- std::fill(ineq.begin(), ineq.end(), 0);
-
- // Set identifier coefficients from src local var constraints.
- for (unsigned j = 0, e = srcOperands.size(); j < e; ++j)
- ineq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] =
- srcLocalVarCst.atIneq(r, j);
- // Local terms.
- for (unsigned j = 0, e = srcNumLocalIds; j < e; j++)
- ineq[newLocalIdOffset + j] = srcLocalVarCst.atIneq(r, srcNumIds + j);
- // Set constant term.
- ineq[ineq.size() - 1] =
- srcLocalVarCst.atIneq(r, srcLocalVarCst.getNumCols() - 1);
- dependenceDomain->addInequality(ineq);
- }
-
- for (unsigned r = 0, e = destLocalVarCst.getNumInequalities(); r < e; r++) {
- std::fill(ineq.begin(), ineq.end(), 0);
- // Set identifier coefficients from dest local var constraints.
- for (unsigned j = 0, e = dstOperands.size(); j < e; ++j)
- ineq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] =
- destLocalVarCst.atIneq(r, j);
- // Local terms.
- for (unsigned j = 0, e = dstNumLocalIds; j < e; j++)
- ineq[newLocalIdOffset + numSrcLocalIds + j] =
- destLocalVarCst.atIneq(r, dstNumIds + j);
- // Set constant term.
- ineq[ineq.size() - 1] =
- destLocalVarCst.atIneq(r, destLocalVarCst.getNumCols() - 1);
-
- dependenceDomain->addInequality(ineq);
- }
- return success();
-}
-
// Returns the number of outer loop common to 'src/dstDomain'.
// Loops common to 'src/dst' domains are added to 'commonLoops' if non-null.
static unsigned
}
}
+LogicalResult MemRefAccess::getAccessRelation(FlatAffineRelation &rel) const {
+ // Create set corresponding to domain of access.
+ FlatAffineValueConstraints domain;
+ if (failed(getOpIndexSet(opInst, &domain)))
+ return failure();
+
+ // Get access relation from access map.
+ AffineValueMap accessValueMap;
+ getAccessMap(&accessValueMap);
+ if (failed(getRelationFromMap(accessValueMap, rel)))
+ return failure();
+
+ FlatAffineRelation domainRel(rel.getNumDomainDims(), /*numRangeDims=*/0,
+ domain);
+
+ // Merge and align domain ids of `ret` and ids of `domain`. Since the domain
+ // of the access map is a subset of the domain of access, the domain ids of
+ // `ret` are guranteed to be a subset of ids of `domain`.
+ for (unsigned i = 0, e = domain.getNumDimIds(); i < e; ++i) {
+ unsigned loc;
+ if (rel.findId(domain.getValue(i), &loc)) {
+ rel.swapId(i, loc);
+ } else {
+ rel.insertDomainId(i);
+ rel.setValue(i, domain.getValue(i));
+ }
+ }
+
+ // Append domain constraints to `ret`.
+ domainRel.appendRangeId(rel.getNumRangeDims());
+ domainRel.mergeLocalIds(rel);
+ domainRel.mergeSymbolIds(rel);
+ rel.append(domainRel);
+
+ return success();
+}
+
// Populates 'accessMap' with composition of AffineApplyOps reachable from
// indices of MemRefAccess.
void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
// common to both accesses (see Dependence in AffineAnalysis.h for details).
//
// The memref access dependence check is comprised of the following steps:
-// *) Compute access functions for each access. Access functions are computed
-// using AffineValueMaps initialized with the indices from an access, then
-// composed with AffineApplyOps reachable from operands of that access,
-// until operands of the AffineValueMap are loop IVs or symbols.
-// *) Build iteration domain constraints for each access. Iteration domain
-// constraints are pairs of inequality constraints representing the
-// upper/lower loop bounds for each AffineForOp in the loop nest associated
-// with each access.
-// *) Build dimension and symbol position maps for each access, which map
-// Values from access functions and iteration domains to their position
-// in the merged constraint system built by this method.
+// *) Build access relation for each access. An access relation maps elements
+// of an iteration domain to the element(s) of an array domain accessed by
+// that iteration of the associated statement through some array reference.
+// *) Compute the dependence relation by composing access relation of
+// `srcAccess` with the inverse of access relation of `dstAccess`.
+// Doing this builds a relation between iteration domain of `srcAccess`
+// to the iteration domain of `dstAccess` which access the same memory
+// location.
+// *) Add ordering constraints for `srcAccess` to be accessed before
+// `dstAccess`.
//
// This method builds a constraint system with the following column format:
//
// }
// }
//
-// The access functions would be the following:
-//
-// src: (%i0 * 2 - %i1 * 4 + %N, %i1 * 3 - %M)
-// dst: (%i2 * 7 + %i3 * 9 - %M, %i3 * 11 - %K)
-//
-// The iteration domains for the src/dst accesses would be the following:
+// The access relation for `srcAccess` would be the following:
//
-// src: 0 <= %i0 <= 100, 0 <= %i1 <= 50
-// dst: 0 <= %i2 <= 100, 0 <= %i3 <= 50
+// [src_dim0, src_dim1, mem_dim0, mem_dim1, %N, %M, const]
+// 2 -4 -1 0 1 0 0 = 0
+// 0 3 0 -1 0 -1 0 = 0
+// 1 0 0 0 0 0 0 >= 0
+// -1 0 0 0 0 0 100 >= 0
+// 0 1 0 0 0 0 0 >= 0
+// 0 -1 0 0 0 0 50 >= 0
//
-// The symbols by both accesses would be assigned to a canonical position order
-// which will be used in the dependence constraint system:
+// The access relation for `dstAccess` would be the following:
//
-// symbol name: %M %N %K
-// symbol pos: 0 1 2
+// [dst_dim0, dst_dim1, mem_dim0, mem_dim1, %M, %K, const]
+// 7 9 -1 0 -1 0 0 = 0
+// 0 11 0 -1 0 -1 0 = 0
+// 1 0 0 0 0 0 0 >= 0
+// -1 0 0 0 0 0 100 >= 0
+// 0 1 0 0 0 0 0 >= 0
+// 0 -1 0 0 0 0 50 >= 0
//
-// Equality constraints are built by equating each result of src/destination
-// access functions. For this example, the following two equality constraints
-// will be added to the dependence constraint system:
+// The equalities in the above relations correspond to the access maps while
+// the inequalities corresspond to the iteration domain constraints.
//
-// [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const]
-// 2 -4 -7 -9 1 1 0 0 = 0
-// 0 3 0 -11 -1 0 1 0 = 0
+// The dependence relation formed:
//
-// Inequality constraints from the iteration domain will be meged into
-// the dependence constraint system
-//
-// [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const]
+// [src_dim0, src_dim1, dst_dim0, dst_dim1, %M, %N, %K, const]
+// 2 -4 -7 -9 1 1 0 0 = 0
+// 0 3 0 -11 -1 0 1 0 = 0
// 1 0 0 0 0 0 0 0 >= 0
// -1 0 0 0 0 0 0 100 >= 0
// 0 1 0 0 0 0 0 0 >= 0
!isa<AffineWriteOpInterface>(dstAccess.opInst))
return DependenceResult::NoDependence;
- // Get composed access function for 'srcAccess'.
- AffineValueMap srcAccessMap;
- srcAccess.getAccessMap(&srcAccessMap);
-
- // Get composed access function for 'dstAccess'.
- AffineValueMap dstAccessMap;
- dstAccess.getAccessMap(&dstAccessMap);
-
- // Get iteration domain for the 'srcAccess' operation.
- FlatAffineValueConstraints srcDomain;
- if (failed(getOpIndexSet(srcAccess.opInst, &srcDomain)))
+ // Create access relation from each MemRefAccess.
+ FlatAffineRelation srcRel, dstRel;
+ if (failed(srcAccess.getAccessRelation(srcRel)))
return DependenceResult::Failure;
-
- // Get iteration domain for 'dstAccess' operation.
- FlatAffineValueConstraints dstDomain;
- if (failed(getOpIndexSet(dstAccess.opInst, &dstDomain)))
+ if (failed(dstAccess.getAccessRelation(dstRel)))
return DependenceResult::Failure;
+ FlatAffineValueConstraints srcDomain = srcRel.getDomainSet();
+ FlatAffineValueConstraints dstDomain = dstRel.getDomainSet();
+
// Return 'NoDependence' if loopDepth > numCommonLoops and if the ancestor
// operation of 'srcAccess' does not properly dominate the ancestor
// operation of 'dstAccess' in the same common operation block.
numCommonLoops)) {
return DependenceResult::NoDependence;
}
- // Build dim and symbol position maps for each access from access operand
- // Value to position in merged constraint system.
- ValuePositionMap valuePosMap;
- buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap,
- dstAccessMap, &valuePosMap,
- dependenceConstraints);
- initDependenceConstraints(srcDomain, dstDomain, srcAccessMap, dstAccessMap,
- valuePosMap, dependenceConstraints);
-
- assert(valuePosMap.getNumDims() ==
- srcDomain.getNumDimIds() + dstDomain.getNumDimIds());
- // Create memref access constraint by equating src/dst access functions.
- // Note that this check is conservative, and will fail in the future when
- // local variables for mod/div exprs are supported.
- if (failed(addMemRefAccessConstraints(srcAccessMap, dstAccessMap, valuePosMap,
- dependenceConstraints)))
- return DependenceResult::Failure;
+ // Compute the dependence relation by composing `srcRel` with the inverse of
+ // `dstRel`. Doing this builds a relation between iteration domain of
+ // `srcAccess` to the iteration domain of `dstAccess` which access the same
+ // memory locations.
+ dstRel.inverse();
+ dstRel.compose(srcRel);
+ *dependenceConstraints = dstRel;
// Add 'src' happens before 'dst' ordering constraints.
addOrderingConstraints(srcDomain, dstDomain, loopDepth,
dependenceConstraints);
- // Add src and dst domain constraints.
- addDomainConstraints(srcDomain, dstDomain, valuePosMap,
- dependenceConstraints);
// Return 'NoDependence' if the solution space is empty: no dependence.
- if (dependenceConstraints->isEmpty()) {
+ if (dependenceConstraints->isEmpty())
return DependenceResult::NoDependence;
- }
// Compute dependence direction vector and return true.
- if (dependenceComponents != nullptr) {
+ if (dependenceComponents != nullptr)
computeDirectionVector(srcDomain, dstDomain, loopDepth,
dependenceConstraints, dependenceComponents);
- }
LLVM_DEBUG(llvm::dbgs() << "Dependence polyhedron:\n");
LLVM_DEBUG(dependenceConstraints->dump());
}
}
+void FlatAffineConstraints::convertDimToLocal(unsigned dimStart,
+ unsigned dimLimit) {
+ assert(dimLimit <= getNumDimIds() && "Invalid dim pos range");
+
+ if (dimStart >= dimLimit)
+ return;
+
+ // Append new local variables corresponding to the dimensions to be converted.
+ unsigned convertCount = dimLimit - dimStart;
+ unsigned newLocalIdStart = getNumIds();
+ appendLocalId(convertCount);
+
+ // Swap the new local variables with dimensions.
+ for (unsigned i = 0; i < convertCount; ++i)
+ swapId(i + dimStart, i + newLocalIdStart);
+
+ // Remove dimensions converted to local variables.
+ removeIdRange(dimStart, dimLimit);
+}
+
std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound(
unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
ArrayRef<AffineExpr> localExprs, MLIRContext *context) const {
return map.replaceDimsAndSymbols(dimReplacements, symReplacements,
dims.size(), numSymbols);
}
+
+FlatAffineValueConstraints FlatAffineRelation::getDomainSet() const {
+ FlatAffineValueConstraints domain = *this;
+ // Convert all range variables to local variables.
+ domain.convertDimToLocal(getNumDomainDims(),
+ getNumDomainDims() + getNumRangeDims());
+ return domain;
+}
+
+FlatAffineValueConstraints FlatAffineRelation::getRangeSet() const {
+ FlatAffineValueConstraints range = *this;
+ // Convert all domain variables to local variables.
+ range.convertDimToLocal(0, getNumDomainDims());
+ return range;
+}
+
+void FlatAffineRelation::compose(const FlatAffineRelation &other) {
+ assert(getNumDomainDims() == other.getNumRangeDims() &&
+ "Domain of this and range of other do not match");
+ assert(std::equal(values.begin(), values.begin() + getNumDomainDims(),
+ other.values.begin() + other.getNumDomainDims()) &&
+ "Domain of this and range of other do not match");
+
+ FlatAffineRelation rel = other;
+ mergeSymbolIds(rel);
+ mergeLocalIds(rel);
+
+ // Convert domain of `this` and range of `rel` to local identifiers.
+ convertDimToLocal(0, getNumDomainDims());
+ rel.convertDimToLocal(rel.getNumDomainDims(), rel.getNumDimIds());
+ // Add dimensions such that both relations become `domainRel -> rangeThis`.
+ appendDomainId(rel.getNumDomainDims());
+ rel.appendRangeId(getNumRangeDims());
+
+ auto thisMaybeValues = getMaybeDimValues();
+ auto relMaybeValues = rel.getMaybeDimValues();
+
+ // Add and match domain of `rel` to domain of `this`.
+ for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i)
+ if (relMaybeValues[i].hasValue())
+ setValue(i, relMaybeValues[i].getValue());
+ // Add and match range of `this` to range of `rel`.
+ for (unsigned i = 0, e = getNumRangeDims(); i < e; ++i) {
+ unsigned rangeIdx = rel.getNumDomainDims() + i;
+ if (thisMaybeValues[rangeIdx].hasValue())
+ rel.setValue(rangeIdx, thisMaybeValues[rangeIdx].getValue());
+ }
+
+ // Append `this` to `rel` and simplify constraints.
+ rel.append(*this);
+ rel.removeRedundantLocalVars();
+
+ *this = rel;
+}
+
+void FlatAffineRelation::inverse() {
+ unsigned oldDomain = getNumDomainDims();
+ unsigned oldRange = getNumRangeDims();
+ // Add new range ids.
+ appendRangeId(oldDomain);
+ // Swap new ids with domain.
+ for (unsigned i = 0; i < oldDomain; ++i)
+ swapId(i, oldDomain + oldRange + i);
+ // Remove the swapped domain.
+ removeIdRange(0, oldDomain);
+ // Set domain and range as inverse.
+ numDomainDims = oldRange;
+ numRangeDims = oldDomain;
+}
+
+void FlatAffineRelation::insertDomainId(unsigned pos, unsigned num) {
+ assert(pos <= getNumDomainDims() &&
+ "Id cannot be inserted at invalid position");
+ insertDimId(pos, num);
+ numDomainDims += num;
+}
+
+void FlatAffineRelation::insertRangeId(unsigned pos, unsigned num) {
+ assert(pos <= getNumRangeDims() &&
+ "Id cannot be inserted at invalid position");
+ insertDimId(getNumDomainDims() + pos, num);
+ numRangeDims += num;
+}
+
+void FlatAffineRelation::appendDomainId(unsigned num) {
+ insertDimId(getNumDomainDims(), num);
+ numDomainDims += num;
+}
+
+void FlatAffineRelation::appendRangeId(unsigned num) {
+ insertDimId(getNumDimIds(), num);
+ numRangeDims += num;
+}
+
+void FlatAffineRelation::removeIdRange(unsigned idStart, unsigned idLimit) {
+ if (idStart >= idLimit)
+ return;
+
+ // Compute number of domain and range identifiers to remove. This is done by
+ // intersecting the range of domain/range ids with range of ids to remove.
+ unsigned intersectDomainLHS = std::min(idLimit, getNumDomainDims());
+ unsigned intersectDomainRHS = idStart;
+ unsigned intersectRangeLHS = std::min(idLimit, getNumDimIds());
+ unsigned intersectRangeRHS = std::max(idStart, getNumDomainDims());
+
+ FlatAffineValueConstraints::removeIdRange(idStart, idLimit);
+
+ if (intersectDomainLHS > intersectDomainRHS)
+ numDomainDims -= intersectDomainLHS - intersectDomainRHS;
+ if (intersectRangeLHS > intersectRangeRHS)
+ numRangeDims -= intersectRangeLHS - intersectRangeRHS;
+}
+
+LogicalResult mlir::getRelationFromMap(AffineMap &map,
+ FlatAffineRelation &rel) {
+ // Get flattened affine expressions.
+ std::vector<SmallVector<int64_t, 8>> flatExprs;
+ FlatAffineConstraints localVarCst;
+ if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst)))
+ return failure();
+
+ unsigned oldDimNum = localVarCst.getNumDimIds();
+ unsigned oldCols = localVarCst.getNumCols();
+ unsigned numRangeIds = map.getNumResults();
+ unsigned numDomainIds = map.getNumDims();
+
+ // Add range as the new expressions.
+ localVarCst.appendDimId(numRangeIds);
+
+ // Add equalities between source and range.
+ SmallVector<int64_t, 8> eq(localVarCst.getNumCols());
+ for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
+ // Zero fill.
+ std::fill(eq.begin(), eq.end(), 0);
+ // Fill equality.
+ for (unsigned j = 0, f = oldDimNum; j < f; ++j)
+ eq[j] = flatExprs[i][j];
+ for (unsigned j = oldDimNum, f = oldCols; j < f; ++j)
+ eq[j + numRangeIds] = flatExprs[i][j];
+ // Set this dimension to -1 to equate lhs and rhs and add equality.
+ eq[numDomainIds + i] = -1;
+ localVarCst.addEquality(eq);
+ }
+
+ // Create relation and return success.
+ rel = FlatAffineRelation(numDomainIds, numRangeIds, localVarCst);
+ return success();
+}
+
+LogicalResult mlir::getRelationFromMap(const AffineValueMap &map,
+ FlatAffineRelation &rel) {
+
+ AffineMap affineMap = map.getAffineMap();
+ if (failed(getRelationFromMap(affineMap, rel)))
+ return failure();
+
+ // Set symbol values for domain dimensions and symbols.
+ for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i)
+ rel.setValue(i, map.getOperand(i));
+ for (unsigned i = rel.getNumDimIds(), e = rel.getNumDimAndSymbolIds(); i < e;
+ ++i)
+ rel.setValue(i, map.getOperand(i - rel.getNumRangeDims()));
+
+ return success();
+}