[tensorexpr] Allocate intermediate buffers at compile time (#64227)
authorHui Guo <huiguo@fb.com>
Wed, 8 Sep 2021 22:30:59 +0000 (15:30 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 22:34:44 +0000 (15:34 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64227

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D30652220

Pulled By: huiguoo

fbshipit-source-id: cd75005cdfa42751318de7174b44e14a3a01634e

test/cpp/tensorexpr/test_kernel.cpp
torch/csrc/jit/tensorexpr/kernel.cpp
torch/csrc/jit/tensorexpr/kernel.h

index c4bf777..d12f142 100644 (file)
@@ -85,6 +85,44 @@ TEST_F(Kernel, InliningIntermediates) {
   }
 }
 
+TEST_F(Kernel, PreAllocIntermediateBufs) {
+  const auto graph_string = R"IR(
+graph(%a.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu),
+      %b.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu)):
+  %2 : int = prim::Constant[value=1]()
+  %c.2 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::matmul(%a.1, %b.1) # test_matmul.py:12:12
+  %3 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %c.2, %2) # test_matmul.py:13:15
+  return (%3))IR";
+  auto graph = std::make_shared<Graph>();
+  parseIR(graph_string, &*graph);
+
+  auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
+  auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
+  auto o = at::zeros({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
+  auto ref = at::matmul(a, b) + a;
+  TensorExprKernel k(graph, {}, true);
+
+  std::vector<at::Tensor> inputs = {a, b};
+  auto stmt = k.getCodeGenStmt();
+
+  std::ostringstream oss;
+  oss << *stmt;
+
+  // Check whether the intermediate buffer has been added to constants
+  auto constants = k.getConstantDescriptors();
+  ASSERT_EQ(constants.size(), 1);
+
+  // Check the IR we produced
+  torch::jit::testing::FileCheck().check_not("Alloc")->run(oss.str());
+  torch::jit::testing::FileCheck().check_not("Free")->run(oss.str());
+
+  // Check correctness
+  std::vector<IValue> stack = fmap<IValue>(inputs);
+  k.run(stack);
+  o = stack[0].toTensor();
+  ASSERT_TRUE(at::allclose(o, ref));
+}
+
 TEST_F(Kernel, _1) {
   const auto graph_string = R"IR(
       graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
index c5f7e99..f850d7d 100644 (file)
@@ -2763,7 +2763,14 @@ StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) {
     }
   }
 
-  l.prepareForCodegen();
+  if (pre_alloc_) {
+    auto interm_bufs = l.getIntermediateBufs();
+    preAllocIntermediateBufs(interm_bufs);
+    l.prepareForCodegen(interm_bufs);
+  } else {
+    l.prepareForCodegen();
+  }
+
   GRAPH_DEBUG("after prepareForCodegen", *l.root_stmt());
   l.simplify();
   GRAPH_DEBUG("after simplification", *l.root_stmt());
@@ -3080,6 +3087,46 @@ void TensorExprKernel::bindConstant(const torch::jit::Value* v) {
   bufs_[v] = buf;
 }
 
+void TensorExprKernel::preAllocIntermediateBufs(
+    std::unordered_set<BufPtr>& interm_bufs) {
+  std::vector<std::pair<BufPtr, void*>> allocated_bufs;
+  for (auto it = interm_bufs.begin(); it != interm_bufs.end();) {
+    // Check if buf shape is static and compute its size if static.
+    auto buf = *it;
+    bool is_static = true;
+    size_t size =
+        elementSize(buf->dtype().scalar_type()) * buf->dtype().lanes();
+    for (auto& d : buf->dims()) {
+      if (!d->isConstant()) {
+        is_static = false;
+        break;
+      }
+      size = size * (*intValue(d));
+    }
+    // Only allocate memory for static bufs.
+    if (!is_static) {
+      ++it;
+      continue;
+    }
+    auto bp = (void*)malloc(size);
+    if (!bp) {
+      ++it;
+      continue;
+    }
+    allocated_bufs.emplace_back(buf, bp);
+    it = interm_bufs.erase(it);
+  }
+  std::sort(
+      allocated_bufs.begin(),
+      allocated_bufs.end(),
+      [](const auto& a, const auto& b) {
+        return a.first->name_hint() > b.first->name_hint();
+      });
+  for (auto& a : allocated_bufs) {
+    constants_.push_back({a.first, a.second});
+  }
+}
+
 void TensorExprKernel::compile() {
   GRAPH_DUMP("TensorExprKernel graph:", graph_);
 
@@ -3155,13 +3202,13 @@ void TensorExprKernel::compile() {
     bufs_.erase(output);
   }
 
+  BackendType backendType = inferBackendTypeFromDevice(device_);
+  StmtPtr stmt = transformLoops(backendType, block);
+
   for (auto c : constants_) {
     bufferArgs_.emplace_back(BufHandle(c.buf));
   }
 
-  BackendType backendType = inferBackendTypeFromDevice(device_);
-  StmtPtr stmt = transformLoops(backendType, block);
-
   // Generate code.
   codegen_ = CreateCodeGen(
       getCodeGenName(backendType),
@@ -3173,10 +3220,12 @@ void TensorExprKernel::compile() {
 
 TensorExprKernel::TensorExprKernel(
     const std::shared_ptr<Graph>& subgraph,
-    std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings)
+    std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings,
+    bool pre_alloc /*= false*/)
     : graph_(subgraph),
       code_(subgraph, ""),
-      custom_lowerings_(std::move(custom_lowerings)) {
+      custom_lowerings_(std::move(custom_lowerings)),
+      pre_alloc_(pre_alloc) {
   allow_fallback_ = fallbackAllowed();
   if (!allow_fallback_) {
     compile();
index 4b92b02..00faa87 100644 (file)
@@ -140,7 +140,8 @@ class TORCH_API TensorExprKernel {
   explicit TensorExprKernel(
       const std::shared_ptr<Graph>& subgraph,
       std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings =
-          {});
+          {},
+      bool pre_alloc = false);
 
   void run(Stack& stack);
   void runFast(
@@ -241,6 +242,12 @@ class TORCH_API TensorExprKernel {
     return custom_lowerings_;
   }
 
+  // Allocate memory for intermediate buffers at compile time.
+  // Specifically, we pre-allocate memory for intermediate buffers with static
+  // size and manage these buffers in the way we manage JIT constant tensors:
+  // push the buf args into the stack so NNC IR can access them at runtime.
+  void preAllocIntermediateBufs(std::unordered_set<BufPtr>& interm_bufs);
+
  private:
   struct UnpackedTensorOptions {
     c10::optional<c10::ScalarType> dtype;
@@ -279,6 +286,7 @@ class TORCH_API TensorExprKernel {
   std::vector<ConstantDescr> constants_;
 
   std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings_;
+  bool pre_alloc_{false};
 };
 
 TORCH_API int& getTECudaPointwiseLoopLevels();