mv_machine_learning: move async-related code into behavior class
author Inki Dae <inki.dae@samsung.com>
Wed, 10 May 2023 02:25:07 +0000 (11:25 +0900)
committer Kwanghoon Son <k.son@samsung.com>
Tue, 4 Jul 2023 05:04:45 +0000 (14:04 +0900)
[Issue type] : code cleanup

Move the async-related code from the adapter class, ObjectDetectionAdapter,
into the behavior class, ObjectDetection. This reduces code duplication
because the behavior logic no longer has to be implemented separately for
each input data type.
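
With the async behavior templated inside ObjectDetection, the adapter only
dispatches on the model's tensor data type instead of repeating the queue and
thread handling per type. A minimal, self-contained sketch of the resulting
split (illustrative names only, not the actual mediavision classes):

    #include <iostream>
    #include <memory>
    #include <vector>

    // Behavior class: owns the async logic once, templated on the input type.
    struct Detection {
        template<typename T> void performAsync(const std::vector<T> &input)
        {
            std::cout << "queued " << input.size() << " values" << std::endl;
        }
    };

    // Adapter class: only decides which instantiation to call.
    struct DetectionAdapter {
        std::unique_ptr<Detection> _detection { std::make_unique<Detection>() };

        void performAsync(bool isFloatModel)
        {
            if (isFloatModel)
                _detection->performAsync<float>({ 0.1f, 0.2f });
            else
                _detection->performAsync<unsigned char>({ 1, 2 });
        }
    };

    int main()
    {
        DetectionAdapter adapter;
        adapter.performAsync(true);  // MV_INFERENCE_DATA_FLOAT32 path
        adapter.performAsync(false); // MV_INFERENCE_DATA_UINT8 path
        return 0;
    }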

Change-Id: Ifc193152859736ae146eb47e3a46d3ec39ee98d8
Signed-off-by: Inki Dae <inki.dae@samsung.com>
mv_machine_learning/object_detection/include/object_detection.h
mv_machine_learning/object_detection/include/object_detection_adapter.h
mv_machine_learning/object_detection/src/object_detection.cpp
mv_machine_learning/object_detection/src/object_detection_adapter.cpp

index 0ce3040eb3390d66df710cd88adb17b6257e5561..78cb090462be4b501248605c0021213cda090475 100644 (file)
@@ -18,6 +18,8 @@
 #define __OBJECT_DETECTION_H__
 
 #include <queue>
+#include <thread>
+#include <mutex>
 #include <mv_common.h>
 #include <mv_inference_type.h>
 #include "mv_private.h"
@@ -39,13 +41,23 @@ private:
        void loadLabel();
        void getEngineList();
        void getDeviceList(const char *engine_type);
+       void updateResult(ObjectDetectionResult &result);
+       template<typename T>
+       void preprocess(mv_source_h &mv_src, std::shared_ptr<MetaInfo> metaInfo, std::vector<T> &inputVector);
+       template<typename T> void pushToInput(ObjectDetectionQueue<T> &input);
+       ObjectDetectionResult popFromOutput();
+       bool isOutputQueueEmpty();
 
        ObjectDetectionTaskType _task_type;
        template<typename T> std::queue<ObjectDetectionQueue<T> > static _incoming_queue;
        std::queue<ObjectDetectionResult> _outgoing_queue;
        std::mutex _incoming_queue_mutex;
        std::mutex _outgoing_queue_mutex;
-       int _input_data_type;
+       int _input_data_type {};
+       std::unique_ptr<std::thread> _thread_handle;
+       bool _exit_thread {};
+       ObjectDetectionResult _current_result {};
+       unsigned long _input_index {};
 
 protected:
        std::unique_ptr<mediavision::inference::Inference> _inference;
@@ -68,6 +80,7 @@ protected:
 public:
        ObjectDetection(ObjectDetectionTaskType task_type);
        virtual ~ObjectDetection() = default;
+       void preDestroy();
        ObjectDetectionTaskType getTaskType();
        void setUserModel(std::string model_file, std::string meta_file, std::string label_file);
        void setEngineInfo(std::string engine_type, std::string device_type);
@@ -76,19 +89,17 @@ public:
        void getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices);
        void getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type);
        std::shared_ptr<MetaInfo> getInputMetaInfo();
+       bool exitThread();
        void parseMetaFile(const char *meta_file_name);
        void configure();
        void prepare();
-       template<typename T>
-       void preprocess(mv_source_h &mv_src, std::shared_ptr<MetaInfo> metaInfo, std::vector<T> &inputVector);
-       template<typename T> void inference(std::vector<std::vector<T> > &inputVectors);
+       template<typename V> V &getOutput();
        template<typename T> void perform(mv_source_h &mv_src, std::shared_ptr<MetaInfo> metaInfo);
-       template<typename T> void pushToInput(ObjectDetectionQueue<T> &input);
+       template<typename T> void performAsync(ObjectDetectionInput &input, std::shared_ptr<MetaInfo> metaInfo);
+       template<typename T> void inference(std::vector<std::vector<T> > &inputVectors);
        template<typename T> ObjectDetectionQueue<T> popFromInput();
        template<typename T> bool isInputQueueEmpty();
        void pushToOutput(ObjectDetectionResult &output);
-       ObjectDetectionResult popFromOutput();
-       bool isOutputQueueEmpty();
        virtual ObjectDetectionResult &result() = 0;
 };
 
index d3c063da9e367e82ce5fe11120e9ad4bb101db53..3db8e9e105046fe71e8702d5f680b31612520dfd 100644 (file)
@@ -17,9 +17,6 @@
 #ifndef __OBJECT_DETECTION_ADAPTER_H__
 #define __OBJECT_DETECTION_ADAPTER_H__
 
-#include <queue>
-#include <mutex>
-#include <thread>
 #include <dlog.h>
 
 #include "EngineConfig.h"
@@ -40,10 +37,6 @@ private:
        std::string _model_file;
        std::string _meta_file;
        std::string _label_file;
-       std::unique_ptr<std::thread> _thread_handle;
-       ObjectDetectionResult _current_result {};
-       bool _exit_thread {};
-       unsigned long _input_index {};
 
        void updateResult(ObjectDetectionResult &result);
 
@@ -66,15 +59,6 @@ public:
        void perform() override;
        void performAsync(T &t) override;
        V &getOutput() override;
-
-       ObjectDetection *getObjectDetection()
-       {
-               return _object_detection.get();
-       }
-       bool exitThread()
-       {
-               return _exit_thread;
-       }
 };
 
 } // machine_learning
index f5c3f3739f54864b988144ead64d64e3a16c5446..f1fcabf0627e70039f7b2c1493809b6efae9c556 100644 (file)
@@ -42,6 +42,19 @@ ObjectDetection::ObjectDetection(ObjectDetectionTaskType task_type)
        _parser = make_unique<ObjectDetectionParser>();
 }
 
+void ObjectDetection::preDestroy()
+{
+       if (_thread_handle) {
+               _exit_thread = true;
+               _thread_handle->join();
+       }
+}
+
+bool ObjectDetection::exitThread()
+{
+       return _exit_thread;
+}
+
 ObjectDetectionTaskType ObjectDetection::getTaskType()
 {
        return _task_type;
@@ -304,6 +317,77 @@ template<typename T> void ObjectDetection::perform(mv_source_h &mv_src, shared_p
        inference<T>(inputVectors);
 }
 
+template<typename T> void inferenceThreadLoop(ObjectDetection *object)
+{
+       // If the user called the destroy API then this thread loop will be terminated.
+       while (!object->exitThread() || !object->isInputQueueEmpty<T>()) {
+               // If input queue is empty then skip inference request.
+               if (object->isInputQueueEmpty<T>())
+                       continue;
+
+               ObjectDetectionQueue<T> input = object->popFromInput<T>();
+
+               LOGD("Popped : input index = %lu", input.index);
+
+               object->inference<T>(input.inputs);
+
+               ObjectDetectionResult &result = object->result();
+               result.is_valid = false;
+
+               object->pushToOutput(result);
+
+               input.completion_cb(input.handle, input.user_data);
+       }
+}
+
+template<typename T> void ObjectDetection::performAsync(ObjectDetectionInput &input, shared_ptr<MetaInfo> metaInfo)
+{
+       _input_index++;
+
+       if (!isInputQueueEmpty<T>())
+               return;
+
+       vector<T> inputVector;
+
+       preprocess<T>(input.inference_src, metaInfo, inputVector);
+
+       vector<vector<T> > inputVectors = { inputVector };
+       ObjectDetectionQueue<T> in_queue = { _input_index,                input.handle, input.inference_src,
+                                                                                input.completion_cb, inputVectors, input.user_data };
+
+       pushToInput<T>(in_queue);
+       LOGD("Pushed : input index = %lu", in_queue.index);
+
+       if (!_thread_handle)
+               _thread_handle = make_unique<thread>(&inferenceThreadLoop<T>, this);
+}
+
+void ObjectDetection::updateResult(ObjectDetectionResult &result)
+{
+       _current_result = result;
+       _current_result.is_valid = true;
+}
+
+template<typename V> V &ObjectDetection::getOutput()
+{
+       if (_thread_handle) {
+               // There may be two or more Native APIs which use the getOutput() function.
+               // Therefore, the current result should be kept until a new inference request is made for result consistency.
+               if (_current_result.is_valid)
+                       return _current_result;
+
+               if (isOutputQueueEmpty())
+                       throw InvalidOperation("Output queue is empty.");
+
+               V result = popFromOutput();
+               updateResult(result);
+
+               return _current_result;
+       }
+
+       return result();
+}
+
 void ObjectDetection::getOutputNames(vector<string> &names)
 {
        TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
@@ -377,6 +461,7 @@ template void ObjectDetection::perform<float>(mv_source_h &mv_src, shared_ptr<Me
 template void ObjectDetection::pushToInput<float>(ObjectDetectionQueue<float> &input);
 template ObjectDetectionQueue<float> ObjectDetection::popFromInput();
 template bool ObjectDetection::isInputQueueEmpty<float>();
+template void ObjectDetection::performAsync<float>(ObjectDetectionInput &input, shared_ptr<MetaInfo> metaInfo);
 
 template void ObjectDetection::preprocess<unsigned char>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo,
                                                                                                                 vector<unsigned char> &inputVector);
@@ -386,5 +471,9 @@ template void ObjectDetection::pushToInput<unsigned char>(ObjectDetectionQueue<u
 template ObjectDetectionQueue<unsigned char> ObjectDetection::popFromInput();
 template bool ObjectDetection::isInputQueueEmpty<unsigned char>();
 
+template void ObjectDetection::performAsync<unsigned char>(ObjectDetectionInput &input, shared_ptr<MetaInfo> metaInfo);
+
+template ObjectDetectionResult &ObjectDetection::getOutput();
+
 }
 }
\ No newline at end of file
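
The inferenceThreadLoop() added above keeps running while exit has not been
requested or the input queue still holds work, so queued requests are drained
before the worker stops; preDestroy() sets the exit flag and joins the thread.
A simplified, self-contained sketch of that pattern with plain standard-library
types (illustrative only; the real loop uses the ObjectDetectionQueue/Result
types and the mutex-guarded accessors shown in the diff):

    #include <atomic>
    #include <iostream>
    #include <mutex>
    #include <queue>
    #include <thread>

    static std::queue<int> s_inputs;
    static std::mutex s_mutex;
    // An atomic flag avoids a data race on the exit request in this sketch.
    static std::atomic<bool> s_exit_thread { false };

    static bool isEmpty()
    {
        std::lock_guard<std::mutex> lock(s_mutex);
        return s_inputs.empty();
    }

    static int pop()
    {
        std::lock_guard<std::mutex> lock(s_mutex);
        int value = s_inputs.front();
        s_inputs.pop();
        return value;
    }

    static void workerLoop()
    {
        // Mirror the "!exitThread() || !isInputQueueEmpty()" condition:
        // keep looping until exit was requested AND the queue is drained.
        while (!s_exit_thread || !isEmpty()) {
            if (isEmpty())
                continue; // nothing queued yet; skip this iteration

            std::cout << "processed input " << pop() << std::endl;
        }
    }

    int main()
    {
        for (int index = 0; index < 3; ++index)
            s_inputs.push(index);

        std::thread worker(workerLoop);
        s_exit_thread = true; // like preDestroy(): request exit, then join
        worker.join();
        return 0;
    }
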
index e2b60caf3a481753ed862eddeaaec3929bf41f30..c4f7fffbc79da4cec61d0967bc2a6738254d199b 100644 (file)
@@ -26,8 +26,7 @@ namespace mediavision
 {
 namespace machine_learning
 {
-template<typename T, typename V>
-ObjectDetectionAdapter<T, V>::ObjectDetectionAdapter() : _source(), _exit_thread(), _input_index()
+template<typename T, typename V> ObjectDetectionAdapter<T, V>::ObjectDetectionAdapter() : _source()
 {
        // In default, Mobilenet v1 ssd model will be used.
        // If other model is set by user then strategy pattern will be used
@@ -37,10 +36,7 @@ ObjectDetectionAdapter<T, V>::ObjectDetectionAdapter() : _source(), _exit_thread
 
 template<typename T, typename V> ObjectDetectionAdapter<T, V>::~ObjectDetectionAdapter()
 {
-       if (_thread_handle) {
-               _exit_thread = true;
-               _thread_handle->join();
-       }
+       _object_detection->preDestroy();
 }
 
 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::create(int type)
@@ -150,97 +146,20 @@ template<typename T, typename V> void ObjectDetectionAdapter<T, V>::perform()
                throw InvalidOperation("Invalid model data type.");
 }
 
-template<typename T, typename V> void ObjectDetectionAdapter<T, V>::updateResult(ObjectDetectionResult &result)
-{
-       _current_result = result;
-       _current_result.is_valid = true;
-}
-
 template<typename T, typename V> V &ObjectDetectionAdapter<T, V>::getOutput()
 {
-       if (_thread_handle) {
-               // There may be two or more Native APIs which utilize getOutput() function.
-               // Therefore, the current result should be kept until a new inference request is made for result consistency.
-               if (_current_result.is_valid)
-                       return _current_result;
-
-               if (_object_detection->isOutputQueueEmpty())
-                       throw InvalidOperation("Output queue is empty.");
-
-               ObjectDetectionResult result = _object_detection->popFromOutput();
-               updateResult(result);
-
-               return _current_result;
-       }
-
-       return _object_detection->result();
-}
-
-template<typename T, typename V, typename N> void inferenceThreadLoop(ObjectDetectionAdapter<T, V> *adapter)
-{
-       ObjectDetection *object = adapter->getObjectDetection();
-
-       // If user called destroy API then this thread loop will be terminated.
-       while (!adapter->exitThread() || !object->isInputQueueEmpty<N>()) {
-               // If input queue is empty then skip inference request.
-               if (object->isInputQueueEmpty<N>())
-                       continue;
-
-               ObjectDetectionQueue<N> input = object->popFromInput<N>();
-
-               LOGD("Pop : input index = %lu", input.index);
-
-               object->inference<N>(input.inputs);
-
-               ObjectDetectionResult &result = object->result();
-               result.is_valid = false;
-
-               object->pushToOutput(result);
-
-               input.completion_cb(input.handle, input.user_data);
-       }
+       return _object_detection->getOutput<V>();
 }
 
 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::performAsync(T &t)
 {
        shared_ptr<MetaInfo> metaInfo = _object_detection->getInputMetaInfo();
 
-       _input_index++;
-
        if (metaInfo->dataType == MV_INFERENCE_DATA_UINT8) {
-               if (!_object_detection->isInputQueueEmpty<float>())
-                       return;
-
-               vector<unsigned char> inputVector;
-
-               _object_detection->preprocess<unsigned char>(t.inference_src, metaInfo, inputVector);
-
-               vector<vector<unsigned char> > inputs = { inputVector };
-               ObjectDetectionQueue<unsigned char> input = { _input_index,        t.handle, t.inference_src,
-                                                                                                         t.completion_cb, inputs,       t.user_data };
-
-               _object_detection->pushToInput<unsigned char>(input);
-               LOGD("Push : input index = %lu", input.index);
-
-               if (!_thread_handle)
-                       _thread_handle = make_unique<thread>(&inferenceThreadLoop<T, V, float>, this);
+               _object_detection->performAsync<unsigned char>(t, metaInfo);
        } else if (metaInfo->dataType == MV_INFERENCE_DATA_FLOAT32) {
-               if (!_object_detection->isInputQueueEmpty<float>())
-                       return;
-
-               vector<float> inputVector;
-
-               _object_detection->preprocess<float>(t.inference_src, metaInfo, inputVector);
-
-               vector<vector<float> > inputs = { inputVector };
-               ObjectDetectionQueue<float> input = { _input_index,        t.handle, t.inference_src,
-                                                                                         t.completion_cb, inputs,       t.user_data };
-
-               _object_detection->pushToInput<float>(input);
-               LOGD("Push : input index = %lu", input.index);
-
-               if (!_thread_handle)
-                       _thread_handle = make_unique<thread>(&inferenceThreadLoop<T, V, float>, this);
+               _object_detection->performAsync<float>(t, metaInfo);
+               // TODO
        } else {
                throw InvalidOperation("Invalid model data type.");
        }