[nnc] Updated sliceTail to do inplace mutation (#63532)
authorRaghavan Raman <raghavanr@fb.com>
Fri, 20 Aug 2021 05:50:32 +0000 (22:50 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 20 Aug 2021 05:55:30 +0000 (22:55 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63532

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D30412184

Pulled By: navahgar

fbshipit-source-id: e7669d3b9d24e14501f3feb6505c88d1d42030c6

test/cpp/tensorexpr/test_loopnest.cpp
torch/csrc/jit/tensorexpr/loopnest.cpp

index b550f48..898ee52 100644 (file)
@@ -380,7 +380,7 @@ TEST(LoopNest, ExprSliceTail) {
   LoopNest::sliceTail(loops[0], 4, &head, &tail);
 
   ASSERT_NE(head, nullptr);
-  ASSERT_NE(head, loops[0]);
+  ASSERT_EQ(head, loops[0]);
   ASSERT_NE(tail, nullptr);
   ASSERT_NE(tail, loops[0]);
 
index 3c39dcd..a296d8c 100644 (file)
@@ -1345,16 +1345,11 @@ void LoopNest::sliceTail(ForPtr f, int factor, ForPtr* head, ForPtr* tail) {
 
   ExprPtr tail_start = alloc<Max>(
       f->start(), alloc<Sub>(f->stop(), alloc<IntImm>(factor)), true);
-  *head = alloc<For>(
-      f->var(),
-      f->start(),
-      tail_start,
-      Stmt::clone(f->body()),
-      f->loop_options());
   *tail = alloc<For>(f->var(), tail_start, f->stop(), Stmt::clone(f->body()));
+  p->insert_stmt_after(*tail, f);
 
-  p->replace_stmt(f, *head);
-  p->insert_stmt_after(*tail, *head);
+  f->set_stop(tail_start);
+  *head = f;
 
   if (f->loop_options().is_gpu_block_index() ||
       f->loop_options().is_gpu_thread_index()) {