From aa63c0d9df54ea1fe65791143c8cee6b34cab4f2 Mon Sep 17 00:00:00 2001 From: Hao Lu Date: Mon, 16 Aug 2021 16:30:53 -0700 Subject: [PATCH] [PyPer] Skip printing out per node time when do_profile is on (#63256) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63256 This suppresses printing out the per node time which is very long when the net has too many ops. It can be easily turned on by setting `--pt_sr_print_per_node_time=1`. Reviewed By: ajyu, mikeiovine Differential Revision: D30298331 fbshipit-source-id: 32b3f93b3fe19d335654168311fda93331a1e706 --- torch/csrc/jit/runtime/static/impl.cpp | 15 +++++++++------ torch/csrc/jit/runtime/static/impl.h | 3 ++- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index 2d8b6c4..f51c4e0 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -845,7 +845,8 @@ void StaticRuntime::benchmark( const std::vector& args, const std::unordered_map& kwargs, const int warmup_runs, - const int main_runs) { + const int main_runs, + bool print_per_node_time) { float time_per_iter = benchmark_model(args, kwargs, warmup_runs, main_runs); std::cout << "Static runtime ms per iter: " << time_per_iter << ". Iters per second: " << 1000.0 / time_per_iter << std::endl; @@ -853,11 +854,13 @@ void StaticRuntime::benchmark( IndividualMetrics results = benchmark_individual_ops(args, kwargs, warmup_runs, main_runs); - for (const auto i : c10::irange(nodes_.size())) { - const Node* node = nodes_[i].node(); - std::cout << "Node #" << i << ": " << results.time_per_node[i] - << " ms/iter, "; - node->print(std::cout, 0, nullptr, false); + if (print_per_node_time) { + for (const auto i : c10::irange(nodes_.size())) { + const Node* node = nodes_[i].node(); + std::cout << "Node #" << i << ": " << results.time_per_node[i] + << " ms/iter, "; + node->print(std::cout, 0, nullptr, false); + } } std::vector> time_per_node_type_vec{ diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h index bf28dfc..cc36df0 100644 --- a/torch/csrc/jit/runtime/static/impl.h +++ b/torch/csrc/jit/runtime/static/impl.h @@ -209,7 +209,8 @@ class TORCH_API StaticRuntime { const std::vector& args, const std::unordered_map& kwargs, const int warmup_runs, - const int main_runs); + const int main_runs, + bool print_per_node_time = false); float benchmark_model( const std::vector& args, -- 2.7.4