task_manager: delegate collecting the results to each node 01/311801/6
authorInki Dae <inki.dae@samsung.com>
Tue, 28 May 2024 06:32:24 +0000 (15:32 +0900)
committerInki Dae <inki.dae@samsung.com>
Wed, 29 May 2024 03:07:08 +0000 (12:07 +0900)
Delegate collecting the results to each node. With this, we can use
only generic interface.

Change-Id: I5afd506bbc763fdcd86c18ac3980d295fd894c67
Signed-off-by: Inki Dae <inki.dae@samsung.com>
12 files changed:
common/include/SingleoCommonTypes.h
services/task_manager/include/BridgeNode.h
services/task_manager/include/CallbackNode.h
services/task_manager/include/EndpointNode.h
services/task_manager/include/INode.h
services/task_manager/include/InferenceNode.h
services/task_manager/include/TaskNode.h
services/task_manager/include/TrainingNode.h
services/task_manager/src/CallbackNode.cpp
services/task_manager/src/InferenceNode.cpp
services/task_manager/src/TaskManager.cpp
test/services/test_task_manager.cpp

index 31cabf34639cada78552543f500e79948e796b94..2109088b70853a3b14a3a23d192c8dfa07954d48 100644 (file)
@@ -17,6 +17,7 @@
 #ifndef __SINGLEO_COMMON_TYPES_H__
 #define __SINGLEO_COMMON_TYPES_H__
 
+#include <memory>
 #include <cstring>
 #include <vector>
 #include "SingleoLog.h"
@@ -103,24 +104,41 @@ struct BaseResultType {
        {}
        virtual ~BaseResultType()
        {}
+
+       virtual std::shared_ptr<BaseResultType> clone() = 0;
 };
 
 struct OdResultType : public BaseResultType {
        OdResultType() : BaseResultType(ResultType::OBJECT_DETECTION)
        {}
        std::vector<Rect> _rects;
+
+       std::shared_ptr<BaseResultType> clone() override
+       {
+               return std::make_shared<OdResultType>(*this);
+       }
 };
 
 struct FdResultType : public BaseResultType {
        FdResultType() : BaseResultType(ResultType::FACE_DETECTION)
        {}
        std::vector<Rect> _rects;
+
+       std::shared_ptr<BaseResultType> clone() override
+       {
+               return std::make_shared<FdResultType>(*this);
+       }
 };
 
 struct FldResultType : public BaseResultType {
        FldResultType() : BaseResultType(ResultType::FACE_LANDMARK)
        {}
        std::vector<Point> _points;
+
+       std::shared_ptr<BaseResultType> clone() override
+       {
+               return std::make_shared<FldResultType>(*this);
+       }
 };
 
 enum class ServiceType { NONE, AUTO_ZOOM };
index 32fdd40dbbc66740703418b249a21ad802cc1f54..fe5199e650d1c9f6682dd31e5e0e00dd4c95dc5a 100644 (file)
@@ -44,6 +44,11 @@ public:
                if (!_cb)
                        throw singleo::exception::InvalidOperation("Bridge node callback is not set");
 
+               _results.clear();
+
+               for (const auto &d : _dependencies)
+                       std::copy(d->results().begin(), d->results().end(), std::back_inserter(_results));
+
                _cb(this);
        }
 };
index ea3e5dcc18aa937ae662d9edd34f249544633279..20b7caee9d590f7f94fe4d0daeeab27d6eb4c2ee 100644 (file)
@@ -61,9 +61,7 @@ public:
        virtual void configure() = 0;
        virtual void invoke() = 0;
        void setCb(const NodeCb &cb);
-       std::vector<std::shared_ptr<BaseResultType> > &getResults();
-       void addResult(std::shared_ptr<BaseResultType> result);
-       void clearResults();
+       std::vector<std::shared_ptr<BaseResultType> > &results() override;
 };
 
 }
