Restructuring prof dag counters (#13321)
authorHassan Eslami <heslami@fb.com>
Thu, 20 Dec 2018 05:35:08 +0000 (21:35 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 20 Dec 2018 05:48:30 +0000 (21:48 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13321

This diff simply refactors the `ProfDAGCounters` into two:
* `ProfDAGCounters` that gathers stats at runtime.
* `ProfDAGReport` which holds the report from the gathered stats once stats collection is done.

This refactoring allow us to implement `+=` for `ProfDAGReport`, which can be used for aggregating same-net reports on each host.

Reviewed By: donglimm

Differential Revision: D12837988

fbshipit-source-id: 0470c5fd6437f12711cab25a15a12965d79b2a91

caffe2/core/net_async_base.cc
caffe2/core/net_async_base.h
caffe2/core/prof_dag_counters.cc
caffe2/core/prof_dag_counters.h

index 680893f..8bfecab 100644 (file)
@@ -461,16 +461,20 @@ void AsyncNetBase::finalizeEvents() {
 }
 
 ProfDAGProtos AsyncNetBase::GetOperatorStats() const {
-  return counters_.GetOperatorStats();
+  return counters_.GetReport().GetOperatorStats();
 }
 
 ProfDAGProtos AsyncNetBase::GetPerOperatorCost() const {
-  return counters_.GetPerOperatorCost();
+  return counters_.GetReport().GetPerOperatorCost();
+}
+
+ProfDAGReport AsyncNetBase::GetProfReport() const {
+  return counters_.GetReport();
 }
 
 AsyncNetBase::~AsyncNetBase() {
   if (options_.report_stats_) {
-    counters_.PrintStats();
+    counters_.GetReport().PrintStats();
   }
 }
 
index 8792c04..e63c1aa 100644 (file)
@@ -76,6 +76,7 @@ class CAFFE2_API AsyncNetBase : public NetBase {
 
   ProfDAGProtos GetOperatorStats() const;
   ProfDAGProtos GetPerOperatorCost() const;
+  ProfDAGReport GetProfReport() const;
 
  protected:
   bool canSchedule(
index f256880..6b576d3 100644 (file)
@@ -5,29 +5,31 @@
 
 namespace caffe2 {
 
-ProfDAGCounters::ProfDAGCounters(const std::shared_ptr<const NetDef>& net_def)
-    : net_name_(net_def->name()), num_runs_(0) {
-  op_types_.reserve(net_def->op_size());
-  for (auto op_id = 0; op_id < net_def->op_size(); ++op_id) {
-    op_types_.push_back(net_def->op(op_id).type());
+ProfDAGCounters::ProfDAGCounters(const std::shared_ptr<const NetDef>& net_def) {
+  report_.net_name_ = net_def->name();
+  report_.num_runs_ = 0;
+  auto num_ops = net_def->op_size();
+  report_.op_types_.reserve(num_ops);
+  for (auto op_id = 0; op_id < num_ops; ++op_id) {
+    report_.op_types_.push_back(net_def->op(op_id).type());
   }
-  time_per_op_total_.resize(op_types_.size());
+  report_.time_per_op_total_.resize(num_ops);
 }
 
 void ProfDAGCounters::ReportRunStart() {
-  num_runs_ += 1;
+  report_.num_runs_ += 1;
   timer_.Start();
-
+  auto num_ops = report_.op_types_.size();
   op_start_times_run_.clear();
-  op_start_times_run_.resize(op_types_.size(), -1.0);
+  op_start_times_run_.resize(num_ops, -1.0);
   op_end_times_run_.clear();
-  op_end_times_run_.resize(op_types_.size(), -1.0);
+  op_end_times_run_.resize(num_ops, -1.0);
   op_async_end_times_run_.clear();
-  op_async_end_times_run_.resize(op_types_.size(), -1.0);
+  op_async_end_times_run_.resize(num_ops, -1.0);
 }
 
 void ProfDAGCounters::AddPerOpStartTime(size_t op_id) {
-  if (num_runs_ <= 1) {
+  if (report_.num_runs_ <= 1) {
     return;
   }
 
@@ -36,7 +38,7 @@ void ProfDAGCounters::AddPerOpStartTime(size_t op_id) {
 }
 
 void ProfDAGCounters::AddPerOpEndTime(size_t op_id) {
-  if (num_runs_ <= 1) {
+  if (report_.num_runs_ <= 1) {
     return;
   }
 
@@ -45,7 +47,7 @@ void ProfDAGCounters::AddPerOpEndTime(size_t op_id) {
 }
 
 void ProfDAGCounters::AddPerOpAsyncEndTime(size_t op_id) {
-  if (num_runs_ <= 1) {
+  if (report_.num_runs_ <= 1) {
     return;
   }
 
@@ -54,16 +56,16 @@ void ProfDAGCounters::AddPerOpAsyncEndTime(size_t op_id) {
 }
 
 void ProfDAGCounters::ReportRunEnd() {
-  if (num_runs_ <= 1) {
+  if (report_.num_runs_ <= 1) {
     return;
   }
 
   auto runtime = timer_.MilliSeconds();
-  runtime_stats_ += ProfDAGStats(runtime);
+  report_.runtime_stats_ += ProfDAGStats(runtime);
 
   CaffeMap<std::string, float> cum_per_type_time_run_;
   CaffeMap<std::string, float> cum_per_type_invocations_run_;
-  for (auto op_id = 0; op_id < op_types_.size(); ++op_id) {
+  for (auto op_id = 0; op_id < report_.op_types_.size(); ++op_id) {
     float op_time;
     CAFFE_ENFORCE(op_start_times_run_[op_id] > 0);
     if (op_async_end_times_run_[op_id] > 0) {
@@ -77,21 +79,25 @@ void ProfDAGCounters::ReportRunEnd() {
       op_time = op_sync_time;
     }
 
-    time_per_op_total_[op_id] += ProfDAGStats(op_time);
+    report_.time_per_op_total_[op_id] += ProfDAGStats(op_time);
 
-    const string& op_type = op_types_[op_id];
+    const string& op_type = report_.op_types_[op_id];
     cum_per_type_time_run_[op_type] += op_time;
     cum_per_type_invocations_run_[op_type] += 1;
   }
 
   for (const auto& kv : cum_per_type_time_run_) {
-    time_per_op_type_total_[kv.first] += ProfDAGStats(kv.second);
-    times_per_run_per_type_total_[kv.first] +=
+    report_.time_per_op_type_total_[kv.first] += ProfDAGStats(kv.second);
+    report_.times_per_run_per_type_total_[kv.first] +=
         ProfDAGStats(cum_per_type_invocations_run_[kv.first]);
   }
 }
 
-ProfDAGProto ProfDAGCounters::statsProto(
+ProfDAGReport ProfDAGCounters::GetReport() const {
+  return report_;
+}
+
+ProfDAGProto ProfDAGReport::statsProto(
     const std::string& name,
     const ProfDAGStats& stats) const {
   ProfDAGProto stats_proto;
@@ -102,30 +108,34 @@ ProfDAGProto ProfDAGCounters::statsProto(
   return stats_proto;
 }
 
-ProfDAGProtos ProfDAGCounters::GetOperatorStats() const {
-  CAFFE_ENFORCE_GT(num_runs_, 1, "Insufficient number of runs");
+ProfDAGProtos ProfDAGReport::GetOperatorStats() const {
   ProfDAGProtos prof_dag_protos;
-  for (auto& item : time_per_op_type_total_) {
-    auto buf = prof_dag_protos.add_stats();
-    buf->CopyFrom(statsProto(item.first, item.second));
+  prof_dag_protos.set_net_name(net_name_);
+  if (num_runs_ > 1) {
+    for (auto& item : time_per_op_type_total_) {
+      auto buf = prof_dag_protos.add_stats();
+      buf->CopyFrom(statsProto(item.first, item.second));
+    }
   }
   return prof_dag_protos;
 }
 
-ProfDAGProtos ProfDAGCounters::GetPerOperatorCost() const {
-  CAFFE_ENFORCE_GT(num_runs_, 1, "Insufficient number of runs");
+ProfDAGProtos ProfDAGReport::GetPerOperatorCost() const {
   ProfDAGProtos prof_dag_protos;
-  for (int op_id = 0; op_id < op_types_.size(); op_id++) {
-    const string& op_type = op_types_[op_id];
-    auto buf = prof_dag_protos.add_stats();
-    std::string op_output_name =
-        net_name_ + "___" + to_string(op_id) + "___" + op_type;
-    buf->CopyFrom(statsProto(op_output_name, time_per_op_total_[op_id]));
+  prof_dag_protos.set_net_name(net_name_);
+  if (num_runs_ > 1) {
+    for (int op_id = 0; op_id < op_types_.size(); op_id++) {
+      const string& op_type = op_types_[op_id];
+      auto buf = prof_dag_protos.add_stats();
+      std::string op_output_name =
+          net_name_ + "___" + to_string(op_id) + "___" + op_type;
+      buf->CopyFrom(statsProto(op_output_name, time_per_op_total_[op_id]));
+    }
   }
   return prof_dag_protos;
 }
 
-void ProfDAGCounters::PrintStats() {
+void ProfDAGReport::PrintStats() {
   if (num_runs_ <= 1) {
     LOG(INFO) << "Insufficient number of runs";
     return;
@@ -153,4 +163,48 @@ void ProfDAGCounters::PrintStats() {
   LOG(INFO) << debug_out.str();
 }
 
+ProfDAGReport& ProfDAGReport::operator+=(const ProfDAGReport& rhs) {
+  // Verify nets are compatible for addition
+  CAFFE_ENFORCE_EQ(
+      net_name_, rhs.net_name_, "Incompatible nets to add counters");
+  CAFFE_ENFORCE_EQ(
+      op_types_.size(),
+      rhs.op_types_.size(),
+      "Incompatible nets to add counters");
+  for (auto idx = 0; idx < op_types_.size(); ++idx) {
+    CAFFE_ENFORCE_EQ(
+        op_types_[idx],
+        rhs.op_types_[idx],
+        "Incompatible nets to add counters");
+  }
+
+  if (rhs.num_runs_ <= 1) {
+    // rhs does not have valid profiling results, do nothing
+    return *this;
+  } else if (num_runs_ <= 1) {
+    // "this" does not have valid profiling results, but rhs does. copy rhs
+    time_per_op_total_ = rhs.time_per_op_total_;
+    time_per_op_type_total_ = rhs.time_per_op_type_total_;
+    times_per_run_per_type_total_ = rhs.times_per_run_per_type_total_;
+    runtime_stats_ = rhs.runtime_stats_;
+    num_runs_ = rhs.num_runs_;
+    return *this;
+  }
+
+  // Do the addition
+  for (auto idx = 0; idx < time_per_op_total_.size(); ++idx) {
+    time_per_op_total_[idx] += rhs.time_per_op_total_.at(idx);
+  }
+  for (auto& item : time_per_op_type_total_) {
+    item.second += rhs.time_per_op_type_total_.at(item.first);
+  }
+  for (auto& item : times_per_run_per_type_total_) {
+    item.second += rhs.times_per_run_per_type_total_.at(item.first);
+  }
+  runtime_stats_ += rhs.runtime_stats_;
+  num_runs_ += rhs.num_runs_;
+
+  return *this;
+}
+
 } // namespace caffe2
index 49d7305..18c8dc8 100644 (file)
@@ -49,13 +49,9 @@ class ProfDAGStats {
   size_t cnt_;
 };
 
-/**
- * A simple wrapper around prof_dag's counters
- */
-class ProfDAGCounters {
+class ProfDAGReport {
  public:
-  explicit ProfDAGCounters(const std::shared_ptr<const NetDef>& net_def);
-
+  friend class ProfDAGCounters;
   // Collects the execution time per each operator type
   ProfDAGProtos GetOperatorStats() const;
 
@@ -63,14 +59,7 @@ class ProfDAGCounters {
   // formatted as a map: (netName__opIndex__opType, cost)
   ProfDAGProtos GetPerOperatorCost() const;
 
-  // ReportRunStart/End are called at the beginning and at the end of
-  // each net's run
-  void ReportRunStart();
-  void ReportRunEnd();
-
-  void AddPerOpStartTime(size_t op_id);
-  void AddPerOpEndTime(size_t op_id);
-  void AddPerOpAsyncEndTime(size_t op_id);
+  ProfDAGReport& operator+=(const ProfDAGReport& rhs);
 
   void PrintStats();
 
@@ -80,6 +69,9 @@ class ProfDAGCounters {
 
   std::vector<std::string> op_types_;
 
+  std::string net_name_;
+
+  int num_runs_;
   // Cumulative stats per operator instance of the net
   std::vector<ProfDAGStats> time_per_op_total_;
 
@@ -88,15 +80,33 @@ class ProfDAGCounters {
 
   CaffeMap<std::string, ProfDAGStats> times_per_run_per_type_total_;
 
-  std::string net_name_;
+  ProfDAGStats runtime_stats_;
+};
 
-  int num_runs_;
+/**
+ * A simple wrapper around prof_dag's counters
+ */
+class ProfDAGCounters {
+ public:
+  explicit ProfDAGCounters(const std::shared_ptr<const NetDef>& net_def);
+
+  // ReportRunStart/End are called at the beginning and at the end of
+  // each net's run
+  void ReportRunStart();
+  void ReportRunEnd();
+
+  void AddPerOpStartTime(size_t op_id);
+  void AddPerOpEndTime(size_t op_id);
+  void AddPerOpAsyncEndTime(size_t op_id);
+  ProfDAGReport GetReport() const;
+
+ private:
   Timer timer_;
-  ProfDAGStats runtime_stats_;
 
   std::vector<float> op_start_times_run_;
   std::vector<float> op_end_times_run_;
   std::vector<float> op_async_end_times_run_;
+  ProfDAGReport report_;
 };
 
 } // namespace caffe2