From 7c7c10a0233fc8060aab4082094a189803cbe5ac Mon Sep 17 00:00:00 2001 From: wren romano <2998727+wrengr@users.noreply.github.com> Date: Fri, 24 Mar 2023 14:17:11 -0700 Subject: [PATCH] [mlir][sparse] Updating the `Merger::{exp,lat,set}` methods to return const This helps the `Merger` maintain invariants, as well as clarifying the immutability of the underlying objects (with the one exception of `TensorExp::val`). Depends On: D146559 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D146083 --- .../mlir/Dialect/SparseTensor/Utils/Merger.h | 60 +++++++++++++++++++--- .../Dialect/SparseTensor/Transforms/CodegenEnv.cpp | 13 ++++- .../Dialect/SparseTensor/Transforms/CodegenEnv.h | 6 +-- .../SparseTensor/Transforms/Sparsification.cpp | 36 +++++++------ 4 files changed, 88 insertions(+), 27 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h index 8b1e91a..7e83dfb 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -498,12 +498,60 @@ public: } /// Convenience getters to immediately access the stored nodes. - /// Typically it is inadvisible to keep the reference around, as in - /// `TensorExpr &te = merger.exp(e)`, since insertions into the merger - /// may cause data movement and invalidate the underlying memory address. - TensorExp &exp(ExprId e) { return tensorExps[e]; } - LatPoint &lat(LatPointId p) { return latPoints[p]; } - SmallVector &set(LatSetId s) { return latSets[s]; } + /// These methods return `const&` because the underlying objects must + /// not be mutated by client code. The only exception is for mutating + /// the value associated with an expression, for which there are + /// dedicated methods below. + /// + /// NOTE: It is inadvisable to keep the reference alive for a long + /// time (e.g., as in `TensorExpr &te = merger.exp(e)`), since insertions + /// into the merger can cause data movement which will invalidate the + /// underlying memory address. This isn't just a problem with the `&` + /// references, but also applies to the `ArrayRef`. In particular, + /// using `for (LatPointId p : merger.set(s))` will run into the same + /// dangling-reference problems if the loop body inserts new sets. + const TensorExp &exp(ExprId e) const { return tensorExps[e]; } + const LatPoint &lat(LatPointId p) const { return latPoints[p]; } + ArrayRef set(LatSetId s) const { return latSets[s]; } + + /// Checks whether the given expression has an associated value. + bool hasExprValue(ExprId e) const { + return static_cast(tensorExps[e].val); + } + + /// Sets the expression to have the associated value. Asserts that + /// the new value is defined, and that the expression does not already + /// have a value. If you want to overwrite a previous associated value, + /// use `updateExprValue` instead. + void setExprValue(ExprId e, Value v) { + assert(v && "Got an undefined value"); + auto &val = tensorExps[e].val; + assert(!val && "Expression already has an associated value"); + val = v; + } + + /// Clears the value associated with the expression. Asserts that the + /// expression does indeed have an associated value before clearing it. + /// If you don't want to check for a previous associated value first, + /// then use `updateExprValue` instead. + void clearExprValue(ExprId e) { + auto &val = tensorExps[e].val; + assert(val && "Expression does not have an associated value to clear"); + val = Value(); + } + + /// Unilaterally updates the expression to have the associated value. + /// That is, unlike `setExprValue` and `clearExprValue`, this method + /// does not perform any checks on whether the expression had a + /// previously associated value nor whether the new value is defined. + // + // TODO: The unilateral update semantics are required by the + // current implementation of `CodegenEnv::genLoopBoundary`; however, + // that implementation seems a bit dubious. We would much rather have + // the semantics `{ clearExprValue(e); setExprValue(e, v); }` or + // `{ clearExprValue(e); if (v) setExprValue(e, v); }` since those + // provide better invariants. + void updateExprValue(ExprId e, Value v) { tensorExps[e].val = v; } #ifndef NDEBUG /// Print methods (for debugging). diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp index 974c86d..5d9c347 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp @@ -130,6 +130,9 @@ std::optional CodegenEnv::genLoopBoundary( auto r = callback(params); // may update parameters unsigned i = 0; if (isReduc()) { + // FIXME: This requires `updateExprValue` to perform updates without + // checking for a previous value; but it's not clear whether that's + // by design or might be a potential source for bugs. updateReduc(params[i++]); if (redValidLexInsert) setValidLexInsert(params[i++]); @@ -281,12 +284,18 @@ void CodegenEnv::startReduc(ExprId exp, Value val) { void CodegenEnv::updateReduc(Value val) { assert(isReduc()); - redVal = exp(redExp).val = val; + redVal = val; + // NOTE: `genLoopBoundary` requires that this performs a unilateral + // update without checking for a previous value first. (It's not + // clear whether any other callsites also require that.) + latticeMerger.updateExprValue(redExp, val); } Value CodegenEnv::endReduc() { + assert(isReduc()); Value val = redVal; - updateReduc(Value()); + redVal = val; + latticeMerger.clearExprValue(redExp); redExp = kInvalidId; return val; } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h index 0041ad0..e11e242 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h @@ -66,9 +66,9 @@ public: // Merger delegates. // - TensorExp &exp(ExprId e) { return latticeMerger.exp(e); } - LatPoint &lat(LatPointId l) { return latticeMerger.lat(l); } - SmallVector &set(LatSetId s) { return latticeMerger.set(s); } + const TensorExp &exp(ExprId e) const { return latticeMerger.exp(e); } + const LatPoint &lat(LatPointId l) const { return latticeMerger.lat(l); } + ArrayRef set(LatSetId s) const { return latticeMerger.set(s); } DimLevelType dlt(TensorId t, LoopId i) const { return latticeMerger.getDimLevelType(t, i); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index f760244..3343a51 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -1057,7 +1057,7 @@ static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp, assert(env.exp(exp).val); Value v0 = env.exp(exp).val; genInsertionStore(env, builder, t, v0); - env.exp(exp).val = Value(); + env.merger().clearExprValue(exp); // Yield modified insertion chain along true branch. Value mchain = env.getInsertionChain(); builder.create(op.getLoc(), mchain); @@ -1137,10 +1137,8 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e, if (kind == TensorExp::Kind::kReduce) env.endCustomReduc(); // exit custom - if (kind == TensorExp::Kind::kSelect) { - assert(!exp.val); - env.exp(e).val = v0; // Preserve value for later use. - } + if (kind == TensorExp::Kind::kSelect) + env.merger().setExprValue(e, v0); // Preserve value for later use. return ee; } @@ -1192,7 +1190,10 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp, } } else { // Start or end loop invariant hoisting of a tensor load. - env.exp(exp).val = atStart ? genTensorLoad(env, builder, exp) : Value(); + if (atStart) + env.merger().setExprValue(exp, genTensorLoad(env, builder, exp)); + else + env.merger().clearExprValue(exp); } } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant && env.exp(exp).kind != TensorExp::Kind::kLoopVar) { @@ -1346,8 +1347,7 @@ static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at, /// Generates the induction structure for a while-loop. static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx, - bool needsUniv, BitVector &induction, - scf::WhileOp whileOp) { + bool needsUniv, scf::WhileOp whileOp) { Location loc = env.op().getLoc(); // Finalize each else branch of all if statements. if (env.isReduc() || env.isExpand() || env.getInsertionChain()) { @@ -1386,7 +1386,7 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx, /// Generates a single if-statement within a while-loop. static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx, - BitVector &conditions) { + const BitVector &conditions) { Location loc = env.op().getLoc(); SmallVector types; Value cond; @@ -1486,13 +1486,10 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp, // Maintain the universal index only if it is actually // consumed by a subsequent lattice point. if (needsUniv) { - unsigned lsize = env.set(lts).size(); - for (unsigned i = 1; i < lsize; i++) { - const LatPointId li = env.set(lts)[i]; + for (const LatPointId li : env.set(lts).drop_front()) if (!env.merger().hasAnySparse(env.lat(li).simple) && !env.merger().hasSparseIdxReduction(env.lat(li).simple)) return true; - } } return false; } @@ -1675,7 +1672,7 @@ static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop, LoopId idx, LatPointId li, bool needsUniv) { // End a while-loop. if (auto whileOp = dyn_cast(loop)) { - finalizeWhileOp(env, rewriter, idx, needsUniv, env.lat(li).bits, whileOp); + finalizeWhileOp(env, rewriter, idx, needsUniv, whileOp); } else if (auto forOp = dyn_cast(loop)) { // Any iteration of a reduction for-loop creates a valid lex insert. if (env.isReduc() && env.getValidLexInsert()) @@ -1726,10 +1723,14 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp, bool needsUniv = startLoopSeq(env, rewriter, exp, at, idx, ldx, lts); // Emit a loop for every lattice point L0 >= Li in this loop sequence. - unsigned lsize = env.set(lts).size(); + // + // NOTE: We cannot change this to `for (const LatPointId li : env.set(lts))` + // because the loop body causes data-movement which invalidates + // the iterator. + const unsigned lsize = env.set(lts).size(); for (unsigned i = 0; i < lsize; i++) { - // Start a loop. const LatPointId li = env.set(lts)[i]; + // Start a loop. auto [loop, isSingleCond] = startLoop(env, rewriter, at, li, needsUniv); // Visit all lattices points with Li >= Lj to generate the @@ -1737,6 +1738,9 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp, Value redInput = env.getReduc(); Value cntInput = env.getExpandCount(); Value insInput = env.getInsertionChain(); + // NOTE: We cannot change this to `for (const LatPointId lj : env.set(lts))` + // because the loop body causes data-movement which invalidates the + // iterator. for (unsigned j = 0; j < lsize; j++) { const LatPointId lj = env.set(lts)[j]; const ExprId ej = env.lat(lj).exp; -- 2.7.4