[PyPer] Skip printing out per node time when do_profile is on (#63256)
authorHao Lu <hlu@fb.com>
Mon, 16 Aug 2021 23:30:53 +0000 (16:30 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 16 Aug 2021 23:32:19 +0000 (16:32 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63256

This suppresses printing out the per node time which is very long when the net has too many ops. It can be easily turned on by setting `--pt_sr_print_per_node_time=1`.

Reviewed By: ajyu, mikeiovine

Differential Revision: D30298331

fbshipit-source-id: 32b3f93b3fe19d335654168311fda93331a1e706

torch/csrc/jit/runtime/static/impl.cpp
torch/csrc/jit/runtime/static/impl.h

index 2d8b6c4..f51c4e0 100644 (file)
@@ -845,7 +845,8 @@ void StaticRuntime::benchmark(
     const std::vector<c10::IValue>& args,
     const std::unordered_map<std::string, c10::IValue>& kwargs,
     const int warmup_runs,
-    const int main_runs) {
+    const int main_runs,
+    bool print_per_node_time) {
   float time_per_iter = benchmark_model(args, kwargs, warmup_runs, main_runs);
   std::cout << "Static runtime ms per iter: " << time_per_iter
             << ". Iters per second: " << 1000.0 / time_per_iter << std::endl;
@@ -853,11 +854,13 @@ void StaticRuntime::benchmark(
   IndividualMetrics results =
       benchmark_individual_ops(args, kwargs, warmup_runs, main_runs);
 
-  for (const auto i : c10::irange(nodes_.size())) {
-    const Node* node = nodes_[i].node();
-    std::cout << "Node #" << i << ": " << results.time_per_node[i]
-              << " ms/iter, ";
-    node->print(std::cout, 0, nullptr, false);
+  if (print_per_node_time) {
+    for (const auto i : c10::irange(nodes_.size())) {
+      const Node* node = nodes_[i].node();
+      std::cout << "Node #" << i << ": " << results.time_per_node[i]
+                << " ms/iter, ";
+      node->print(std::cout, 0, nullptr, false);
+    }
   }
 
   std::vector<std::pair<std::string, double>> time_per_node_type_vec{
index bf28dfc..cc36df0 100644 (file)
@@ -209,7 +209,8 @@ class TORCH_API StaticRuntime {
       const std::vector<c10::IValue>& args,
       const std::unordered_map<std::string, c10::IValue>& kwargs,
       const int warmup_runs,
-      const int main_runs);
+      const int main_runs,
+      bool print_per_node_time = false);
 
   float benchmark_model(
       const std::vector<c10::IValue>& args,