/// <= 15}, output = {0 <= d0 <= 6, 1 <= d1 <= 15}.
LogicalResult unionBoundingBox(const FlatAffineConstraints &other);
+ /// Returns 'true' if this constraint system and 'other' are in the same
+ /// space, i.e., if they are associated with the same set of identifiers,
+ /// appearing in the same order. Returns 'false' otherwise.
+ bool areIdsAlignedWithOther(const FlatAffineConstraints &other);
+
+ /// Merge and align the identifiers of 'this' and 'other' starting at
+ /// 'offset', so that both constraint systems get the union of the contained
+ /// identifiers that is dimension-wise and symbol-wise unique; both
+ /// constraint systems are updated so that they have the union of all
+ /// identifiers, with this's original identifiers appearing first followed by
+ /// any of other's identifiers that didn't appear in 'this'. Local
+ /// identifiers of each system are by design separate/local and are placed
+ /// one after other (this's followed by other's).
+ // E.g.: Input: 'this' has (%i, %j) [%M, %N]
+ // 'other' has (%k, %j) [%P, %N, %M]
+ // Output: both 'this' and 'other' have (%i, %j, %k) [%M, %N, %P]
+ //
+ void mergeAndAlignIdsWithOther(unsigned offset, FlatAffineConstraints *other);
+
/// Returns the total constraint count, i.e. the number of inequalities plus
/// the number of equalities.
unsigned getNumConstraints() const {
return getNumInequalities() + getNumEqualities();
}
const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
unsigned dstLoopDepth, ComputationSliceState *sliceState);
+/// Computes in 'sliceUnion' the union of all slice bounds computed at
+/// 'dstLoopDepth' between all pairs in 'srcOps' and 'dstOps' which access the
+/// same memref. Returns 'success' if union was computed, 'failure' otherwise.
+LogicalResult computeSliceUnion(ArrayRef<Operation *> srcOps,
+ ArrayRef<Operation *> dstOps,
+ unsigned dstLoopDepth,
+ ComputationSliceState *sliceUnion);
+
/// Creates a clone of the computation contained in the loop nest surrounding
/// 'srcOpInst', slices the iteration space of src loop based on slice bounds
/// in 'sliceState', and inserts the computation slice at the beginning of the
/// Checks if two constraint systems are in the same space, i.e., if they are
/// associated with the same set of identifiers, appearing in the same order.
-bool areIdsAligned(const FlatAffineConstraints &A,
- const FlatAffineConstraints &B) {
+static bool areIdsAligned(const FlatAffineConstraints &A,
+ const FlatAffineConstraints &B) {
// Aligned means all of: same number of dim ids, same number of symbol ids,
// same total id count, and identical id lists per getIds().equals().
return A.getNumDimIds() == B.getNumDimIds() &&
A.getNumSymbolIds() == B.getNumSymbolIds() &&
A.getNumIds() == B.getNumIds() && A.getIds().equals(B.getIds());
}
+/// Member-function entry point for the file-local areIdsAligned() helper:
+/// returns areIdsAligned(*this, other) unchanged.
+bool FlatAffineConstraints::areIdsAlignedWithOther(
+ const FlatAffineConstraints &other) {
+ return areIdsAligned(*this, other); // NOTE(review): only reads '*this'; the method could probably be const - confirm.
+}
+
/// Checks if the SSA values associated with `cst''s identifiers are unique.
static bool LLVM_ATTRIBUTE_UNUSED
areIdsUnique(const FlatAffineConstraints &cst) {
// Eg: Input: A has ((%i %j) [%M %N]) and B has (%k, %j) [%P, %N, %M])
// Output: both A, B have (%i, %j, %k) [%M, %N, %P]
//
-// TODO(mlir-team): expose this function at some point.
static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A,
FlatAffineConstraints *B) {
assert(offset <= A->getNumDimIds() && offset <= B->getNumDimIds());
assert(areIdsAligned(*A, *B) && "IDs expected to be aligned");
}
+// Forwards to the file-local 'mergeAndAlignIds' with A = this and B = 'other'.
+void FlatAffineConstraints::mergeAndAlignIdsWithOther(
+ unsigned offset, FlatAffineConstraints *other) {
+ mergeAndAlignIds(offset, this, other);
+}
+
// This routine may add additional local variables if the flattened expression
// corresponding to the map has such variables due to mod's, ceildiv's, and
// floordiv's in it.
if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/true,
/*lower=*/true)))
return failure();
- if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/true,
- /*lower=*/true)))
- return failure();
continue;
}
if (lbMap && failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false,
/*lower=*/true)))
return failure();
- if (lbMap && failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false,
- /*lower=*/true)))
- return failure();
if (ubMap && failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false,
/*lower=*/false)))
#include "mlir/IR/Builders.h"
#include "mlir/StandardOps/Ops.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
return nullptr;
}
+// Returns the memref accessed by 'op'. 'op' is expected to be a LoadOp or a
+// StoreOp: anything that is not a LoadOp is handed to cast<StoreOp>.
+static Value *getLoadOrStoreMemRef(Operation *op) {
+  auto loadOp = dyn_cast<LoadOp>(op);
+  // Not a load: treat it as a store.
+  if (!loadOp)
+    return cast<StoreOp>(op).getMemRef();
+  return loadOp.getMemRef();
+}
+
+// Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'.
+//
+// Walks the dim identifiers of 'cst'; for each one whose SSA value is absent
+// from 'ivs', looks up the owning affine.for and adds that loop's domain to
+// 'cst' via addAffineForOpDomain. Returns 'failure' as soon as one domain
+// cannot be added, 'success' otherwise.
+//
+// File-local helper (internal linkage, like the other helpers in this file).
+// 'ivs' is only queried, so it is taken by const reference.
+static LogicalResult addMissingLoopIVBounds(const SmallPtrSet<Value *, 8> &ivs,
+                                            FlatAffineConstraints *cst) {
+  // The dim id count is read once up front, as in the original loop.
+  for (unsigned i = 0, e = cst->getNumDimIds(); i < e; ++i) {
+    auto *value = cst->getIdValue(i);
+    // Already covered by the caller's recorded IV set: nothing to add.
+    if (ivs.count(value) != 0)
+      continue;
+    assert(isForInductionVar(value) &&
+           "expected dim identifier to be a loop induction variable");
+    auto loop = getForInductionVarOwner(value);
+    if (failed(cst->addAffineForOpDomain(loop)))
+      return failure();
+  }
+  return success();
+}
+
+/// Computes in 'sliceUnion' the union of all slice bounds computed at
+/// 'dstLoopDepth' between all pairs in 'srcOps' and 'dstOps' which access the
+/// same memref. Returns 'success' if union was computed, 'failure' otherwise.
+LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> srcOps,
+ ArrayRef<Operation *> dstOps,
+ unsigned dstLoopDepth,
+ ComputationSliceState *sliceUnion) {
+ unsigned numSrcOps = srcOps.size();
+ unsigned numDstOps = dstOps.size();
+ assert(numSrcOps > 0 && numDstOps > 0);
+
+ // Collect each memref accessed by both an op in 'srcOps' and an op in 'dstOps'.
+ llvm::SmallDenseSet<Value *> memrefIntersection;
+ for (auto *srcOp : srcOps) {
+ auto *srcMemRef = getLoadOrStoreMemRef(srcOp);
+ for (auto *dstOp : dstOps) {
+ if (srcMemRef == getLoadOrStoreMemRef(dstOp))
+ memrefIntersection.insert(srcMemRef);
+ }
+ }
+ // Return failure if no src/dst pair shares a memref (the set is only tested for emptiness).
+ if (memrefIntersection.empty())
+ return failure();
+
+ // Compute the union of slice bounds between all pairs in 'srcOps' and
+ // 'dstOps' in 'sliceUnionCst'.
+ FlatAffineConstraints sliceUnionCst;
+ assert(sliceUnionCst.getNumDimAndSymbolIds() == 0);
+ for (unsigned i = 0; i < numSrcOps; ++i) {
+ MemRefAccess srcAccess(srcOps[i]);
+ for (unsigned j = 0; j < numDstOps; ++j) {
+ MemRefAccess dstAccess(dstOps[j]);
+ if (srcAccess.memref != dstAccess.memref) // Only same-memref pairs contribute.
+ continue;
+ // Compute slice bounds for 'srcAccess' and 'dstAccess'.
+ ComputationSliceState tmpSliceState;
+ if (failed(mlir::getBackwardComputationSliceState(
+ srcAccess, dstAccess, dstLoopDepth, &tmpSliceState))) {
+ LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bounds\n."); // NOTE(review): prints newline then period; period-then-newline was probably intended.
+ return failure();
+ }
+
+ if (sliceUnionCst.getNumDimAndSymbolIds() == 0) {
+ // First matching pair: seed 'sliceUnionCst' directly from its slice bounds.
+ if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Unable to compute slice bound constraints\n.");
+ return failure();
+ }
+ assert(sliceUnionCst.getNumDimAndSymbolIds() > 0);
+ continue;
+ }
+
+ // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
+ FlatAffineConstraints tmpSliceCst;
+ if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Unable to compute slice bound constraints\n.");
+ return failure();
+ }
+
+ // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed.
+ if (!sliceUnionCst.areIdsAlignedWithOther(tmpSliceCst)) {
+
+ // Pre-constraint id alignment: record loop IVs used in each constraint
+ // system.
+ SmallPtrSet<Value *, 8> sliceUnionIVs;
+ for (unsigned k = 0, l = sliceUnionCst.getNumDimIds(); k < l; ++k)
+ sliceUnionIVs.insert(sliceUnionCst.getIdValue(k));
+ SmallPtrSet<Value *, 8> tmpSliceIVs;
+ for (unsigned k = 0, l = tmpSliceCst.getNumDimIds(); k < l; ++k)
+ tmpSliceIVs.insert(tmpSliceCst.getIdValue(k));
+
+ sliceUnionCst.mergeAndAlignIdsWithOther(/*offset=*/0, &tmpSliceCst);
+
+ // Post-constraint id alignment: add loop IV bounds missing after
+ // id alignment to constraint systems. This can occur if one constraint
+ // system uses a loop IV that is not used by the other. The call
+ // to unionBoundingBox below expects constraints for each loop IV, even
+ // if they are the unsliced full loop bounds added here.
+ if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst)))
+ return failure();
+ if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst)))
+ return failure();
+ }
+ // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
+ if (failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Unable to compute union bounding box of slice bounds."
+ "\n.");
+ return failure();
+ }
+ }
+ }
+
+ // Store 'numSrcLoopIVs' before converting dst loop IVs to dims.
+ unsigned numSrcLoopIVs = sliceUnionCst.getNumDimIds();
+
+ // Convert any dst loop IVs which are symbol identifiers to dim identifiers.
+ sliceUnionCst.convertLoopIVSymbolsToDims();
+ sliceUnion->clearBounds(); // 'sliceUnion' is dereferenced unconditionally: callers must pass a non-null pointer.
+ sliceUnion->lbs.resize(numSrcLoopIVs, AffineMap());
+ sliceUnion->ubs.resize(numSrcLoopIVs, AffineMap());
+
+ // Get slice bounds from slice union constraints 'sliceUnionCst'.
+ sliceUnionCst.getSliceBounds(numSrcLoopIVs, srcOps[0]->getContext(),
+ &sliceUnion->lbs, &sliceUnion->ubs);
+
+ // Add slice bound operands of union: the ids from position 'numSrcLoopIVs' up to the last dim/symbol id.
+ SmallVector<Value *, 4> sliceBoundOperands;
+ sliceUnionCst.getIdValues(numSrcLoopIVs,
+ sliceUnionCst.getNumDimAndSymbolIds(),
+ &sliceBoundOperands);
+
+ // Copy src loop IVs (the first 'numSrcLoopIVs' ids) from 'sliceUnionCst' to 'sliceUnion'.
+ sliceUnion->ivs.clear();
+ sliceUnionCst.getIdValues(0, numSrcLoopIVs, &sliceUnion->ivs);
+
+ // Give each bound its own copy of 'sliceBoundOperands' for subsequent
+ // canonicalization.
+ sliceUnion->lbOperands.resize(numSrcLoopIVs, sliceBoundOperands);
+ sliceUnion->ubOperands.resize(numSrcLoopIVs, sliceBoundOperands);
+ return success();
+}
+
const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";
// Computes memref dependence between 'srcAccess' and 'dstAccess', projects
// out any dst loop IVs at depth greater than 'dstLoopDepth', and computes slice
return true;
}
-// Computes the union of all slice bounds computed between 'srcOpInst'
-// and each load op in 'dstLoadOpInsts' at 'dstLoopDepth', and returns
-// the union in 'sliceState'. Returns true on success, false otherwise.
-// TODO(andydavis) Move this to a loop fusion utility function.
-static bool getSliceUnion(Operation *srcOpInst,
- ArrayRef<Operation *> dstLoadOpInsts,
- unsigned numSrcLoopIVs, unsigned dstLoopDepth,
- ComputationSliceState *sliceState) {
- MemRefAccess srcAccess(srcOpInst);
- unsigned numDstLoadOpInsts = dstLoadOpInsts.size();
- assert(numDstLoadOpInsts > 0);
- // Compute the slice bounds between 'srcOpInst' and 'dstLoadOpInsts[0]'.
- if (failed(mlir::getBackwardComputationSliceState(
- srcAccess, MemRefAccess(dstLoadOpInsts[0]), dstLoopDepth,
- sliceState)))
- return false;
- // Handle the common case of one dst load without a copy.
- if (numDstLoadOpInsts == 1)
- return true;
-
- // Initialize 'sliceUnionCst' with the bounds computed in previous step.
- FlatAffineConstraints sliceUnionCst;
- if (failed(sliceState->getAsConstraints(&sliceUnionCst))) {
- LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bound constraints\n.");
- return false;
- }
-
- // Compute the union of slice bounds between 'srcOpInst' and each load
- // in 'dstLoadOpInsts' in range [1, numDstLoadOpInsts), in 'sliceUnionCst'.
- for (unsigned i = 1; i < numDstLoadOpInsts; ++i) {
- MemRefAccess dstAccess(dstLoadOpInsts[i]);
- // Compute slice bounds for 'srcOpInst' and 'dstLoadOpInsts[i]'.
- ComputationSliceState tmpSliceState;
- if (failed(mlir::getBackwardComputationSliceState(
- srcAccess, dstAccess, dstLoopDepth, &tmpSliceState))) {
- LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bounds\n.");
- return false;
- }
-
- // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
- FlatAffineConstraints tmpSliceCst;
- if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Unable to compute slice bound constraints\n.");
- return false;
- }
- // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
- if (failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
- LLVM_DEBUG(llvm::dbgs()
- << "Unable to compute union bounding box of slice bounds.\n.");
- return false;
- }
- }
-
- // Convert any dst loop IVs which are symbol identifiers to dim identifiers.
- sliceUnionCst.convertLoopIVSymbolsToDims();
-
- sliceState->clearBounds();
- sliceState->lbs.resize(numSrcLoopIVs, AffineMap());
- sliceState->ubs.resize(numSrcLoopIVs, AffineMap());
-
- // Get slice bounds from slice union constraints 'sliceUnionCst'.
- sliceUnionCst.getSliceBounds(numSrcLoopIVs, srcOpInst->getContext(),
- &sliceState->lbs, &sliceState->ubs);
- // Add slice bound operands of union.
- SmallVector<Value *, 4> sliceBoundOperands;
- sliceUnionCst.getIdValues(numSrcLoopIVs,
- sliceUnionCst.getNumDimAndSymbolIds(),
- &sliceBoundOperands);
- // Give each bound its own copy of 'sliceBoundOperands' for subsequent
- // canonicalization.
- sliceState->lbOperands.resize(numSrcLoopIVs, sliceBoundOperands);
- sliceState->ubOperands.resize(numSrcLoopIVs, sliceBoundOperands);
- return true;
-}
-
// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
DenseMap<Operation *, int64_t> computeCostMap;
for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
// Compute the union of slice bounds of all ops in 'dstLoadOpInsts'.
- if (!getSliceUnion(srcOpInst, dstLoadOpInsts, numSrcLoopIVs, i,
- &sliceStates[i - 1])) {
+ if (failed(mlir::computeSliceUnion({srcOpInst}, dstLoadOpInsts,
+ /*dstLoopDepth=*/i,
+ &sliceStates[i - 1]))) {
LLVM_DEBUG(llvm::dbgs()
- << "getSliceUnion failed for loopDepth: " << i << "\n");
+ << "computeSliceUnion failed for loopDepth: " << i << "\n");
continue;
}
continue;
// TODO(andydavis) Remove assert and surrounding code when
// canFuseLoops is fully functional.
+ mlir::ComputationSliceState sliceUnion;
FusionResult result = mlir::canFuseLoops(
cast<AffineForOp>(srcNode->op), cast<AffineForOp>(dstNode->op),
- bestDstLoopDepth, /*srcSlice=*/nullptr);
+ bestDstLoopDepth, &sliceUnion);
assert(result.value == FusionResult::Success);
(void)result;
unsigned j, unsigned loopDepth) {
AffineForOp srcForOp = loops[i];
AffineForOp dstForOp = loops[j];
- FusionResult result = mlir::canFuseLoops(srcForOp, dstForOp, loopDepth,
- /*srcSlice=*/nullptr);
+ mlir::ComputationSliceState sliceUnion;
+ // TODO(andydavis) Test at deeper loop depths current loop depth + 1.
+ FusionResult result =
+ mlir::canFuseLoops(srcForOp, dstForOp, loopDepth + 1, &sliceUnion);
if (result.value == FusionResult::FailBlockDependence) {
srcForOp.getOperation()->emitRemark("block-level dependence preventing"
" fusion of loop nest ")
using namespace mlir;
-// Gathers all load and store operations in 'opA' into 'values', where
+// Gathers all load and store memref accesses in 'opA' into 'values', where
// 'values[memref] == true' for each store operation.
-static void getLoadsAndStores(Operation *opA, DenseMap<Value *, bool> &values) {
+static void getLoadAndStoreMemRefAccesses(Operation *opA,
+ DenseMap<Value *, bool> &values) {
opA->walk([&](Operation *op) {
if (auto loadOp = dyn_cast<LoadOp>(op)) {
if (values.count(loadOp.getMemRef()) == 0)
// Record memref values from all loads/store in loop nest rooted at 'opA'.
// Map from memref value to bool which is true if store, false otherwise.
DenseMap<Value *, bool> values;
- getLoadsAndStores(opA, values);
+ getLoadAndStoreMemRefAccesses(opA, values);
// For each 'opX' in block in range ('opA', 'opB'), check if there is a data
// dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref
// Record memref values from all loads/store in loop nest rooted at 'opB'.
// Map from memref value to bool which is true if store, false otherwise.
DenseMap<Value *, bool> values;
- getLoadsAndStores(opB, values);
+ getLoadAndStoreMemRefAccesses(opB, values);
// For each 'opX' in block in range ('opA', 'opB') in reverse order,
// check if there is a data dependence from 'opX' to 'opB':
return forOpB.getOperation();
}
+// Walks the loop nest rooted at 'forOp' and appends every LoadOp and StoreOp
+// it visits to 'loadAndStoreOps'. Returns false if the walk encountered an
+// AffineIfOp, true otherwise. The walk is not cut short when an AffineIfOp is
+// seen, so 'loadAndStoreOps' is fully populated either way.
+static bool
+gatherLoadsAndStores(AffineForOp forOp,
+                     SmallVectorImpl<Operation *> &loadAndStoreOps) {
+  bool sawAffineIf = false;
+  forOp.getOperation()->walk([&](Operation *nestedOp) {
+    if (isa<AffineIfOp>(nestedOp)) {
+      sawAffineIf = true;
+      return;
+    }
+    if (isa<LoadOp>(nestedOp) || isa<StoreOp>(nestedOp))
+      loadAndStoreOps.push_back(nestedOp);
+  });
+  return !sawAffineIf;
+}
+
// TODO(andydavis) Add support for the following features in subsequent CLs:
-// *) Computing union of slices computed between src/dst loads and stores.
// *) Compute dependences of unfused src/dst loops.
// *) Compute dependences of src/dst loop as if they were fused.
// *) Check for fusion preventing dependences (e.g. a dependence which changes
FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
unsigned dstLoopDepth,
ComputationSliceState *srcSlice) {
- // Return 'false' if 'srcForOp' and 'dstForOp' are not in the same block.
+ // Return 'failure' if 'dstLoopDepth == 0'.
+ if (dstLoopDepth == 0) {
+ LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n.");
+ return FusionResult::FailPrecondition;
+ }
+ // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
auto *block = srcForOp.getOperation()->getBlock();
if (block != dstForOp.getOperation()->getBlock()) {
LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n.");
return FusionResult::FailPrecondition;
}
- // Return 'false' if no valid insertion point for fused loop nest in 'block'
+ // Return 'failure' if no valid insertion point for fused loop nest in 'block'
// exists which would preserve dependences.
if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n.");
return FusionResult::FailBlockDependence;
}
+
+ // Gather all load and store ops in 'srcForOp'.
+ SmallVector<Operation *, 4> srcLoadAndStoreOps;
+ if (!gatherLoadsAndStores(srcForOp, srcLoadAndStoreOps)) {
+ LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n.");
+ return FusionResult::FailPrecondition;
+ }
+
+ // Gather all load and store ops in 'dstForOp'.
+ SmallVector<Operation *, 4> dstLoadAndStoreOps;
+ if (!gatherLoadsAndStores(dstForOp, dstLoadAndStoreOps)) {
+ LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n.");
+ return FusionResult::FailPrecondition;
+ }
+
+ // Compute union of computation slices computed from all pairs in
+ // {'srcLoadAndStoreOps', 'dstLoadAndStoreOps'}.
+ if (failed(mlir::computeSliceUnion(srcLoadAndStoreOps, dstLoadAndStoreOps,
+ dstLoopDepth, srcSlice))) {
+ LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
+ return FusionResult::FailPrecondition;
+ }
+
return FusionResult::Success;
}