--- /dev/null
+#include "caffe2/core/net_async_task.h"
+
+#include "caffe2/core/net_async_task_graph.h"
+
+namespace caffe2 {
+
+AsyncTask::AsyncTask(const std::vector<OperatorBase*>& ops) : ops_(ops) {
+ CAFFE_ENFORCE(!ops_.empty());
+ device_option_ = ops_.front()->device_option();
+ for (auto& op : ops_) {
+ CAFFE_ENFORCE(IsSameDevice(device_option_, op->device_option()));
+ }
+ Reset();
+}
+
+void AsyncTask::handleChainError(
+ OperatorBase* op,
+ const char* err_str,
+ bool save_exception) {
+ std::string err_msg = err_str;
+ if (op) {
+ err_msg += ", op " + (op->has_debug_def() ? op->type() : " unknown");
+ }
+ LOG(ERROR) << err_msg;
+
+ // save error message and exception in chain's Event
+ auto last_op = ops_.back();
+ if (save_exception) {
+ last_op->event().SetFinishedWithException(err_msg.c_str());
+ } else {
+ last_op->event().SetFinished(err_msg.c_str());
+ }
+
+ // set future as completed with an error
+ // TODO: exceptions in future
+ future_.SetCompleted(err_msg.c_str());
+}
+
+bool AsyncTask::Run(const ExecutionOptions& options) {
+ // TODO: insert CUDA's async stream waits; tracing and counters
+ OperatorBase* op = nullptr;
+ try {
+ for (auto op_idx = 0; op_idx < ops_.size(); ++op_idx) {
+ op = ops_[op_idx];
+ int stream_id = 0; // TODO: thread local stream id
+ if (!op->RunAsync(stream_id)) {
+ handleChainError(op, "Failed to execute an op");
+ return false;
+ }
+ }
+
+ if (options.finish_chain_) {
+ op = ops_.back();
+ op->Finish();
+ }
+
+ // set the future as successfully completed or, in case of async CPU,
+ // use op's callback
+ if (IsCPUDeviceType(device_option_.device_type()) &&
+ ops_.back()->HasAsyncPart()) {
+ auto& event = ops_.back()->event();
+ event.SetCallback([this, &event]() {
+ CAFFE_ENFORCE(event.IsFinished());
+ if (event.Query() == EventStatus::EVENT_SUCCESS) {
+ future_.SetCompleted();
+ } else {
+ // TODO: support for exceptions
+ future_.SetCompleted(event.ErrorMessage().c_str());
+ }
+ });
+ } else {
+ future_.SetCompleted();
+ }
+ } catch (const std::exception& e) {
+ handleChainError(op, e.what(), /* save_exception */ true);
+ return false;
+ } catch (...) {
+ handleChainError(
+ op,
+ "Failed to execute task: unknown error",
+ /* save_exception */ true);
+ return false;
+ }
+
+ return true;
+}
+
+void AsyncTask::Reset() {
+ for (auto& op : ops_) {
+ op->ResetEvent();
+ }
+ future_.ResetState();
+}
+
+DeviceOption AsyncTask::GetDeviceOption() const {
+ return device_option_;
+}
+
+AsyncTaskFuture& AsyncTask::GetFuture() {
+ return future_;
+}
+
+const AsyncTaskFuture& AsyncTask::GetFuture() const {
+ return future_;
+}
+
+}; // namespace caffe2
--- /dev/null
+#ifndef CAFFE2_NET_ASYNC_TASK_H
+#define CAFFE2_NET_ASYNC_TASK_H
+
+#include "caffe2/core/net_async_base.h"
+#include "caffe2/core/net_async_task_future.h"
+#include "caffe2/core/operator.h"
+
+#include <vector>
+
+namespace caffe2 {
+
+// AsyncTask represents an asynchronous execution of a chain of ops.
+class AsyncTask {
+ public:
+ AsyncTask(const std::vector<OperatorBase*>& ops);
+
+ bool Run(const ExecutionOptions& options);
+
+ void Reset();
+
+ DeviceOption GetDeviceOption() const;
+
+ AsyncTaskFuture& GetFuture();
+ const AsyncTaskFuture& GetFuture() const;
+
+ private:
+ void handleChainError(
+ OperatorBase* op,
+ const char* err_msg,
+ bool save_exception = false);
+
+ std::vector<OperatorBase*> ops_;
+ DeviceOption device_option_;
+ AsyncTaskFuture future_;
+};
+
+} // namespace caffe2
+
+#endif // CAFFE2_NET_ASYNC_TASK_H
--- /dev/null
+#include "caffe2/core/net_async_task_future.h"
+
+#include "c10/util/Logging.h"
+#include "caffe2/core/common.h"
+
+namespace caffe2 {
+
+AsyncTaskFuture::AsyncTaskFuture() : completed_(false), failed_(false) {}
+
+AsyncTaskFuture::AsyncTaskFuture(const std::vector<AsyncTaskFuture*>& futures)
+ : completed_(false), failed_(false) {
+ if (futures.size() > 1) {
+ parent_counter_ = caffe2::make_unique<ParentCounter>(futures.size());
+ for (auto future : futures) {
+ future->SetCallback([this](const AsyncTaskFuture* f) {
+ if (f->IsFailed()) {
+ std::unique_lock<std::mutex> lock(parent_counter_->err_mutex);
+ if (parent_counter_->parent_failed) {
+ parent_counter_->err_msg += ", " + f->ErrorMessage();
+ } else {
+ parent_counter_->parent_failed = true;
+ parent_counter_->err_msg = f->ErrorMessage();
+ }
+ }
+ int count = --parent_counter_->parent_count;
+ if (count == 0) {
+ // thread safe to use parent_counter here
+ if (!parent_counter_->parent_failed) {
+ SetCompleted();
+ } else {
+ SetCompleted(parent_counter_->err_msg.c_str());
+ }
+ }
+ });
+ }
+ } else {
+ CAFFE_ENFORCE_EQ(futures.size(), 1);
+ auto future = futures.back();
+ future->SetCallback([this](const AsyncTaskFuture* f) {
+ if (!f->IsFailed()) {
+ SetCompleted();
+ } else {
+ SetCompleted(f->ErrorMessage().c_str());
+ }
+ });
+ }
+}
+
+bool AsyncTaskFuture::IsCompleted() const {
+ return completed_;
+}
+
+bool AsyncTaskFuture::IsFailed() const {
+ return failed_;
+}
+
+std::string AsyncTaskFuture::ErrorMessage() const {
+ return err_msg_;
+}
+
+void AsyncTaskFuture::Wait() const {
+ std::unique_lock<std::mutex> lock(mutex_);
+ while (!completed_) {
+ cv_completed_.wait(lock);
+ }
+}
+
+void AsyncTaskFuture::SetCallback(
+ std::function<void(const AsyncTaskFuture*)> callback) {
+ std::unique_lock<std::mutex> lock(mutex_);
+
+ callbacks_.push_back(callback);
+ if (completed_) {
+ callback(this);
+ }
+}
+
+void AsyncTaskFuture::SetCompleted(const char* err_msg) {
+ std::unique_lock<std::mutex> lock(mutex_);
+
+ CAFFE_ENFORCE(!completed_, "Calling SetCompleted on a completed future");
+ completed_ = true;
+
+ if (err_msg) {
+ failed_ = true;
+ err_msg_ = err_msg;
+ }
+
+ for (auto& callback : callbacks_) {
+ callback(this);
+ }
+
+ cv_completed_.notify_all();
+}
+
+// ResetState is called on a completed future,
+// does not reset callbacks to keep task graph structure
+void AsyncTaskFuture::ResetState() {
+ std::unique_lock<std::mutex> lock(mutex_);
+ if (parent_counter_) {
+ parent_counter_->Reset();
+ }
+ completed_ = false;
+ failed_ = false;
+ err_msg_ = "";
+}
+
+AsyncTaskFuture::~AsyncTaskFuture() {}
+
+} // namespace caffe2
--- /dev/null
+#ifndef CAFFE2_NET_ASYNC_TASK_FUTURE_H
+#define CAFFE2_NET_ASYNC_TASK_FUTURE_H
+
+#include <atomic>
+#include <condition_variable>
+#include <functional>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <vector>
+
+namespace caffe2 {
+
+// Represents the state of AsyncTask execution, that can be queried with
+// IsCompleted/IsFailed. Callbacks are supported through SetCallback and
+// are called upon future's completion.
+
+class AsyncTaskFuture {
+ public:
+ AsyncTaskFuture();
+ // Creates a future completed when all given futures are completed
+ explicit AsyncTaskFuture(const std::vector<AsyncTaskFuture*>& futures);
+ ~AsyncTaskFuture();
+
+ AsyncTaskFuture(const AsyncTaskFuture&) = delete;
+
+ AsyncTaskFuture& operator=(const AsyncTaskFuture&) = delete;
+
+ bool IsCompleted() const;
+
+ bool IsFailed() const;
+
+ std::string ErrorMessage() const;
+
+ void Wait() const;
+
+ void SetCallback(std::function<void(const AsyncTaskFuture*)> callback);
+
+ void SetCompleted(const char* err_msg = nullptr);
+
+ void ResetState();
+
+ private:
+ mutable std::mutex mutex_;
+ mutable std::condition_variable cv_completed_;
+ std::atomic<bool> completed_;
+ std::atomic<bool> failed_;
+ std::string err_msg_;
+ std::vector<std::function<void(const AsyncTaskFuture*)>> callbacks_;
+
+ struct ParentCounter {
+ explicit ParentCounter(int init_parent_count)
+ : init_parent_count_(init_parent_count),
+ parent_count(init_parent_count),
+ parent_failed(false) {}
+
+ void Reset() {
+ std::unique_lock<std::mutex> lock(err_mutex);
+ parent_count = init_parent_count_;
+ parent_failed = false;
+ err_msg = "";
+ }
+
+ const int init_parent_count_;
+ std::atomic<int> parent_count;
+ std::mutex err_mutex;
+ std::atomic<bool> parent_failed;
+ std::string err_msg;
+ };
+
+ std::unique_ptr<ParentCounter> parent_counter_;
+};
+
+} // namespace caffe2
+
+#endif // CAFFE2_NET_ASYNC_TASK_FUTURE_H
--- /dev/null
+#include "caffe2/core/net_async_task_graph.h"
+
+#include "caffe2/core/net_parallel.h"
+
+namespace caffe2 {
+
+AsyncTaskGraph::AsyncTaskGraph(
+ ExecutorHelper* helper,
+ const ExecutionOptions& options)
+ : helper_(helper), options_(options), frozen_(false) {}
+
+bool AsyncTaskGraph::CreateNode(
+ int node_id,
+ const std::vector<OperatorBase*>& ops) {
+ CAFFE_ENFORCE(!frozen_);
+ if (!nodes_.count(node_id)) {
+ nodes_[node_id] = caffe2::make_unique<AsyncTask>(ops);
+ return true;
+ } else {
+ return false;
+ }
+}
+
+bool AsyncTaskGraph::AddDependency(
+ int child_node_id,
+ const std::vector<int>& parent_node_ids) {
+ CAFFE_ENFORCE(!frozen_);
+ CAFFE_ENFORCE(!parent_node_ids.empty());
+ CAFFE_ENFORCE(nodes_.count(child_node_id));
+ for (auto node_id : parent_node_ids) {
+ CAFFE_ENFORCE(nodes_.count(node_id));
+ }
+ CAFFE_ENFORCE(!parents_.count(child_node_id));
+
+ auto* child_task = nodes_[child_node_id].get();
+ auto child_device = child_task->GetDeviceOption();
+
+ std::vector<AsyncTaskFuture*> parent_futures;
+ for (auto node_id : parent_node_ids) {
+ parents_[child_node_id].insert(node_id);
+ children_[node_id].insert(child_node_id);
+ parent_futures.push_back(&nodes_[node_id]->GetFuture());
+ }
+
+ AsyncTaskFuture* parents_future = nullptr;
+ if (parent_futures.size() > 1) {
+ edge_futures_.push_back(
+ caffe2::make_unique<AsyncTaskFuture>(parent_futures));
+ parents_future = edge_futures_.back().get();
+ } else {
+ CAFFE_ENFORCE_EQ(parent_futures.size(), 1);
+ parents_future = parent_futures.back();
+ }
+
+ // TODO: CUDA polling
+ parents_future->SetCallback(
+ [this, child_task, child_device](const AsyncTaskFuture* f) {
+ CAFFE_ENFORCE(f->IsCompleted());
+ if (!f->IsFailed()) {
+ // if we're in the correct thread pool and DFS scheduling is enabled,
+ // immediately call task inline, otherwise send task into thread pool
+ auto* pool = helper_->GetPool(child_device);
+ if (pool->inThreadPool() && options_.use_dfs_scheduling_) {
+ child_task->Run(options_);
+ } else {
+ pool->run([this, child_task]() { child_task->Run(options_); });
+ }
+ } else {
+ // skip task execution and propagate error further
+ child_task->GetFuture().SetCompleted(f->ErrorMessage().c_str());
+ }
+ });
+
+ return true;
+}
+
+void AsyncTaskGraph::FreezeGraph() {
+ if (frozen_) {
+ return;
+ }
+
+ CAFFE_ENFORCE(!run_future_);
+ CAFFE_ENFORCE(root_tasks_.empty());
+
+ std::vector<AsyncTaskFuture*> final_futures;
+ for (auto& kv : nodes_) {
+ auto task_id = kv.first;
+ auto* task = kv.second.get();
+
+ if (parents_[task_id].empty()) {
+ root_tasks_.push_back(task);
+ }
+
+ if (children_[task_id].empty()) {
+ auto& future = task->GetFuture();
+ final_futures.push_back(&future);
+ }
+ }
+
+ CAFFE_ENFORCE(!root_tasks_.empty());
+ CAFFE_ENFORCE(!final_futures.empty());
+
+ run_future_ = caffe2::make_unique<AsyncTaskFuture>(final_futures);
+
+ frozen_ = true;
+}
+
+AsyncTaskFuture* AsyncTaskGraph::ExecuteGraph() {
+ CAFFE_ENFORCE(frozen_);
+ CAFFE_ENFORCE(run_future_ && !run_future_->IsCompleted());
+
+ // TODO: run root tasks inline in inference mode
+ for (auto* task : root_tasks_) {
+ auto task_device = task->GetDeviceOption();
+ helper_->GetPool(task_device)->run([this, task]() { task->Run(options_); });
+ }
+
+ return run_future_.get();
+}
+
+AsyncTaskFuture* AsyncTaskGraph::GetFuture() {
+ CAFFE_ENFORCE(frozen_);
+ return run_future_.get();
+}
+
+void AsyncTaskGraph::Reset() {
+ CAFFE_ENFORCE(frozen_);
+ for (auto& kv : nodes_) {
+ kv.second->Reset();
+ }
+ for (auto& future : edge_futures_) {
+ future->ResetState();
+ }
+ if (run_future_) {
+ run_future_->ResetState();
+ }
+}
+
+}; // namespace caffe2
--- /dev/null
+#ifndef CAFFE2_NET_ASYNC_TASK_GRAPH_H
+#define CAFFE2_NET_ASYNC_TASK_GRAPH_H
+
+#include "caffe2/core/net_async_base.h"
+#include "caffe2/core/net_async_task.h"
+#include "caffe2/core/net_async_task_future.h"
+#include "caffe2/core/operator.h"
+
+namespace caffe2 {
+
+// AsyncTaskGraph represents an execution of a net, it owns the tasks and
+// associated futures, sets up future callbacks and propagates errors.
+// Usage steps:
+// - Adding graph nodes and edges through CreateNode/AddDependency;
+// - Freezing the graph (FreezeGraph), after the freezing a future
+// can be obtained using GetFuture;
+// - Execution of the graph is scheduled through ExecuteGraph, after each
+// execution Reset must be called to prepare the graph for the next run
+
+class AsyncTaskGraphBase {
+ public:
+ virtual bool CreateNode(
+ int node_id,
+ const std::vector<OperatorBase*>& ops) = 0;
+
+ virtual bool AddDependency(
+ int child_node_id,
+ const std::vector<int>& parent_node_ids) = 0;
+
+ virtual void FreezeGraph() = 0;
+
+ virtual AsyncTaskFuture* ExecuteGraph() = 0;
+
+ virtual AsyncTaskFuture* GetFuture() = 0;
+
+ virtual void Reset() = 0;
+
+ virtual ~AsyncTaskGraphBase() noexcept {}
+};
+
+class AsyncTaskGraph : public AsyncTaskGraphBase {
+ public:
+ AsyncTaskGraph(ExecutorHelper* helper, const ExecutionOptions& options);
+
+ bool CreateNode(int node_id, const std::vector<OperatorBase*>& ops) override;
+
+ bool AddDependency(int child_node_id, const std::vector<int>& parent_node_ids)
+ override;
+
+ void FreezeGraph() override;
+
+ AsyncTaskFuture* ExecuteGraph() override;
+
+ AsyncTaskFuture* GetFuture() override;
+
+ void Reset() override;
+
+ private:
+ // used to, e.g., get access to executor's thread pools
+ // TODO: pass tracer and counters through ExecutorHelper
+ ExecutorHelper* helper_;
+ ExecutionOptions options_;
+
+ bool frozen_;
+
+ std::unordered_map<int, std::unique_ptr<AsyncTask>> nodes_;
+ std::unordered_map<int, std::unordered_set<int>> parents_;
+ std::unordered_map<int, std::unordered_set<int>> children_;
+ std::vector<std::unique_ptr<AsyncTaskFuture>> edge_futures_;
+
+ std::vector<AsyncTask*> root_tasks_;
+
+ std::unique_ptr<AsyncTaskFuture> run_future_;
+};
+
+} // namespace caffe2
+
+#endif // CAFFE2_NET_ASYNC_TASK_GRAPH_H
--- /dev/null
+#include "caffe2/core/net_parallel.h"
+
+#include "caffe2/core/operator.h"
+
+#include <sstream>
+
+C10_DEFINE_string(
+ caffe2_task_graph_engine,
+ "futures",
+ "Task graph engine type used by net executor");
+
+namespace caffe2 {
+
+ParallelNet::ParallelNet(
+ const std::shared_ptr<const NetDef>& net_def,
+ Workspace* ws)
+ : NetBase(net_def, ws), options_(net_def), run_future_(nullptr) {
+ num_workers_ = net_def->num_workers();
+ CAFFE_ENFORCE_GT(
+ num_workers_, 0, "Expected positive number of worker threads");
+
+ helper_ = caffe2::make_unique<ParallelNetExecutorHelper>(this);
+ task_graph_ = TaskGraphRegistry()->Create(
+ FLAGS_caffe2_task_graph_engine, helper_.get(), options_);
+
+ // initialize operators
+ operator_nodes_ = dag_utils::prepareOperatorNodes(net_def, ws);
+ operators_.reserve(operator_nodes_.size());
+ for (const auto& node : operator_nodes_) {
+ auto op = node.operator_.get();
+ op->SetExecutorHelper(helper_.get());
+ operators_.push_back(op);
+ }
+
+ // compute chains
+ // TODO: inference mode for chaining
+ auto execution_chains = dag_utils::computeChains(operator_nodes_);
+ std::vector<std::vector<int>> chains;
+ chains.reserve(execution_chains.size());
+ for (const auto& kv : execution_chains) {
+ chains.push_back(kv.second);
+ }
+ auto chain_nodes = dag_utils::prepareChainGraphNodes(operator_nodes_, chains);
+ CAFFE_ENFORCE_EQ(chains.size(), chain_nodes.size());
+
+ // disable unused events
+ for (const auto& chain : chains) {
+ for (const auto& op_id : chain) {
+ if (op_id == chain.back() || op_id == chain.front()) {
+ continue;
+ }
+ auto op = operators_[op_id];
+ if (IsCPUDeviceType(op->device_option().device_type()) &&
+ op->HasAsyncPart()) {
+ continue;
+ }
+ op->DisableEvent();
+ }
+ }
+
+ // initialize task graph
+ for (auto chain_id = 0; chain_id < chains.size(); ++chain_id) {
+ std::vector<OperatorBase*> ops;
+ ops.reserve(chains[chain_id].size());
+ for (auto op_id : chains[chain_id]) {
+ ops.push_back(operators_[op_id]);
+ }
+ CAFFE_ENFORCE(task_graph_->CreateNode(chain_id, ops));
+ }
+ for (auto chain_id = 0; chain_id < chain_nodes.size(); ++chain_id) {
+ if (!chain_nodes[chain_id].parents_.empty()) {
+ CAFFE_ENFORCE(
+ task_graph_->AddDependency(chain_id, chain_nodes[chain_id].parents_));
+ }
+ }
+
+ // Freeze graph and initialize graph execution future
+ task_graph_->FreezeGraph();
+ run_future_ = task_graph_->GetFuture();
+ run_future_->SetCallback([this](const AsyncTaskFuture* /* unused */) {
+ StopAllObservers();
+ finishRun();
+ });
+
+ LOG(INFO) << "Initialized parallel net: '" << Name()
+ << "', #ops: " << net_def->op_size()
+ << ", #chains: " << chains.size() << ", #workers: " << num_workers_
+ << ", dfs scheduling: " << options_.use_dfs_scheduling_
+ << ", task graph engine: " << FLAGS_caffe2_task_graph_engine;
+}
+
+bool ParallelNet::RunAsync() {
+ reset();
+ StartAllObservers();
+
+ try {
+ task_graph_->ExecuteGraph();
+ } catch (const std::exception&) {
+ StopAllObservers();
+ return false;
+ }
+
+ return true;
+}
+
+void ParallelNet::Wait() {
+ CAFFE_ENFORCE(run_future_);
+ run_future_->Wait();
+}
+
+void ParallelNet::reset() {
+ task_graph_->Reset();
+}
+
+bool ParallelNet::handleRunError() {
+ CAFFE_ENFORCE(run_future_ && run_future_->IsCompleted());
+ // TODO: throw saved exceptions
+ if (run_future_->IsFailed()) {
+ LOG(ERROR) << "Failed parallel run (" << Name()
+ << "): " << run_future_->ErrorMessage();
+ }
+ return !run_future_->IsFailed();
+}
+
+TaskThreadPoolBase* ParallelNet::poolGetter(
+ PoolsMap& pools,
+ int device_type,
+ int device_id,
+ int pool_size) {
+ std::unique_lock<std::mutex> pools_lock(pools_mutex_);
+ auto pool = pools[device_id][pool_size];
+ if (!pool) {
+ pool = ThreadPoolRegistry()->Create(
+ DeviceTypeName(device_type),
+ device_id,
+ pool_size,
+ options_.use_per_net_pools_);
+ pools[device_id][pool_size] = pool;
+ }
+ return pool.get();
+}
+
+TaskThreadPoolBase* ParallelNet::Pool(const DeviceOption& device_option) {
+ if (options_.use_single_pool_) {
+ return poolGetter(cpu_pools_, PROTO_CPU, -1, num_workers_);
+ }
+ const auto device_type = device_option.device_type();
+ if (IsCPUDeviceType(device_type)) {
+ auto numa_node_id = -1;
+ if (device_option.has_numa_node_id()) {
+ numa_node_id = device_option.numa_node_id();
+ CAFFE_ENFORCE_GE(numa_node_id, 0, "Invalid NUMA node id: ", numa_node_id);
+ }
+ CAFFE_ENFORCE_LT(
+ numa_node_id,
+ FLAGS_caffe2_net_async_max_numa_nodes,
+ "Invalid NUMA node id: ",
+ numa_node_id);
+ return poolGetter(cpu_pools_, device_type, numa_node_id, num_workers_);
+ } else if (IsGPUDeviceType(device_type)) {
+ auto gpu_id = device_option.device_id();
+ CAFFE_ENFORCE(
+ gpu_id >= 0 && gpu_id < FLAGS_caffe2_net_async_max_gpus,
+ "Invalid GPU id: " + caffe2::to_string(gpu_id));
+ return poolGetter(gpu_pools_, device_type, gpu_id, num_workers_);
+ } else {
+ CAFFE_THROW("Unsupported device type " + caffe2::to_string(device_type));
+ }
+}
+
+bool ParallelNet::SupportsAsync() {
+ return true;
+}
+
+void ParallelNet::finishRun() {}
+
+std::vector<OperatorBase*> ParallelNet::GetOperators() const {
+ return operators_;
+}
+
+std::shared_ptr<AsyncTaskGraphBase> GetAsyncTaskGraph(
+ ExecutorHelper* helper,
+ const ExecutionOptions& options) {
+ return std::make_shared<AsyncTaskGraph>(helper, options);
+}
+
+C10_DEFINE_SHARED_REGISTRY(
+ TaskGraphRegistry,
+ AsyncTaskGraphBase,
+ ExecutorHelper*,
+ const ExecutionOptions&);
+
+C10_REGISTER_CREATOR(TaskGraphRegistry, futures, GetAsyncTaskGraph);
+
+REGISTER_NET(parallel, ParallelNet);
+
+} // namespace caffe2
--- /dev/null
+#ifndef CAFFE2_CORE_NET_PARALLEL_H
+#define CAFFE2_CORE_NET_PARALLEL_H
+
+#include "caffe2/core/net_async_base.h"
+#include "caffe2/core/net_async_task_graph.h"
+
+C10_DECLARE_string(caffe2_task_graph_engine);
+
+namespace caffe2 {
+
+class ParallelNetExecutorHelper;
+
+class CAFFE2_API ParallelNet : public NetBase {
+ public:
+ ParallelNet(const std::shared_ptr<const NetDef>& net_def, Workspace* ws);
+
+ bool RunAsync() override;
+ void Wait() override;
+
+ bool SupportsAsync() override;
+ std::vector<OperatorBase*> GetOperators() const override;
+
+ TaskThreadPoolBase* Pool(const DeviceOption& device_option);
+
+ protected:
+ bool handleRunError() override;
+ virtual void finishRun();
+ virtual void reset();
+
+ ExecutionOptions options_;
+ int num_workers_;
+
+ std::unique_ptr<ParallelNetExecutorHelper> helper_;
+ std::shared_ptr<AsyncTaskGraphBase> task_graph_;
+ AsyncTaskFuture* run_future_;
+
+ std::vector<dag_utils::OperatorNode> operator_nodes_;
+ std::vector<OperatorBase*> operators_;
+
+ std::mutex pools_mutex_;
+ typedef std::unordered_map<
+ int,
+ std::unordered_map<int, std::shared_ptr<TaskThreadPoolBase>>>
+ PoolsMap;
+ PoolsMap cpu_pools_;
+ PoolsMap gpu_pools_;
+ TaskThreadPoolBase*
+ poolGetter(PoolsMap& pools, int device_type, int device_id, int pool_size);
+
+ friend class ParallelNetExecutorHelper;
+ C10_DISABLE_COPY_AND_ASSIGN(ParallelNet);
+};
+
+C10_DECLARE_SHARED_REGISTRY(
+ TaskGraphRegistry,
+ AsyncTaskGraphBase,
+ ExecutorHelper*,
+ const ExecutionOptions&);
+
+std::shared_ptr<AsyncTaskGraphBase> GetAsyncTaskGraph(
+ ExecutorHelper* helper,
+ const ExecutionOptions& options);
+
+class ParallelNetExecutorHelper : public ExecutorHelper {
+ public:
+ explicit ParallelNetExecutorHelper(ParallelNet* net) : net_(net) {}
+ TaskThreadPoolBase* GetPool(const DeviceOption& option) const override {
+ return net_->Pool(option);
+ }
+
+ private:
+ ParallelNet* net_;
+};
+
+} // namespace caffe2
+
+#endif // CAFFE2_CORE_NET_PARALLEL_H
import unittest
-EXECUTORS = ["async_scheduling", "dag", "async_dag"]
+EXECUTORS = ["parallel", "async_scheduling"]
ITERATIONS = 1