From fd68d36109c6fcebb6d758046b88b0664acccf51 Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Tue, 6 Jun 2023 22:51:32 +0000 Subject: [PATCH] [mlir][sparse] unifying enterLoopOverTensorAtLvl and enterCoIterationOverTensorsAtLvls The tensor levels are now explicitly categorized into different `LoopCondKind` to instruct LoopEmitter generate different code for different kinds of condition (e.g., `SparseCond`, `SparseSliceCond`, `SparseAffineIdxCond`, etc) The process of generating a while loop is now dissembled into three steps and they are dispatched to different LoopCondKind handler. 1. Generate LoopCondition (e.g., `pos <= posHi` for `SparseCond`, `slice.isNonEmpty` for `SparseAffineIdxCond`) 2. Generate LoopBody (e.g., compute the coordinates) 3. Generate ExtraChecks (e.g., `if (onSlice(crd))` for `SparseSliceCond`) Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D152464 --- .../SparseTensor/Transforms/LoopEmitter.cpp | 1028 ++++++++++---------- .../Dialect/SparseTensor/Transforms/LoopEmitter.h | 241 +++-- .../Transforms/SparseTensorRewriting.cpp | 3 +- .../SparseTensor/Transforms/Sparsification.cpp | 77 +- mlir/test/Dialect/SparseTensor/sorted_coo.mlir | 4 +- 5 files changed, 762 insertions(+), 591 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp index f466cce..d0884ca 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp @@ -101,6 +101,28 @@ static std::pair fromSliceCrd(OpBuilder &builder, Location loc, return std::make_pair(crd, rem); } +// Generates a bool value for while loop condition that tries to iterate over a +// fully reduced level with affine index expression. +static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc, + Value crdBuf, Value crdHi, Value posit, + Value posHi) { + Value inBound = CMPI(ult, posit, posHi); + auto ifOp = + builder.create(loc, builder.getI1Type(), inBound, true); + // if (inbound) + // yield coord < crdHi + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + Value crd = genIndexLoad(builder, loc, crdBuf, posit); + YIELD(CMPI(ult, crd, crdHi)); + // else + // yield false + builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + YIELD(constantI1(builder, loc, false)); + + builder.setInsertionPointAfter(ifOp); + return ifOp.getResult(0); +} + std::pair LoopEmitter::genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd, TensorId tid, Level lvl) { @@ -470,6 +492,41 @@ void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc, localInsertPos = builder.getInsertionPoint()->getPrevNode(); } +void LoopEmitter::categorizeLoopCondition( + ArrayRef tidLvls, SmallVectorImpl &dnConds, + SmallVectorImpl &spConds) { + // Finds out the tensor level that we should use to generate loops. Amongs all + // the tensor levels, there is at most one sparse tensor level. + for (auto [t, l] : unpackTensorLevelRange(tidLvls)) { + assert(lvlTypes[t].size() > l); // Must be a valid tid, dim pair + auto lvlType = lvlTypes[t][l]; + // Must be a recognizable DLT. + assert(isDenseDLT(lvlType) || isCompressedDLT(lvlType) || + isCompressedWithHiDLT(lvlType) || isSingletonDLT(lvlType)); + + bool isSparse = !isDenseDLT(lvlType); + bool isSlice = isSparseSlices[t]; + bool isAffine = !dependentLvlMap[t][l].empty(); + bool isUnRedu = false; + // TODO: Supports affine index expression on sparse tensor slices. 
+ assert(!isSlice || !isAffine); + + // Whether the affine index expression has been fully reduced or not. + if (!dependentLvlMap[t][l].empty()) + isUnRedu = !depFullyReduced(t, l); + + auto &dstVec = isSparse ? spConds : dnConds; + dstVec.emplace_back( + makeTensorLevel(t, l), + makeLoopCondKind(isSparse, isSlice, isAffine, isUnRedu)); + } + + std::sort(spConds.begin(), spConds.end(), [](auto lhs, auto rhs) { + // AffineUnRed > Affine > Slice > Trivial + return static_cast(lhs.second) > static_cast(rhs.second); + }); +} + void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc, ArrayRef tidLvls) { // TODO: sort @@ -561,7 +618,7 @@ Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) { } } -Operation *LoopEmitter::emitForLoopOverTensorAtLvl( +std::pair LoopEmitter::emitForLoopOverTensorAtLvl( OpBuilder &builder, Location loc, TensorId tid, Level dstLvl, Value lo, Value hi, MutableArrayRef reduc, bool isParallel) { bool isSparseCond = isCompressedDLT(lvlTypes[tid][dstLvl]) || @@ -651,166 +708,433 @@ Operation *LoopEmitter::emitForLoopOverTensorAtLvl( assert(crd); coords[tid][dstLvl] = crd; - return loop; + return {loop, crd}; } -Operation *LoopEmitter::emitWhileLoopOverSliceAtSparseLvl( - OpBuilder &builder, Location loc, Value pLo, Value pHi, Value offset, - Value sliceSize, TensorId tid, Level lvl, MutableArrayRef reduc) { - // TODO: we should generalize the method to support iteration over for - // normal slices as well to allow early break. - Operation *insertPoint = nullptr; - Operation *loop = - genSliceLvlTraverseLoop( - builder, loc, pLo, pHi, offset, sliceSize, tid, lvl, reduc, - /*genYield=*/false, // unaware of the yield values from user yet - [this, tid, lvl, reduc, offset, - &insertPoint](OpBuilder &builder, Location loc, Value iv, - MutableArrayRef innerReduc) { - assert(innerReduc.size() == reduc.size()); - // Updates users' reduction variable inplace - for (unsigned i = 0, e = reduc.size(); i < e; i++) - reduc[i] = innerReduc[i]; - // Loads the coordinates. - Value absC = - genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], iv); - - // We need to substract the offset to get relative coordinates. - // TODO: how to assert relC >=0 during runtime? - insertPoint = builder.create(loc, absC, offset); - posits[tid][lvl] = iv; - coords[tid][lvl] = insertPoint->getResult(0); - }) - .first; - // Sets the insertionn pointer inside loop body. - builder.setInsertionPointAfter(insertPoint); - return loop; +Value LoopEmitter::genWhileLoopConditions(OpBuilder &builder, Location loc, + ValueRange ivs, TensorLvlCond cond) { + auto [tid, lvl] = unpackTensorLevel(cond.first); + + switch (cond.second) { + case LoopCondKind::SparseCond: { + const auto reassoc = getCollapseReassociation(tid, lvl); + assert(reassoc.size() == ivs.size()); + assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType())); + // We used the first level bound as the bound the collapsed set of levels. + return CMPI(ult, ivs.back(), highs[tid][reassoc.front()]); + } + case LoopCondKind::SparseSliceCond: { + assert(ivs.size() == 1); + return CMPI(ult, ivs.back(), highs[tid][lvl]); + } + case LoopCondKind::SparseAffineCond: { + assert(ivs.size() == 1); + Value crdHi; // loop upper bound + { + OpBuilder::InsertionGuard guard(builder); + Operation *loop = builder.getInsertionBlock()->getParentOp(); + // crdHi is a loop invariant, hosit the computation outside the loop. 
+ if (llvm::isa_and_nonnull(loop)) + builder.setInsertionPoint(loop); + crdHi = ADDI(getMostRecentSliceOnLvl(tid, lvl).offset, + sliceSizes[tid][lvl].back()); + } + assert(crdHi); + return genSparseReducedAffineCond(builder, loc, + coordinatesBuffers[tid][lvl], crdHi, + ivs[0], highs[tid][lvl]); + } + case LoopCondKind::SparseAffineUnRedCond: { + assert(ivs.size() == 3); + return ivs.front(); // isNonEmpty + } + default: + llvm_unreachable("Unhandled LoopCondKind"); + } + llvm_unreachable("Unhandled LoopCondKind"); } -Operation *LoopEmitter::enterLoopOverTensorAtLvl(OpBuilder &builder, - Location loc, - ArrayRef tidLvls, - MutableArrayRef reduc, - bool isParallel) { - // TODO: support multiple return on parallel for? - assert(!isParallel || reduc.size() <= 1); - bool isSparseCond = false, isSparseSliceCond = false; - auto [tid, lvl] = unpackTensorLevel(tidLvls.front()); +std::optional LoopEmitter::genWhileLoopBody(OpBuilder &builder, + Location loc, ValueRange ivs, + TensorLvlCond cond) { + auto [tid, lvl] = unpackTensorLevel(cond.first); - // Finds out the tensor level that we should use to generate loops. Amongs all - // the tensor levels, there is at most one sparse tensor level. - for (auto [t, l] : unpackTensorLevelRange(tidLvls)) { - assert(lvlTypes[t].size() > l); // Must be a valid tid, dim pair - assert(!coords[t][l] || // We cannot re-enter the same level - !dependentLvlMap[t][l].empty()); // unless it is a slice-driver loop - auto lvlType = lvlTypes[t][l]; - // Must be a recognizable DLT. - assert(isDenseDLT(lvlType) || isCompressedDLT(lvlType) || - isCompressedWithHiDLT(lvlType) || isSingletonDLT(lvlType)); + switch (cond.second) { + case LoopCondKind::SparseCond: { + const auto reassoc = getCollapseReassociation(tid, lvl); + assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType())); + // Links the SSA chain for segHi. + for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) + if (!isUniqueDLT(lvlTypes[tid][reassoc[i]])) + segHi[tid][reassoc[i]] = ivs[i]; + + // Updates position. For collapsed COO, the position is the same across + // consecutive levels. + for (auto srcLvl : reassoc) + posits[tid][srcLvl] = ivs.back(); + + // Update coordinates. + coords[tid][lvl] = genSparseCrd(builder, loc, tid, lvl); + return std::nullopt; + } + case LoopCondKind::SparseSliceCond: { + assert(ivs.size() == 1); + posits[tid][lvl] = ivs.front(); + Value sCrd = genSparseCrd(builder, loc, tid, lvl); + // Converts the coordinate loaded from the actual sparse tensor to the + // coordinates in the sparse slice. + auto [dCrd, pred] = genSliceLegitPredicate(builder, loc, sCrd, tid, lvl); + coords[tid][lvl] = dCrd; + return pred; + } + case LoopCondKind::SparseAffineCond: { + assert(ivs.size() == 1); + // Coord is the relative offset related to its parents. + assert(sliceStack[tid].back().depth == 1 && "TODO: not yet implement"); + // Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1] + Value posit = ivs[0]; + Value crdBuf = coordinatesBuffers[tid][lvl]; + // We need to substract the offset to get relative coordinates. + // TODO: Maybe assert relC >=0 during runtime in debug build? + Value absC = genIndexLoad(builder, loc, crdBuf, posit); + auto relC = SUBI(absC, getFinalSliceOnLvl(tid, lvl).offset); + posits[tid][lvl] = posit; + coords[tid][lvl] = relC; + return std::nullopt; + } + case LoopCondKind::SparseAffineUnRedCond: { + assert(ivs.size() == 3); + // Coord is the relative offset related to its parents. 
+ // Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1] + assert(sliceStack[tid].back().depth == 1 && "TODO: not yet implement"); + // Updates the current slice info + SliceInfo &sliceInfo = sliceStack[tid].back(); + sliceInfo.isNonEmpty = ivs[0]; + sliceInfo.minCrd = ivs[1]; + sliceInfo.offset = ivs[2]; + coords[tid][lvl] = sliceInfo.offset; + // No extra check is needed before accessing the tensor level. + return std::nullopt; + } + default: + llvm_unreachable("Unhandled LoopCondKind"); + } + llvm_unreachable("Unhandled LoopCondKind"); +} - // This is a slice-driven loop on sparse level. - if (!dependentLvlMap[t][l].empty() && !isDenseDLT(lvlType)) { - assert(!isSparseSliceCond && !isSparseCond); - isSparseSliceCond = true; - tid = t; - lvl = l; - continue; +ValueRange LoopEmitter::genCheckedValue(OpBuilder &builder, Location loc, + Value pred, ValueRange curArgs, + TensorLvlCond cond) { + // Currently only sparse slice condition need extra check. + assert(isSliceCond(cond.second) && isSparseCond(cond.second)); + assert(curArgs.size() == 1); + Value nextPos = ADDI(curArgs.front(), C_IDX(1)); + return SELECT(pred, curArgs.front(), nextPos)->getResults(); +} + +std::pair LoopEmitter::emitWhileLoopOverTensorsAtLvls( + OpBuilder &builder, Location loc, ArrayRef spConds, + MutableArrayRef reduc, bool needsUniv) { + // NOTE: the slice driven tensor-related reduction variable must + // appear before normal tensors. + assert(!spConds.empty()); + + // The set of induction variables for the while loop. + SmallVector ivs; + // Segement sizes for induction variables used for different kinds of loop + // conditions. + SmallVector opSegSize; + + // Construct the while-loop with a parameter for each coordinate. + for (auto [tl, cKind] : spConds) { + auto [tid, lvl] = unpackTensorLevel(tl); + const auto lvlTp = lvlTypes[tid][lvl]; + // Dense level are handled by the shared univeral index. + assert(!isDenseCond(cKind)); + // Must be a recognizable sparse level. + assert(isCompressedDLT(lvlTp) || isCompressedWithHiDLT(lvlTp) || + isSingletonDLT(lvlTp)); + + unsigned prevSz = ivs.size(); + const auto reassoc = getCollapseReassociation(tid, lvl); + if (isAffineIdxCond(cKind)) { + // TODO: Support view-based reshape on sparse levels with affine index + // expressions. + assert(reassoc.size() == 1); + if (isAffineIdxUnRedCond(cKind)) { + SliceInfo &sliceInfo = sliceStack[tid].back(); + // The order matters! + ivs.push_back(sliceInfo.isNonEmpty); + ivs.push_back(sliceInfo.minCrd); + ivs.push_back(sliceInfo.offset); + } else { + ivs.push_back(posits[tid][lvl]); // loop lower bound (pos low). + } + // We reduced one more dependency after entering the loop. + levelReducedDep[tid][lvl]++; + } else { + assert(dependentLvlMap[tid][lvl].empty()); + for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) { + // This is the segment high for each non-unique levels. + if (!isUniqueDLT(lvlTypes[tid][reassoc[i]])) + ivs.push_back(C_IDX(0)); + } + const Value pos = posits[tid][reassoc.front()]; + ivs.push_back(pos); } + opSegSize.push_back(ivs.size() - prevSz); + } + + // The position where user-supplied reduction variable starts. + ivs.append(reduc.begin(), reduc.end()); + // Update universal index. + if (needsUniv) + ivs.push_back(loopSeqStack.back().first); + + // Ensures all operands are valid. 
+ assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; })); + TypeRange types = ValueRange(ivs).getTypes(); + auto whileOp = builder.create(loc, types, ivs); + + SmallVector locs(types.size(), loc); + Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs); + Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs); - bool isSparse = isCompressedDLT(lvlType) || isSingletonDLT(lvlType) || - isCompressedWithHiDLT(lvlType); - // We can at most have one sparse input, otherwise, a while loop is - // required to co-iterate multiple sparse tensors. - assert(!isSparseCond || !isSparse); - assert(!isSparseSliceCond || !isSparseCond); - if (isSparse) { - tid = t; - lvl = l; + // Generates loop conditions. + builder.setInsertionPointToStart(before); + ValueRange bArgs = before->getArguments(); + Value whileCond = nullptr; // bool values for loop condition. + for (auto [c, segSz] : llvm::zip_equal(spConds, opSegSize)) { + Value cv = genWhileLoopConditions(builder, loc, bArgs.take_front(segSz), c); + bArgs = bArgs.drop_front(segSz); + whileCond = !whileCond ? cv : ANDI(whileCond, cv); + } + // The remaining block arguments are user-provided reduction values and an + // optional universal index. Make sure their sizes match. + assert(bArgs.size() == reduc.size() + needsUniv ? 1 : 0); + builder.create(loc, whileCond, before->getArguments()); + + // Generates loop body. + builder.setInsertionPointToStart(after); + ValueRange aArgs = after->getArguments(); + // Since some LoopCondKind might need extra checks to filter out invalid + // iterations, we maintains another array to hold the iteration arguments to + // yield if the checks fails. + SmallVector nextArgs(aArgs.begin(), aArgs.end()); + // A mutable alias for convenient slicing. + MutableArrayRef nextArgsRef = nextArgs; + Value extraPred = nullptr; + for (auto [c, segSz] : llvm::zip_equal(spConds, opSegSize)) { + ValueRange condArgs = aArgs.take_front(segSz); + auto pred = genWhileLoopBody(builder, loc, condArgs, c); + assert(pred.has_value() == isCondWithExtraCheck(c.second)); + if (pred.has_value()) { + // We need all extra checks to pass. + extraPred = extraPred == nullptr ? *pred : ANDI(*pred, extraPred); + ValueRange nxArgs = genCheckedValue(builder, loc, *pred, condArgs, c); + assert(nxArgs.size() == segSz); + // Update the value for cases when some check fails. + for (unsigned i = 0; i < segSz; i++) { + nextArgsRef[i] = nxArgs[i]; + } } - isSparseCond = isSparseCond || isSparse; + aArgs = aArgs.drop_front(segSz); + nextArgsRef = nextArgsRef.drop_front(segSz); } - DimLevelType lvlType = lvlTypes[tid][lvl]; - // TODO: Dense slice driven loop can be generated using for loop as well. - assert(!isSparseSliceCond || !isDenseDLT(lvlType)); - bool isDenseSliceCond = - isDenseDLT(lvlType) && !dependentLvlMap[tid][lvl].empty(); - // if the slice is fully reduced, we can now use TACO-based algorithm to - // iterate it. + if (extraPred) { + auto ifOp = builder.create(loc, types, extraPred, /*else*/ true); + // Marks this special IfOp so that Sparsification does not finalizing it. + ifOp->setAttr(getLoopEmitterLoopAttrName(), + StringAttr::get(builder.getContext(), "slice")); + // Links the SSA chain outside the if statement. + YIELD(ifOp->getResults()); - Operation *l = nullptr; + // If not all slices are legit, yield the updated value. + builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + YIELD(nextArgs); + + // If all slices are legit, start the user generated code. 
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + } + + for (auto [tid, dstLvl] : unpackTensorLevelFromCondRange(spConds)) { + const auto reassoc = getCollapseReassociation(tid, dstLvl); + assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType())); + // TODO: Refactors this into smaller functions. + // NOTE: For all the collapsed level (except for the last one, that is why + // the loop ends with `reassoc.size() - 1`), as each iteration is advanced + // by the segment size of the last level, which does not always invalidate + // the segment size for the previous levels, thus we need to propagate the + // segment sizes across loop iterations and only forward if needed. + // + // E.g., for a COO tensor with the following coordinates array. + // (0, 0, 1), + // (0, 0, 2), + // (1, 1, 1), + // segHi[lvl=0] = segHi[lvl=1] = 2 + // segHi[lvl=2] = 1, + // the first iteration does not invalidate segHi[0] and segHi[1] + for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) { + const Level srcLvl = reassoc[i]; + if (!isUniqueDLT(lvlTypes[tid][srcLvl])) { + const Value pos = posits[tid][srcLvl]; + const auto oldSegHi = segHi[tid][srcLvl]; + assert(oldSegHi); + Value newSegHi = builder.create( + loc, arith::CmpIPredicate::uge, pos, oldSegHi); + auto ifNewSegHi = builder.create(loc, builder.getIndexType(), + newSegHi, true); + { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(ifNewSegHi.thenBlock()); + YIELD(genSegmentHigh(builder, loc, tid, srcLvl, pos, + highs[tid][srcLvl])); + // Else, resues the same segment high. + builder.setInsertionPointToStart(ifNewSegHi.elseBlock()); + YIELD(oldSegHi); + } + highs[tid][srcLvl + 1] = segHi[tid][srcLvl] = ifNewSegHi.getResult(0); + } + }; + const auto srcLvl = reassoc.back(); + if (!isUniqueDLT(lvlTypes[tid][srcLvl])) { + segHi[tid][srcLvl] = genSegmentHigh( + builder, loc, tid, srcLvl, posits[tid][srcLvl], highs[tid][srcLvl]); + } + } - // At most one tensor used as condition in for loop; - SmallVector condTidLvl; - // There might be multiple dense slice driven tensor. + // In-place update on reduction variable. + assert(aArgs.size() == reduc.size() + needsUniv ? 1 : 0); + for (unsigned i = 0, e = reduc.size(); i < e; i++) + reduc[i] = aArgs[i]; + + Value min; + // Finds the minimum coordinate + if (!needsUniv) { + for (auto [tid, lvl] : unpackTensorLevelFromCondRange(spConds)) { + const auto lvlTp = lvlTypes[tid][lvl]; + if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) || + isCompressedWithHiDLT(lvlTp)) { + const auto crd = coords[tid][lvl]; + if (min) { + Value cmp = CMPI(ult, coords[tid][lvl], min); + min = SELECT(cmp, coords[tid][lvl], min); + } else { + min = crd; + } + } + } + } else { + assert(!min); + // Otherwise, universal index is the minimal pos. + min = whileOp.getAfterArguments().back(); + } + + return {whileOp, min}; +} + +bool LoopEmitter::shouldIteratedByForLoop(ArrayRef sparseConds, + bool genDedup) { + assert(llvm::all_of(sparseConds, + [](TensorLvlCond c) { return isSparseCond(c.second); })); + + // If we need to co-iterate over two sparse tensors, we need a while loop + if (sparseConds.size() > 1) + return false; + + // We also need a while loop for levels with affine index expression for + // non-unique levels when deduplication is required. 
+ if (sparseConds.size() == 1) { + auto [tid, lvl] = unpackTensorLevel(sparseConds.back().first); + auto reassoc = getCollapseReassociation(tid, lvl); + return !isAffineIdxCond(sparseConds.back().second) && + !(genDedup && !isUniqueDLT(lvlTypes[tid][reassoc.back()])); + } + + return true; +} + +Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls( + OpBuilder &builder, Location loc, ArrayRef tidLvls, + MutableArrayRef reduc, bool tryParallel, bool genDedup, + bool needsUniv) { + // Sanity checks. + assert(!tidLvls.empty()); + for (auto [t, l] : unpackTensorLevelRange(tidLvls)) { + assert(!coords[t][l] || // We cannot re-enter the same level + !dependentLvlMap[t][l].empty()); // unless it is a slice-driver loop + } + // TODO: support multiple return on parallel for? + tryParallel = tryParallel && reduc.size() <= 1; + + SmallVector spConds; + SmallVector dnConds; + categorizeLoopCondition(tidLvls, dnConds, spConds); + + // Only when there is at least one sparse conditions, do we really need the + // universal index. + // TODO: Maybe we should instead requires merger to pass in a valid value at + // the first place instead of adjusting it in LoopEmitter? + needsUniv = !spConds.empty() && needsUniv; + // The TensorLevel used for loop conditions. + // If there is any sparse level, we need to use the sparse condition. + // If all levels are dense, we can pick arbitary one (dense slice-driven loop + // can be generated using a simple ForOp as well). + Operation *l = nullptr; + Value iv = nullptr; SmallVector sliceDrivenInfo; + SmallVector trivialLvls; // Generates loops differently depending on whether we need a slice-driven // loop or a simple level traversal loop. - if (isSparseSliceCond) { - bool fullyReduced = depFullyReduced(tid, lvl); - if (!fullyReduced) { - l = emitSliceDrivenLoopOverTensorAtLvl(builder, loc, tid, lvl, reduc); - } else { - // If the slice is fully reduced, we can now use TACO-based algorithm to - // iterate it. - l = emitWhileLoopOverSliceAtSparseLvl( - builder, loc, posits[tid][lvl], highs[tid][lvl], - getFinalSliceOnLvl(tid, lvl).offset, sliceSizes[tid][lvl].back(), tid, - lvl, reduc); - } - levelReducedDep[tid][lvl]++; - sliceDrivenInfo.emplace_back(tid, lvl, fullyReduced); - } else { - Value lo = isSparseCond ? posits[tid][lvl] // current offset - : loopSeqStack.back().first; // universal index + if (shouldIteratedByForLoop(spConds, genDedup) && !needsUniv) { + assert(spConds.size() <= 1); + TensorLvlCond tlCond = spConds.empty() ? dnConds.front() : spConds.front(); + auto loopCondKind = tlCond.second; + auto [tid, lvl] = unpackTensorLevel(tlCond.first); + Value lo = isSparseCond(loopCondKind) + ? posits[tid][lvl] // current offset + : loopSeqStack.back().first; // universal index Value hi = highs[tid][lvl]; - if (isDenseSliceCond) { - bool fullyReduced = depFullyReduced(tid, lvl); - Value sliceSz = sliceSizes[tid][lvl][sliceStack[tid].back().depth - 1]; - // Adjust for loop hi for dense slice-driven loop. - if (fullyReduced) { - hi = sliceSz; - condTidLvl.push_back(makeTensorLevel(tid, lvl)); - } else { - hi = SUBI(lvlSizes[tid][lvl], sliceSz); + if (isDenseCond(loopCondKind) && isAffineIdxCond(loopCondKind)) { + bool unReduc = isAffineIdxUnRedCond(loopCondKind); + assert(unReduc == !depFullyReduced(tid, lvl)); + hi = sliceSizes[tid][lvl][sliceStack[tid].back().depth - 1]; + if (unReduc) { + // Adjust for loop hi for dense slice-driven loop. 
+ hi = SUBI(lvlSizes[tid][lvl], hi); hi = ADDI(hi, C_IDX(1)); } - } else { - condTidLvl.push_back(makeTensorLevel(tid, lvl)); } - l = emitForLoopOverTensorAtLvl(builder, loc, tid, lvl, lo, hi, reduc, - isParallel); - } - Value iv = coords[tid][lvl]; - for (auto [t, l] : unpackTensorLevelRange(tidLvls)) { - // We only need to handle slice-driven loops on dense level here. - // If it is a slice-driven loop on sparse level, it needs a while loop to - // insert break statements, and it must have been handled correctly in L692. - if (!dependentLvlMap[t][l].empty() && isDenseDLT(lvlTypes[t][l])) { - // Pushes sliced levels to build correct LoopInfo. - bool fullyReduc = depFullyReduced(t, l); - SliceInfo &info = sliceStack[t].back(); - if (fullyReduc) { - posits[t][l] = genAddress(builder, loc, t, l, ADDI(info.offset, iv)); + std::tie(l, iv) = emitForLoopOverTensorAtLvl(builder, loc, tid, lvl, lo, hi, + reduc, tryParallel); + // For loop condition must be a trivial condition (levels without affine + // index expression). + trivialLvls.push_back(tlCond.first); + } else { + for (auto [tl, cKind] : spConds) { + if (isAffineIdxCond(cKind)) { + auto [tid, lvl] = unpackTensorLevel(tl); + bool unReduc = isAffineIdxUnRedCond(cKind); + assert(unReduc == !depFullyReduced(tid, lvl)); + sliceDrivenInfo.emplace_back(tid, lvl, /*fullyReduced=*/!unReduc); } else { - // Puts sliced dense loop into LoopInfo so that LoopEmitter knows how to - // exit it. - sliceDrivenInfo.emplace_back(t, l, fullyReduc); - // Update the slice information as we enter the new loop. - assert(*info.slicedOnLvl == l); - info.minCrd = info.offset = iv; - info.isNonEmpty = constantI1(builder, loc, true); - levelReducedDep[t][l]++; + trivialLvls.push_back(tl); } } + + std::tie(l, iv) = + emitWhileLoopOverTensorsAtLvls(builder, loc, spConds, reduc, needsUniv); } + + // Enter dense tensor levels. + enterTensorsAtDenseLvls(builder, loc, dnConds, iv, sliceDrivenInfo); // NOTE: we can also prepare for next dim here in advance + // Pushes the loop into stack. - loopStack.emplace_back(condTidLvl, sliceDrivenInfo, l, + loopStack.emplace_back(trivialLvls, sliceDrivenInfo, l, builder.getInsertionBlock(), iv, loopTag); - // Emit extra locals. - emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tidLvls); return l; } @@ -886,229 +1210,11 @@ void LoopEmitter::genDenseAffineAddress(OpBuilder &builder, Location loc, AffineExpr lvlExpr) { auto [tid, lvl] = unpackTensorLevel(tidLvl); assert(isDenseDLT(lvlTypes[tid][lvl])); - // For dense levels, the level-coordinate also serves as the position. + // For dense levels, the vel-coordinate also serves as the position. Value lvlCrd = genAffine(builder, loc, lvlExpr); posits[tid][lvl] = genAddress(builder, loc, tid, lvl, lvlCrd); } -Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls( - OpBuilder &builder, Location loc, ArrayRef tidLvls, - bool needsUniv, MutableArrayRef reduc) { - // NOTE: the slice driven tensor-related reduction variable must - // appear before normal tensors. - SmallVector types; - SmallVector operands; - // Construct the while-loop with a parameter for each coordinate. - const Type indexType = builder.getIndexType(); - for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) { - // TODO: support coiteration with slice driven tensors. 
- const auto lvlTp = lvlTypes[tid][lvl]; - assert(dependentLvlMap[tid][lvl].empty() && "TODO: not yet implemented"); - if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) || - isCompressedWithHiDLT(lvlTp)) { - const auto reassoc = getCollapseReassociation(tid, lvl); - for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) { - if (!isUniqueDLT(lvlTypes[tid][reassoc[i]])) { - // This is the segment high for each non-unique levels. - types.push_back(indexType); - operands.push_back(C_IDX(0)); - } - } - const auto pos = posits[tid][reassoc.front()]; - assert(pos); - types.push_back(indexType); - operands.push_back(pos); - } - } - // The position where user-supplied reduction variable starts. - for (Value rec : reduc) { - types.push_back(rec.getType()); - operands.push_back(rec); - } - if (needsUniv) { - types.push_back(indexType); - // Update universal index. - operands.push_back(loopSeqStack.back().first); - } - assert(types.size() == operands.size()); - scf::WhileOp whileOp = builder.create(loc, types, operands); - - SmallVector locs(types.size(), loc); - Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs); - Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs); - - // Build the "before" region, which effectively consists - // of a conjunction of "i < upper" tests on all induction. - builder.setInsertionPointToStart(&whileOp.getBefore().front()); - Value cond; - unsigned o = 0; - for (auto [t, lvl] : unpackTensorLevelRange(tidLvls)) { - const TensorId tid = t; // Why `t` can not be captured by lambda? - const auto lvlTp = lvlTypes[tid][lvl]; - if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) || - isCompressedWithHiDLT(lvlTp)) { - const auto reassoc = getCollapseReassociation(tid, lvl); - assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType())); - for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) { - if (!isUniqueDLT(lvlTypes[tid][reassoc[i]])) { - // Links the SSA chain for segHi. - segHi[tid][reassoc[i]] = after->getArgument(o++); - } - } - Value op1 = before->getArgument(o); - // We used the first level bound as the bound the collapsed set of levels. - Value op2 = highs[tid][reassoc.front()]; - Value opc = CMPI(ult, op1, op2); - cond = cond ? ANDI(cond, opc) : opc; - // Update positions - Value pos = after->getArgument(o++); - // For COO, the position is the same across consecutive levels. - /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header. - llvm::for_each(reassoc, [this, tid, pos](Level srcLvl) { - posits[tid][srcLvl] = pos; - }); - } - } - builder.create(loc, cond, before->getArguments()); - - // Generates while body. - builder.setInsertionPointToStart(&whileOp.getAfter().front()); - - SmallVector> slicesPreds; - unsigned i = 0; - for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) { - // Prepares for next level. - const auto lvlTp = lvlTypes[tid][lvl]; - if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) || - isCompressedWithHiDLT(lvlTp)) { - coords[tid][lvl] = genSparseCrd(builder, loc, tid, lvl); - if (isSparseSlices[tid]) { - auto [trans, pred] = - genSliceLegitPredicate(builder, loc, coords[tid][lvl], tid, lvl); - slicesPreds.emplace_back(pred, i); - // Updates to the relative coordinate to the slice. - coords[tid][lvl] = trans; - } - i++; - } - } - - if (!slicesPreds.empty()) { - // Skips invalid loop iteration when slice coordinate is inapplicable. - SmallVector yields(after->getArguments()); - // Generates a list of if statments - // pos = in_slice ? 
pos : pos + 1 - // TODO: instead of always picking pos + 1, we should set pos = high to - // break to loop if the coordinates are larger than the slice size. - // - // This "idx" is the index into `llvm::zip(tids, lvls)` - for (auto [pred, idx] : slicesPreds) { - Value nextPos = ADDI(yields[idx], C_IDX(1)); - yields[idx] = SELECT(pred, yields[idx], nextPos); - } - - Value pred = slicesPreds.front().first; - for (int i = 1, e = slicesPreds.size(); i < e; i++) { - pred = ANDI(pred, slicesPreds[i].first); - } - auto ifOp = builder.create(loc, types, pred, /*else*/ true); - ifOp->setAttr(getLoopEmitterLoopAttrName(), - StringAttr::get(builder.getContext(), "slice")); - YIELD(ifOp->getResults()); - assert(types.size() == yields.size()); - // If not all slices are legit - builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - YIELD(yields); - - // If all slices are legit, start the user generated code. - builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - } - - Value min; - // Finds the minimum coordinate - if (!needsUniv) { - for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) { - const auto lvlTp = lvlTypes[tid][lvl]; - if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) || - isCompressedWithHiDLT(lvlTp)) { - const auto crd = coords[tid][lvl]; - if (min) { - Value cmp = CMPI(ult, coords[tid][lvl], min); - min = SELECT(cmp, coords[tid][lvl], min); - } else { - min = crd; - } - } - } - } else { - assert(!min); - // Otherwise, universal index is the minimal pos. - min = after->getArguments().back(); - } - - // Sets up the loop stack. - loopStack.emplace_back(tidLvls, ArrayRef(), whileOp, - builder.getInsertionBlock(), min, loopTag); - assert(loopStack.size() == loopSeqStack.size()); - - for (auto [tid, dstLvl] : unpackTensorLevelRange(tidLvls)) { - const auto reassoc = getCollapseReassociation(tid, dstLvl); - assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType())); - // TODO: Refactors this into smaller functions. - // NOTE: For all the collapsed level (except for the last one, that is why - // the loop ends with `reassoc.size() - 1`), as each iteration is advanced - // by the segment size of the last level, which does not always invalidate - // the segment size for the previous levels, thus we need to propagate the - // segment sizes across loop iterations and only forward if needed. - // - // E.g., for a COO tensor with the following coordinates array. - // (0, 0, 1), - // (0, 0, 2), - // (1, 1, 1), - // segHi[lvl=0] = segHi[lvl=1] = 2 - // segHi[lvl=2] = 1, - // the first iteration does not invalidate segHi[0] and segHi[1] - for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) { - const Level srcLvl = reassoc[i]; - if (!isUniqueDLT(lvlTypes[tid][srcLvl])) { - const Value pos = posits[tid][srcLvl]; - const auto oldSegHi = segHi[tid][srcLvl]; - assert(oldSegHi); - Value newSegHi = builder.create( - loc, arith::CmpIPredicate::uge, pos, oldSegHi); - auto ifNewSegHi = builder.create(loc, builder.getIndexType(), - newSegHi, true); - { - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(ifNewSegHi.thenBlock()); - YIELD(genSegmentHigh(builder, loc, tid, srcLvl, pos, - highs[tid][srcLvl])); - // Else, resues the same segment high. 
- builder.setInsertionPointToStart(ifNewSegHi.elseBlock()); - YIELD(oldSegHi); - } - highs[tid][srcLvl + 1] = segHi[tid][srcLvl] = ifNewSegHi.getResult(0); - } - }; - const auto srcLvl = reassoc.back(); - if (!isUniqueDLT(lvlTypes[tid][srcLvl])) { - segHi[tid][srcLvl] = genSegmentHigh( - builder, loc, tid, srcLvl, posits[tid][srcLvl], highs[tid][srcLvl]); - } - } - - // Emits extra locals - emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tidLvls); - - // Updates reduction variables - assert(after->getNumArguments() == o + reduc.size() + (needsUniv ? 1 : 0)); - // In-place update on reduction variable. - for (unsigned i = 0, e = reduc.size(); i < e; i++) - reduc[i] = after->getArgument(o + i); - - return whileOp; -} - void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc, TensorId tid, Level dstLvl) { assert(isValidLevel(tid, dstLvl)); @@ -1159,20 +1265,35 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc, llvm_unreachable("Unrecognized level-type!"); } -void LoopEmitter::emitExtraLocalsForTensorsAtDenseLvls( - OpBuilder &builder, Location loc, ArrayRef tidLvls) { - // Initialize dense positions. Note that we generate dense coordinates of the - // output tensor unconditionally, since they may not appear in the lattice, - // but may be needed for linearized codegen. - for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) { - if (isSynTensor(tid)) - continue; +void LoopEmitter::enterTensorsAtDenseLvls( + OpBuilder &builder, Location loc, ArrayRef dnConds, Value iv, + SmallVectorImpl &sliceInfo) { + for (auto [dnTidLvl, denseLoopCond] : dnConds) { + auto [tid, lvl] = unpackTensorLevel(dnTidLvl); + assert(isDenseDLT(lvlTypes[tid][lvl])); - if (isDenseDLT(lvlTypes[tid][lvl])) { - // Slice-driven dense level should have be handled already. - if (!dependentLvlMap[tid][lvl].empty()) + if (isAffineIdxCond(denseLoopCond)) { + // Pushes sliced levels to build correct LoopInfo. + bool unReduc = isAffineIdxUnRedCond(denseLoopCond); + SliceInfo &info = sliceStack[tid].back(); + if (unReduc) { + // Pushes sliced dense loop info to tell LoopEmitter how to exit it. + sliceInfo.emplace_back(tid, lvl, /*fullyReduced=*/false); + // Update the slice information as we enter the new loop. + assert(*info.slicedOnLvl == lvl); + info.minCrd = info.offset = iv; + info.isNonEmpty = constantI1(builder, loc, true); + levelReducedDep[tid][lvl]++; + } else { + posits[tid][lvl] = + genAddress(builder, loc, tid, lvl, ADDI(info.offset, iv)); + } + } else { + // Skips the synthetic tensor + if (isSynTensor(tid)) continue; - + // A dense level with trivial index expression. 
+ assert(dependentLvlMap[tid][lvl].empty()); auto enc = getSparseTensorEncoding(tensors[tid].getType()); if (enc && !isSparseOutput(tid)) { bool validPos = lvl == 0 || posits[tid][lvl - 1]; @@ -1182,8 +1303,7 @@ void LoopEmitter::emitExtraLocalsForTensorsAtDenseLvls( assert(isOutputTensor(tid)); continue; } - posits[tid][lvl] = - genAddress(builder, loc, tid, lvl, loopStack.back().iv); + posits[tid][lvl] = genAddress(builder, loc, tid, lvl, iv); // NOTE: we can also prepare for next lvl here in advance } } @@ -1270,7 +1390,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc, // Finished iterating a tensor, clean up // We only do the clean up on for loop as while loops do not necessarily // finish the iteration on a sparse tensor - for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.tidLvls)) { + for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.trivialTidLvls)) { // Reset to null. coords[tid][lvl] = Value(); posits[tid][lvl] = Value(); @@ -1285,6 +1405,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc, const LoopInfo &loopInfo = loopStack.back(); auto whileOp = llvm::cast(loopInfo.loop); Value iv = loopInfo.iv; + Value one = C_IDX(1); // Finalize the induction. Note that the induction could be performed // in the individual if-branches to avoid re-evaluating the conditions. @@ -1299,31 +1420,32 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc, assert(isCompressedDLT(lvlTypes[tid][lvl])); levelReducedDep[tid][lvl]--; if (!resolved) { + // TODO: support coiterating multiple slices + assert(loopInfo.trivialTidLvls.empty() && + loopInfo.sliceDrivenInfo.size() == 1); genSliceNextInduction(builder, loc, whileOp, tid, lvl, operands, o); continue; } - // TODO: We need to distinguish coiterate loop with slice-driven loop and - // fully reduced while op for iterating one slices. - // FIXME: since we didn't implement coiteration, this must be iteration - // just on fully resolved slice. - assert(loopInfo.sliceDrivenInfo.size() == 1 && loopInfo.tidLvls.empty()); - // The if guard to filter out out-range coordinates. - assert(llvm::isa(builder.getInsertionBlock()->getParentOp())); + + if (loopInfo.trivialTidLvls.empty() && + loopInfo.sliceDrivenInfo.size() == 1) { + // Forwards the position iterator. + operands.push_back(ADDI(posits[tid][lvl], one)); + } else { + const Value pos = posits[tid][lvl]; + const Value nxPos = ADDI(posits[tid][lvl], one); + Value cmp = CMPI(eq, coords[tid][lvl], iv); + operands.push_back(SELECT(cmp, nxPos, pos)); + } + + // The coordinate is invalid now. + coords[tid][lvl] = nullptr; + + // Update the position iterator as we exit the while loop. posits[tid][lvl] = whileOp->getResult(o++); - // FIXME: we are not using continue here since we do not support - // coiteration on slices. But it need to be treated similarly as the - // universal index. - o++; // skip continue flag. - // Since we did not push two results from whileOp. The size of the - // operands vector is smaller than the actual number of return values from - // the whileOp. - // It is because we are actually generating yield in the IfOp inside the - // whileOp to only iterates over inbound coordinates within the slices. 
- delta += 2; }; - Value one = C_IDX(1); - for (auto [tid, dstLvl] : unpackTensorLevelRange(loopInfo.tidLvls)) { + for (auto [tid, dstLvl] : unpackTensorLevelRange(loopInfo.trivialTidLvls)) { const auto lvlTp = lvlTypes[tid][dstLvl]; if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) || isCompressedWithHiDLT(lvlTp)) { @@ -1357,6 +1479,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc, llvm::for_each(reassoc, [this, newTid, newPos](Level srcLvl) { posits[newTid][srcLvl] = newPos; }); + // The coordinate is invalid now. coords[tid][dstLvl] = nullptr; // The segment high is invalid now. @@ -1439,25 +1562,6 @@ const LoopEmitter::SliceInfo &LoopEmitter::getMostRecentSliceOnLvl(TensorId tid, llvm_unreachable("Failed to find sliceInfo"); } -static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc, - Value crdBuf, Value crdHi, Value posit, - Value posHi, Value cont) { - Value inBound = CMPI(ult, posit, posHi); - auto ifOp = builder.create(loc, cont.getType(), inBound, true); - // if (inbound) - // yield coord < crdHi - builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - Value crd = genIndexLoad(builder, loc, crdBuf, posit); - YIELD(CMPI(ult, crd, crdHi)); - // else - // yield false - builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - YIELD(constantI1(builder, loc, false)); - - builder.setInsertionPointAfter(ifOp); - return ifOp.getResult(0); -} - // Generates a while loop to iterate over a slice sparse level as follows. // // while(coords[loopLo] < offset + size) { @@ -1466,15 +1570,13 @@ static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc, // } std::pair LoopEmitter::genSliceLvlTraverseLoop( OpBuilder &builder, Location loc, Value posLo, Value posHi, Value offset, - Value size, TensorId tid, Level lvl, ValueRange userReduc, bool genYield, + Value size, TensorId tid, Level lvl, ValueRange userReduc, LoopBodyBuilder bodyBuilder) { Value c1 = C_IDX(1); Value sliceHi = ADDI(offset, sliceSizes[tid][lvl].back()); + SmallVector reduc{posLo}; // loop lower bounds + const unsigned numMetaReduc = reduc.size(); - SmallVector reduc = { - posLo, // loop lower bounds - constantI1(builder, loc, true), // continue - }; // Append user required reduction value. reduc.append(userReduc.begin(), userReduc.end()); scf::WhileOp whileOp = builder.create( @@ -1482,28 +1584,28 @@ std::pair LoopEmitter::genSliceLvlTraverseLoop( /*beforeBuilder=*/ [this, posHi, sliceHi, tid, lvl](OpBuilder &builder, Location loc, ValueRange args) { - Value cond = genSparseReducedAffineCond( - builder, loc, coordinatesBuffers[tid][lvl], sliceHi, args[0], posHi, - args[1]); + Value cond = genSparseReducedAffineCond(builder, loc, + coordinatesBuffers[tid][lvl], + sliceHi, args[0], posHi); // continue if not yet break nor out of bound. builder.create(loc, cond, args); }, /*afterBuilder=*/ - [c1, genYield, bodyBuilder](OpBuilder &builder, Location loc, - ValueRange args) { + [c1, numMetaReduc, bodyBuilder](OpBuilder &builder, Location loc, + ValueRange args) { Value iv = args[0]; - TypeRange types = args.drop_front(2).getTypes(); - // The coordinate must be in bound as guaranteed by the loop condition. - // We generate a fake if operation here only to hide the two extra loop - // induction variable maintained by us from user, and relies on later - // optimization pass to remove it. 
- Value cont = constantI1(builder, loc, true); - auto ifOp = builder.create(loc, types, cont, + TypeRange types = args.drop_front(numMetaReduc).getTypes(); + // The coordinate must be in bound as guaranteed by the loop + // condition. We generate a fake if operation here only to hide the + // extra loop induction variables maintained by us from users, which + // will be removed by later optimization pass. + auto ifOp = builder.create(loc, types, + constantI1(builder, loc, true), /*withElseBlock=*/!types.empty()); { // 2 reduction variable maintained by us. - SmallVector ifRet = args.drop_front(2); - assert(ifRet.size() == args.size() - 2); + SmallVector ifRet = args.drop_front(numMetaReduc); + assert(ifRet.size() == args.size() - 1); OpBuilder::InsertionGuard guard(builder); // If coord >= sliceHi. @@ -1516,10 +1618,6 @@ std::pair LoopEmitter::genSliceLvlTraverseLoop( builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); // Delegates to users' callback. bodyBuilder(builder, loc, iv, ifRet); - if (genYield) { - builder.setInsertionPointToEnd(&ifOp.getThenRegion().front()); - YIELD(ifRet); - } } // Marks this speical ifOp to avoid sparisification finalizing it. ifOp->setAttr(getLoopEmitterLoopAttrName(), @@ -1528,13 +1626,12 @@ std::pair LoopEmitter::genSliceLvlTraverseLoop( SmallVector yields; // Increase induction variable. yields.push_back(ADDI(iv, c1)); - yields.push_back(cont); yields.append(ifOp.getResults().begin(), ifOp.getResults().end()); YIELD(yields); }); builder.setInsertionPointAfter(whileOp); - return std::make_pair(whileOp, whileOp.getResults().drop_front(2)); + return std::make_pair(whileOp, whileOp.getResults().drop_front(numMetaReduc)); } // Generates a loop nest that traverse all the unresolved levels in between. @@ -1590,7 +1687,6 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse( genSliceLvlTraverseLoop( builder, loc, loopLo, loopHi, offset, sliceSizes[tid][firstLvl].back(), tid, firstLvl, iterArgs, - false, [&](OpBuilder &builder, Location, Value iv, MutableArrayRef reduc) { ip = builder.saveInsertionPoint(); @@ -1710,7 +1806,8 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc, // FIXME: We need the relative offset related to the base slice. Value absOffset = offsetFromMinCoord(builder, loc, minCrd, size, isNonEmpty); - sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, lvl, /*depth=*/1); + sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, lvl, + /*depth=*/1); } // Fills in the slicePosBuffer before slice-driven loop begin. @@ -1796,8 +1893,8 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc, Value sPHi = genIndexLoad(builder, loc, positionsBuffers[tid][lvl], pHi); - // isNonEmpty = isNonEmpty || lvlNonEmpty, i.e., as long as there is one - // non-empty lvl, the slice is non-empty. + // isNonEmpty = isNonEmpty || lvlNonEmpty, i.e., as long as there is + // one non-empty lvl, the slice is non-empty. Value lvlNonEmpty = CMPI(ult, sPLo, sPHi); nonEmpty = builder.create(loc, lvlNonEmpty, nonEmpty); @@ -1884,8 +1981,8 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid, // We do not need cache for dense levels. if (slicePosBuffer[tid][lvl][0] == nullptr && !isDenseDLT(lvlType)) { OpBuilder::InsertionGuard guard(builder); - // The buffer can be reused, and the size is loop invariant: it only depends - // on the iteration graph's toposort. + // The buffer can be reused, and the size is loop invariant: it only + // depends on the iteration graph's toposort. 
builder.setInsertionPointAfter(localInsertPos); Value bufSize = C_IDX(1); Value c2 = C_IDX(2); @@ -1904,9 +2001,9 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid, Value sz = *(sliceSizes[tid][lvl].rbegin() + depth - 1); bufSize = MULI(bufSize, sz); } - // For a pair of [pLo, pHi]. Note that we can not compress pHi because slice - // creates segments in the index buffer so that the pHi for the current - // level is no longer the pLo for the next level. + // For a pair of [pLo, pHi]. Note that we can not compress pHi because + // slice creates segments in the index buffer so that the pHi for the + // current level is no longer the pLo for the next level. bufSize = MULI(bufSize, c2); // Additional two metadata {memSize, idx} at head. bufSize = ADDI(bufSize, c2); @@ -2117,59 +2214,6 @@ void LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc, info.offset = whileOp.getResult(retIdx++); } -Operation *LoopEmitter::emitSliceDrivenLoopOverTensorAtLvl( - OpBuilder &builder, Location loc, TensorId tid, Level lvl, - MutableArrayRef reduc) { - assert(!depFullyReduced(tid, lvl)); - SliceInfo &sliceInfo = sliceStack[tid].back(); - assert(sliceInfo.slicedOnLvl == lvl); - - // The order matters! - SmallVector operands{sliceInfo.isNonEmpty, sliceInfo.minCrd, - sliceInfo.offset}; - // number of reduction maintained by us. - size_t numMetaReduc = operands.size(); - - // Append user-required reduction values. - operands.append(reduc.begin(), reduc.end()); - assert(operands.size() == numMetaReduc + reduc.size()); - - // while (slice.nonEmpty()) { - // bodyBuilder(); - // SliceNext(); - // } - auto whileOp = builder.create( - loc, ValueRange(operands).getTypes(), operands, - /*beforeBuilder=*/ - [](OpBuilder &builder, Location loc, ValueRange args) { - builder.create(loc, /*isNonEmpty*/ args[0], args); - }, - /*afterBuilder=*/ - [this, tid, lvl, reduc, numMetaReduc, - &sliceInfo](OpBuilder &builder, Location loc, ValueRange args) { - assert(args.size() == reduc.size() + numMetaReduc); - sliceInfo.isNonEmpty = args[0]; - sliceInfo.minCrd = args[1]; - sliceInfo.offset = args[2]; - // The slice offset is used to coiterate with other tensors' - // coordinates. - Value c = sliceInfo.offset; - if (sliceInfo.depth > 1) { - // Coord is the relative offset related to its parents. - // Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1] - llvm_unreachable("TODO: not yet implement"); - } - coords[tid][lvl] = c; - - for (unsigned i = 0, e = reduc.size(); i < e; i++) - reduc[i] = args[i + numMetaReduc]; - }); - - // Set the insertion point to while loop body. - builder.setInsertionPointToEnd(&whileOp.getAfter().front()); - return whileOp; -} - #undef CMPI #undef C_IDX #undef YIELD diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h index 8fa7912..f178366 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h @@ -141,33 +141,39 @@ public: /// Exits the current loop sequence, this will reset universal index to 0. void exitCurrentLoopSeq(OpBuilder &builder, Location loc); - // TODO: Get rid of `lvls` in the argument list? Track the level we - // are currently at internally. Then it would be enterNextLvlForTensor. - // Still need a way to specify the lvl for non-annotated tensors though, - // as those can be accessed out of order. 
- // - /// Emits loop over tensor_tid_lvl, it assumes that loops between - /// tensor_tid_[0, lvl - 1] have already been generated. - /// The function will also perform in-place update on the `reduc` vector to - /// return the reduction variable used inside the generated loop. - Operation *enterLoopOverTensorAtLvl(OpBuilder &builder, Location loc, - ArrayRef tidLvls, - MutableArrayRef reduc = {}, - bool isParallel = false); - + /// Enters a loop that tries to locate a coordinates in a sparse level based + /// on the value evaluated by the provided affine expression. + /// DEPRECATED: affine index expression should be handled by index reduction + /// loop, filter loop-based solution is slow. Operation *enterFilterLoopOverTensorAtLvl(OpBuilder &builder, Location loc, TensorId tid, Level lvl, AffineExpr affine, MutableArrayRef reduc = {}); + /// Emits the address for a dense level based on the value evaluated by the + /// provided affine expression. + /// DEPRECATED: affine index expression should be handled by index reduction + /// loop, filter loop-based solution is slow. void genDenseAffineAddress(OpBuilder &builder, Location loc, TensorLevel tidLvl, AffineExpr lvlExpr); + // TODO: Get rid of `lvls` in the argument list? Track the level we + // are currently at internally. Then it would be enterNextLvlForTensor. + // Still need a way to specify the lvl for non-annotated tensors though, + // as those can be accessed out of order. + // /// Emits a co-iteration loop over a set of tensors. + /// Emits loop over tensor_tid_lvl, it assumes that loops between + /// tensor_tid_[0, lvl - 1] have already been generated. + /// The function will also perform in-place update on the `reduc` vector to + /// return the reduction variable used inside the generated loop. Operation *enterCoIterationOverTensorsAtLvls( OpBuilder &builder, Location loc, ArrayRef tidLvls, - bool needsUniv, MutableArrayRef reduc = {}); + MutableArrayRef reduc = {}, bool isParallel = false, + bool genDedup = false, bool needsUniv = false); + /// Generates code to exit the current loop (e.g., generates yields, forwards + /// loop induction variables, etc). void exitCurrentLoop(RewriterBase &rewriter, Location loc, MutableArrayRef reduc = {}); @@ -232,6 +238,15 @@ public: }); } + template + auto unpackTensorLevelFromCondRange(ContainerTy &&c) const { + using EltTy = decltype(*c.begin()); + static_assert(std::is_same_v, TensorLvlCond>, + "Must be unpacking a TensorLvlCond range"); + return unpackTensorLevelRange( + llvm::make_first_range(std::forward(c))); + } + /// /// Getters. /// @@ -251,6 +266,10 @@ public: } private: + /// + /// Structure definitions that hold different kinds of loops information. + /// + // A tuple that stored the slice-driven loop information. struct SliceLoopInfo final { SliceLoopInfo(TensorId tid, Level lvl, bool reduced) @@ -262,18 +281,22 @@ private: // LoopInfo stores information of a loop generated by LoopEmitter. E.g., // the set of tensors levels that the loop is iterating over. struct LoopInfo final { - LoopInfo(ArrayRef tidLvls, + LoopInfo(ArrayRef trivialTidLvls, ArrayRef sliceDrivenInfo, Operation *loop, Block *userBlock, Value iv, StringAttr loopTag) - : tidLvls(tidLvls), sliceDrivenInfo(sliceDrivenInfo), loop(loop), - userCodeBlock(userBlock), iv(iv) { + : trivialTidLvls(trivialTidLvls), sliceDrivenInfo(sliceDrivenInfo), + loop(loop), userCodeBlock(userBlock), iv(iv) { // Attached a special tag to loop emitter generated loop. 
if (loopTag) loop->setAttr(LoopEmitter::getLoopEmitterLoopAttrName(), loopTag); } - // The set of that the loop is operating on - const llvm::SmallVector tidLvls; - // Slice-driven loop conditions. + // The set of , with *only* trivial index expressions, that are + // used as the condition for the generated loop. Extra information is + // required for levels with non-tivial index expressions, which is + // maintained by the sliceDrivenInfo array below. + const llvm::SmallVector trivialTidLvls; + // The set of , with *only* non-trivial index expressions, that + // are used as the condition for the generated loop. const llvm::SmallVector sliceDrivenInfo; const Operation *loop; // the loop operation Block *const userCodeBlock; // the block holding users' generated code. @@ -304,9 +327,100 @@ private: unsigned depth; // the depth (relative to dependentDimMap[tid][lvl]). }; + /// + /// Enums for different kinds of loop conditions. + /// + + // The bit indicating whether the loop conditions is sparse. + static constexpr uint8_t kSparseCond = 1 << 3; + // The bit indicating whether the loop iterates over sparse tensor slices + // (i.e., with non-empty SliceDimAttr). + static constexpr uint8_t kSliceCond = 1 << 2; + // The bit indicating whether the loop iterates over tensor levels with + // non-trivial affine index reduction. + static constexpr uint8_t kAffineIdxCond = 1 << 1; + // The bit indicating whether the loop iterates over tensor levels with + // non-trivial affine index reduction, and it is not fully reduced. + static constexpr uint8_t kAffineIdxCondUnRed = 1 << 0; + + enum class LoopCondKind : uint8_t { + // Dense conditions. + DenseCond = 0, + DenseSliceCond = kSliceCond, + DenseAffineCond = kAffineIdxCond, + DenseAffineUnRedCond = kAffineIdxCond | kAffineIdxCondUnRed, + // Sparse Conditions. + SparseCond = kSparseCond, + SparseSliceCond = kSparseCond | kSliceCond, + SparseAffineCond = kSparseCond | kAffineIdxCond, + SparseAffineUnRedCond = kSparseCond | kAffineIdxCond | kAffineIdxCondUnRed, + }; + using TensorLvlCond = std::pair; + + /// Sparse or dense loop condition. + static bool isSparseCond(LoopCondKind k) { + return static_cast(k) & kSparseCond; + } + static bool isDenseCond(LoopCondKind k) { return !isSparseCond(k); } + + /// Whether loops over sparse tensor slices or sparse tensors. + static bool isSliceCond(LoopCondKind k) { + return static_cast(k) & kSliceCond; + } + + /// Affine or trivial index expression loop condition. + static bool isAffineIdxCond(LoopCondKind k) { + return static_cast(k) & kAffineIdxCond; + } + static bool isTrivalIdxCond(LoopCondKind k) { return !isAffineIdxCond(k); } + + /// Whether the affine index expression is not fully reduced. + static bool isAffineIdxUnRedCond(LoopCondKind k) { + return isAffineIdxCond(k) && static_cast(k) & kAffineIdxCondUnRed; + } + static bool isAffineIdxRedCond(LoopCondKind k) { + return isAffineIdxCond(k) && !isAffineIdxUnRedCond(k); + } + + // Whether the loop condition kind requires extra check inside the loop body. + // E.g., to iterate over sparse tensor slice, we need to check whether the + // current cooridnate is on the slice (e.g., due to stride) or not. + static bool isCondWithExtraCheck(LoopCondKind k) { + return isSparseCond(k) && isSliceCond(k); + } + + static LoopCondKind makeLoopCondKind(bool isSparse, bool isSlice, + bool isAffine, bool isUnRedu) { + assert(!isUnRedu || isAffine); + uint8_t bits = 0; + bits = isSparse ? bits | kSparseCond : bits; + bits = isSlice ? 
+    bits = isAffine ? bits | kAffineIdxCond : bits;
+    bits = isUnRedu ? bits | kAffineIdxCondUnRed : bits;
+    LoopCondKind kind = static_cast<LoopCondKind>(bits);
+
+    // Sanity checks.
+    assert(isSparse == isSparseCond(kind));
+    assert(isSlice == isSliceCond(kind));
+    assert(isAffine == isAffineIdxCond(kind));
+    assert(isUnRedu == isAffineIdxUnRedCond(kind));
+    return kind;
+  }
+
+  void categorizeLoopCondition(ArrayRef<TensorLevel> tidLvls,
+                               SmallVectorImpl<TensorLvlCond> &dnConds,
+                               SmallVectorImpl<TensorLvlCond> &spConds);
+
+  ///
+  /// LoopEmitter internal helper functions.
+  ///
+
+  using LoopBodyBuilder = llvm::function_ref<void(OpBuilder &, Location, Value,
+                                                  MutableArrayRef<Value>)>;
+
+  /// Whether the list of sparse conditions should be iterated by a for loop.
+  bool shouldIteratedByForLoop(ArrayRef<TensorLvlCond> spConds, bool genDedup);
+
   /// Linearizes address for dense dimension (i.e., p = (i * d0) + j).
   Value genAddress(OpBuilder &builder, Location loc, TensorId tid, Level lvl,
                    Value iv);
@@ -354,31 +468,51 @@ private:
   void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
                                   TensorId tid, Level lvl);

-  /// Emits extra locals, since the locals might not be in simplified lattices
-  /// point used to generate the loops, but are still required to generate
-  /// expressions.
-  void emitExtraLocalsForTensorsAtDenseLvls(OpBuilder &builder, Location loc,
-                                            ArrayRef<TensorLevel> tidLvls);
-
-  /// Emits a for loop to iterate over a tensor level with the provided lower
-  /// bound `lo` and upper bound `hi`.
-  /// Apart from iterating just single tensor level, for loops can be used for
-  /// slice-driven loop on dense level too.
-  Operation *emitForLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
-                                        TensorId tid, Level lvl, Value lo,
-                                        Value hi, MutableArrayRef<Value> reduc,
-                                        bool isParallel);
-
-  /// Emits a while loop to iterate over a sparse level that has been sliced.
-  /// Inserts break statement when the coordinate exceeds the sliceSize;
-  /// The method sets the insertion point inside the generated while loop body
-  /// after the break statement before return (so that callers need to handle
-  /// only in-bound coordinates).
-  Operation *emitWhileLoopOverSliceAtSparseLvl(OpBuilder &builder, Location loc,
-                                               Value pLo, Value pHi,
-                                               Value offset, Value sliceSize,
-                                               TensorId tid, Level lvl,
-                                               MutableArrayRef<Value> reduc);
+  /// Enter dense tensor levels. Since the dense tensor condition could be
+  /// optimized away from the loop condition, we need to compute the
+  /// positions/coordinates inside the loop body.
+  void enterTensorsAtDenseLvls(OpBuilder &builder, Location loc,
+                               ArrayRef<TensorLvlCond> dnConds, Value iv,
+                               SmallVectorImpl<SliceLoopInfo> &sliceInfo);
+
+  /// Emits a for loop to iterate over a tensor level with the provided
+  /// lower bound `lo` and upper bound `hi`. Apart from iterating over just a
+  /// single tensor level, a for loop can also be used for a slice-driven loop
+  /// on a dense level.
+  /// Returns a pair: the loop generated and the value for the induction
+  /// variable.
+  std::pair<Operation *, Value>
+  emitForLoopOverTensorAtLvl(OpBuilder &builder, Location loc, TensorId tid,
                             Level lvl, Value lo, Value hi,
                             MutableArrayRef<Value> reduc, bool isParallel);
+
+  /// Emits a while loop to co-iterate over a list of sparse conditions, or a
+  /// (complex) single sparse condition that cannot be handled by a for loop
+  /// (e.g., an index reduction loop).
+  /// Returns a pair: the loop generated and the value for the induction
+  /// variable (which is the minimum coordinate of all the tensors being
+  /// iterated).
+  std::pair<Operation *, Value>
+  emitWhileLoopOverTensorsAtLvls(OpBuilder &builder, Location loc,
+                                 ArrayRef<TensorLvlCond> spConds,
+                                 MutableArrayRef<Value> reduc, bool needsUniv);
+
+  /// Generates the while loop condition for the given tensor level condition.
+  Value genWhileLoopConditions(OpBuilder &builder, Location loc, ValueRange ivs,
+                               TensorLvlCond cond);
+
+  /// Generates the while loop body for the given tensor level condition.
+  std::optional<Value> genWhileLoopBody(OpBuilder &builder, Location loc,
                                         ValueRange ivs, TensorLvlCond cond);
+
+  /// Generates the values (to forward the loop) if the extra check fails.
+  /// E.g., to iterate over a sparse tensor slice, we need:
+  ///
+  /// pos = onSlice(curCrd) ? pos : pos + 1
+  ///
+  /// to skip invalid coordinates that are not included in the slice.
+  ValueRange genCheckedValue(OpBuilder &builder, Location loc, Value pred,
                             ValueRange curArg, TensorLvlCond cond);

   /// Exits a for loop, returns the reduction results, e.g.,
   /// For sequential for loops:
@@ -488,7 +622,7 @@ private:
   std::pair<Operation *, ValueRange>
   genSliceLvlTraverseLoop(OpBuilder &builder, Location loc, Value pLo,
                           Value pHi, Value offset, Value size, TensorId tid,
-                          Level lvl, ValueRange userReduc, bool genYield,
+                          Level lvl, ValueRange userReduc,
                           LoopBodyBuilder bodyBuilder);

   /// Generates a nested loop that iterates over tid on all the coordinates on
@@ -530,19 +664,6 @@ private:
                               SmallVectorImpl<Value> &operands,
                               unsigned &retIdx);

-  /// Generates a slice-driven while loop as follows.
-  ///
-  /// curSlice = getFirstNonEmptySlice(tensor).
-  ///
-  /// while(isNonEmpty) {
-  ///    ..user code..
-  ///    isNonEmpty, curSlice = getNextNonEmptySlice(curSlice)
-  /// }
-  Operation *emitSliceDrivenLoopOverTensorAtLvl(OpBuilder &builder,
-                                                Location loc, TensorId tid,
-                                                Level lvl,
-                                                MutableArrayRef<Value> reduc);
-
   /// A optional string attribute that should be attached to the loop
   /// generated by loop emitter, it might help following passes to identify
   /// loops that operates on sparse tensors more easily.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 2c5289a..bb9c15e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1180,7 +1180,8 @@ public:
       loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
       // Note that reduc will be taken care of by loop emitter and get updated
       // in place.
-      loopEmitter.enterLoopOverTensorAtLvl(rewriter, loc, tidLvls, reduc);
+      loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls,
+                                                    reduc);
     }

     SmallVector<Value> lcvs;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 891befd..637c16f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1306,12 +1306,11 @@ static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
   llvm_unreachable("unexpected parallelization strategy");
 }

-/// Generates a for-loop on a single index.
-static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter,
-                         bool isInner, LoopId ldx,
-                         ArrayRef<TensorLevel> tidLvls) {
+/// Whether or not the current loop being generated should be parallelized (if
+/// possible) according to the configuration.
+static bool shouldTryParallize(CodegenEnv &env, LoopId ldx, bool isOuter,
+                               ArrayRef<TensorLevel> tidLvls) {
   linalg::GenericOp op = env.op();
-  Location loc = op.getLoc();
   auto iteratorTypes = op.getIteratorTypesArray();
   bool isSparse = llvm::any_of(tidLvls, [ldx, &env](TensorLevel tidLvl) {
     // Queries the DLT based on the tensor id and loop idx, as requested by
@@ -1321,38 +1320,44 @@ static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter,
     return isCompressedDLT(dlt) || isSingletonDLT(dlt);
   });

-  bool isParallel = isParallelFor(env, isOuter, isSparse);
+  return isParallelFor(env, isOuter, isSparse);
+}

+/// Generates a "filter loop" on the given tid level to locate a coordinate
+/// whose value matches the one evaluated by the affine expression in its
+/// matching indexing map.
+static Operation *genFilterLoop(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
+                                TensorLevel tidLvl) {
+  linalg::GenericOp op = env.op();
+  Location loc = op.getLoc();
   Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
-    if (env.merger().isFilterLoop(ldx)) {
-      const auto [tid, lvl] = env.unpackTensorLevel(tidLvls.front());
-      // tids/lvls must only have one value because filter loops only
-      // corresponding to the one and only sparse tensor level.
-      assert(isSparse && tidLvls.size() == 1);
-      OpOperand *t = &op->getOpOperand(tid);
-      auto enc = getSparseTensorEncoding(t->get().getType());
-      // Retrieves the affine expression for the filter loop.
-      // FIXME: `toOrigDim` is deprecated.
-      AffineExpr a =
-          op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, lvl));
-      return env.emitter().enterFilterLoopOverTensorAtLvl(builder, loc, tid,
-                                                          lvl, a, reduc);
-    }
-    return env.emitter().enterLoopOverTensorAtLvl(builder, loc, tidLvls, reduc,
-                                                  isParallel);
+    assert(env.merger().isFilterLoop(ldx));
+    const auto [tid, lvl] = env.unpackTensorLevel(tidLvl);
+    // tids/lvls must only have one value because filter loops only
+    // correspond to the one and only sparse tensor level.
+    OpOperand *t = &op->getOpOperand(tid);
+    auto enc = getSparseTensorEncoding(t->get().getType());
+    // Retrieves the affine expression for the filter loop.
+    // FIXME: `toOrigDim` is deprecated.
+    AffineExpr a = op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, lvl));
+    return env.emitter().enterFilterLoopOverTensorAtLvl(builder, loc, tid, lvl,
+                                                        a, reduc);
   });
-  assert(loop);
   return loop;
 }

-/// Emit a while-loop for co-iteration over multiple indices.
-static Operation *genWhile(CodegenEnv &env, OpBuilder &builder, LoopId idx,
-                           bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
+/// Emits a loop to co-iterate over the list of tensor levels. The generated
+/// loop can either be a for loop or a while loop, depending on whether there
+/// is at most one sparse level in the list.
+static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
+                                 LoopId idx, ArrayRef<TensorLevel> tidLvls,
+                                 bool tryParallel, bool needsUniv) {
   Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
     // Construct the while-loop with a parameter for each
     // index.
     return env.emitter().enterCoIterationOverTensorsAtLvls(
-        builder, env.op().getLoc(), tidLvls, needsUniv, reduc);
+        builder, env.op().getLoc(), tidLvls, reduc, tryParallel,
+        /*genDedup=*/true, needsUniv);
   });
   assert(loop);
   return loop;
@@ -1361,15 +1366,15 @@ static Operation *genWhile(CodegenEnv &env, OpBuilder &builder, LoopId idx,

 /// Generates a for-loop or a while-loop, depending on whether it implements
 /// singleton iteration or co-iteration over the given conjunction.
 static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
-                          bool needsUniv, ArrayRef<TensorLevel> tidLvls,
-                          bool isFor) {
-  const LoopId idx = env.topSortAt(at);
-  if (isFor) {
-    bool isOuter = at == 0;
-    bool isInner = at == env.topSortSize() - 1;
-    return genFor(env, builder, isOuter, isInner, idx, tidLvls);
+                          bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
+  const LoopId ldx = env.topSortAt(at);
+  if (env.merger().isFilterLoop(ldx)) {
+    assert(tidLvls.size() == 1);
+    return genFilterLoop(env, builder, ldx, tidLvls.front());
   }
-  return genWhile(env, builder, idx, needsUniv, tidLvls);
+
+  bool tryParallel = shouldTryParallize(env, ldx, at == 0, tidLvls);
+  return genCoIteration(env, builder, ldx, tidLvls, tryParallel, needsUniv);
 }

 /// Generates the induction structure for a while-loop.
@@ -1684,7 +1689,7 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
                            tidLvls, affineTidLvls);

   // Emit the for/while-loop control.
-  Operation *loop = genLoop(env, builder, at, needsUniv, tidLvls, isSingleCond);
+  Operation *loop = genLoop(env, builder, at, needsUniv, tidLvls);
   Location loc = env.op().getLoc();
   for (auto [tidLvl, exp] : affineTidLvls) {
     env.emitter().genDenseAffineAddress(builder, loc, tidLvl, exp);
diff --git a/mlir/test/Dialect/SparseTensor/sorted_coo.mlir b/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
index e4e65ef..0118932 100644
--- a/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
+++ b/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
@@ -185,8 +185,6 @@ func.func @matvec(%arga: tensor<32x64xf64, #SortedCOO>,
// CHECK:             ^bb0(%[[VAL_26:.*]]: index, %[[VAL_27:.*]]: index):
// CHECK:               %[[VAL_28:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_26]]] : memref>
// CHECK:               %[[VAL_29:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref>
-// CHECK:               %[[VAL_30:.*]] = arith.cmpi ult, %[[VAL_29]], %[[VAL_28]] : index
-// CHECK:               %[[VAL_31:.*]] = arith.select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : index
// CHECK:               %[[VAL_32:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_26]]] : memref>
// CHECK:               %[[VAL_33:.*]] = scf.while (%[[VAL_34:.*]] = %[[VAL_26]]) : (index) -> index {
// CHECK:                 %[[VAL_35:.*]] = arith.cmpi ult, %[[VAL_34]], %[[VAL_17]] : index
@@ -219,6 +217,8 @@ func.func @matvec(%arga: tensor<32x64xf64, #SortedCOO>,
// CHECK:                 %[[VAL_51:.*]] = arith.addi %[[VAL_50]], %[[VAL_6]] : index
// CHECK:                 scf.yield %[[VAL_51]] : index
// CHECK:               }
+// CHECK:               %[[VAL_30:.*]] = arith.cmpi ult, %[[VAL_29]], %[[VAL_28]] : index
+// CHECK:               %[[VAL_31:.*]] = arith.select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : index
// CHECK:               %[[VAL_52:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index
// CHECK:               %[[VAL_53:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index
// CHECK:               %[[VAL_54:.*]] = arith.andi %[[VAL_52]], %[[VAL_53]] : i1
-- 
2.7.4
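
For readers outside the MLIR tree, the following self-contained C++ sketch (not part of the patch) illustrates the LoopCondKind bit encoding that the header above introduces. The constant, enum, and predicate names mirror the patch; the main() driver and its printf output are purely hypothetical.

#include <cassert>
#include <cstdint>
#include <cstdio>

// Bit flags mirroring the patch: sparse / slice / affine / not-fully-reduced.
constexpr uint8_t kSparseCond = 1 << 3;
constexpr uint8_t kSliceCond = 1 << 2;
constexpr uint8_t kAffineIdxCond = 1 << 1;
constexpr uint8_t kAffineIdxCondUnRed = 1 << 0;

enum class LoopCondKind : uint8_t {
  DenseCond = 0,
  DenseSliceCond = kSliceCond,
  DenseAffineCond = kAffineIdxCond,
  DenseAffineUnRedCond = kAffineIdxCond | kAffineIdxCondUnRed,
  SparseCond = kSparseCond,
  SparseSliceCond = kSparseCond | kSliceCond,
  SparseAffineCond = kSparseCond | kAffineIdxCond,
  SparseAffineUnRedCond = kSparseCond | kAffineIdxCond | kAffineIdxCondUnRed,
};

// Compose a kind from the four independent properties, like makeLoopCondKind.
LoopCondKind makeLoopCondKind(bool isSparse, bool isSlice, bool isAffine,
                              bool isUnRedu) {
  assert(!isUnRedu || isAffine); // "unreduced" only makes sense for affine conds
  uint8_t bits = 0;
  if (isSparse) bits |= kSparseCond;
  if (isSlice)  bits |= kSliceCond;
  if (isAffine) bits |= kAffineIdxCond;
  if (isUnRedu) bits |= kAffineIdxCondUnRed;
  return static_cast<LoopCondKind>(bits);
}

// Only sparse slice conditions need an extra in-loop check (stride filtering).
bool isCondWithExtraCheck(LoopCondKind k) {
  uint8_t b = static_cast<uint8_t>(k);
  return (b & kSparseCond) && (b & kSliceCond);
}

int main() {
  LoopCondKind k = makeLoopCondKind(/*isSparse=*/true, /*isSlice=*/true,
                                    /*isAffine=*/false, /*isUnRedu=*/false);
  std::printf("SparseSliceCond needs extra check: %d\n", isCondWithExtraCheck(k));
  // categorizeLoopCondition in the patch sorts sparse conditions by this
  // numeric encoding (descending), so more specialized kinds are handled first.
  return 0;
}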
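
The genCheckedValue documentation above gives the forwarding rule `pos = onSlice(curCrd) ? pos : pos + 1` for SparseSliceCond. As a rough scalar analogue only (not the emitter's actual code), the sketch below assumes a hypothetical Slice descriptor with offset/stride/size and shows why coordinates stored in the position range still need a membership check before reaching user code.

#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical slice: crd is on the slice iff crd >= offset,
// (crd - offset) is a multiple of stride, and the relative index is < size.
struct Slice {
  int64_t offset, stride, size;
  bool contains(int64_t crd) const {
    if (crd < offset)
      return false;
    int64_t rel = crd - offset;
    return rel % stride == 0 && rel / stride < size;
  }
};

int main() {
  // Coordinates stored at positions [pLo, pHi) of a compressed level.
  std::vector<int64_t> coords = {0, 2, 3, 5, 6, 9};
  Slice slice{/*offset=*/2, /*stride=*/3, /*size=*/3}; // selects 2, 5, 8
  int64_t pos = 0, posHi = static_cast<int64_t>(coords.size());
  while (pos < posHi) {
    int64_t crd = coords[pos];
    if (slice.contains(crd)) {
      // "Loop body": visit the coordinate, reported relative to the slice.
      std::printf("visit slice crd %lld (absolute %lld)\n",
                  (long long)((crd - slice.offset) / slice.stride),
                  (long long)crd);
    }
    // This sketch simply skips the body for off-slice coordinates and always
    // advances; the emitter instead folds the check into the while loop via
    // pos = onSlice(crd) ? pos : pos + 1, as documented for genCheckedValue.
    ++pos;
  }
  return 0;
}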