From: wren romano <2998727+wrengr@users.noreply.github.com> Date: Tue, 14 Mar 2023 20:00:29 +0000 (-0700) Subject: [mlir][sparse] Updating `Merger::foreachTensorLoopId` to take `LatPointId` X-Git-Tag: upstream/17.0.6~14667 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=b60de1dfcc15d9505de958fe160b45bea11286f2;p=platform%2Fupstream%2Fllvm.git [mlir][sparse] Updating `Merger::foreachTensorLoopId` to take `LatPointId` Since all callsites of `foreachTensorLoopId` would simply look up the `LatPointId` to extract its `BitVector`, it's cleaner to let the `Merger` handle that instead. This seems to better capture the intent of the `foreachTensorLoopId` method, and improves decoupling (since it removes a place that leaks the implementation detail that we use `BitVector`). Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D146082 --- diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h index 3c5d2d3..59c5b78 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -437,11 +437,11 @@ public: /// for each `TensorLoopId` and passing it the corresponding tensor /// identifier, level, and level-type. void - foreachTensorLoopId(const BitVector &bits, + foreachTensorLoopId(LatPointId p, function_ref, DimLevelType)> callback) const { - for (const TensorLoopId b : bits.set_bits()) + for (const TensorLoopId b : latPoints[p].bits.set_bits()) callback(b, tensor(b), getLvl(b), getDimLevelType(b)); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 9fedd5a..2779e2d 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -1273,18 +1273,18 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp, SmallVector tids; SmallVector lvls; - env.merger().foreachTensorLoopId( - env.lat(l0).bits, [&](TensorLoopId b, TensorId tid, - std::optional lvl, DimLevelType dlt) { - assert(env.merger().loop(b) == idx); - if (isDenseDLT(dlt) || isUndefDLT(dlt)) { - needsUniv = true; - } else { - // sparse/singleton levels. - tids.push_back(tid); - lvls.push_back(*lvl); - } - }); + env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid, + std::optional lvl, + DimLevelType dlt) { + assert(env.merger().loop(b) == idx); + if (isDenseDLT(dlt) || isUndefDLT(dlt)) { + needsUniv = true; + } else { + // sparse/singleton levels. + tids.push_back(tid); + lvls.push_back(*lvl); + } + }); env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tids, lvls); @@ -1342,7 +1342,6 @@ static bool translateBitsToTidLvlPairs( CodegenEnv &env, LatPointId li, LoopId ldx, SmallVectorImpl &tids, SmallVectorImpl &lvls, SmallVectorImpl &affineTids, SmallVectorImpl &affineLvls, SmallVectorImpl &exps) { - const BitVector &all = env.lat(li).bits; const BitVector &simple = env.lat(li).simple; const TensorId outTid = env.merger().getOutTensorID(); const std::optional outLvl = env.merger().getLvl(outTid, ldx); @@ -1350,8 +1349,8 @@ static bool translateBitsToTidLvlPairs( unsigned numloopCond = 0; bool hasNonUnique = false; env.merger().foreachTensorLoopId( - all, [&, ldx](TensorLoopId b, TensorId tid, std::optional lvl, - DimLevelType dlt) { + li, [&, ldx](TensorLoopId b, TensorId tid, std::optional lvl, + DimLevelType dlt) { if (simple.test(b)) { if (isUndefDLT(dlt)) { // An undefined dlt in the lattices, we probably mean to