mv_machine_learning: convert ObjectDetection3d into a template class
authorVibhav Aggarwal <v.aggarwal@samsung.com>
Fri, 10 Nov 2023 10:10:07 +0000 (19:10 +0900)
committerInki Dae <inki.dae@samsung.com>
Tue, 14 Nov 2023 07:43:10 +0000 (16:43 +0900)
[Issue type] code refactoring

After the introduction of MachineLearningConfig, it is now possible
to convert each task group's object class into a template class.
This patch introduces this change to the object detection 3d task group.

Change-Id: Ib887846b4b1466f70d9131079f6f7ce62598fbb5
Signed-off-by: Vibhav Aggarwal <v.aggarwal@samsung.com>
mv_machine_learning/object_detection_3d/include/iobject_detection_3d.h [new file with mode: 0644]
mv_machine_learning/object_detection_3d/include/object_detection_3d.h
mv_machine_learning/object_detection_3d/include/object_detection_3d_adapter.h
mv_machine_learning/object_detection_3d/include/objectron.h
mv_machine_learning/object_detection_3d/src/object_detection_3d.cpp
mv_machine_learning/object_detection_3d/src/object_detection_3d_adapter.cpp
mv_machine_learning/object_detection_3d/src/objectron.cpp

diff --git a/mv_machine_learning/object_detection_3d/include/iobject_detection_3d.h b/mv_machine_learning/object_detection_3d/include/iobject_detection_3d.h
new file mode 100644 (file)
index 0000000..94efaf6
--- /dev/null
@@ -0,0 +1,49 @@
+/**
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __IOBJECT_DETECTION_H__
+#define __IOBJECT_DETECTION_H__
+
+#include <mv_common.h>
+
+#include "object_detection_3d_type.h"
+
+namespace mediavision
+{
+namespace machine_learning
+{
+class IObjectDetection3d
+{
+public:
+       virtual ~IObjectDetection3d() {};
+
+       virtual ObjectDetection3dTaskType getTaskType() = 0;
+       virtual void setEngineInfo(std::string engine_type_name, std::string device_type_name) = 0;
+       virtual void getNumberOfEngines(unsigned int *number_of_engines) = 0;
+       virtual void getEngineType(unsigned int engine_index, char **engine_type) = 0;
+       virtual void getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices) = 0;
+       virtual void getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type) = 0;
+       virtual std::shared_ptr<MetaInfo> getInputMetaInfo() = 0;
+       virtual void configure() = 0;
+       virtual void prepare() = 0;
+       virtual void perform(mv_source_h &mv_src, std::shared_ptr<MetaInfo> metaInfo) = 0;
+       virtual ObjectDetection3dResult &result() = 0;
+};
+
+} // machine_learning
+} // mediavision
+
+#endif
\ No newline at end of file
index 370e54a..64dd038 100644 (file)
@@ -25,6 +25,7 @@
 #include "inference_engine_common_impl.h"
 #include "Inference.h"
 #include "object_detection_3d_type.h"
+#include "iobject_detection_3d.h"
 #include "MetaParser.h"
 #include "machine_learning_config.h"
 #include "ObjectDetection3dParser.h"
@@ -34,7 +35,7 @@ namespace mediavision
 {
 namespace machine_learning
 {
-class ObjectDetection3d
+template<typename T> class ObjectDetection3d : public IObjectDetection3d
 {
 private:
        ObjectDetection3dTaskType _task_type;
@@ -53,16 +54,14 @@ protected:
 
        void getOutputNames(std::vector<std::string> &names);
        void getOutputTensor(std::string &target_name, std::vector<float> &tensor);
-       template<typename T>
        void preprocess(mv_source_h &mv_src, std::shared_ptr<MetaInfo> metaInfo, std::vector<T> &inputVector);
-       template<typename T> void inference(std::vector<std::vector<T> > &inputVectors);
+       void inference(std::vector<std::vector<T> > &inputVectors);
 
 public:
        ObjectDetection3d(ObjectDetection3dTaskType task_type, std::shared_ptr<MachineLearningConfig> config);
        virtual ~ObjectDetection3d() = default;
 
        ObjectDetection3dTaskType getTaskType();
-       void setUserModel(std::string model_file, std::string meta_file, std::string label_file);
        void setEngineInfo(std::string engine_type_name, std::string device_type_name);
        void getNumberOfEngines(unsigned int *number_of_engines);
        void getEngineType(unsigned int engine_index, char **engine_type);
@@ -71,7 +70,7 @@ public:
        std::shared_ptr<MetaInfo> getInputMetaInfo();
        void configure();
        void prepare();
-       template<typename T> void perform(mv_source_h &mv_src, std::shared_ptr<MetaInfo> metaInfo);
+       void perform(mv_source_h &mv_src, std::shared_ptr<MetaInfo> metaInfo);
        virtual ObjectDetection3dResult &result() = 0;
 };
 
index 4b0b266..da40f64 100644 (file)
@@ -23,6 +23,7 @@
 #include "itask.h"
 #include "machine_learning_config.h"
 #include "objectron.h"
+#include "iobject_detection_3d.h"
 
 namespace mediavision
 {
@@ -31,7 +32,7 @@ namespace machine_learning
 template<typename T, typename V> class ObjectDetection3dAdapter : public mediavision::common::ITask<T, V>
 {
 private:
-       std::unique_ptr<ObjectDetection3d> _object_detection_3d;
+       std::unique_ptr<IObjectDetection3d> _object_detection_3d;
        std::shared_ptr<MachineLearningConfig> _config;
        T _source;
        const std::string _config_file_name = "object_detection_3d.json";
index d86fb63..ca0cf99 100644 (file)
@@ -29,8 +29,12 @@ namespace mediavision
 {
 namespace machine_learning
 {
-class Objectron : public ObjectDetection3d
+template<typename T> class Objectron : public ObjectDetection3d<T>
 {
+       using ObjectDetection3d<T>::_preprocess;
+       using ObjectDetection3d<T>::_inference;
+       using ObjectDetection3d<T>::_config;
+
 private:
        ObjectDetection3dResult _result;
 
index cef186a..82ef21e 100644 (file)
@@ -34,18 +34,20 @@ namespace mediavision
 {
 namespace machine_learning
 {
-ObjectDetection3d::ObjectDetection3d(ObjectDetection3dTaskType task_type, std::shared_ptr<MachineLearningConfig> config)
+template<typename T>
+ObjectDetection3d<T>::ObjectDetection3d(ObjectDetection3dTaskType task_type,
+                                                                               std::shared_ptr<MachineLearningConfig> config)
                : _task_type(task_type), _config(config)
 {
        _inference = make_unique<Inference>();
 }
 
-ObjectDetection3dTaskType ObjectDetection3d::getTaskType()
+template<typename T> ObjectDetection3dTaskType ObjectDetection3d<T>::getTaskType()
 {
        return _task_type;
 }
 
-void ObjectDetection3d::getEngineList()
+template<typename T> void ObjectDetection3d<T>::getEngineList()
 {
        for (auto idx = MV_INFERENCE_BACKEND_NONE + 1; idx < MV_INFERENCE_BACKEND_MAX; ++idx) {
                auto backend = _inference->getSupportedInferenceBackend(idx);
@@ -57,7 +59,7 @@ void ObjectDetection3d::getEngineList()
        }
 }
 
-void ObjectDetection3d::getDeviceList(const char *engine_type)
+template<typename T> void ObjectDetection3d<T>::getDeviceList(const char *engine_type)
 {
        // TODO. add device types available for a given engine type later.
        //       In default, cpu and gpu only.
@@ -65,7 +67,8 @@ void ObjectDetection3d::getDeviceList(const char *engine_type)
        _valid_devices.push_back("gpu");
 }
 
-void ObjectDetection3d::setEngineInfo(std::string engine_type_name, std::string device_type_name)
+template<typename T>
+void ObjectDetection3d<T>::setEngineInfo(std::string engine_type_name, std::string device_type_name)
 {
        if (engine_type_name.empty() || device_type_name.empty())
                throw InvalidParameter("Invalid engine info.");
@@ -86,7 +89,7 @@ void ObjectDetection3d::setEngineInfo(std::string engine_type_name, std::string
                 device_type_name.c_str(), device_type);
 }
 
-void ObjectDetection3d::getNumberOfEngines(unsigned int *number_of_engines)
+template<typename T> void ObjectDetection3d<T>::getNumberOfEngines(unsigned int *number_of_engines)
 {
        if (!_valid_backends.empty()) {
                *number_of_engines = _valid_backends.size();
@@ -97,7 +100,7 @@ void ObjectDetection3d::getNumberOfEngines(unsigned int *number_of_engines)
        *number_of_engines = _valid_backends.size();
 }
 
-void ObjectDetection3d::getEngineType(unsigned int engine_index, char **engine_type)
+template<typename T> void ObjectDetection3d<T>::getEngineType(unsigned int engine_index, char **engine_type)
 {
        if (!_valid_backends.empty()) {
                if (_valid_backends.size() <= engine_index)
@@ -115,7 +118,8 @@ void ObjectDetection3d::getEngineType(unsigned int engine_index, char **engine_t
        *engine_type = const_cast<char *>(_valid_backends[engine_index].data());
 }
 
-void ObjectDetection3d::getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices)
+template<typename T>
+void ObjectDetection3d<T>::getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices)
 {
        if (!_valid_devices.empty()) {
                *number_of_devices = _valid_devices.size();
@@ -126,7 +130,8 @@ void ObjectDetection3d::getNumberOfDevices(const char *engine_type, unsigned int
        *number_of_devices = _valid_devices.size();
 }
 
-void ObjectDetection3d::getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type)
+template<typename T>
+void ObjectDetection3d<T>::getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type)
 {
        if (!_valid_devices.empty()) {
                if (_valid_devices.size() <= device_index)
@@ -144,7 +149,7 @@ void ObjectDetection3d::getDeviceType(const char *engine_type, const unsigned in
        *device_type = const_cast<char *>(_valid_devices[device_index].data());
 }
 
-void ObjectDetection3d::loadLabel()
+template<typename T> void ObjectDetection3d<T>::loadLabel()
 {
        if (_config->getLabelFilePath().empty())
                return;
@@ -162,9 +167,8 @@ void ObjectDetection3d::loadLabel()
                _labels.push_back(line);
 }
 
-void ObjectDetection3d::configure()
+template<typename T> void ObjectDetection3d<T>::configure()
 {
-       _config->loadMetaFile(make_unique<ObjectDetection3dParser>(static_cast<int>(_task_type)));
        loadLabel();
 
        int ret = _inference->bind(_config->getBackendType(), _config->getTargetDeviceType());
@@ -172,7 +176,7 @@ void ObjectDetection3d::configure()
                throw InvalidOperation("Fail to bind a backend engine.");
 }
 
-void ObjectDetection3d::prepare()
+template<typename T> void ObjectDetection3d<T>::prepare()
 {
        int ret = _inference->configureInputMetaInfo(_config->getInputMetaMap());
        if (ret != MEDIA_VISION_ERROR_NONE)
@@ -190,7 +194,7 @@ void ObjectDetection3d::prepare()
                throw InvalidOperation("Fail to load model files.");
 }
 
-shared_ptr<MetaInfo> ObjectDetection3d::getInputMetaInfo()
+template<typename T> shared_ptr<MetaInfo> ObjectDetection3d<T>::getInputMetaInfo()
 {
        TensorBuffer &tensor_buffer = _inference->getInputTensorBuffer();
        IETensorBuffer &tensor_info_map = tensor_buffer.getIETensorBuffer();
@@ -206,7 +210,7 @@ shared_ptr<MetaInfo> ObjectDetection3d::getInputMetaInfo()
 }
 
 template<typename T>
-void ObjectDetection3d::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo, vector<T> &inputVector)
+void ObjectDetection3d<T>::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo, vector<T> &inputVector)
 {
        LOGI("ENTER");
 
@@ -238,7 +242,7 @@ void ObjectDetection3d::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> met
        LOGI("LEAVE");
 }
 
-template<typename T> void ObjectDetection3d::inference(vector<vector<T> > &inputVectors)
+template<typename T> void ObjectDetection3d<T>::inference(vector<vector<T> > &inputVectors)
 {
        LOGI("ENTER");
 
@@ -249,18 +253,18 @@ template<typename T> void ObjectDetection3d::inference(vector<vector<T> > &input
        LOGI("LEAVE");
 }
 
-template<typename T> void ObjectDetection3d::perform(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo)
+template<typename T> void ObjectDetection3d<T>::perform(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo)
 {
        vector<T> inputVector;
 
-       preprocess<T>(mv_src, metaInfo, inputVector);
+       preprocess(mv_src, metaInfo, inputVector);
 
        vector<vector<T> > inputVectors = { inputVector };
 
-       inference<T>(inputVectors);
+       inference(inputVectors);
 }
 
-void ObjectDetection3d::getOutputNames(vector<string> &names)
+template<typename T> void ObjectDetection3d<T>::getOutputNames(vector<string> &names)
 {
        TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
        IETensorBuffer &ie_tensor_buffer = tensor_buffer_obj.getIETensorBuffer();
@@ -269,7 +273,7 @@ void ObjectDetection3d::getOutputNames(vector<string> &names)
                names.push_back(it->first);
 }
 
-void ObjectDetection3d::getOutputTensor(string &target_name, vector<float> &tensor)
+template<typename T> void ObjectDetection3d<T>::getOutputTensor(string &target_name, vector<float> &tensor)
 {
        LOGI("ENTER");
 
@@ -286,15 +290,8 @@ void ObjectDetection3d::getOutputTensor(string &target_name, vector<float> &tens
        LOGI("LEAVE");
 }
 
-template void ObjectDetection3d::preprocess<float>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo,
-                                                                                                  vector<float> &inputVector);
-template void ObjectDetection3d::inference<float>(vector<vector<float> > &inputVectors);
-template void ObjectDetection3d::perform<float>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo);
-
-template void ObjectDetection3d::preprocess<unsigned char>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo,
-                                                                                                                  vector<unsigned char> &inputVector);
-template void ObjectDetection3d::inference<unsigned char>(vector<vector<unsigned char> > &inputVectors);
-template void ObjectDetection3d::perform<unsigned char>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo);
+template class ObjectDetection3d<float>;
+template class ObjectDetection3d<unsigned char>;
 
 }
 }
\ No newline at end of file
index 728d537..d2fa862 100644 (file)
@@ -38,14 +38,18 @@ template<typename T, typename V> ObjectDetection3dAdapter<T, V>::~ObjectDetectio
 
 template<typename T, typename V> void ObjectDetection3dAdapter<T, V>::create(ObjectDetection3dTaskType task_type)
 {
-       if (_object_detection_3d) {
-               // If current task type is same as a given one then skip.
-               if (_object_detection_3d->getTaskType() == task_type)
-                       return;
-       }
+       _config->loadMetaFile(make_unique<ObjectDetection3dParser>(static_cast<int>(task_type)));
+       mv_inference_data_type_e dataType = _config->getInputMetaMap().begin()->second->dataType;
 
-       if (task_type == ObjectDetection3dTaskType::OBJECTRON)
-               _object_detection_3d = make_unique<Objectron>(task_type, _config);
+       if (task_type == ObjectDetection3dTaskType::OBJECTRON) {
+               if (dataType == MV_INFERENCE_DATA_UINT8)
+                       _object_detection_3d = make_unique<Objectron<unsigned char> >(task_type, _config);
+               else if (dataType == MV_INFERENCE_DATA_FLOAT32)
+                       _object_detection_3d = make_unique<Objectron<float> >(task_type, _config);
+               else
+                       throw InvalidOperation("Invalid model data type.");
+       } else
+               throw InvalidParameter("Invalid object detection 3d task type.");
 }
 
 template<typename T, typename V>
@@ -126,12 +130,7 @@ template<typename T, typename V> void ObjectDetection3dAdapter<T, V>::setInput(T
 template<typename T, typename V> void ObjectDetection3dAdapter<T, V>::perform()
 {
        shared_ptr<MetaInfo> metaInfo = _object_detection_3d->getInputMetaInfo();
-       if (metaInfo->dataType == MV_INFERENCE_DATA_UINT8)
-               _object_detection_3d->perform<unsigned char>(_source.inference_src, metaInfo);
-       else if (metaInfo->dataType == MV_INFERENCE_DATA_FLOAT32)
-               _object_detection_3d->perform<float>(_source.inference_src, metaInfo);
-       else
-               throw InvalidOperation("Invalid model data type.");
+       _object_detection_3d->perform(_source.inference_src, metaInfo);
 }
 
 template<typename T, typename V> void ObjectDetection3dAdapter<T, V>::performAsync(T &t)
index 6cb13b1..030456b 100644 (file)
@@ -30,22 +30,23 @@ namespace mediavision
 {
 namespace machine_learning
 {
-Objectron::Objectron(ObjectDetection3dTaskType task_type, std::shared_ptr<MachineLearningConfig> config)
-               : ObjectDetection3d(task_type, config), _result()
+template<typename T>
+Objectron<T>::Objectron(ObjectDetection3dTaskType task_type, std::shared_ptr<MachineLearningConfig> config)
+               : ObjectDetection3d<T>(task_type, config), _result()
 {}
 
-Objectron::~Objectron()
+template<typename T> Objectron<T>::~Objectron()
 {}
 
-ObjectDetection3dResult &Objectron::result()
+template<typename T> ObjectDetection3dResult &Objectron<T>::result()
 {
        vector<string> names;
 
-       ObjectDetection3d::getOutputNames(names);
+       ObjectDetection3d<T>::getOutputNames(names);
 
        vector<float> keypoints;
 
-       ObjectDetection3d::getOutputTensor(names[1], keypoints);
+       ObjectDetection3d<T>::getOutputTensor(names[1], keypoints);
 
        size_t output_size = keypoints.size();
 
@@ -65,7 +66,7 @@ ObjectDetection3dResult &Objectron::result()
        vector<float> probability_vec;
 
        // names[0] is "Identity"
-       ObjectDetection3d::getOutputTensor(names[0], probability_vec);
+       ObjectDetection3d<T>::getOutputTensor(names[0], probability_vec);
 
        _result.probability = static_cast<unsigned int>(probability_vec[0] * 100);
 
@@ -85,5 +86,8 @@ ObjectDetection3dResult &Objectron::result()
        return _result;
 }
 
+template class Objectron<float>;
+template class Objectron<unsigned char>;
+
 }
 }