From 527348a6fe934dcfeeb88352e0cb9e263c76b78d Mon Sep 17 00:00:00 2001 From: Hui Guo Date: Wed, 8 Sep 2021 15:30:59 -0700 Subject: [PATCH] [tensorexpr] Add 'is_allocated' flag for buffers and use it to insert 'Alloc/Free' stmts (#64226) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64226 Test Plan: Imported from OSS Reviewed By: ZolotukhinM Differential Revision: D30652221 Pulled By: huiguoo fbshipit-source-id: ef9bb0e3db2c444b476e5fc23956bc34ae0f0111 --- torch/csrc/jit/tensorexpr/loopnest.cpp | 21 ++++++++++++++++----- torch/csrc/jit/tensorexpr/loopnest.h | 20 +++++++++++++++++--- torch/csrc/jit/tensorexpr/tensorexpr_init.cpp | 5 ++++- 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 4a70700..a957b29 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -59,7 +59,7 @@ LoopNest::LoopNest(const std::vector& output_tensors) { verify(root_stmt_); } -const std::unordered_set LoopNest::getIntermediateBufs() const { +std::unordered_set LoopNest::getIntermediateBufs() const { std::unordered_set result; auto input_bufs = getInputBufs(); auto bufs = NodeFinder::find(root_stmt_); @@ -963,8 +963,17 @@ BlockPtr findLowestContainingBlock(const std::vector& uses) { return b; } -StmtPtr LoopNest::insertAllocFree(StmtPtr stmt) { - auto intermediate_bufs = getIntermediateBufs(); +StmtPtr LoopNest::insertAllocFree( + StmtPtr stmt, + const c10::optional>& + interm_bufs /* = c10::nullopt*/) { + std::unordered_set intermediate_bufs; + if (interm_bufs) { + intermediate_bufs = *interm_bufs; + } else { + intermediate_bufs = getIntermediateBufs(); + } + if (intermediate_bufs.size() == 0ULL) { return stmt; } @@ -1041,7 +1050,9 @@ void LoopNest::eliminateDeadStores() { root_stmt_ = root_stmt_->accept_mutator(&deleter); } -void LoopNest::prepareForCodegen() { +void LoopNest::prepareForCodegen( + const c10::optional>& + interm_bufs /*= c10::nullopt*/) { // Expand reduction ops. ReductionExpander reduceExpander; root_stmt_ = reduceExpander.expand(root_stmt_); @@ -1049,7 +1060,7 @@ void LoopNest::prepareForCodegen() { root_stmt_ = FlattenIndexes(root_stmt_); // Add allocs and frees for intermediate buffers at the global level. - root_stmt_ = insertAllocFree(root_stmt_); + root_stmt_ = insertAllocFree(root_stmt_, interm_bufs); } namespace { diff --git a/torch/csrc/jit/tensorexpr/loopnest.h b/torch/csrc/jit/tensorexpr/loopnest.h index 42f072d..f71a6e5 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.h +++ b/torch/csrc/jit/tensorexpr/loopnest.h @@ -538,19 +538,33 @@ class TORCH_API LoopNest { void vectorizeInnerLoops(); void eliminateDeadStores(); - void prepareForCodegen(); + + // Make the stmt ready for codegen. The optional argument 'interm_bufs' allows + // users to specify intermediate buffers that need runtime allocation. In + // default, we will insert 'Alloc/Free' stmts to allocate all intermediate + // buffers at runtime but users may have pre-allocated some of them at compile + // time, and in that case the user can specify what buffers to insert + // 'Alloc/Free' stmts for using 'interm_bufs'. + // TODO: refactor function 'prepareForCodegen' to remove argument + // 'interm_bufs'. + void prepareForCodegen( + const c10::optional>& interm_bufs = + c10::nullopt); const std::unordered_set getInputBufs() const; const std::unordered_set getOutputBufs() const { return output_bufs_; } + std::unordered_set getIntermediateBufs() const; private: void initialize( const std::vector& output_tensors, const std::vector& tensors_to_compute); - StmtPtr insertAllocFree(StmtPtr stmt); - const std::unordered_set getIntermediateBufs() const; + StmtPtr insertAllocFree( + StmtPtr stmt, + const c10::optional>& interm_bufs = + c10::nullopt); StmtPtr root_stmt_; diff --git a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp index c924bde..ad8962d 100644 --- a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp +++ b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp @@ -388,7 +388,10 @@ void initTensorExprBindings(PyObject* module) { return std::make_unique(s, buf_nodes); })) .def("vectorize_inner_loops", &LoopNest::vectorizeInnerLoops) - .def("prepare_for_codegen", &LoopNest::prepareForCodegen) + .def( + "prepare_for_codegen", + [](LoopNest& self) { return self.prepareForCodegen(); }, + py::return_value_policy::reference) .def( "get_loop_body_for", [](const LoopNest& self, Tensor t) { return self.getLoopBodyFor(t); }, -- 2.7.4