[Static Runtime] Benchmark reports native nodes (#63346)
authorMike Iovine <mikeiovine@fb.com>
Wed, 18 Aug 2021 21:56:51 +0000 (14:56 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 18 Aug 2021 22:05:08 +0000 (15:05 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63346

We have seen that we can get significant perf wins essentially for free by implementing native ops for ops that we cannot write out variants for (e.g. TupleUnpack D30306955 (https://github.com/pytorch/pytorch/commit/078b8004a62a51f75e1fbd8d08eea359af6bb1d7), append D30326461 (https://github.com/pytorch/pytorch/commit/9d9e7a8d7294834ddad957ddb1f4cd5a0e741e55)). Therefore, whether or not SR is using a native implementation is valuable information. By capturing this in the benchmarking suite, we can hopefully avoid wasting time profiling/manually inspecting `native_ops.cpp`

Reviewed By: hlu1

Differential Revision: D30346752

fbshipit-source-id: 205b090513b6a5a6ce4cb92f75ab0395b15d08f9

torch/csrc/jit/runtime/static/impl.cpp
torch/csrc/jit/runtime/static/impl.h

index a0c3bac..1ee69a6 100644 (file)
@@ -897,10 +897,12 @@ void StaticRuntime::benchmark(
     std::cout << std::setw(15) << ms << " ms. " << std::setw(10)
               << results.percent_per_node_type[kind] << "%. " << kind << " ("
               << results.instances_per_node_type[kind] << " nodes";
-    if (results.out_nodes.count(kind) == 0) {
-      std::cout << ")" << std::endl;
-    } else {
+    if (results.out_nodes.count(kind)) {
       std::cout << ", out variant)" << std::endl;
+    } else if (results.native_nodes.count(kind)) {
+      std::cout << ", native)" << std::endl;
+    } else {
+      std::cout << ")" << std::endl;
     }
   }
   std::cout << std::setw(15) << results.total_time << " ms. in Total"
@@ -1136,6 +1138,8 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
     if (nodes_[i].has_out_variant()) {
       results.out_nodes.insert(kind);
       results.out_nodes_count++;
+    } else if (nodes_[i].has_native()) {
+      results.native_nodes.insert(kind);
     }
     results.total_time += results.time_per_node[i];
   }
index cc36df0..b16cfef 100644 (file)
@@ -231,6 +231,7 @@ class TORCH_API StaticRuntime {
     std::unordered_map<std::string, float> percent_per_node_type;
     std::unordered_map<std::string, int> instances_per_node_type;
     std::unordered_set<std::string> out_nodes;
+    std::unordered_set<std::string> native_nodes;
   };
 
   IndividualMetrics benchmark_individual_ops(
@@ -410,6 +411,10 @@ class TORCH_API ProcessedNode {
     return static_cast<bool>(fn_);
   }
 
+  bool has_native() const {
+    return static_cast<bool>(native_fn_);
+  }
+
   bool verify_outputs_not_overlapping_with_immutable_inputs() const;
 
  private: