From 5c27a580ecae83ce9aa253c6a4b0638aeb00d0f3 Mon Sep 17 00:00:00 2001
From: Hui Guo
Date: Wed, 8 Sep 2021 15:30:59 -0700
Subject: [PATCH] [tensorexpr] Allocate intermediate buffers at compile time
 (#64227)

Summary:
Pre-allocate memory at compile time for intermediate buffers whose shapes are
statically known, and manage them the way JIT constant tensors are managed:
they are appended to `constants_` and passed to codegen as buffer arguments,
so the generated code no longer emits runtime `Alloc`/`Free` for them.
Buffers with dynamic shapes (or whose allocation fails) keep the existing
runtime-allocation path. The behavior is opt-in via a new `pre_alloc`
constructor argument on `TensorExprKernel` (default `false`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64227

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D30652220

Pulled By: huiguoo

fbshipit-source-id: cd75005cdfa42751318de7174b44e14a3a01634e
---
 test/cpp/tensorexpr/test_kernel.cpp  | 38 ++++++++++++++++++++++
 torch/csrc/jit/tensorexpr/kernel.cpp | 61 ++++++++++++++++++++++++++++++++----
 torch/csrc/jit/tensorexpr/kernel.h   | 10 +++++-
 3 files changed, 102 insertions(+), 7 deletions(-)

diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp
index c4bf777..d12f142 100644
--- a/test/cpp/tensorexpr/test_kernel.cpp
+++ b/test/cpp/tensorexpr/test_kernel.cpp
@@ -85,6 +85,44 @@ TEST_F(Kernel, InliningIntermediates) {
   }
 }
 
+TEST_F(Kernel, PreAllocIntermediateBufs) {
+  const auto graph_string = R"IR(
+graph(%a.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu),
+      %b.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu)):
+  %2 : int = prim::Constant[value=1]()
+  %c.2 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::matmul(%a.1, %b.1) # test_matmul.py:12:12
+  %3 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %c.2, %2) # test_matmul.py:13:15
+  return (%3))IR";
+  auto graph = std::make_shared<Graph>();
+  parseIR(graph_string, &*graph);
+
+  auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
+  auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
+  auto o = at::zeros({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
+  auto ref = at::matmul(a, b) + a;
+  TensorExprKernel k(graph, {}, true);
+
+  std::vector<at::Tensor> inputs = {a, b};
+  auto stmt = k.getCodeGenStmt();
+
+  std::ostringstream oss;
+  oss << *stmt;
+
+  // Check whether the intermediate buffer has been added to constants
+  auto constants = k.getConstantDescriptors();
+  ASSERT_EQ(constants.size(), 1);
+
+  // Check the IR we produced
+  torch::jit::testing::FileCheck().check_not("Alloc")->run(oss.str());
+  torch::jit::testing::FileCheck().check_not("Free")->run(oss.str());
+
+  // Check correctness
+  std::vector<IValue> stack = fmap<IValue>(inputs);
+  k.run(stack);
+  o = stack[0].toTensor();
+  ASSERT_TRUE(at::allclose(o, ref));
+}
+
 TEST_F(Kernel, _1) {
   const auto graph_string = R"IR(
       graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index c5f7e99..f850d7d 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -2763,7 +2763,14 @@ StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) {
     }
   }
 
-  l.prepareForCodegen();
+  if (pre_alloc_) {
+    auto interm_bufs = l.getIntermediateBufs();
+    preAllocIntermediateBufs(interm_bufs);
+    l.prepareForCodegen(interm_bufs);
+  } else {
+    l.prepareForCodegen();
+  }
+
   GRAPH_DEBUG("after prepareForCodegen", *l.root_stmt());
   l.simplify();
   GRAPH_DEBUG("after simplification", *l.root_stmt());
@@ -3080,6 +3087,46 @@ void TensorExprKernel::bindConstant(const torch::jit::Value* v) {
   bufs_[v] = buf;
 }
 
+void TensorExprKernel::preAllocIntermediateBufs(
+    std::unordered_set<BufPtr>& interm_bufs) {
+  std::vector<std::pair<BufPtr, void*>> allocated_bufs;
+  for (auto it = interm_bufs.begin(); it != interm_bufs.end();) {
+    // Check if buf shape is static and compute its size if static.
+    auto buf = *it;
+    bool is_static = true;
+    size_t size =
+        elementSize(buf->dtype().scalar_type()) * buf->dtype().lanes();
+    for (auto& d : buf->dims()) {
+      if (!d->isConstant()) {
+        is_static = false;
+        break;
+      }
+      size = size * (*intValue(d));
+    }
+    // Only allocate memory for static bufs.
+    if (!is_static) {
+      ++it;
+      continue;
+    }
+    auto bp = (void*)malloc(size);
+    if (!bp) {
+      ++it;
+      continue;
+    }
+    allocated_bufs.emplace_back(buf, bp);
+    it = interm_bufs.erase(it);
+  }
+  std::sort(
+      allocated_bufs.begin(),
+      allocated_bufs.end(),
+      [](const auto& a, const auto& b) {
+        return a.first->name_hint() > b.first->name_hint();
+      });
+  for (auto& a : allocated_bufs) {
+    constants_.push_back({a.first, a.second});
+  }
+}
+
 void TensorExprKernel::compile() {
   GRAPH_DUMP("TensorExprKernel graph:", graph_);
 
@@ -3155,13 +3202,13 @@ void TensorExprKernel::compile() {
     bufs_.erase(output);
   }
 
+  BackendType backendType = inferBackendTypeFromDevice(device_);
+  StmtPtr stmt = transformLoops(backendType, block);
+
   for (auto c : constants_) {
     bufferArgs_.emplace_back(BufHandle(c.buf));
   }
 
-  BackendType backendType = inferBackendTypeFromDevice(device_);
-  StmtPtr stmt = transformLoops(backendType, block);
-
   // Generate code.
   codegen_ = CreateCodeGen(
       getCodeGenName(backendType),
@@ -3173,10 +3220,12 @@ void TensorExprKernel::compile() {
 
 TensorExprKernel::TensorExprKernel(
     const std::shared_ptr<Graph>& subgraph,
-    std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings)
+    std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings,
+    bool pre_alloc /*= false*/)
     : graph_(subgraph),
       code_(subgraph, ""),
-      custom_lowerings_(std::move(custom_lowerings)) {
+      custom_lowerings_(std::move(custom_lowerings)),
+      pre_alloc_(pre_alloc) {
   allow_fallback_ = fallbackAllowed();
   if (!allow_fallback_) {
     compile();
diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h
index 4b92b02..00faa87 100644
--- a/torch/csrc/jit/tensorexpr/kernel.h
+++ b/torch/csrc/jit/tensorexpr/kernel.h
@@ -140,7 +140,8 @@ class TORCH_API TensorExprKernel {
   explicit TensorExprKernel(
       const std::shared_ptr<Graph>& subgraph,
       std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings =
-          {});
+          {},
+      bool pre_alloc = false);
 
   void run(Stack& stack);
   void runFast(
@@ -241,6 +242,12 @@ class TORCH_API TensorExprKernel {
     return custom_lowerings_;
   }
 
+  // Allocate memory for intermediate buffers at compile time.
+  // Specifically, we pre-allocate memory for intermediate buffers with static
+  // size and manage these buffers in the way we manage JIT constant tensors:
+  // push the buf args into the stack so NNC IR can access them at runtime.
+  void preAllocIntermediateBufs(std::unordered_set<BufPtr>& interm_bufs);
+
  private:
   struct UnpackedTensorOptions {
     c10::optional<c10::ScalarType> dtype;
@@ -279,6 +286,7 @@ class TORCH_API TensorExprKernel {
   std::vector<ConstantDescr> constants_;
   std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings_;
+  bool pre_alloc_{false};
 };
 
 TORCH_API int& getTECudaPointwiseLoopLevels();
-- 
2.7.4
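The new flag is exercised end to end by the PreAllocIntermediateBufs test
above; a caller opts in simply by passing true as the third constructor
argument. Below is a minimal usage sketch, not part of the patch: the helper
name runWithPreAlloc is illustrative, and it assumes a parsed TorchScript
Graph with tensor inputs on the CPU, as in the test.

#include <ATen/ATen.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>

#include <memory>
#include <vector>

using torch::jit::Graph;
using torch::jit::IValue;
using torch::jit::tensorexpr::TensorExprKernel;

// Hypothetical helper: compile a TE kernel with pre_alloc enabled so that
// statically-sized intermediate buffers are malloc'ed once at compile time
// and registered next to the JIT constant tensors, leaving no runtime
// Alloc/Free for them in the generated code.
at::Tensor runWithPreAlloc(
    const std::shared_ptr<Graph>& graph,
    const std::vector<at::Tensor>& inputs) {
  TensorExprKernel kernel(graph, /*custom_lowerings=*/{}, /*pre_alloc=*/true);

  // run() takes a stack of IValues and writes the outputs back onto it.
  std::vector<IValue> stack;
  for (const auto& t : inputs) {
    stack.emplace_back(t);
  }
  kernel.run(stack);
  return stack[0].toTensor();
}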
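Two implementation details are worth noting. First, in compile(),
transformLoops() now runs before constants_ is copied into bufferArgs_:
preAllocIntermediateBufs() is invoked from transformLoops() and appends the
pre-allocated buffers to constants_, so copying first would leave them out of
the codegen argument list. Second, pre-allocation is best-effort: buffers with
non-constant dimensions, or for which malloc returns null, stay in interm_bufs
and fall back to the usual runtime Alloc/Free inserted by prepareForCodegen().
The sort by name_hint keeps the order of the new constants deterministic,
since interm_bufs is an unordered set.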