index 3070029c3deffeddfafc91b2dae2fc51692b5bb5..cf9d073f4a054b4ae7ef91735841211414e27e1d 100644 (file)
@@ -40,6 +40,11 @@ public:
 
        void invoke() final
        {
+               _results.clear();
+
+               for (auto &d : _dependencies)
+                       std::copy(d->results().begin(), d->results().end(), std::back_inserter(_results));
+
                if (_cb)
                        _cb(this);
        }
index d436e99fe6bc0a475acd02c6d303ebe944b17305..80ce83f474f2dea11c8c2d65a98835c5e7bb15f7 100644 (file)
@@ -45,6 +45,7 @@ public:
        virtual std::shared_ptr<BaseDataType> &getOutput() = 0;
        virtual void configure() = 0;
        virtual void invoke() = 0;
+       virtual std::vector<std::shared_ptr<BaseResultType> > &results() = 0;
        virtual void wait() = 0;
        virtual void wakeup() = 0;
 };
index f758e0ea1166fba34cdc308fcf17e2667eb80112..82fb707b56161db2bfb5a41d7862da368a3aac76 100644 (file)
@@ -33,6 +33,7 @@ class InferenceNode : public TaskNode
 private:
        std::unique_ptr<inference::IInferenceTaskInterface> _task;
        std::mutex _resultMutex;
+       std::vector<std::shared_ptr<BaseResultType> > _results;
 
 public:
        InferenceNode(std::string name = "inference")
@@ -43,11 +44,9 @@ public:
        virtual ~InferenceNode() = default;
 
        void setInferenceTask(std::unique_ptr<inference::IInferenceTaskInterface> &&task);
-       void lockResult();
-       void unlockResult();
        void configure() final;
        void invoke() final;
-       BaseResultType &getTaskResult() final;
+       std::vector<std::shared_ptr<BaseResultType> > &results() final;
 };
 
 }
index 21d8b1be6ec929db473c6d3eab8f2ba95c503f12..38807c9dabf639d4f2aee97271de4dfdf7157a3b 100644 (file)
@@ -58,7 +58,7 @@ public:
        void wakeup() override;
        virtual void configure() = 0;
        virtual void invoke() = 0;
-       virtual BaseResultType &getTaskResult() = 0;
+       virtual std::vector<std::shared_ptr<BaseResultType> > &results() = 0;
 };
 
 }
index 0f8fe511fc7a0a4e3128d84a2cf6f9da92515a3c..97ed21c1f03e48fce0b624b4529120b7147896ca 100644 (file)
@@ -51,9 +51,9 @@ public:
                // TODO. implement invoke here.
        }
 
-       BaseResultType &getTaskResult()
+       std::vector<std::shared_ptr<BaseResultType> > &results()
        {
-               // TODO. implement getTaskresult here.
+               // TODO. implement results here.
        }
 };
 
index 655e5592c57ebfaedb0ce17cb52e8b11473ed61e..05eef6915b89258867d584321456102bcbcf35f4 100644 (file)
@@ -66,21 +66,11 @@ std::shared_ptr<BaseDataType> &CallbackNode::getOutput()
        return _output;
 }
 
-void CallbackNode::addResult(std::shared_ptr<BaseResultType> result)
-{
-       _results.push_back(result);
-}
-
-vector<shared_ptr<BaseResultType> > &CallbackNode::getResults()
+std::vector<std::shared_ptr<BaseResultType> > &CallbackNode::results()
 {
        return _results;
 }
 
