From 07c5cb8c48d655ba73adc2da2b88399f3ab48638 Mon Sep 17 00:00:00 2001
From: Mike Iovine
Date: Fri, 27 Aug 2021 17:37:05 -0700
Subject: [PATCH] [Static Runtime] Optimize memory planner initialization
 (#64101)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64101

Checking `getOutOfPlaceOperation(n)` is very expensive, especially in
multithreaded environments, due to a lock acquisition when the NNC cache is
queried. This slows down memory planner initialization and, by extension, the
latency of the first static runtime inference.

There are two optimizations in this diff:

* Cache the result of `p_node->has_out_variant()` to avoid the call to
  `getOutOfPlaceOperation`. This speeds up calls to `canReuseInputsOutputs`,
  which in turn speeds up `isOptimizableContainerType`.
* Precompute `isOptimizableContainerType` for every node during static runtime
  initialization to avoid a pass over each node's inputs at memory planner
  construction time.

Test Plan: All unit tests pass: `buck test caffe2/benchmarks/static_runtime/...`

Reviewed By: movefast1990

Differential Revision: D30595579

fbshipit-source-id: 70aaa7af9589c739c672788bf662f711731864f2
---
 torch/csrc/jit/runtime/static/impl.cpp | 31 ++++++++++++++++++++++---------
 torch/csrc/jit/runtime/static/impl.h   | 11 +++++++++++
 torch/csrc/jit/runtime/static/ops.cpp  | 29 ++++++++++++++++++-----------
 torch/csrc/jit/runtime/static/ops.h    |  8 ++++++--
 4 files changed, 57 insertions(+), 22 deletions(-)

diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index 643842a..ee8e903 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -319,7 +319,9 @@ LivenessMap GetLivenessMap(
 // first: Values that are candidates for memory planning
 // second: A deterministc order of all values
 std::pair<std::vector<const Value*>, std::vector<const Value*>>
-GetMemoryPlanningCandidates(const std::shared_ptr<torch::jit::Graph>& graph) {
+GetMemoryPlanningCandidates(
+    const std::shared_ptr<torch::jit::Graph>& graph,
+    const FastMap<Node*, bool>& node_has_out_variant) {
   // for determinism
   FastSet<const Value*> seen_values;
   std::vector<const Value*> all_values;
@@ -328,7 +330,8 @@ GetMemoryPlanningCandidates(const std::shared_ptr<torch::jit::Graph>& graph) {
   // these need to be removed from "can_reuse" after analyzing all nodes
   FastSet<const Value*> cannot_reuse;
   for (auto* n : graph->nodes()) {
-    bool can_reuse_inputs_outputs = canReuseInputsOutputs(n);
+    bool can_reuse_inputs_outputs =
+        canReuseInputsOutputs(n, node_has_out_variant);
     for (const auto* v : n->inputs()) {
       if (!seen_values.count(v)) {
         all_values.emplace_back(v);
@@ -628,6 +631,7 @@ StaticModule::StaticModule(

   // construct SSA definition for non-constant nodes
   int node_idx = 0;
+  FastMap<Node*, bool> node_has_out_variant;
   for (Node* node : graph_->nodes()) {
     if (node->kind() == prim::Constant) {
       continue;
@@ -639,14 +643,22 @@ StaticModule::StaticModule(
       input_ssa_defs.emplace_back(value_to_ssa_def.at(input));
     }
     node_inputs_ssa_def_map_[node_idx] = input_ssa_defs;
-    nodes_.emplace_back(
-        ProcessedNode(node, std::move(ivalue_inputs), opts.enable_out_variant));
+    auto pnode =
+        ProcessedNode(node, std::move(ivalue_inputs), opts.enable_out_variant);
+    node_has_out_variant.emplace(node, pnode.has_out_variant());
+    nodes_.emplace_back(std::move(pnode));
     for (const auto i : c10::irange(node->outputs().size())) {
       value_to_ivalue[node->outputs()[i]] = nullptr;
       value_to_ssa_def[node->outputs()[i]] = std::make_pair(node_idx, i);
     }
     node_idx++;
   }
+  for (auto& pnode : nodes_) {
+    if (pnode.outputs().size() == 1 &&
+        isOptimizableContainerType(pnode.node(), node_has_out_variant)) {
+      node_is_optimizable_container_type_.emplace(pnode.node());
+    }
+  }
   for (auto output : graph_->outputs()) {
     output_ssa_defs_.emplace_back(value_to_ssa_def[output]);
   }
@@ -657,7 +669,7 @@ StaticModule::StaticModule(

   if (opts_.optimize_memory) {
     auto lm = GetLivenessMap(graph_, external_values_, alias_db);
-    auto values = GetMemoryPlanningCandidates(graph_);
+    auto values = GetMemoryPlanningCandidates(graph_, node_has_out_variant);
     value_to_same_storage_values_ =
         GenerateSameStorageValues(lm, external_values_, values, alias_db);
   }
@@ -1177,7 +1189,8 @@ void StaticRuntime::check_for_memory_leak(bool output_returned) {
       // check for intermediates
       if (!ival->isNone()) {
         TORCH_CHECK(
-            ival->isTensor() || isOptimizableContainerType(pnode.node()),
+            ival->isTensor() ||
+                static_module_.is_optimizable_container_type(pnode.node()),
             error_msg);
         if (ival->isTensor()) {
           const auto& t = ival->toTensor();
@@ -1262,9 +1275,9 @@ MemoryPlanner::MemoryPlanner(
         const auto& type = out_v->type();
         if (type->castRaw<TensorType>()) {
           managed_tensor_values.insert(out_v);
-        } else if (isOptimizableContainerType(pnode.node())) {
-          // We "leak" certain container types because their allocations take
-          // a long time
+        } else if (runtime->is_optimizable_container_type(pnode.node())) {
+          // We "leak" certain container types because their allocations
+          // take a long time
           leaked_values.insert(out_v);
         }
       }
diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h
index 6cff047..d8a99f7 100644
--- a/torch/csrc/jit/runtime/static/impl.h
+++ b/torch/csrc/jit/runtime/static/impl.h
@@ -160,6 +160,11 @@ class TORCH_API StaticModule {
     return nodes_;
   }

+  bool is_optimizable_container_type(Node* n) const {
+    auto it = node_is_optimizable_container_type_.find(n);
+    return it != node_is_optimizable_container_type_.end();
+  }
+
   const c10::optional<c10::FunctionSchema>& schema() const {
     return schema_;
   }
@@ -204,6 +209,8 @@ class TORCH_API StaticModule {

   // map a value to the set of values that may share the same storage with it
   FastMap<const Value*, std::vector<const Value*>> value_to_same_storage_values_;
+
+  FastSet<Node*> node_is_optimizable_container_type_;
 };

 class TORCH_API StaticRuntime {
@@ -287,6 +294,10 @@ class TORCH_API StaticRuntime {

   void check_for_memory_leak(bool output_returned = true);

+  bool is_optimizable_container_type(Node* n) const {
+    return static_module_.is_optimizable_container_type(n);
+  }
+
  private:
   // helper method for copying input args/kwargs into inputs_
   void set_inputs(
diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp
index f171d28..3b58668 100644
--- a/torch/csrc/jit/runtime/static/ops.cpp
+++ b/torch/csrc/jit/runtime/static/ops.cpp
@@ -25,6 +25,7 @@
 #include
 #include
 #include
+#include

 C10_DEFINE_bool(
     static_runtime_enable_fast_math,
@@ -312,27 +313,33 @@ bool hasVarArgs(Node* n) {
   return false;
 }

-// Expensive check, use sparingly.
-// This is needed to make sure that we only switch to out variants for the
-// supported overloads, which is checked in the `Generate` step in
-// `SROperatorRegistry()->Create(op_name)->Generate(n)`
-bool canReuseInputsOutputs(Node* n) {
+bool canReuseInputsOutputs(
+    Node* n,
+    const FastMap<Node*, bool>& node_has_out_variant) {
+  auto it = node_has_out_variant.find(n);
+  if (it != node_has_out_variant.end()) {
+    return it->second;
+  }
   return getOutOfPlaceOperation(n) != nullptr;
 }

 // returns true if the producers of the inputs
 // to this operations are out of place.
 // This means the IValues will not change run to run
-bool inputsCanRunOutOfPlace(Node* n) {
+bool inputsCanRunOutOfPlace(
+    Node* n,
+    const FastMap<Node*, bool>& node_has_out_variant) {
   for (auto* input : n->inputs()) {
-    if (!canReuseInputsOutputs(input->node())) {
+    if (!canReuseInputsOutputs(input->node(), node_has_out_variant)) {
       return false;
     }
   }
   return true;
 }

-bool isOptimizableContainerType(Node* n) {
+bool isOptimizableContainerType(
+    Node* n,
+    const FastMap<Node*, bool>& node_has_out_variant) {
   const auto& type = n->output()->type();
   bool is_supported_type = false;
   if (type->kind() == TypeKind::ListType) {
@@ -348,7 +355,7 @@ bool isOptimizableContainerType(Node* n) {
     });
     is_supported_type = iter != types.end();
   }
-  return is_supported_type && inputsCanRunOutOfPlace(n);
+  return is_supported_type && inputsCanRunOutOfPlace(n, node_has_out_variant);
 }

 REGISTER_OPERATOR_FUNCTOR(
@@ -356,7 +363,7 @@ REGISTER_OPERATOR_FUNCTOR(
     prim::ListConstruct,
     prim_ListConstruct,
     [](Node* n) -> SROperator {
       const auto& type = n->output()->type()->expectRef<ListType>();
-      bool can_optimize = isOptimizableContainerType(n);
+      bool can_optimize = isOptimizableContainerType(n, FastMap<Node*, bool>());
       return [can_optimize, &type](ProcessedNode* p_node) {
         const auto& out_l = p_node->Output(0);
         if (!out_l.isNone() && can_optimize) {
@@ -376,7 +383,7 @@ REGISTER_OPERATOR_FUNCTOR(
     prim::TupleConstruct,
     prim_TupleConstruct,
     [](Node* n) -> SROperator {
-      bool can_optimize = isOptimizableContainerType(n);
+      bool can_optimize = isOptimizableContainerType(n, FastMap<Node*, bool>());
       return [can_optimize](ProcessedNode* p_node) {
         const auto& out_l = p_node->Output(0);
         if (!out_l.isNone() && can_optimize) {
diff --git a/torch/csrc/jit/runtime/static/ops.h b/torch/csrc/jit/runtime/static/ops.h
index ff5d69e..311143c 100644
--- a/torch/csrc/jit/runtime/static/ops.h
+++ b/torch/csrc/jit/runtime/static/ops.h
@@ -133,8 +133,12 @@ bool opIsRegistered(const c10::Symbol& op_name);
 // as native ops in Static Runtime
 bool nativeOpIsRegistered(const c10::Symbol& op_name);

-bool canReuseInputsOutputs(Node* n);
-bool isOptimizableContainerType(Node* n);
+bool canReuseInputsOutputs(
+    Node* n,
+    const FastMap<Node*, bool>& node_has_out_variant);
+bool isOptimizableContainerType(
+    Node* n,
+    const FastMap<Node*, bool>& node_has_out_variant);

 std::function<void(ProcessedNode*)> getOutOfPlaceOperation(Node* n);
 std::function<void(ProcessedNode*)> getNativeOperation(Node* n);
-- 
2.7.4
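
The core pattern in this diff is the lookup-with-fallback in `canReuseInputsOutputs`: callers that precomputed each node's out-variant status pass it in via a map, and only a miss falls through to the slow `getOutOfPlaceOperation` path, so behavior is unchanged for callers that cannot supply the map. Below is a minimal, self-contained sketch of that pattern; `Node`, `expensive_lookup`, and the `FastMap` alias here are illustrative stand-ins, not the real torch::jit definitions.

```cpp
// Sketch of the caching pattern used in this diff (stand-in types only).
#include <iostream>
#include <unordered_map>

struct Node {
  int id;
};

// Static Runtime aliases a flat hash map; std::unordered_map stands in here.
template <typename K, typename V>
using FastMap = std::unordered_map<K, V>;

// Stand-in for getOutOfPlaceOperation(): imagine this acquires a lock and
// queries the NNC kernel cache, which is what made the original check slow.
bool expensive_lookup(const Node* n) {
  return n->id % 2 == 0;
}

// Lookup-with-fallback: a hit answers from the precomputed map with no
// locking; a miss falls back to the original slow path with the same result.
bool can_reuse_inputs_outputs(
    const Node* n,
    const FastMap<const Node*, bool>& node_has_out_variant) {
  auto it = node_has_out_variant.find(n);
  if (it != node_has_out_variant.end()) {
    return it->second; // cache hit: no lock, no registry query
  }
  return expensive_lookup(n); // fallback, identical answer
}

int main() {
  Node a{0}, b{1};
  FastMap<const Node*, bool> cache;
  cache.emplace(&a, expensive_lookup(&a)); // precomputed once at init time
  std::cout << can_reuse_inputs_outputs(&a, cache)   // hit
            << can_reuse_inputs_outputs(&b, cache)   // miss -> slow path
            << '\n';
}
```

This also explains why the `REGISTER_OPERATOR_FUNCTOR` call sites pass an empty `FastMap<Node*, bool>()`: at operator-registration time no precomputed results exist yet, so those calls simply take the fallback path as before.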