[TensorExpr] Make CacheReplacer and IndexFlattener mutate stmts/exprs inplace. (#63527)
author Mikhail Zolotukhin <mvz@fb.com>
Thu, 19 Aug 2021 05:56:47 +0000 (22:56 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 19 Aug 2021 05:59:31 +0000 (22:59 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63527

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D30411411

Pulled By: ZolotukhinM

fbshipit-source-id: efb14ee57b36537fa4fefa89bdd6bafe7151c012

test/cpp/tensorexpr/test_loopnest.cpp
torch/csrc/jit/tensorexpr/loopnest.cpp
torch/csrc/jit/tensorexpr/registerizer.cpp

index f2ae208..4a2a1d0 100644 (file)
@@ -4017,7 +4017,7 @@ TEST(LoopNest, DeadStoreEliminationWithIntermediates) {
 
   // Will eliminate the write to g, but not f since it used by the producer of
   // h.
-  LoopNest loop(stmt, {h.node()});
+  LoopNest loop(Stmt::clone(stmt), {h.node()});
   loop.eliminateDeadStores();
 
   checkIR(loop.root_stmt(), R"IR(
index e9bc76c..2256369 100644 (file)
@@ -109,12 +109,13 @@ class IndexFlattener : public IRMutator {
     ExprPtr value = v->value();
     ExprPtr new_value = value->accept_mutator(this);
     if (v->indices().size() == 1 && value == new_value) {
-      return (StmtPtr)v;
+      return v;
     }
-    return alloc<Store>(
-        v->buf(),
-        std::vector<ExprPtr>({flatten_index(v->buf()->dims(), v->indices())}),
-        new_value);
+    std::vector<ExprPtr> indices = {
+        flatten_index(v->buf()->dims(), v->indices())};
+    v->set_indices(indices);
+    v->set_value(new_value);
+    return v;
   }
 };
 
@@ -2575,8 +2576,9 @@ class CacheReplacer : public IRMutator {
       ExprPtr sub = IRSimplifier::simplify(alloc<Sub>(index, offset));
       newIndices.push_back(sub);
     }
-
-    return alloc<Load>(cache_, newIndices);
+    v->set_buf(cache_);
+    v->set_indices(newIndices);
+    return v;
   }
 
   StmtPtr mutate(StorePtr v) override {
@@ -2596,8 +2598,10 @@ class CacheReplacer : public IRMutator {
       ExprPtr sub = IRSimplifier::simplify(alloc<Sub>(index, offset));
       newIndices.push_back(sub);
     }
-
-    return alloc<Store>(cache_, newIndices, newValue);
+    v->set_buf(cache_);
+    v->set_indices(newIndices);
+    v->set_value(newValue);
+    return v;
   }
 
   BufPtr buf_;
@@ -2669,21 +2673,13 @@ LoopNest::AccessResult LoopNest::cacheAccesses(
 
   // Replace acceses to the producer in the consumer with the cache.
   CacheReplacer replacer(producer, tmp_buf, info.start);
-  // TODO: Can we reuse 'consumer' below without cloning?
-  StmtPtr new_consumer =
-      IRSimplifier::simplify(Stmt::clone(consumer)->accept_mutator(&replacer));
+  consumer->accept_mutator(&replacer);
 
   // replace the old consumer with the replaced consumer.
-  BlockPtr consumer_block = nullptr;
+  BlockPtr consumer_block = to<Block>(consumer);
+  BlockPtr parent_block = to<Block>(consumer->get_parent());
   // if the consumer is a block, we should mutate it in place.
-  if ((consumer_block = to<Block>(consumer))) {
-    consumer_block->clear();
-    consumer_block->append_stmt(new_consumer);
-  } else {
-    consumer_block = to<Block>(consumer->get_parent());
-    assert(consumer_block);
-    consumer_block->replace_stmt(consumer, new_consumer);
-  }
+  bool is_block = consumer_block != nullptr;
 
   // If there's a reduction and we are operating on the reduce axis, we need to
   // initialize the cache with 0s. Also, we can't just write the result straight
@@ -2715,7 +2711,11 @@ LoopNest::AccessResult LoopNest::cacheAccesses(
           alloc<For>(new_loop_vars[i], alloc<IntImm>(0), tmp_dims[i], tmp_init);
     }
 
-    consumer_block->insert_stmt_before(tmp_init, new_consumer);
+    if (is_block) {
+      consumer_block->prepend_stmt(tmp_init);
+    } else {
+      parent_block->insert_stmt_before(tmp_init, consumer);
+    }
 
     // Reduce back to the original buffer:
     StmtPtr tmp_store = alloc<Store>(
@@ -2732,9 +2732,13 @@ LoopNest::AccessResult LoopNest::cacheAccesses(
           new_loop_vars[i], alloc<IntImm>(0), tmp_dims[i], tmp_store);
     }
 
-    consumer_block->insert_stmt_after(tmp_store, new_consumer);
+    if (is_block) {
+      consumer_block->append_stmt(tmp_store);
+    } else {
+      parent_block->insert_stmt_after(tmp_store, consumer);
+    }
 
-    return std::make_pair(tmp_buf, new_consumer);
+    return std::make_pair(tmp_buf, consumer);
   }
 
   if (hasReads) {
@@ -2747,7 +2751,11 @@ LoopNest::AccessResult LoopNest::cacheAccesses(
           new_loop_vars[i], alloc<IntImm>(0), tmp_dims[i], tmp_store);
     }
 
-    consumer_block->insert_stmt_before(tmp_store, new_consumer);
+    if (is_block) {
+      consumer_block->prepend_stmt(tmp_store);
+    } else {
+      parent_block->insert_stmt_before(tmp_store, consumer);
+    }
   }
 
   if (hasWrites) {
@@ -2760,10 +2768,14 @@ LoopNest::AccessResult LoopNest::cacheAccesses(
           new_loop_vars[i], alloc<IntImm>(0), tmp_dims[i], tmp_store);
     }
 
-    consumer_block->insert_stmt_after(tmp_store, new_consumer);
+    if (is_block) {
+      consumer_block->append_stmt(tmp_store);
+    } else {
+      parent_block->insert_stmt_after(tmp_store, consumer);
+    }
   }
 
-  return std::make_pair(tmp_buf, new_consumer);
+  return std::make_pair(tmp_buf, consumer);
 }
 
 /*
index 07aee20..bc26581 100644 (file)
@@ -668,8 +668,10 @@ StmtPtr RegisterizerReplacer::mutate(StorePtr v) {
 
   ExprPtr new_val = v->value()->accept_mutator(this);
 
-  return alloc<Store>(
-      info->replacement().var_wrapper, std::vector<ExprPtr>({}), new_val);
+  v->set_value(new_val);
+  v->set_buf(info->replacement().var_wrapper);
+  v->set_indices({});
+  return v;
 }
 
 StmtPtr RegisterizerReplacer::mutate(BlockPtr v) {