[tensorexpr] Add 'is_allocated' flag for buffers and use it to insert 'Alloc/Free...
authorHui Guo <huiguo@fb.com>
Wed, 8 Sep 2021 22:30:59 +0000 (15:30 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 22:34:42 +0000 (15:34 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64226

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D30652221

Pulled By: huiguoo

fbshipit-source-id: ef9bb0e3db2c444b476e5fc23956bc34ae0f0111

torch/csrc/jit/tensorexpr/loopnest.cpp
torch/csrc/jit/tensorexpr/loopnest.h
torch/csrc/jit/tensorexpr/tensorexpr_init.cpp

index 4a70700..a957b29 100644 (file)
@@ -59,7 +59,7 @@ LoopNest::LoopNest(const std::vector<Tensor>& output_tensors) {
   verify(root_stmt_);
 }
 
-const std::unordered_set<BufPtr> LoopNest::getIntermediateBufs() const {
+std::unordered_set<BufPtr> LoopNest::getIntermediateBufs() const {
   std::unordered_set<BufPtr> result;
   auto input_bufs = getInputBufs();
   auto bufs = NodeFinder<Buf>::find(root_stmt_);
@@ -963,8 +963,17 @@ BlockPtr findLowestContainingBlock(const std::vector<BufLoadOrStoreUse>& uses) {
   return b;
 }
 
-StmtPtr LoopNest::insertAllocFree(StmtPtr stmt) {
-  auto intermediate_bufs = getIntermediateBufs();
+StmtPtr LoopNest::insertAllocFree(
+    StmtPtr stmt,
+    const c10::optional<std::unordered_set<BufPtr>>&
+        interm_bufs /*= c10::nullopt*/) {
+  std::unordered_set<BufPtr> intermediate_bufs;
+  if (interm_bufs) {
+    intermediate_bufs = *interm_bufs;
+  } else {
+    intermediate_bufs = getIntermediateBufs();
+  }
+
   if (intermediate_bufs.size() == 0ULL) {
     return stmt;
   }
@@ -1041,7 +1050,9 @@ void LoopNest::eliminateDeadStores() {
   root_stmt_ = root_stmt_->accept_mutator(&deleter);
 }
 
-void LoopNest::prepareForCodegen() {
+void LoopNest::prepareForCodegen(
+    const c10::optional<std::unordered_set<BufPtr>>&
+        interm_bufs /*= c10::nullopt*/) {
   // Expand reduction ops.
   ReductionExpander reduceExpander;
   root_stmt_ = reduceExpander.expand(root_stmt_);
@@ -1049,7 +1060,7 @@ void LoopNest::prepareForCodegen() {
   root_stmt_ = FlattenIndexes(root_stmt_);
 
   // Add allocs and frees for intermediate buffers at the global level.
-  root_stmt_ = insertAllocFree(root_stmt_);
+  root_stmt_ = insertAllocFree(root_stmt_, interm_bufs);
 }
 
 namespace {
index 42f072d..f71a6e5 100644 (file)
@@ -538,19 +538,33 @@ class TORCH_API LoopNest {
   void vectorizeInnerLoops();
 
   void eliminateDeadStores();
-  void prepareForCodegen();
+
+  // Make the stmt ready for codegen. The optional argument 'interm_bufs' allows
+  // users to specify intermediate buffers that need runtime allocation. By
+  // default, we will insert 'Alloc/Free' stmts to allocate all intermediate
+  // buffers at runtime but users may have pre-allocated some of them at compile
+  // time, and in that case the user can specify what buffers to insert
+  // 'Alloc/Free' stmts for using 'interm_bufs'.
+  // TODO: refactor function 'prepareForCodegen' to remove argument
+  // 'interm_bufs'.
+  void prepareForCodegen(
+      const c10::optional<std::unordered_set<BufPtr>>& interm_bufs =
+          c10::nullopt);
 
   const std::unordered_set<BufPtr> getInputBufs() const;
   const std::unordered_set<BufPtr> getOutputBufs() const {
     return output_bufs_;
   }
+  std::unordered_set<BufPtr> getIntermediateBufs() const;
 
  private:
   void initialize(
       const std::vector<Tensor>& output_tensors,
       const std::vector<Tensor>& tensors_to_compute);
-  StmtPtr insertAllocFree(StmtPtr stmt);
-  const std::unordered_set<BufPtr> getIntermediateBufs() const;
+  StmtPtr insertAllocFree(
+      StmtPtr stmt,
+      const c10::optional<std::unordered_set<BufPtr>>& interm_bufs =
+          c10::nullopt);
 
   StmtPtr root_stmt_;
 
index c924bde..ad8962d 100644 (file)
@@ -388,7 +388,10 @@ void initTensorExprBindings(PyObject* module) {
         return std::make_unique<LoopNest>(s, buf_nodes);
       }))
       .def("vectorize_inner_loops", &LoopNest::vectorizeInnerLoops)
-      .def("prepare_for_codegen", &LoopNest::prepareForCodegen)
+      .def(
+          "prepare_for_codegen",
+          [](LoopNest& self) { return self.prepareForCodegen(); },
+          py::return_value_policy::reference)
       .def(
           "get_loop_body_for",
           [](const LoopNest& self, Tensor t) { return self.getLoopBodyFor(t); },