From 6e00b31b15ba9a09b6aa71b0da1ba200be482011 Mon Sep 17 00:00:00 2001 From: Mikhail Zolotukhin Date: Wed, 18 Aug 2021 22:56:47 -0700 Subject: [PATCH] [TensorExpr] Make CacheReplacer and IndexFlattener mutate stmts/exprs inplace. (#63527) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63527 Test Plan: Imported from OSS Reviewed By: navahgar Differential Revision: D30411411 Pulled By: ZolotukhinM fbshipit-source-id: efb14ee57b36537fa4fefa89bdd6bafe7151c012 --- test/cpp/tensorexpr/test_loopnest.cpp | 2 +- torch/csrc/jit/tensorexpr/loopnest.cpp | 66 ++++++++++++++++++------------ torch/csrc/jit/tensorexpr/registerizer.cpp | 6 ++- 3 files changed, 44 insertions(+), 30 deletions(-) diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp index f2ae208..4a2a1d0 100644 --- a/test/cpp/tensorexpr/test_loopnest.cpp +++ b/test/cpp/tensorexpr/test_loopnest.cpp @@ -4017,7 +4017,7 @@ TEST(LoopNest, DeadStoreEliminationWithIntermediates) { // Will eliminate the write to g, but not f since it used by the producer of // h. - LoopNest loop(stmt, {h.node()}); + LoopNest loop(Stmt::clone(stmt), {h.node()}); loop.eliminateDeadStores(); checkIR(loop.root_stmt(), R"IR( diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index e9bc76c..2256369 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -109,12 +109,13 @@ class IndexFlattener : public IRMutator { ExprPtr value = v->value(); ExprPtr new_value = value->accept_mutator(this); if (v->indices().size() == 1 && value == new_value) { - return (StmtPtr)v; + return v; } - return alloc( - v->buf(), - std::vector({flatten_index(v->buf()->dims(), v->indices())}), - new_value); + std::vector indices = { + flatten_index(v->buf()->dims(), v->indices())}; + v->set_indices(indices); + v->set_value(new_value); + return v; } }; @@ -2575,8 +2576,9 @@ class CacheReplacer : public IRMutator { ExprPtr sub = IRSimplifier::simplify(alloc(index, offset)); newIndices.push_back(sub); } - - return alloc(cache_, newIndices); + v->set_buf(cache_); + v->set_indices(newIndices); + return v; } StmtPtr mutate(StorePtr v) override { @@ -2596,8 +2598,10 @@ class CacheReplacer : public IRMutator { ExprPtr sub = IRSimplifier::simplify(alloc(index, offset)); newIndices.push_back(sub); } - - return alloc(cache_, newIndices, newValue); + v->set_buf(cache_); + v->set_indices(newIndices); + v->set_value(newValue); + return v; } BufPtr buf_; @@ -2669,21 +2673,13 @@ LoopNest::AccessResult LoopNest::cacheAccesses( // Replace acceses to the producer in the consumer with the cache. CacheReplacer replacer(producer, tmp_buf, info.start); - // TODO: Can we reuse 'consumer' below without cloning? - StmtPtr new_consumer = - IRSimplifier::simplify(Stmt::clone(consumer)->accept_mutator(&replacer)); + consumer->accept_mutator(&replacer); // replace the old consumer with the replaced consumer. - BlockPtr consumer_block = nullptr; + BlockPtr consumer_block = to(consumer); + BlockPtr parent_block = to(consumer->get_parent()); // if the consumer is a block, we should mutate it in place. - if ((consumer_block = to(consumer))) { - consumer_block->clear(); - consumer_block->append_stmt(new_consumer); - } else { - consumer_block = to(consumer->get_parent()); - assert(consumer_block); - consumer_block->replace_stmt(consumer, new_consumer); - } + bool is_block = consumer_block != nullptr; // If there's a reduction and we are operating on the reduce axis, we need to // initialize the cache with 0s. Also, we can't just write the result straight @@ -2715,7 +2711,11 @@ LoopNest::AccessResult LoopNest::cacheAccesses( alloc(new_loop_vars[i], alloc(0), tmp_dims[i], tmp_init); } - consumer_block->insert_stmt_before(tmp_init, new_consumer); + if (is_block) { + consumer_block->prepend_stmt(tmp_init); + } else { + parent_block->insert_stmt_before(tmp_init, consumer); + } // Reduce back to the original buffer: StmtPtr tmp_store = alloc( @@ -2732,9 +2732,13 @@ LoopNest::AccessResult LoopNest::cacheAccesses( new_loop_vars[i], alloc(0), tmp_dims[i], tmp_store); } - consumer_block->insert_stmt_after(tmp_store, new_consumer); + if (is_block) { + consumer_block->append_stmt(tmp_store); + } else { + parent_block->insert_stmt_after(tmp_store, consumer); + } - return std::make_pair(tmp_buf, new_consumer); + return std::make_pair(tmp_buf, consumer); } if (hasReads) { @@ -2747,7 +2751,11 @@ LoopNest::AccessResult LoopNest::cacheAccesses( new_loop_vars[i], alloc(0), tmp_dims[i], tmp_store); } - consumer_block->insert_stmt_before(tmp_store, new_consumer); + if (is_block) { + consumer_block->prepend_stmt(tmp_store); + } else { + parent_block->insert_stmt_before(tmp_store, consumer); + } } if (hasWrites) { @@ -2760,10 +2768,14 @@ LoopNest::AccessResult LoopNest::cacheAccesses( new_loop_vars[i], alloc(0), tmp_dims[i], tmp_store); } - consumer_block->insert_stmt_after(tmp_store, new_consumer); + if (is_block) { + consumer_block->append_stmt(tmp_store); + } else { + parent_block->insert_stmt_after(tmp_store, consumer); + } } - return std::make_pair(tmp_buf, new_consumer); + return std::make_pair(tmp_buf, consumer); } /* diff --git a/torch/csrc/jit/tensorexpr/registerizer.cpp b/torch/csrc/jit/tensorexpr/registerizer.cpp index 07aee20..bc26581 100644 --- a/torch/csrc/jit/tensorexpr/registerizer.cpp +++ b/torch/csrc/jit/tensorexpr/registerizer.cpp @@ -668,8 +668,10 @@ StmtPtr RegisterizerReplacer::mutate(StorePtr v) { ExprPtr new_val = v->value()->accept_mutator(this); - return alloc( - info->replacement().var_wrapper, std::vector({}), new_val); + v->set_value(new_val); + v->set_buf(info->replacement().var_wrapper); + v->set_indices({}); + return v; } StmtPtr RegisterizerReplacer::mutate(BlockPtr v) { -- 2.7.4