From 45564ae69a0dbc36ee9ee55e352e859fe83d653c Mon Sep 17 00:00:00 2001 From: Inki Dae Date: Fri, 14 Jun 2024 12:11:28 +0900 Subject: [PATCH] task_manager: add RunQueueManager support Add a RunQueueManager class which schedules the nodes of a graph at runtime. With this patch, we can drop some node dependencies from the task manager by making each concrete node update the next nodes that will be scheduled by the task manager. Change-Id: I7c448aa5ae579b2f61ec2e5b765789f2e82966ad Signed-off-by: Inki Dae --- services/auto_zoom/src/AutoZoom.cpp | 4 +- services/common/include/AsyncManager.h | 4 +- services/task_manager/include/BridgeNode.h | 22 +++-- services/task_manager/include/CallbackNode.h | 7 +- services/task_manager/include/EndpointNode.h | 3 + services/task_manager/include/INode.h | 2 + services/task_manager/include/InferenceNode.h | 2 + .../task_manager/include/RunQueueManager.h | 88 +++++++++++++++++++ services/task_manager/include/TaskManager.h | 4 +- services/task_manager/include/TaskNode.h | 2 + services/task_manager/include/TrainingNode.h | 4 + services/task_manager/src/InferenceNode.cpp | 7 ++ services/task_manager/src/TaskManager.cpp | 45 ++++------ 13 files changed, 157 insertions(+), 37 deletions(-) create mode 100644 services/task_manager/include/RunQueueManager.h diff --git a/services/auto_zoom/src/AutoZoom.cpp b/services/auto_zoom/src/AutoZoom.cpp index b0321c8..ce35e40 100644 --- a/services/auto_zoom/src/AutoZoom.cpp +++ b/services/auto_zoom/src/AutoZoom.cpp @@ -144,8 +144,10 @@ void AutoZoom::configure(InputConfigBase &config) AutoZoom::~AutoZoom() { - if (_async_mode) + if (_async_mode) { _input_service->streamOff(); + _async_manager->clear(); + } _taskManager->clear(); } diff --git a/services/common/include/AsyncManager.h b/services/common/include/AsyncManager.h index a03a073..e45d066 100644 --- a/services/common/include/AsyncManager.h +++ b/services/common/include/AsyncManager.h @@ -118,7 +118,9 @@ public: _thread_handle = 
std::make_unique(&AsyncManager::invokeThread, this); } - ~AsyncManager() + ~AsyncManager() = default; + + void clear() { _exit_thread = true; _thread_handle->join(); diff --git a/services/task_manager/include/BridgeNode.h b/services/task_manager/include/BridgeNode.h index 078762a..5c8fc2f 100644 --- a/services/task_manager/include/BridgeNode.h +++ b/services/task_manager/include/BridgeNode.h @@ -57,14 +57,26 @@ public: _inputBuffer->release(); _enabled = false; - // Bridge node got the result from previous task node so enable this bridge node. - if (_outputBuffer) - _enabled = true; + if (!_outputBuffer) + return; + + _enabled = true; } - bool isEnabled() + void updateRunQueue(RunQueueManager *runQueueManager) final { - return _enabled; + // If this bridge node failed then wakeup all next nodes because + // other nodes can wait for them to finish. In this case, + // the next nodes aren't invoked. + if (!_enabled) { + for (const auto &n : _nexts) + n->wakeup(); + + return; + } + + for (const auto &n : _nexts) + runQueueManager->pushNode(n); } }; diff --git a/services/task_manager/include/CallbackNode.h b/services/task_manager/include/CallbackNode.h index e98333d..e064211 100644 --- a/services/task_manager/include/CallbackNode.h +++ b/services/task_manager/include/CallbackNode.h @@ -21,6 +21,7 @@ #include #include "INode.h" +#include "RunQueueManager.h" #include "SingleoException.h" #include "SharedBuffer.h" @@ -50,6 +51,10 @@ public: CallbackNode() = default; virtual ~CallbackNode() = default; + virtual void configure() = 0; + virtual void invoke() = 0; + virtual void updateRunQueue(RunQueueManager *runQueueManager) = 0; + NodeType getType() override; std::string &getName() override { @@ -64,8 +69,6 @@ public: std::shared_ptr &getOutputBuffer() override; void wait() override; void wakeup() override; - virtual void configure() = 0; - virtual void invoke() = 0; void clear() override; void setCb(const NodeCb &cb); std::vector > &results() override; diff --git 
a/services/task_manager/include/EndpointNode.h b/services/task_manager/include/EndpointNode.h index 771b1e2..7e1f6d1 100644 --- a/services/task_manager/include/EndpointNode.h +++ b/services/task_manager/include/EndpointNode.h @@ -50,6 +50,9 @@ public: _inputBuffer->release(); } + + void updateRunQueue(RunQueueManager *runQueueManager) override + {} }; } diff --git a/services/task_manager/include/INode.h b/services/task_manager/include/INode.h index e15604d..5cc04ce 100644 --- a/services/task_manager/include/INode.h +++ b/services/task_manager/include/INode.h @@ -30,6 +30,7 @@ namespace singleo namespace services { enum class NodeType { NONE, INFERENCE, TRAINING, BRIDGE, ENDPOINT }; +class RunQueueManager; class INode { @@ -52,6 +53,7 @@ public: virtual void wait() = 0; virtual void wakeup() = 0; virtual void clear() = 0; + virtual void updateRunQueue(RunQueueManager *runQueueManager) = 0; }; using NodeCb = std::function; diff --git a/services/task_manager/include/InferenceNode.h b/services/task_manager/include/InferenceNode.h index dc6f779..e810f24 100644 --- a/services/task_manager/include/InferenceNode.h +++ b/services/task_manager/include/InferenceNode.h @@ -18,6 +18,7 @@ #define __INFERENCE_NODE_H__ #include +#include #include #include "TaskNode.h" @@ -47,6 +48,7 @@ public: void configure() final; void invoke() final; std::vector > &results() final; + void updateRunQueue(RunQueueManager *runQueueManager) override; }; } diff --git a/services/task_manager/include/RunQueueManager.h b/services/task_manager/include/RunQueueManager.h new file mode 100644 index 0000000..5f90653 --- /dev/null +++ b/services/task_manager/include/RunQueueManager.h @@ -0,0 +1,88 @@ +/** + * Copyright (c) 2024 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __RUN_QUEUE_MANAGER_H__ +#define __RUN_QUEUE_MANAGER_H__ + +#include <deque> +#include <memory> +#include <mutex> +#include <unordered_set> + +#include "INode.h" +#include "SingleoLog.h" + +namespace singleo +{ +namespace services +{ +/* Thread-safe FIFO of graph nodes waiting to be launched; a node is accepted only once until clear(). */ +class RunQueueManager +{ +private: + std::unordered_set<std::shared_ptr<INode> > _uniqueRunNodes; /* every node accepted since the last clear() */ + std::deque<std::shared_ptr<INode> > _runNodes; /* accepted nodes in FIFO order; kept after being popped */ + std::size_t _head = 0; /* index in _runNodes of the next node popNode() returns */ + std::mutex _runQueueMutex; + + bool isDuplicated(const std::shared_ptr<INode> &node) + { + return _uniqueRunNodes.find(node) != _uniqueRunNodes.end(); + } + +public: + void pushNode(const std::shared_ptr<INode> &node) + { + std::lock_guard<std::mutex> lock(_runQueueMutex); + + if (isDuplicated(node)) + return; + + _runNodes.push_back(node); + _uniqueRunNodes.insert(node); + } + + /* Returns the oldest pending node. The reference stays valid until clear() because popped + * nodes are kept in _runNodes. Throws std::out_of_range when no node is pending. */ + std::shared_ptr<INode> &popNode() + { + std::lock_guard<std::mutex> lock(_runQueueMutex); + auto &node = _runNodes.at(_head); + ++_head; + return node; + } + + bool isEmpty() + { + std::lock_guard<std::mutex> lock(_runQueueMutex); + + return _head == _runNodes.size(); + } + + void clear() + { + std::lock_guard<std::mutex> lock(_runQueueMutex); + + _uniqueRunNodes.clear(); + _runNodes.clear(); /* invalidates references handed out by popNode() */ + _head = 0; + } +}; + +} +} + +#endif \ No newline at end of file diff --git a/services/task_manager/include/TaskManager.h b/services/task_manager/include/TaskManager.h index 16761a0..34f303e 100644 --- a/services/task_manager/include/TaskManager.h +++ b/services/task_manager/include/TaskManager.h @@ -27,6 +27,7 @@ #include "IInferenceTaskInterface.h" #include "SingleoCommonTypes.h" #include "INode.h" +#include "RunQueueManager.h" namespace singleo { @@ -39,9 +40,8 @@ private: std::vector > _nodes; 
std::vector > _results; std::queue > _threads; - std::unordered_set > _is_thread_created; - std::mutex _thread_mutex; + std::unique_ptr _runQueueManager; void threadCb(std::shared_ptr &node); void verifyGraph(); diff --git a/services/task_manager/include/TaskNode.h b/services/task_manager/include/TaskNode.h index c32fa21..6521829 100644 --- a/services/task_manager/include/TaskNode.h +++ b/services/task_manager/include/TaskNode.h @@ -18,6 +18,7 @@ #define __TASK_NODE_H__ #include +#include #include #include "INode.h" @@ -66,6 +67,7 @@ public: virtual void invoke() = 0; void clear() override; virtual std::vector > &results() = 0; + virtual void updateRunQueue(RunQueueManager *runQueueManager) = 0; }; } diff --git a/services/task_manager/include/TrainingNode.h b/services/task_manager/include/TrainingNode.h index 16ddaa0..e5d2ab3 100644 --- a/services/task_manager/include/TrainingNode.h +++ b/services/task_manager/include/TrainingNode.h @@ -18,6 +18,7 @@ #define __TRAINING_NODE_H__ #include +#include #include #include "TaskNode.h" @@ -55,6 +56,9 @@ public: { // TODO. implement results here. 
} + + void updateRunQueue(RunQueueManager *runQueueManager) override + {} }; } diff --git a/services/task_manager/src/InferenceNode.cpp b/services/task_manager/src/InferenceNode.cpp index a04176b..ef0c6bd 100644 --- a/services/task_manager/src/InferenceNode.cpp +++ b/services/task_manager/src/InferenceNode.cpp @@ -16,6 +16,7 @@ #include "SingleoLog.h" #include "InferenceNode.h" +#include "RunQueueManager.h" using namespace std; using namespace singleo::inference; @@ -58,5 +59,11 @@ std::vector > &InferenceNode::results() return _results; } +void InferenceNode::updateRunQueue(RunQueueManager *runQueueManager) +{ + for (auto &n : _nexts) + runQueueManager->pushNode(n); +} + } } \ No newline at end of file diff --git a/services/task_manager/src/TaskManager.cpp b/services/task_manager/src/TaskManager.cpp index 73b1d6a..ed9e29b 100644 --- a/services/task_manager/src/TaskManager.cpp +++ b/services/task_manager/src/TaskManager.cpp @@ -39,8 +39,6 @@ void TaskManager::threadCb(shared_ptr &node) for (auto &d : node->getDependencies()) d->wait(); - SINGLEO_LOGD("Launched node name = %s", node->getName().c_str()); - if (node->getType() == NodeType::INFERENCE) { if (_inputs[0]->_data_type != DataType::IMAGE) { SINGLEO_LOGE("Invalid input data type."); @@ -80,24 +78,15 @@ void TaskManager::threadCb(shared_ptr &node) node->invoke(); node->wakeup(); - // Spawn threads for next nodes - for (auto &n : node->getNexts()) { - if (node->getType() == NodeType::BRIDGE) { - auto b_node = dynamic_pointer_cast(node); + node->updateRunQueue(_runQueueManager.get()); - // In case of BRIDGE node, if this bridge node didn't get the result from previous task node, - // isEnabled() is false. So if isEnabled() is false, stop all sub graph pipelines connected to this node. 
- if (!b_node->isEnabled()) { - n->wakeup(); - continue; - } - } + while (!_runQueueManager->isEmpty()) { + auto &n = _runQueueManager->popNode(); - std::lock_guard lock(_thread_mutex); - if (_is_thread_created.find(n) == _is_thread_created.end()) { - _threads.push(make_shared(&TaskManager::threadCb, this, std::ref(n))); - _is_thread_created.insert(n); - } + std::unique_lock lock(_thread_mutex); + + _threads.push(make_shared(&TaskManager::threadCb, this, std::ref(n))); + lock.unlock(); /* release through the guard; unlocking _thread_mutex directly would make the guard's destructor unlock it a second time */ } } @@ -247,6 +236,8 @@ void TaskManager::run() auto inputBuffer = make_shared(); + _runQueueManager = make_unique(); + for (auto &i : _inputs) inputBuffer->addInput(i); @@ -261,27 +252,29 @@ void TaskManager::run() throw InvalidOperation("root node should be inference node."); } - // TODO. consider for multiple sources later. - n->setInputBuffer(inputBuffer); - _threads.push(make_shared(&TaskManager::threadCb, this, std::ref(n))); - _is_thread_created.insert(n); + _runQueueManager->pushNode(n); } } + while (!_runQueueManager->isEmpty()) { + auto &n = _runQueueManager->popNode(); + std::lock_guard<std::mutex> lock(_thread_mutex); /* workers started by this loop push to _threads under the same mutex */ + n->setInputBuffer(inputBuffer); + _threads.push(make_shared(&TaskManager::threadCb, this, std::ref(n))); + } + while (true) { std::unique_lock lock(_thread_mutex); if (_threads.empty()) break; auto t = _threads.front(); - _threads.pop(); + _threads.pop(); lock.unlock(); - t->join(); } - _is_thread_created.clear(); _inputs.clear(); // the result has been returned to user so clear each node. @@ -309,7 +302,7 @@ void TaskManager::clear() _inputs.clear(); _nodes.clear(); _results.clear(); - _is_thread_created.clear(); + if (_runQueueManager) _runQueueManager->clear(); /* only assigned in run(), so it may still be null here */ } } -- 2.34.1