Clean up executor's execution flags (#13869)
authorIlia Cherniavskii <iliacher@fb.com>
Fri, 16 Nov 2018 00:59:14 +0000 (16:59 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 16 Nov 2018 01:11:51 +0000 (17:11 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13869

Remove unused flags and consolidate them into one struct

Reviewed By: yinghai

Differential Revision: D13032207

fbshipit-source-id: 2cef093589036238732099e3851a97e739b5fd55

caffe2/core/net_async_base.cc
caffe2/core/net_async_base.h
caffe2/core/net_async_dag_gpu.cc
caffe2/core/net_async_scheduling.cc
caffe2/core/net_async_scheduling.h

index b8e784e..6b11db5 100644 (file)
@@ -11,23 +11,11 @@ C10_DEFINE_int(
     "Number of streams per worker per GPU"
     " to use in GPU thread pool (experimental)");
 
-C10_DECLARE_bool(caffe2_dag_net_collect_stats);
-
 C10_DEFINE_bool(
     caffe2_net_async_inference_mode,
     false,
     "If set, use one single chain containing all ops");
 
-C10_DEFINE_bool(
-    caffe2_net_async_finish_chain,
-    false,
-    "Wait for chain to finish");
-
-C10_DEFINE_bool(
-    caffe2_net_async_always_schedule_child,
-    false,
-    "Always schedule child chains from parent chain");
-
 C10_DEFINE_int(
     caffe2_net_async_max_gpus,
     16,
@@ -58,6 +46,11 @@ C10_DEFINE_bool(
     false,
     "Use per net thread pools");
 
+C10_DEFINE_bool(
+    caffe2_net_async_run_root_tasks_inline,
+    false,
+    "Run root tasks in current thread instread of scheduling to threadpool");
+
 namespace caffe2 {
 
 std::vector<int>& AsyncNetBase::getStreamCounters() {
@@ -68,9 +61,7 @@ std::vector<int>& AsyncNetBase::getStreamCounters() {
 AsyncNetBase::AsyncNetBase(
     const std::shared_ptr<const NetDef>& net_def,
     Workspace* ws)
-    : NetBase(net_def, ws), counters_(net_def) {
-  computeExecutionModeFlags();
-
+    : NetBase(net_def, ws), options_(net_def), counters_(net_def) {
   operator_nodes_ = dag_utils::prepareOperatorNodes(net_def, ws);
   helper_ = caffe2::make_unique<AsyncNetExecutorHelper>(this);
   operators_.reserve(operator_nodes_.size());
@@ -96,7 +87,7 @@ AsyncNetBase::AsyncNetBase(
     const auto& last_op = operators_[chain.back()];
     events_.push_back(&last_op->event());
     // keep events for inner chain ops in case of profiling
-    if (!report_stats_) {
+    if (!options_.report_stats_) {
       for (const auto& op_id : chain) {
         if (op_id == chain.back() || op_id == chain.front()) {
           continue;
@@ -157,14 +148,17 @@ TaskThreadPoolBase* AsyncNetBase::poolGetter(
   auto pool = pools[device_id][pool_size];
   if (!pool) {
     pool = ThreadPoolRegistry()->Create(
-        DeviceTypeName(device_type), device_id, pool_size, use_per_net_pools_);
+        DeviceTypeName(device_type),
+        device_id,
+        pool_size,
+        options_.use_per_net_pools_);
     pools[device_id][pool_size] = pool;
   }
   return pool.get();
 }
 
 TaskThreadPoolBase* AsyncNetBase::pool(const DeviceOption& device_option) {
-  if (use_single_pool_) {
+  if (options_.use_single_pool_) {
     return poolGetter(cpu_pools_, PROTO_CPU, -1, num_workers_);
   }
   const auto device_type = device_option.device_type();
@@ -187,9 +181,7 @@ TaskThreadPoolBase* AsyncNetBase::pool(const DeviceOption& device_option) {
         "Invalid GPU id: " + caffe2::to_string(gpu_id));
     return poolGetter(gpu_pools_, device_type, gpu_id, num_workers_);
   } else {
-    CAFFE_THROW(
-        "Unsupported device type " +
-        caffe2::to_string(device_type));
+    CAFFE_THROW("Unsupported device type " + caffe2::to_string(device_type));
   }
 }
 
@@ -204,8 +196,9 @@ int AsyncNetBase::stream(int task_id) {
     }
     do {
       stream_id = getStreamCounters().at(gpu_id)++;
-      getStreamCounters().at(gpu_id) %= streams_per_gpu_;
-    } while (check_stream_status_ && !isStreamFree(task_id, stream_id));
+      getStreamCounters().at(gpu_id) %= options_.streams_per_gpu_;
+    } while (options_.check_stream_status_ &&
+             !isStreamFree(task_id, stream_id));
   }
   return stream_id;
 }
@@ -385,15 +378,15 @@ bool AsyncNetBase::run(int task_id, int stream_id) {
   OperatorBase* op = nullptr;
   try {
     // Optionally insert async wait ops,
-    // skip when using --caffe2_net_async_finish_chain -
+    // skip when finish_chain_ is set -
     // all parents are guaranteed to be finished
-    if (!finish_chain_) {
+    if (!options_.finish_chain_) {
       asyncWait(task_id, stream_id, parents(task_id));
     }
     for (auto& op_id : chains_[task_id]) {
       op = operators_[op_id];
       bool success = false;
-      if (!report_stats_) {
+      if (!options_.report_stats_) {
         TRACE_EVENT(
             tracing::TRACE_OP,
             op_id,
@@ -418,7 +411,7 @@ bool AsyncNetBase::run(int task_id, int stream_id) {
     }
 
     op = nullptr;
-    if (finish_chain_) {
+    if (options_.finish_chain_) {
       operators_[chains_[task_id].back()]->event().Finish();
     }
   } catch (const std::exception& e) {
@@ -465,7 +458,7 @@ ProfDAGProtos AsyncNetBase::GetPerOperatorCost() const {
 }
 
 AsyncNetBase::~AsyncNetBase() {
-  if (report_stats_) {
+  if (options_.report_stats_) {
     counters_.PrintStats();
   }
 }
@@ -490,15 +483,16 @@ C10_REGISTER_CREATOR(
     HIP,
     GetAsyncNetThreadPool<TaskThreadPool, PROTO_HIP>);
 
-void AsyncNetBase::computeExecutionModeFlags() {
+ExecutionOptions::ExecutionOptions(
+    const std::shared_ptr<const NetDef>& net_def) {
   static const std::string kDag = "dag";
   static const std::string kProfDag = "prof_dag";
   static const std::string kAsyncDag = "async_dag";
   static const std::string kSimpleNet = "simple";
 
   std::string net_type;
-  if (net_def_->has_type() && !net_def_->type().empty()) {
-    net_type = net_def_->type();
+  if (net_def->has_type() && !net_def->type().empty()) {
+    net_type = net_def->type();
   } else {
     net_type = kSimpleNet;
   }
@@ -522,8 +516,8 @@ void AsyncNetBase::computeExecutionModeFlags() {
     report_stats_ = false;
   } else {
     streams_per_gpu_ = FLAGS_caffe2_streams_per_gpu;
-    finish_chain_ = FLAGS_caffe2_net_async_finish_chain;
-    always_schedule_child_ = FLAGS_caffe2_net_async_always_schedule_child;
+    finish_chain_ = false;
+    always_schedule_child_ = false;
     check_stream_status_ = FLAGS_caffe2_net_async_check_stream_status;
     use_single_pool_ = FLAGS_caffe2_net_async_use_single_pool;
     use_per_net_pools_ = FLAGS_caffe2_net_async_use_per_net_pools;
@@ -531,14 +525,21 @@ void AsyncNetBase::computeExecutionModeFlags() {
     report_stats_ = false;
   }
 
-  for (int arg_idx = 0; arg_idx < net_def_->arg_size(); ++arg_idx) {
-    auto& arg = net_def_->arg(arg_idx);
+  use_dfs_scheduling_ = false;
+
+  for (int arg_idx = 0; arg_idx < net_def->arg_size(); ++arg_idx) {
+    auto& arg = net_def->arg(arg_idx);
     if (arg.has_name() && arg.name() == "enable_profiling") {
       CAFFE_ENFORCE(arg.has_i(), "enable_profiling should be an int");
       report_stats_ = arg.i() == 1;
-      break;
+    }
+    if (arg.has_name() && arg.name() == "deferrable_mode") {
+      CAFFE_ENFORCE(arg.has_i(), "deferrable_mode should be an int");
+      use_dfs_scheduling_ = arg.i() == 1; // corr. to DFS scheduling
     }
   }
+
+  run_root_tasks_inline_ = FLAGS_caffe2_net_async_run_root_tasks_inline;
 }
 
 } // namespace caffe2
index da5123b..e665441 100644 (file)
 #include "caffe2/utils/thread_pool.h"
 
 C10_DECLARE_int(caffe2_streams_per_gpu);
-C10_DECLARE_bool(caffe2_net_async_finish_chain);
-C10_DECLARE_bool(caffe2_net_async_always_schedule_child);
 C10_DECLARE_int(caffe2_net_async_max_gpus);
 C10_DECLARE_int(caffe2_net_async_max_numa_nodes);
 C10_DECLARE_int(caffe2_net_async_thread_pool_size);
 C10_DECLARE_bool(caffe2_net_async_check_stream_status);
 C10_DECLARE_bool(caffe2_net_async_use_single_pool);
 C10_DECLARE_bool(caffe2_net_async_use_per_net_pools);
+C10_DECLARE_bool(caffe2_net_async_run_root_tasks_inline);
 
 namespace caffe2 {
 
@@ -33,6 +32,21 @@ namespace tracing {
 class Tracer;
 }
 
+struct ExecutionOptions {
+  explicit ExecutionOptions(const std::shared_ptr<const NetDef>& net_def);
+
+  int streams_per_gpu_ = 1;
+  bool finish_chain_ = false;
+  bool always_schedule_child_ = false;
+  bool check_stream_status_ = false;
+  bool use_single_pool_ = false;
+  bool use_per_net_pools_ = false;
+  bool is_blocking_ = false;
+  bool report_stats_ = false;
+  bool use_dfs_scheduling_ = false;
+  bool run_root_tasks_inline_ = false;
+};
+
 class CAFFE2_API AsyncNetBase : public NetBase {
  public:
   AsyncNetBase(const std::shared_ptr<const NetDef>& net_def, Workspace* ws);
@@ -127,15 +141,7 @@ class CAFFE2_API AsyncNetBase : public NetBase {
   std::shared_ptr<tracing::Tracer> tracer_;
 
   // execution mode flags
-  void computeExecutionModeFlags();
-  int streams_per_gpu_;
-  bool finish_chain_;
-  bool always_schedule_child_;
-  bool check_stream_status_;
-  bool use_single_pool_;
-  bool use_per_net_pools_;
-  bool is_blocking_;
-  bool report_stats_;
+  ExecutionOptions options_;
 
   ProfDAGCounters counters_;
 
index ac97d0b..674b88e 100644 (file)
@@ -26,8 +26,6 @@ C10_DEFINE_bool(
 
 C10_DECLARE_bool(caffe2_dag_net_collect_stats);
 
-C10_DECLARE_bool(caffe2_net_async_finish_chain);
-
 C10_DECLARE_int(caffe2_streams_per_gpu);
 
 C10_DECLARE_bool(caffe2_net_async_check_stream_status);
@@ -184,9 +182,6 @@ bool AsyncDAGNet::RunAt(int chain_id, const std::vector<int>& chain) {
   }
 
   const auto& sink_idx = chain.back();
-  if (success && FLAGS_caffe2_net_async_finish_chain) {
-    operator_nodes_[sink_idx].operator_->event().Finish();
-  }
   CAFFE_ENFORCE(
       !eventRecorded_[sink_idx],
       "An event for ",
index 64e3600..a2204e9 100644 (file)
@@ -2,30 +2,12 @@
 
 #include "caffe2/core/net_async_tracing.h"
 
-C10_DEFINE_bool(
-    caffe2_net_async_optimize_polling,
-    true,
-    "Use event callbacks whenever possible instead of polling");
-C10_DEFINE_bool(
-    caffe2_net_async_run_root_tasks_inline,
-    false,
-    "Run root tasks in current thread instread of scheduling to threadpool");
-
 namespace caffe2 {
 
 AsyncSchedulingNet::AsyncSchedulingNet(
     const std::shared_ptr<const NetDef>& net_def,
     Workspace* ws)
-    : AsyncNetBase(net_def, ws), running_(false), use_dfs_scheduling_(false) {
-  for (int arg_idx = 0; arg_idx < net_def->arg_size(); ++arg_idx) {
-    auto& arg = net_def->arg(arg_idx);
-    if (arg.has_name() && arg.name() == "deferrable_mode") {
-      CAFFE_ENFORCE(arg.has_i(), "deferrable_mode should be an int");
-      use_dfs_scheduling_ = arg.i() == 1; // corr. to DFS scheduling
-      break;
-    }
-  }
-}
+    : AsyncNetBase(net_def, ws), running_(false) {}
 
 void AsyncSchedulingNet::reset() {
   AsyncNetBase::reset();
@@ -40,7 +22,7 @@ void AsyncSchedulingNet::Wait() {
 }
 
 bool AsyncSchedulingNet::isInlineTask(int parent_id, int child_id) const {
-  if (!use_dfs_scheduling_) {
+  if (!options_.use_dfs_scheduling_) {
     return false;
   }
   const auto* last_parent_op = lastTaskOp(parent_id);
@@ -57,7 +39,7 @@ void AsyncSchedulingNet::schedule(int task_id, bool run_inline) {
   auto schedule_func = [this, task_id]() {
     if (success_) {
       int stream_id = 0;
-      if (streams_per_gpu_ > 1) {
+      if (options_.streams_per_gpu_ > 1) {
         stream_id = stream(task_id);
       }
       if (!run(task_id, stream_id)) {
@@ -65,7 +47,7 @@ void AsyncSchedulingNet::schedule(int task_id, bool run_inline) {
       }
     }
 
-    if (report_stats_) {
+    if (options_.report_stats_) {
       auto last_op_id = lastTaskOpId(task_id);
       auto* last_op = lastTaskOp(task_id);
       if (last_op->device_option().device_type() == PROTO_CPU &&
@@ -80,12 +62,12 @@ void AsyncSchedulingNet::schedule(int task_id, bool run_inline) {
       if (parent_count == 0) {
         // Schedule a child if:
         // - there is failure, we skip an op execution and finish the job
-        // - forced scheduling though --caffe2_net_async_always_schedule_child
-        // - --caffe2_net_async_finish_chain is set, in this case parents are
+        // - forced scheduling though always_schedule_child_
+        // - finish_chain_ is set, in this case parents are
         //   guaranteed to be finished
         // - in all other cases, check parents with canSchedule
-        if (!success_ || always_schedule_child_ || finish_chain_ ||
-            canSchedule(child_id)) {
+        if (!success_ || options_.always_schedule_child_ ||
+            options_.finish_chain_ || canSchedule(child_id)) {
           // if DFS scheduling is enabled, run children inline,
           // ignore DFS scheduling in callbacks
           schedule(child_id, isInlineTask(task_id, child_id));
@@ -107,8 +89,7 @@ void AsyncSchedulingNet::schedule(int task_id, bool run_inline) {
               if (!canSchedule(parent_id, child_id)) {
                 // we can't schedule a child because of this parent,
                 // check if parent supports callback
-                if (FLAGS_caffe2_net_async_optimize_polling &&
-                    parent_event.SupportsCallback()) {
+                if (parent_event.SupportsCallback()) {
                   parents_with_callback.push_back(parent_id);
                 } else {
                   parent_needs_polling = true;
@@ -223,7 +204,7 @@ void AsyncSchedulingNet::finishRun() {
   std::unique_lock<std::mutex> lock(running_mutex_);
   // wait for scheduled ops and make sure all events are marked as finished
   finalizeEvents();
-  if (report_stats_) {
+  if (options_.report_stats_) {
     counters_.ReportRunEnd();
   }
   // notify observers and waiters
@@ -245,14 +226,14 @@ bool AsyncSchedulingNet::RunAsync() {
 
       StartAllObservers();
       tracing::startIter(tracer_);
-      if (report_stats_) {
+      if (options_.report_stats_) {
         counters_.ReportRunStart();
       }
     }
 
     for (auto task_id = 0; task_id < tasksNum(); ++task_id) {
       if (parents(task_id).empty()) {
-        schedule(task_id, FLAGS_caffe2_net_async_run_root_tasks_inline);
+        schedule(task_id, options_.run_root_tasks_inline_);
       }
     }
   } catch (const std::exception& e) {
@@ -265,7 +246,7 @@ bool AsyncSchedulingNet::RunAsync() {
     finishRun();
   }
 
-  if (is_blocking_) {
+  if (options_.is_blocking_) {
     Wait();
   }
 
index 69563c4..3e753e2 100644 (file)
@@ -27,7 +27,6 @@ class CAFFE2_API AsyncSchedulingNet : public AsyncNetBase {
   std::mutex running_mutex_;
   std::condition_variable running_cv_;
   std::atomic<bool> running_;
-  bool use_dfs_scheduling_;
 
   std::atomic<int> processed_tasks_num_;