+++ /dev/null
-#include "caffe2/core/net_async_task.h"
-
-#include "caffe2/core/net_async_task_graph.h"
-
-namespace caffe2 {
-
-AsyncTask::AsyncTask(const std::vector<OperatorBase*>& ops) : ops_(ops) {
- CAFFE_ENFORCE(!ops_.empty());
- device_option_ = ops_.front()->device_option();
- for (auto& op : ops_) {
- CAFFE_ENFORCE(IsSameDevice(device_option_, op->device_option()));
- }
- Reset();
-}
-
-void AsyncTask::handleChainError(
- OperatorBase* op,
- const char* err_str,
- bool save_exception) {
- std::string err_msg = err_str;
- if (op) {
- err_msg += ", op " + (op->has_debug_def() ? op->type() : " unknown");
- }
- LOG(ERROR) << err_msg;
-
- // save error message and exception in chain's Event
- auto last_op = ops_.back();
- if (save_exception) {
- last_op->event().SetFinishedWithException(err_msg.c_str());
- } else {
- last_op->event().SetFinished(err_msg.c_str());
- }
-
- // set future as completed with an error
- // TODO: exceptions in future
- future_.SetCompleted(err_msg.c_str());
-}
-
-bool AsyncTask::Run(const ExecutionOptions& options) {
- // TODO: insert CUDA's async stream waits; tracing and counters
- OperatorBase* op = nullptr;
- try {
- for (auto op_idx = 0; op_idx < ops_.size(); ++op_idx) {
- op = ops_[op_idx];
- int stream_id = 0; // TODO: thread local stream id
- if (!op->RunAsync(stream_id)) {
- handleChainError(op, "Failed to execute an op");
- return false;
- }
- }
-
- if (options.finish_chain_) {
- op = ops_.back();
- op->Finish();
- }
-
- // set the future as successfully completed or, in case of async CPU,
- // use op's callback
- if (IsCPUDeviceType(device_option_.device_type()) &&
- ops_.back()->HasAsyncPart()) {
- auto& event = ops_.back()->event();
- event.SetCallback([this, &event]() {
- CAFFE_ENFORCE(event.IsFinished());
- if (event.Query() == EventStatus::EVENT_SUCCESS) {
- future_.SetCompleted();
- } else {
- // TODO: support for exceptions
- future_.SetCompleted(event.ErrorMessage().c_str());
- }
- });
- } else {
- future_.SetCompleted();
- }
- } catch (const std::exception& e) {
- handleChainError(op, e.what(), /* save_exception */ true);
- return false;
- } catch (...) {
- handleChainError(
- op,
- "Failed to execute task: unknown error",
- /* save_exception */ true);
- return false;
- }
-
- return true;
-}
-
-void AsyncTask::Reset() {
- for (auto& op : ops_) {
- op->ResetEvent();
- }
- future_.ResetState();
-}
-
-DeviceOption AsyncTask::GetDeviceOption() const {
- return device_option_;
-}
-
-AsyncTaskFuture& AsyncTask::GetFuture() {
- return future_;
-}
-
-const AsyncTaskFuture& AsyncTask::GetFuture() const {
- return future_;
-}
-
-}; // namespace caffe2
+++ /dev/null
-#ifndef CAFFE2_NET_ASYNC_TASK_H
-#define CAFFE2_NET_ASYNC_TASK_H
-
-#include "caffe2/core/net_async_base.h"
-#include "caffe2/core/net_async_task_future.h"
-#include "caffe2/core/operator.h"
-
-#include <vector>
-
-namespace caffe2 {
-
-// AsyncTask represents an asynchronous execution of a chain of ops.
-class AsyncTask {
- public:
- AsyncTask(const std::vector<OperatorBase*>& ops);
-
- bool Run(const ExecutionOptions& options);
-
- void Reset();
-
- DeviceOption GetDeviceOption() const;
-
- AsyncTaskFuture& GetFuture();
- const AsyncTaskFuture& GetFuture() const;
-
- private:
- void handleChainError(
- OperatorBase* op,
- const char* err_msg,
- bool save_exception = false);
-
- std::vector<OperatorBase*> ops_;
- DeviceOption device_option_;
- AsyncTaskFuture future_;
-};
-
-} // namespace caffe2
-
-#endif // CAFFE2_NET_ASYNC_TASK_H
+++ /dev/null
-#include "caffe2/core/net_async_task_future.h"
-
-#include "c10/util/Logging.h"
-#include "caffe2/core/common.h"
-
-namespace caffe2 {
-
-AsyncTaskFuture::AsyncTaskFuture() : completed_(false), failed_(false) {}
-
-AsyncTaskFuture::AsyncTaskFuture(const std::vector<AsyncTaskFuture*>& futures)
- : completed_(false), failed_(false) {
- if (futures.size() > 1) {
- parent_counter_ = caffe2::make_unique<ParentCounter>(futures.size());
- for (auto future : futures) {
- future->SetCallback([this](const AsyncTaskFuture* f) {
- if (f->IsFailed()) {
- std::unique_lock<std::mutex> lock(parent_counter_->err_mutex);
- if (parent_counter_->parent_failed) {
- parent_counter_->err_msg += ", " + f->ErrorMessage();
- } else {
- parent_counter_->parent_failed = true;
- parent_counter_->err_msg = f->ErrorMessage();
- }
- }
- int count = --parent_counter_->parent_count;
- if (count == 0) {
- // thread safe to use parent_counter here
- if (!parent_counter_->parent_failed) {
- SetCompleted();
- } else {
- SetCompleted(parent_counter_->err_msg.c_str());
- }
- }
- });
- }
- } else {
- CAFFE_ENFORCE_EQ(futures.size(), 1);
- auto future = futures.back();
- future->SetCallback([this](const AsyncTaskFuture* f) {
- if (!f->IsFailed()) {
- SetCompleted();
- } else {
- SetCompleted(f->ErrorMessage().c_str());
- }
- });
- }
-}
-
-bool AsyncTaskFuture::IsCompleted() const {
- return completed_;
-}
-
-bool AsyncTaskFuture::IsFailed() const {
- return failed_;
-}
-
-std::string AsyncTaskFuture::ErrorMessage() const {
- return err_msg_;
-}
-
-void AsyncTaskFuture::Wait() const {
- std::unique_lock<std::mutex> lock(mutex_);
- while (!completed_) {
- cv_completed_.wait(lock);
- }
-}
-
-void AsyncTaskFuture::SetCallback(
- std::function<void(const AsyncTaskFuture*)> callback) {
- std::unique_lock<std::mutex> lock(mutex_);
-
- callbacks_.push_back(callback);
- if (completed_) {
- callback(this);
- }
-}
-
-void AsyncTaskFuture::SetCompleted(const char* err_msg) {
- std::unique_lock<std::mutex> lock(mutex_);
-
- CAFFE_ENFORCE(!completed_, "Calling SetCompleted on a completed future");
- completed_ = true;
-
- if (err_msg) {
- failed_ = true;
- err_msg_ = err_msg;
- }
-
- for (auto& callback : callbacks_) {
- callback(this);
- }
-
- cv_completed_.notify_all();
-}
-
-// ResetState is called on a completed future,
-// does not reset callbacks to keep task graph structure
-void AsyncTaskFuture::ResetState() {
- std::unique_lock<std::mutex> lock(mutex_);
- if (parent_counter_) {
- parent_counter_->Reset();
- }
- completed_ = false;
- failed_ = false;
- err_msg_ = "";
-}
-
-AsyncTaskFuture::~AsyncTaskFuture() {}
-
-} // namespace caffe2
+++ /dev/null
-#ifndef CAFFE2_NET_ASYNC_TASK_FUTURE_H
-#define CAFFE2_NET_ASYNC_TASK_FUTURE_H
-
-#include <atomic>
-#include <condition_variable>
-#include <functional>
-#include <memory>
-#include <vector>
-
-namespace caffe2 {
-
-// Represents the state of AsyncTask execution, that can be queried with
-// IsCompleted/IsFailed. Callbacks are supported through SetCallback and
-// are called upon future's completion.
-
-class AsyncTaskFuture {
- public:
- AsyncTaskFuture();
- // Creates a future completed when all given futures are completed
- explicit AsyncTaskFuture(const std::vector<AsyncTaskFuture*>& futures);
- ~AsyncTaskFuture();
-
- AsyncTaskFuture(const AsyncTaskFuture&) = delete;
-
- AsyncTaskFuture& operator=(const AsyncTaskFuture&) = delete;
-
- bool IsCompleted() const;
-
- bool IsFailed() const;
-
- std::string ErrorMessage() const;
-
- void Wait() const;
-
- void SetCallback(std::function<void(const AsyncTaskFuture*)> callback);
-
- void SetCompleted(const char* err_msg = nullptr);
-
- void ResetState();
-
- private:
- mutable std::mutex mutex_;
- mutable std::condition_variable cv_completed_;
- std::atomic<bool> completed_;
- std::atomic<bool> failed_;
- std::string err_msg_;
- std::vector<std::function<void(const AsyncTaskFuture*)>> callbacks_;
-
- struct ParentCounter {
- explicit ParentCounter(int init_parent_count)
- : init_parent_count_(init_parent_count),
- parent_count(init_parent_count),
- parent_failed(false) {}
-
- void Reset() {
- std::unique_lock<std::mutex> lock(err_mutex);
- parent_count = init_parent_count_;
- parent_failed = false;
- err_msg = "";
- }
-
- const int init_parent_count_;
- std::atomic<int> parent_count;
- std::mutex err_mutex;
- std::atomic<bool> parent_failed;
- std::string err_msg;
- };
-
- std::unique_ptr<ParentCounter> parent_counter_;
-};
-
-} // namespace caffe2
-
-#endif // CAFFE2_NET_ASYNC_TASK_FUTURE_H
+++ /dev/null
-#include "caffe2/core/net_async_task_graph.h"
-
-#include "caffe2/core/net_parallel.h"
-
-namespace caffe2 {
-
-AsyncTaskGraph::AsyncTaskGraph(
- ExecutorHelper* helper,
- const ExecutionOptions& options)
- : helper_(helper), options_(options), frozen_(false) {}
-
-bool AsyncTaskGraph::CreateNode(
- int node_id,
- const std::vector<OperatorBase*>& ops) {
- CAFFE_ENFORCE(!frozen_);
- if (!nodes_.count(node_id)) {
- nodes_[node_id] = caffe2::make_unique<AsyncTask>(ops);
- return true;
- } else {
- return false;
- }
-}
-
-bool AsyncTaskGraph::AddDependency(
- int child_node_id,
- const std::vector<int>& parent_node_ids) {
- CAFFE_ENFORCE(!frozen_);
- CAFFE_ENFORCE(!parent_node_ids.empty());
- CAFFE_ENFORCE(nodes_.count(child_node_id));
- for (auto node_id : parent_node_ids) {
- CAFFE_ENFORCE(nodes_.count(node_id));
- }
- CAFFE_ENFORCE(!parents_.count(child_node_id));
-
- auto* child_task = nodes_[child_node_id].get();
- auto child_device = child_task->GetDeviceOption();
-
- std::vector<AsyncTaskFuture*> parent_futures;
- for (auto node_id : parent_node_ids) {
- parents_[child_node_id].insert(node_id);
- children_[node_id].insert(child_node_id);
- parent_futures.push_back(&nodes_[node_id]->GetFuture());
- }
-
- AsyncTaskFuture* parents_future = nullptr;
- if (parent_futures.size() > 1) {
- edge_futures_.push_back(
- caffe2::make_unique<AsyncTaskFuture>(parent_futures));
- parents_future = edge_futures_.back().get();
- } else {
- CAFFE_ENFORCE_EQ(parent_futures.size(), 1);
- parents_future = parent_futures.back();
- }
-
- // TODO: CUDA polling
- parents_future->SetCallback(
- [this, child_task, child_device](const AsyncTaskFuture* f) {
- CAFFE_ENFORCE(f->IsCompleted());
- if (!f->IsFailed()) {
- // if we're in the correct thread pool and DFS scheduling is enabled,
- // immediately call task inline, otherwise send task into thread pool
- auto* pool = helper_->GetPool(child_device);
- if (pool->inThreadPool() && options_.use_dfs_scheduling_) {
- child_task->Run(options_);
- } else {
- pool->run([this, child_task]() { child_task->Run(options_); });
- }
- } else {
- // skip task execution and propagate error further
- child_task->GetFuture().SetCompleted(f->ErrorMessage().c_str());
- }
- });
-
- return true;
-}
-
-void AsyncTaskGraph::FreezeGraph() {
- if (frozen_) {
- return;
- }
-
- CAFFE_ENFORCE(!run_future_);
- CAFFE_ENFORCE(root_tasks_.empty());
-
- std::vector<AsyncTaskFuture*> final_futures;
- for (auto& kv : nodes_) {
- auto task_id = kv.first;
- auto* task = kv.second.get();
-
- if (parents_[task_id].empty()) {
- root_tasks_.push_back(task);
- }
-
- if (children_[task_id].empty()) {
- auto& future = task->GetFuture();
- final_futures.push_back(&future);
- }
- }
-
- CAFFE_ENFORCE(!root_tasks_.empty());
- CAFFE_ENFORCE(!final_futures.empty());
-
- run_future_ = caffe2::make_unique<AsyncTaskFuture>(final_futures);
-
- frozen_ = true;
-}
-
-AsyncTaskFuture* AsyncTaskGraph::ExecuteGraph() {
- CAFFE_ENFORCE(frozen_);
- CAFFE_ENFORCE(run_future_ && !run_future_->IsCompleted());
-
- // TODO: run root tasks inline in inference mode
- for (auto* task : root_tasks_) {
- auto task_device = task->GetDeviceOption();
- helper_->GetPool(task_device)->run([this, task]() { task->Run(options_); });
- }
-
- return run_future_.get();
-}
-
-AsyncTaskFuture* AsyncTaskGraph::GetFuture() {
- CAFFE_ENFORCE(frozen_);
- return run_future_.get();
-}
-
-void AsyncTaskGraph::Reset() {
- CAFFE_ENFORCE(frozen_);
- for (auto& kv : nodes_) {
- kv.second->Reset();
- }
- for (auto& future : edge_futures_) {
- future->ResetState();
- }
- if (run_future_) {
- run_future_->ResetState();
- }
-}
-
-}; // namespace caffe2
+++ /dev/null
-#ifndef CAFFE2_NET_ASYNC_TASK_GRAPH_H
-#define CAFFE2_NET_ASYNC_TASK_GRAPH_H
-
-#include "caffe2/core/net_async_base.h"
-#include "caffe2/core/net_async_task.h"
-#include "caffe2/core/net_async_task_future.h"
-#include "caffe2/core/operator.h"
-
-namespace caffe2 {
-
-// AsyncTaskGraph represents an execution of a net, it owns the tasks and
-// associated futures, sets up future callbacks and propagates errors.
-// Usage steps:
-// - Adding graph nodes and edges through CreateNode/AddDependency;
-// - Freezing the graph (FreezeGraph), after the freezing a future
-// can be obtained using GetFuture;
-// - Execution of the graph is scheduled through ExecuteGraph, after each
-// execution Reset must be called to prepare the graph for the next run
-
-class AsyncTaskGraphBase {
- public:
- virtual bool CreateNode(
- int node_id,
- const std::vector<OperatorBase*>& ops) = 0;
-
- virtual bool AddDependency(
- int child_node_id,
- const std::vector<int>& parent_node_ids) = 0;
-
- virtual void FreezeGraph() = 0;
-
- virtual AsyncTaskFuture* ExecuteGraph() = 0;
-
- virtual AsyncTaskFuture* GetFuture() = 0;
-
- virtual void Reset() = 0;
-
- virtual ~AsyncTaskGraphBase() noexcept {}
-};
-
-class AsyncTaskGraph : public AsyncTaskGraphBase {
- public:
- AsyncTaskGraph(ExecutorHelper* helper, const ExecutionOptions& options);
-
- bool CreateNode(int node_id, const std::vector<OperatorBase*>& ops) override;
-
- bool AddDependency(int child_node_id, const std::vector<int>& parent_node_ids)
- override;
-
- void FreezeGraph() override;
-
- AsyncTaskFuture* ExecuteGraph() override;
-
- AsyncTaskFuture* GetFuture() override;
-
- void Reset() override;
-
- private:
- // used to, e.g., get access to executor's thread pools
- // TODO: pass tracer and counters through ExecutorHelper
- ExecutorHelper* helper_;
- ExecutionOptions options_;
-
- bool frozen_;
-
- std::unordered_map<int, std::unique_ptr<AsyncTask>> nodes_;
- std::unordered_map<int, std::unordered_set<int>> parents_;
- std::unordered_map<int, std::unordered_set<int>> children_;
- std::vector<std::unique_ptr<AsyncTaskFuture>> edge_futures_;
-
- std::vector<AsyncTask*> root_tasks_;
-
- std::unique_ptr<AsyncTaskFuture> run_future_;
-};
-
-} // namespace caffe2
-
-#endif // CAFFE2_NET_ASYNC_TASK_GRAPH_H
+++ /dev/null
-#include "caffe2/core/net_parallel.h"
-
-#include "caffe2/core/operator.h"
-
-#include <sstream>
-
-C10_DEFINE_string(
- caffe2_task_graph_engine,
- "futures",
- "Task graph engine type used by net executor");
-
-namespace caffe2 {
-
-ParallelNet::ParallelNet(
- const std::shared_ptr<const NetDef>& net_def,
- Workspace* ws)
- : NetBase(net_def, ws), options_(net_def), run_future_(nullptr) {
- num_workers_ = net_def->num_workers();
- CAFFE_ENFORCE_GT(
- num_workers_, 0, "Expected positive number of worker threads");
-
- helper_ = caffe2::make_unique<ParallelNetExecutorHelper>(this);
- task_graph_ = TaskGraphRegistry()->Create(
- FLAGS_caffe2_task_graph_engine, helper_.get(), options_);
-
- // initialize operators
- operator_nodes_ = dag_utils::prepareOperatorNodes(net_def, ws);
- operators_.reserve(operator_nodes_.size());
- for (const auto& node : operator_nodes_) {
- auto op = node.operator_.get();
- op->SetExecutorHelper(helper_.get());
- operators_.push_back(op);
- }
-
- // compute chains
- // TODO: inference mode for chaining
- auto execution_chains = dag_utils::computeChains(operator_nodes_);
- std::vector<std::vector<int>> chains;
- chains.reserve(execution_chains.size());
- for (const auto& kv : execution_chains) {
- chains.push_back(kv.second);
- }
- auto chain_nodes = dag_utils::prepareChainGraphNodes(operator_nodes_, chains);
- CAFFE_ENFORCE_EQ(chains.size(), chain_nodes.size());
-
- // disable unused events
- for (const auto& chain : chains) {
- for (const auto& op_id : chain) {
- if (op_id == chain.back() || op_id == chain.front()) {
- continue;
- }
- auto op = operators_[op_id];
- if (IsCPUDeviceType(op->device_option().device_type()) &&
- op->HasAsyncPart()) {
- continue;
- }
- op->DisableEvent();
- }
- }
-
- // initialize task graph
- for (auto chain_id = 0; chain_id < chains.size(); ++chain_id) {
- std::vector<OperatorBase*> ops;
- ops.reserve(chains[chain_id].size());
- for (auto op_id : chains[chain_id]) {
- ops.push_back(operators_[op_id]);
- }
- CAFFE_ENFORCE(task_graph_->CreateNode(chain_id, ops));
- }
- for (auto chain_id = 0; chain_id < chain_nodes.size(); ++chain_id) {
- if (!chain_nodes[chain_id].parents_.empty()) {
- CAFFE_ENFORCE(
- task_graph_->AddDependency(chain_id, chain_nodes[chain_id].parents_));
- }
- }
-
- // Freeze graph and initialize graph execution future
- task_graph_->FreezeGraph();
- run_future_ = task_graph_->GetFuture();
- run_future_->SetCallback([this](const AsyncTaskFuture* /* unused */) {
- StopAllObservers();
- finishRun();
- });
-
- LOG(INFO) << "Initialized parallel net: '" << Name()
- << "', #ops: " << net_def->op_size()
- << ", #chains: " << chains.size() << ", #workers: " << num_workers_
- << ", dfs scheduling: " << options_.use_dfs_scheduling_
- << ", task graph engine: " << FLAGS_caffe2_task_graph_engine;
-}
-
-bool ParallelNet::RunAsync() {
- reset();
- StartAllObservers();
-
- try {
- task_graph_->ExecuteGraph();
- } catch (const std::exception&) {
- StopAllObservers();
- return false;
- }
-
- return true;
-}
-
-void ParallelNet::Wait() {
- CAFFE_ENFORCE(run_future_);
- run_future_->Wait();
-}
-
-void ParallelNet::reset() {
- task_graph_->Reset();
-}
-
-bool ParallelNet::handleRunError() {
- CAFFE_ENFORCE(run_future_ && run_future_->IsCompleted());
- // TODO: throw saved exceptions
- if (run_future_->IsFailed()) {
- LOG(ERROR) << "Failed parallel run (" << Name()
- << "): " << run_future_->ErrorMessage();
- }
- return !run_future_->IsFailed();
-}
-
-TaskThreadPoolBase* ParallelNet::poolGetter(
- PoolsMap& pools,
- int device_type,
- int device_id,
- int pool_size) {
- std::unique_lock<std::mutex> pools_lock(pools_mutex_);
- auto pool = pools[device_id][pool_size];
- if (!pool) {
- pool = ThreadPoolRegistry()->Create(
- DeviceTypeName(device_type),
- device_id,
- pool_size,
- options_.use_per_net_pools_);
- pools[device_id][pool_size] = pool;
- }
- return pool.get();
-}
-
-TaskThreadPoolBase* ParallelNet::Pool(const DeviceOption& device_option) {
- if (options_.use_single_pool_) {
- return poolGetter(cpu_pools_, PROTO_CPU, -1, num_workers_);
- }
- const auto device_type = device_option.device_type();
- if (IsCPUDeviceType(device_type)) {
- auto numa_node_id = -1;
- if (device_option.has_numa_node_id()) {
- numa_node_id = device_option.numa_node_id();
- CAFFE_ENFORCE_GE(numa_node_id, 0, "Invalid NUMA node id: ", numa_node_id);
- }
- CAFFE_ENFORCE_LT(
- numa_node_id,
- FLAGS_caffe2_net_async_max_numa_nodes,
- "Invalid NUMA node id: ",
- numa_node_id);
- return poolGetter(cpu_pools_, device_type, numa_node_id, num_workers_);
- } else if (IsGPUDeviceType(device_type)) {
- auto gpu_id = device_option.device_id();
- CAFFE_ENFORCE(
- gpu_id >= 0 && gpu_id < FLAGS_caffe2_net_async_max_gpus,
- "Invalid GPU id: " + caffe2::to_string(gpu_id));
- return poolGetter(gpu_pools_, device_type, gpu_id, num_workers_);
- } else {
- CAFFE_THROW("Unsupported device type " + caffe2::to_string(device_type));
- }
-}
-
-bool ParallelNet::SupportsAsync() {
- return true;
-}
-
-void ParallelNet::finishRun() {}
-
-std::vector<OperatorBase*> ParallelNet::GetOperators() const {
- return operators_;
-}
-
-std::shared_ptr<AsyncTaskGraphBase> GetAsyncTaskGraph(
- ExecutorHelper* helper,
- const ExecutionOptions& options) {
- return std::make_shared<AsyncTaskGraph>(helper, options);
-}
-
-C10_DEFINE_SHARED_REGISTRY(
- TaskGraphRegistry,
- AsyncTaskGraphBase,
- ExecutorHelper*,
- const ExecutionOptions&);
-
-C10_REGISTER_CREATOR(TaskGraphRegistry, futures, GetAsyncTaskGraph);
-
-REGISTER_NET(parallel, ParallelNet);
-
-} // namespace caffe2
+++ /dev/null
-#ifndef CAFFE2_CORE_NET_PARALLEL_H
-#define CAFFE2_CORE_NET_PARALLEL_H
-
-#include "caffe2/core/net_async_base.h"
-#include "caffe2/core/net_async_task_graph.h"
-
-C10_DECLARE_string(caffe2_task_graph_engine);
-
-namespace caffe2 {
-
-class ParallelNetExecutorHelper;
-
-class CAFFE2_API ParallelNet : public NetBase {
- public:
- ParallelNet(const std::shared_ptr<const NetDef>& net_def, Workspace* ws);
-
- bool RunAsync() override;
- void Wait() override;
-
- bool SupportsAsync() override;
- std::vector<OperatorBase*> GetOperators() const override;
-
- TaskThreadPoolBase* Pool(const DeviceOption& device_option);
-
- protected:
- bool handleRunError() override;
- virtual void finishRun();
- virtual void reset();
-
- ExecutionOptions options_;
- int num_workers_;
-
- std::unique_ptr<ParallelNetExecutorHelper> helper_;
- std::shared_ptr<AsyncTaskGraphBase> task_graph_;
- AsyncTaskFuture* run_future_;
-
- std::vector<dag_utils::OperatorNode> operator_nodes_;
- std::vector<OperatorBase*> operators_;
-
- std::mutex pools_mutex_;
- typedef std::unordered_map<
- int,
- std::unordered_map<int, std::shared_ptr<TaskThreadPoolBase>>>
- PoolsMap;
- PoolsMap cpu_pools_;
- PoolsMap gpu_pools_;
- TaskThreadPoolBase*
- poolGetter(PoolsMap& pools, int device_type, int device_id, int pool_size);
-
- friend class ParallelNetExecutorHelper;
- C10_DISABLE_COPY_AND_ASSIGN(ParallelNet);
-};
-
-C10_DECLARE_SHARED_REGISTRY(
- TaskGraphRegistry,
- AsyncTaskGraphBase,
- ExecutorHelper*,
- const ExecutionOptions&);
-
-std::shared_ptr<AsyncTaskGraphBase> GetAsyncTaskGraph(
- ExecutorHelper* helper,
- const ExecutionOptions& options);
-
-class ParallelNetExecutorHelper : public ExecutorHelper {
- public:
- explicit ParallelNetExecutorHelper(ParallelNet* net) : net_(net) {}
- TaskThreadPoolBase* GetPool(const DeviceOption& option) const override {
- return net_->Pool(option);
- }
-
- private:
- ParallelNet* net_;
-};
-
-} // namespace caffe2
-
-#endif // CAFFE2_CORE_NET_PARALLEL_H
import unittest
-EXECUTORS = ["parallel", "async_scheduling"]
+EXECUTORS = ["async_scheduling", "dag", "async_dag"]
ITERATIONS = 1