From cd3c4a2f1c06c7fc9409d486fbd9e4edd40fb162 Mon Sep 17 00:00:00 2001 From: Dong Li Date: Fri, 28 Dec 2018 15:00:41 -0800 Subject: [PATCH] keep extra_info of each op in ProfDagStats (#15244) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15244 This DIFF keeps track of the extra_info information attached to each operator. When getPerOpStas() is called, it attaches the extra_info to the result ProfDagStats protobuf. Facebook Net transform attaches a global_op_id which is defined as a tuple of (orig_net_name, original_op_index) to each operator, The global_op_id is encoded as extra_info in each operator. Reviewed By: aazzolini Differential Revision: D13016289 fbshipit-source-id: 3e2719ec7ed0ebe47740b77581c565ff7e79b102 --- caffe2/core/prof_dag_counters.cc | 24 +++++++++++++++++++++--- caffe2/core/prof_dag_counters.h | 7 +++++-- caffe2/proto/prof_dag.proto | 4 ++++ 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/caffe2/core/prof_dag_counters.cc b/caffe2/core/prof_dag_counters.cc index 6b576d3..0187a85 100644 --- a/caffe2/core/prof_dag_counters.cc +++ b/caffe2/core/prof_dag_counters.cc @@ -1,4 +1,5 @@ #include "caffe2/core/prof_dag_counters.h" +#include "caffe2/utils/string_utils.h" #include #include @@ -10,8 +11,20 @@ ProfDAGCounters::ProfDAGCounters(const std::shared_ptr& net_def) { report_.num_runs_ = 0; auto num_ops = net_def->op_size(); report_.op_types_.reserve(num_ops); + report_.op_extra_info_.reserve(num_ops); + for (auto op_id = 0; op_id < num_ops; ++op_id) { report_.op_types_.push_back(net_def->op(op_id).type()); + vector op_extra_info; + if (net_def->op(op_id).has_device_option() && + net_def->op(op_id).device_option().extra_info_size() > 0) { + for (auto i = 0; i < net_def->op(op_id).device_option().extra_info_size(); + ++i) { + auto extra_info_str = net_def->op(op_id).device_option().extra_info(i); + op_extra_info.push_back(extra_info_str); + } + } + report_.op_extra_info_.push_back(op_extra_info); } report_.time_per_op_total_.resize(num_ops); } @@ -99,12 +112,16 @@ ProfDAGReport ProfDAGCounters::GetReport() const { ProfDAGProto ProfDAGReport::statsProto( const std::string& name, - const ProfDAGStats& stats) const { + const ProfDAGStats& stats, + const std::vector& op_extra_info) const { ProfDAGProto stats_proto; const auto& moments = stats.computeMoments(); stats_proto.set_mean(moments.first); stats_proto.set_stddev(moments.second); stats_proto.set_name(name); + for (auto& extra_info : op_extra_info) { + stats_proto.add_extra_info(extra_info); + } return stats_proto; } @@ -114,7 +131,7 @@ ProfDAGProtos ProfDAGReport::GetOperatorStats() const { if (num_runs_ > 1) { for (auto& item : time_per_op_type_total_) { auto buf = prof_dag_protos.add_stats(); - buf->CopyFrom(statsProto(item.first, item.second)); + buf->CopyFrom(statsProto(item.first, item.second, vector())); } } return prof_dag_protos; @@ -129,7 +146,8 @@ ProfDAGProtos ProfDAGReport::GetPerOperatorCost() const { auto buf = prof_dag_protos.add_stats(); std::string op_output_name = net_name_ + "___" + to_string(op_id) + "___" + op_type; - buf->CopyFrom(statsProto(op_output_name, time_per_op_total_[op_id])); + buf->CopyFrom(statsProto( + op_output_name, time_per_op_total_[op_id], op_extra_info_[op_id])); } } return prof_dag_protos; diff --git a/caffe2/core/prof_dag_counters.h b/caffe2/core/prof_dag_counters.h index 18c8dc8..a1a494e 100644 --- a/caffe2/core/prof_dag_counters.h +++ b/caffe2/core/prof_dag_counters.h @@ -64,10 +64,13 @@ class ProfDAGReport { void PrintStats(); private: - ProfDAGProto statsProto(const std::string& name, const ProfDAGStats& stats) - const; + ProfDAGProto statsProto( + const std::string& name, + const ProfDAGStats& stats, + const std::vector& op_extra_info) const; std::vector op_types_; + std::vector> op_extra_info_; std::string net_name_; diff --git a/caffe2/proto/prof_dag.proto b/caffe2/proto/prof_dag.proto index 343cff1..c6820d4 100644 --- a/caffe2/proto/prof_dag.proto +++ b/caffe2/proto/prof_dag.proto @@ -42,6 +42,10 @@ message ProfDAGProto { // Blob profiles that this node outputs. repeated BlobProfile output_profile = 5; + + // The extra_info from the operator device option. + repeated string extra_info = 7; + } // Operator profiling information. -- 2.7.4