-void CallbackNode::clearResults()
-{
-       _results.clear();
-}
-
 void CallbackNode::wait()
 {
        unique_lock<mutex> lock(_mutex);
index 97f2d3f0f7481a95d4e0df7e6b6a3257812e9c67..c62a4a38337c95e6e5deba0aaf788b856f76095f 100644 (file)
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 
+#include "SingleoLog.h"
 #include "InferenceNode.h"
 
 using namespace std;
@@ -28,16 +29,6 @@ void InferenceNode::setInferenceTask(unique_ptr<inference::IInferenceTaskInterfa
        _task = move(task);
 }
 
-void InferenceNode::lockResult()
-{
-       _resultMutex.lock();
-}
-
-void InferenceNode::unlockResult()
-{
-       _resultMutex.unlock();
-}
-
 void InferenceNode::configure()
 {
        _task->configure();
@@ -57,11 +48,16 @@ void InferenceNode::invoke()
                delete input->ptr;
 
        _inputs.clear();
+       _results.clear();
+
+       _resultMutex.lock();
+       _results.push_back(_task->result().clone());
+       _resultMutex.unlock();
 }
 
-BaseResultType &InferenceNode::getTaskResult()
+std::vector<std::shared_ptr<BaseResultType> > &InferenceNode::results()
 {
-       return _task->result();
+       return _results;
 }
 
 }
index 534e21b2660b5364d235e46f54a3b9f75567f17e..c84614fd2c59b5c013ede8a295a0c93df23ee4f0 100644 (file)
@@ -59,38 +59,11 @@ void TaskManager::threadCb(shared_ptr<INode> &node)
 
                        node->addInput(callbackNode->getOutput());
                }
-
-               node->invoke();
-               node->wakeup();
-
-               return;
+       } else {
+               // TODO. consider for mulitple inputs later.
+               node->addInput(_inputs[0]);
        }
 
-       // ALl other types of callback nodes are handled here.
-       // TODO. Other callback style nodes such as StoreNode, DebugNode and so on will be added later.
-       auto callbackNode = dynamic_pointer_cast<CallbackNode>(node);
-
-       callbackNode->clearResults();
-
-       for (auto &n : node->getDependencies()) {
-               auto inferenceNode = dynamic_pointer_cast<InferenceNode>(n);
-
-               inferenceNode->lockResult();
-
-               auto &result = inferenceNode->getTaskResult();
-
-               if (result._type == ResultType::FACE_DETECTION)
-                       callbackNode->addResult(make_shared<FdResultType>(dynamic_cast<FdResultType &>(result)));
-               else if (result._type == ResultType::FACE_LANDMARK)
-                       callbackNode->addResult(make_shared<FldResultType>(dynamic_cast<FldResultType &>(result)));
-               else if (result._type == ResultType::OBJECT_DETECTION)
-                       callbackNode->addResult(make_shared<OdResultType>(dynamic_cast<OdResultType &>(result)));
-
-               inferenceNode->unlockResult();
-       }
-
-       // TODO. consider for mulitple inputs later.
-       node->addInput(_inputs[0]);
        node->invoke();
        node->wakeup();
 }
@@ -228,10 +201,8 @@ vector<shared_ptr<BaseResultType> > &TaskManager::output()
                throw InvalidOperation("Node is not set");
        }
 
-       auto lastNode = dynamic_pointer_cast<EndpointNode>(_nodes.back());
-
        _results.clear();
-       _results = lastNode->getResults();
+       _results = _nodes.back()->results();
 
        return _results;
 }
index 3418fa0b3c996588860f70cc08913f7dd8c19dd9..a1dde3d38a9c7aa11b5ef6563c0df8d6c20ccd36 100644 (file)
@@ -46,7 +46,7 @@ void BridgeNodeCallback(INode *node)
 
        cv::Mat cv_image(cv::Size(newImage.width, newImage.height), CV_MAKETYPE(CV_8U, 3), newImage.ptr);
 
-       auto &results = callbackNode->getResults();
+       auto &results = callbackNode->results();
        for (auto r : results) {
                ASSERT_EQ(r->_type, ResultType::FACE_DETECTION);
 
@@ -195,7 +195,7 @@ TEST(SingloTaskManager, MultipleNodesBasedGraphBShouldWork)
 void LastNodeCallback(INode *node)
 {
        auto callbackNode = dynamic_cast<CallbackNode *>(node);
-       auto &results = callbackNode->getResults();
+       auto &results = callbackNode->results();
 
        ASSERT_EQ(results.size(), 2);