From 0337494c6a498f5c208a402d01a2a46020f9e6bd Mon Sep 17 00:00:00 2001 From: Ilia Cherniavskii Date: Wed, 20 Feb 2019 16:22:01 -0800 Subject: [PATCH] Reinforce scheduling invariants (#17132) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17132 schedule() function is not supposed to throw exception and is supposed to succeed in scheduling the full graph of tasks, potential errors (e.g. errors from underlying thread pool, out of memory exceptions etc) are considered not recoverable. The invariant - the graph of tasks is either not executed or executed in full before the call to finishRun() Reviewed By: andrewwdye Differential Revision: D14092457 fbshipit-source-id: a3e5d65dfee5ff5e5e71ec72bb9e576180019698 --- caffe2/core/net_async_base.cc | 4 +- caffe2/core/net_async_base.h | 4 +- caffe2/core/net_async_scheduling.cc | 261 ++++++++++++++++++++---------------- caffe2/core/net_async_scheduling.h | 2 +- 4 files changed, 147 insertions(+), 124 deletions(-) diff --git a/caffe2/core/net_async_base.cc b/caffe2/core/net_async_base.cc index 6c50dbb..54cd0f4 100644 --- a/caffe2/core/net_async_base.cc +++ b/caffe2/core/net_async_base.cc @@ -369,7 +369,7 @@ void AsyncNetBase::handleChainError( int task_id, OperatorBase* op, const char* err_str, - bool save_exception) { + bool save_exception) noexcept { std::string err_msg = err_str; if (op) { err_msg += ", op " + (op->has_debug_def() ? op->type() : " unknown"); @@ -385,7 +385,7 @@ void AsyncNetBase::handleChainError( } } -bool AsyncNetBase::run(int task_id, int stream_id) { +bool AsyncNetBase::run(int task_id, int stream_id) noexcept { OperatorBase* op = nullptr; try { // Optionally insert async wait ops, diff --git a/caffe2/core/net_async_base.h b/caffe2/core/net_async_base.h index e63c1aa..85fb217 100644 --- a/caffe2/core/net_async_base.h +++ b/caffe2/core/net_async_base.h @@ -106,7 +106,7 @@ class CAFFE2_API AsyncNetBase : public NetBase { int task_id, int stream_id, const std::vector& wait_task_ids) const; - bool run(int task_id, int stream_id); + bool run(int task_id, int stream_id) noexcept; int stream(int task_id); TaskThreadPoolBase* pool(const DeviceOption& device_option); TaskThreadPoolBase* pool(); @@ -144,7 +144,7 @@ class CAFFE2_API AsyncNetBase : public NetBase { int task_id, OperatorBase* op, const char* err_msg, - bool save_exception = false); + bool save_exception = false) noexcept; std::atomic success_; // Tracing diff --git a/caffe2/core/net_async_scheduling.cc b/caffe2/core/net_async_scheduling.cc index a2204e9..3e0f366 100644 --- a/caffe2/core/net_async_scheduling.cc +++ b/caffe2/core/net_async_scheduling.cc @@ -32,130 +32,153 @@ bool AsyncSchedulingNet::isInlineTask(int parent_id, int child_id) const { last_parent_op->device_option(), first_child_op->device_option()); } -void AsyncSchedulingNet::schedule(int task_id, bool run_inline) { +// schedule() is not supposed to throw, all exceptions in the ops are caught +// and reported in the end of the graph's execution, the full graph of tasks +// is expected to be scheduled +void AsyncSchedulingNet::schedule(int task_id, bool run_inline) noexcept { if (!testAndSetScheduled(task_id)) { return; } auto schedule_func = [this, task_id]() { - if (success_) { - int stream_id = 0; - if (options_.streams_per_gpu_ > 1) { - stream_id = stream(task_id); - } - if (!run(task_id, stream_id)) { - success_ = false; + try { + if (success_) { + int stream_id = 0; + if (options_.streams_per_gpu_ > 1) { + try { + stream_id = stream(task_id); + } catch (const std::exception& e) { + C10_LOG_EVERY_MS(ERROR, 1000) + << "Failed to select a stream: " << e.what(); + } + } + if (!run(task_id, stream_id)) { + success_ = false; + } } - } - if (options_.report_stats_) { - auto last_op_id = lastTaskOpId(task_id); - auto* last_op = lastTaskOp(task_id); - if (last_op->device_option().device_type() == PROTO_CPU && - last_op->HasAsyncPart()) { - last_op->event().SetCallback( - [this, last_op_id] { counters_.AddPerOpAsyncEndTime(last_op_id); }); + if (options_.report_stats_) { + try { + auto last_op_id = lastTaskOpId(task_id); + auto* last_op = lastTaskOp(task_id); + if (last_op->device_option().device_type() == PROTO_CPU && + last_op->HasAsyncPart()) { + last_op->event().SetCallback([this, last_op_id] { + counters_.AddPerOpAsyncEndTime(last_op_id); + }); + } + } catch (const std::exception& e) { + C10_LOG_EVERY_MS(ERROR, 1000) + << "Failed to report operator stats: " << e.what(); + } } - } - for (auto child_id : children(task_id)) { - int parent_count = updateParentCount(child_id); - if (parent_count == 0) { - // Schedule a child if: - // - there is failure, we skip an op execution and finish the job - // - forced scheduling though always_schedule_child_ - // - finish_chain_ is set, in this case parents are - // guaranteed to be finished - // - in all other cases, check parents with canSchedule - if (!success_ || options_.always_schedule_child_ || - options_.finish_chain_ || canSchedule(child_id)) { - // if DFS scheduling is enabled, run children inline, - // ignore DFS scheduling in callbacks - schedule(child_id, isInlineTask(task_id, child_id)); - } else { - bool parent_failed = false; - bool parent_needs_polling = false; - std::vector parents_with_callback; + for (auto child_id : children(task_id)) { + int parent_count = updateParentCount(child_id); + if (parent_count == 0) { + // Schedule a child if: + // - there is failure, we skip an op execution and finish the job + // - forced scheduling though always_schedule_child_ + // - finish_chain_ is set, in this case parents are + // guaranteed to be finished + // - in all other cases, check parents with canSchedule + if (!success_ || options_.always_schedule_child_ || + options_.finish_chain_ || canSchedule(child_id)) { + // if DFS scheduling is enabled, run children inline, + // ignore DFS scheduling in callbacks + schedule(child_id, isInlineTask(task_id, child_id)); + } else { + bool parent_failed = false; + bool parent_needs_polling = false; + std::vector parents_with_callback; - for (auto parent_id : parents(child_id)) { - auto& parent_event = event(parent_id); - auto parent_status = parent_event.Query(); + for (auto parent_id : parents(child_id)) { + auto& parent_event = event(parent_id); + auto parent_status = parent_event.Query(); - if (parent_status == EventStatus::EVENT_FAILED) { - parent_failed = true; - break; - } else if (parent_status == EventStatus::EVENT_SCHEDULED) { - // parent is not finished yet, check if this is blocking us - // from scheduling a child - if (!canSchedule(parent_id, child_id)) { - // we can't schedule a child because of this parent, - // check if parent supports callback - if (parent_event.SupportsCallback()) { - parents_with_callback.push_back(parent_id); - } else { - parent_needs_polling = true; - break; + if (parent_status == EventStatus::EVENT_FAILED) { + parent_failed = true; + break; + } else if (parent_status == EventStatus::EVENT_SCHEDULED) { + // parent is not finished yet, check if this is blocking us + // from scheduling a child + if (!canSchedule(parent_id, child_id)) { + // we can't schedule a child because of this parent, + // check if parent supports callback + if (parent_event.SupportsCallback()) { + parents_with_callback.push_back(parent_id); + } else { + parent_needs_polling = true; + break; + } } + } else if (parent_status != EventStatus::EVENT_SUCCESS) { + VLOG(1) << "Unexpected parent task state: " << parent_status + << ", task id: " << child_id + << ", parent task id: " << parent_id; + parent_failed = true; + break; } - } else if (parent_status != EventStatus::EVENT_SUCCESS) { - VLOG(1) << "Unexpected parent task state: " << parent_status - << ", task id: " << child_id - << ", parent task id: " << parent_id; - parent_failed = true; - break; } - } - if (parent_failed) { - // one of parents failed, set failure flag and wrap up execution - success_ = false; - schedule(child_id, isInlineTask(task_id, child_id)); - } else if (parent_needs_polling) { - // some parents are blocking us from scheduling a child and don't - // support callbacks, using polling - const auto& child_device_option = event(child_id).GetDeviceOption(); - pool(child_device_option) - ->run(std::bind( - &AsyncSchedulingNet::pollAndSchedule, this, child_id)); - } else if (!parents_with_callback.empty()) { - // some parents are blocking us from scheduling a child and they - // support callbacks - for (auto parent_id : parents_with_callback) { - event(parent_id).SetCallback(std::bind( - &AsyncSchedulingNet::parentCallback, this, parent_id)); + if (parent_failed) { + // one of parents failed, set failure flag and wrap up execution + success_ = false; + schedule(child_id, isInlineTask(task_id, child_id)); + } else if (parent_needs_polling) { + // some parents are blocking us from scheduling a child and don't + // support callbacks, using polling + const auto& child_device_option = + event(child_id).GetDeviceOption(); + pool(child_device_option) + ->run(std::bind( + &AsyncSchedulingNet::pollAndSchedule, this, child_id)); + } else if (!parents_with_callback.empty()) { + // some parents are blocking us from scheduling a child and they + // support callbacks + for (auto parent_id : parents_with_callback) { + event(parent_id).SetCallback(std::bind( + &AsyncSchedulingNet::parentCallback, this, parent_id)); + } + } else { + // we're ready to schedule a child + schedule(child_id, isInlineTask(task_id, child_id)); } - } else { - // we're ready to schedule a child - schedule(child_id, isInlineTask(task_id, child_id)); } } } - } - // In case of net's failure, make sure all pending tasks are finished - if (!success_) { - // Simple logic to capture all pending tasks - check all tasks - // at the end of each task in case of net's failure - for (auto tid = 0; tid < tasksNum(); ++tid) { - if (event(tid).Query() == EventStatus::EVENT_SCHEDULED) { - // SetFinished may throw, e.g. when we call it on already finished - // event, and in some other cases (CUDA) - try { - event(tid).SetFinished("Cancelled"); - } catch (const EnforceNotMet&) { - // ignore + // In case of net's failure, make sure all pending tasks are finished + if (!success_) { + // Simple logic to capture all pending tasks - check all tasks + // at the end of each task in case of net's failure + for (auto tid = 0; tid < tasksNum(); ++tid) { + if (event(tid).Query() == EventStatus::EVENT_SCHEDULED) { + // SetFinished may throw, e.g. when we call it on already finished + // event, and in some other cases (CUDA) + try { + event(tid).SetFinished("Cancelled"); + } catch (const EnforceNotMet&) { + // ignore + } } } } - } - // finishRun may cause waiters to wake up and destroy the net, - // before we call finishRun we need to make sure all other (finishing) - // tasks are done; - // Bumping and checking the counter after the task's job is done - auto tasks_num = tasksNum(); - auto cur_processed_tasks = ++processed_tasks_num_; - if (cur_processed_tasks == tasks_num) { - finishRun(); + // finishRun may cause waiters to wake up and destroy the net, + // before we call finishRun we need to make sure all other (finishing) + // tasks are done; + // Bumping and checking the counter after the task's job is done + auto tasks_num = tasksNum(); + auto cur_processed_tasks = ++processed_tasks_num_; + if (cur_processed_tasks == tasks_num) { + finishRun(); + } + } catch (const std::exception& e) { + // error of core scheduling and/or logic, will call terminate + LOG(FATAL) << "Unexpected error during graph scheduling run: " + << e.what(); + } catch (...) { + LOG(FATAL) << "Unknown error during graph scheduling run"; } }; @@ -215,26 +238,18 @@ void AsyncSchedulingNet::finishRun() { bool AsyncSchedulingNet::RunAsync() { try { - { - std::unique_lock lock(running_mutex_); - if (running_) { - LOG(ERROR) << "Detected concurrent runs"; - return false; - } - running_ = true; - reset(); - - StartAllObservers(); - tracing::startIter(tracer_); - if (options_.report_stats_) { - counters_.ReportRunStart(); - } + std::unique_lock lock(running_mutex_); + if (running_) { + LOG(ERROR) << "Detected concurrent runs"; + return false; } + running_ = true; + reset(); - for (auto task_id = 0; task_id < tasksNum(); ++task_id) { - if (parents(task_id).empty()) { - schedule(task_id, options_.run_root_tasks_inline_); - } + StartAllObservers(); + tracing::startIter(tracer_); + if (options_.report_stats_) { + counters_.ReportRunStart(); } } catch (const std::exception& e) { LOG(ERROR) << "Exception while starting an async run: " << e.what(); @@ -242,6 +257,14 @@ bool AsyncSchedulingNet::RunAsync() { return false; } + // schedule() is not expected to throw, at this moment all the initial tasks + // will be scheduled and the full graph of tasks will be executed + for (auto task_id = 0; task_id < tasksNum(); ++task_id) { + if (parents(task_id).empty()) { + schedule(task_id, options_.run_root_tasks_inline_); + } + } + if (tasksNum() == 0) { finishRun(); } diff --git a/caffe2/core/net_async_scheduling.h b/caffe2/core/net_async_scheduling.h index 3e753e2..03e72dd 100644 --- a/caffe2/core/net_async_scheduling.h +++ b/caffe2/core/net_async_scheduling.h @@ -18,7 +18,7 @@ class CAFFE2_API AsyncSchedulingNet : public AsyncNetBase { bool RunAsync() override; void pollAndSchedule(int task_id); - void schedule(int task_id, bool run_inline = false); + void schedule(int task_id, bool run_inline = false) noexcept; void reset() override; virtual void finishRun(); void parentCallback(int parent_id); -- 2.7.4