From d82667f7e2cd812d98b9cc4f40df46b37a9ef653 Mon Sep 17 00:00:00 2001 From: Raghavan Raman Date: Thu, 19 Aug 2021 22:50:32 -0700 Subject: [PATCH] [nnc] Updated sliceTail to do inplace mutation (#63532) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63532 Test Plan: Imported from OSS Reviewed By: ZolotukhinM Differential Revision: D30412184 Pulled By: navahgar fbshipit-source-id: e7669d3b9d24e14501f3feb6505c88d1d42030c6 --- test/cpp/tensorexpr/test_loopnest.cpp | 2 +- torch/csrc/jit/tensorexpr/loopnest.cpp | 11 +++-------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp index b550f48..898ee52 100644 --- a/test/cpp/tensorexpr/test_loopnest.cpp +++ b/test/cpp/tensorexpr/test_loopnest.cpp @@ -380,7 +380,7 @@ TEST(LoopNest, ExprSliceTail) { LoopNest::sliceTail(loops[0], 4, &head, &tail); ASSERT_NE(head, nullptr); - ASSERT_NE(head, loops[0]); + ASSERT_EQ(head, loops[0]); ASSERT_NE(tail, nullptr); ASSERT_NE(tail, loops[0]); diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 3c39dcd..a296d8c 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -1345,16 +1345,11 @@ void LoopNest::sliceTail(ForPtr f, int factor, ForPtr* head, ForPtr* tail) { ExprPtr tail_start = alloc( f->start(), alloc(f->stop(), alloc(factor)), true); - *head = alloc( - f->var(), - f->start(), - tail_start, - Stmt::clone(f->body()), - f->loop_options()); *tail = alloc(f->var(), tail_start, f->stop(), Stmt::clone(f->body())); + p->insert_stmt_after(*tail, f); - p->replace_stmt(f, *head); - p->insert_stmt_after(*tail, *head); + f->set_stop(tail_start); + *head = f; if (f->loop_options().is_gpu_block_index() || f->loop_options().is_gpu_thread_index()) { -- 2.7.4