#define __TASK_MANAGER_H__
#include <vector>
+#include <queue>
+#include <unordered_set>
#include <memory>
#include <thread>
+#include <mutex>
#include "IInferenceTaskInterface.h"
#include "SingleoCommonTypes.h"
std::vector<std::shared_ptr<BaseDataType> > _inputs;
std::vector<std::shared_ptr<INode> > _nodes;
std::vector<std::shared_ptr<BaseResultType> > _results;
+ std::queue<std::shared_ptr<std::thread> > _threads;
+ std::unordered_set<std::shared_ptr<INode> > _is_thread_created;
+
+ std::mutex _thread_mutex;
void threadCb(std::shared_ptr<INode> &node);
void verifyGraph();
node->invoke();
node->wakeup();
+
+ // Spawn threads for next nodes
+ for (auto &n : node->getNexts()) {
+ std::lock_guard<std::mutex> lock(_thread_mutex);
+ if (_is_thread_created.find(n) == _is_thread_created.end()) {
+ _threads.push(make_shared<thread>(&TaskManager::threadCb, this, std::ref(n)));
+ _is_thread_created.insert(n);
+ }
+ }
}
void TaskManager::addInput(BaseDataType &input)
verifyGraph();
- std::vector<std::unique_ptr<std::thread> > threads;
-
auto inputBuffer = make_shared<SharedBuffer>();
for (auto &i : _inputs)
// TODO. consider for multiple sources later.
n->setInputBuffer(inputBuffer);
+ _threads.push(make_shared<thread>(&TaskManager::threadCb, this, std::ref(n)));
+ _is_thread_created.insert(n);
}
-
- threads.push_back(make_unique<thread>(&TaskManager::threadCb, this, std::ref(n)));
}
- for (auto &t : threads)
- t->join();
+ while (true) {
+ std::unique_lock<std::mutex> lock(_thread_mutex);
+ if (_threads.empty())
+ break;
+ auto t = _threads.front();
+ _threads.pop();
+
+ lock.unlock();
+
+ t->join();
+ }
+ _is_thread_created.clear();
_inputs.clear();
}
void TaskManager::clear()
{
+ _inputs.clear();
_nodes.clear();
+ _results.clear();
+ _is_thread_created.clear();
}
}