mv_machine_learning: convert ObjectDetection class into template class
author Vibhav Aggarwal <v.aggarwal@samsung.com>
Fri, 10 Nov 2023 11:27:23 +0000 (20:27 +0900)
committer Kwanghoon Son <k.son@samsung.com>
Tue, 14 Nov 2023 07:52:55 +0000 (16:52 +0900)
[Issue type] code refactoring

Change-Id: Ib84f8211a919c82e389659bdce1b652f4419e2fb
Signed-off-by: Vibhav Aggarwal <v.aggarwal@samsung.com>
mv_machine_learning/object_detection/include/mobilenet_v1_ssd.h
mv_machine_learning/object_detection/include/mobilenet_v2_ssd.h
mv_machine_learning/object_detection/include/object_detection.h
mv_machine_learning/object_detection/src/face_detection_adapter.cpp
mv_machine_learning/object_detection/src/mobilenet_v1_ssd.cpp
mv_machine_learning/object_detection/src/mobilenet_v2_ssd.cpp
mv_machine_learning/object_detection/src/object_detection.cpp
mv_machine_learning/object_detection/src/object_detection_adapter.cpp

index 3d78f1e..bb309bf 100644 (file)
@@ -29,8 +29,12 @@ namespace mediavision
 {
 namespace machine_learning
 {
-class MobilenetV1Ssd : public ObjectDetection
+template<typename T> class MobilenetV1Ssd : public ObjectDetection<T>
 {
+       using ObjectDetection<T>::_config;
+       using ObjectDetection<T>::_preprocess;
+       using ObjectDetection<T>::_labels;
+
 private:
        ObjectDetectionResult _result;
 
index 232512d..c598c43 100644 (file)
@@ -29,8 +29,12 @@ namespace mediavision
 {
 namespace machine_learning
 {
-class MobilenetV2Ssd : public ObjectDetection
+template<typename T> class MobilenetV2Ssd : public ObjectDetection<T>
 {
+       using ObjectDetection<T>::_config;
+       using ObjectDetection<T>::_preprocess;
+       using ObjectDetection<T>::_labels;
+
 private:
        ObjectDetectionResult _result;
 
index 1468cc3..4669b89 100644 (file)
@@ -41,7 +41,7 @@ namespace mediavision
 {
 namespace machine_learning
 {
-class ObjectDetection : public IObjectDetection
+template<typename T> class ObjectDetection : public IObjectDetection
 {
 private:
        ObjectDetectionTaskType _task_type { ObjectDetectionTaskType::OBJECT_DETECTION_TASK_NONE };
@@ -51,11 +51,8 @@ private:
        void loadLabel();
        void getEngineList();
        void getDeviceList(const char *engine_type);
-       template<typename T>
        void preprocess(mv_source_h &mv_src, std::shared_ptr<MetaInfo> metaInfo, std::vector<T> &inputVector);
        std::shared_ptr<MetaInfo> getInputMetaInfo();
-       template<typename T> void perform(mv_source_h &mv_src, std::shared_ptr<MetaInfo> metaInfo);
-       template<typename T> void performAsync(ObjectDetectionInput &input, std::shared_ptr<MetaInfo> metaInfo);
 
 protected:
        std::unique_ptr<mediavision::inference::Inference> _inference;
@@ -67,7 +64,7 @@ protected:
 
        void getOutputNames(std::vector<std::string> &names);
        void getOutputTensor(std::string target_name, std::vector<float> &tensor);
-       template<typename T> void inference(std::vector<std::vector<T> > &inputVectors);
+       void inference(std::vector<std::vector<T> > &inputVectors);
        virtual ObjectDetectionResult &result() = 0;
 
 public:
index 4c72a4a..526d805 100644 (file)
@@ -43,16 +43,21 @@ template<typename T, typename V> FaceDetectionAdapter<T, V>::~FaceDetectionAdapt
 
 template<typename T, typename V> void FaceDetectionAdapter<T, V>::create(ObjectDetectionTaskType task_type)
 {
-       // If a concrete class object created already exists, reset the object
-       // so that other concrete class object can be created again according to a given task_type.
-       if (_object_detection) {
-               // If default task type is same as a given one then skip.
-               if (_object_detection->getTaskType() == task_type)
-                       return;
+       _config->loadMetaFile(make_unique<ObjectDetectionParser>(static_cast<int>(task_type)));
+       mv_inference_data_type_e dataType = _config->getInputMetaMap().begin()->second->dataType;
+
+       switch (task_type) {
+       case ObjectDetectionTaskType::FD_MOBILENET_V1_SSD:
+               if (dataType == MV_INFERENCE_DATA_UINT8)
+                       _object_detection = make_unique<MobilenetV1Ssd<unsigned char> >(task_type, _config);
+               else if (dataType == MV_INFERENCE_DATA_FLOAT32)
+                       _object_detection = make_unique<MobilenetV1Ssd<float> >(task_type, _config);
+               else
+                       throw InvalidOperation("Invalid model data type.");
+               break;
+       default:
+               throw InvalidOperation("Invalid face detection task type.");
        }
-
-       if (task_type == ObjectDetectionTaskType::FD_MOBILENET_V1_SSD)
-               _object_detection = make_unique<MobilenetV1Ssd>(task_type, _config);
        // TODO.
 }
 
index cf70ac9..d08d239 100644 (file)
@@ -31,14 +31,15 @@ namespace mediavision
 {
 namespace machine_learning
 {
-MobilenetV1Ssd::MobilenetV1Ssd(ObjectDetectionTaskType task_type, std::shared_ptr<MachineLearningConfig> config)
-               : ObjectDetection(task_type, config), _result()
+template<typename T>
+MobilenetV1Ssd<T>::MobilenetV1Ssd(ObjectDetectionTaskType task_type, std::shared_ptr<MachineLearningConfig> config)
+               : ObjectDetection<T>(task_type, config), _result()
 {}
 
-MobilenetV1Ssd::~MobilenetV1Ssd()
+template<typename T> MobilenetV1Ssd<T>::~MobilenetV1Ssd()
 {}
 
-ObjectDetectionResult &MobilenetV1Ssd::result()
+template<typename T> ObjectDetectionResult &MobilenetV1Ssd<T>::result()
 {
        // Clear _result object because result() function can be called every time user wants
        // so make sure to clear existing result data before getting the data again.
@@ -46,17 +47,17 @@ ObjectDetectionResult &MobilenetV1Ssd::result()
 
        vector<string> names;
 
-       ObjectDetection::getOutputNames(names);
+       ObjectDetection<T>::getOutputNames(names);
 
        vector<float> number_tensor;
 
        // TFLite_Detection_PostProcess:3
-       ObjectDetection::getOutputTensor(names[3], number_tensor);
+       ObjectDetection<T>::getOutputTensor(names[3], number_tensor);
 
        vector<float> label_tensor;
 
        // TFLite_Detection_PostProcess:1
-       ObjectDetection::getOutputTensor(names[1], label_tensor);
+       ObjectDetection<T>::getOutputTensor(names[1], label_tensor);
 
        vector<float> score_tensor;
        map<float, unsigned int, std::greater<float> > sorted_score;
@@ -65,7 +66,7 @@ ObjectDetectionResult &MobilenetV1Ssd::result()
        auto decodingScore = static_pointer_cast<DecodingScore>(scoreMetaInfo->decodingTypeMap[DecodingType::SCORE]);
 
        // TFLite_Detection_PostProcess:2
-       ObjectDetection::getOutputTensor(names[2], score_tensor);
+       ObjectDetection<T>::getOutputTensor(names[2], score_tensor);
        for (size_t idx = 0; idx < score_tensor.size(); ++idx) {
                if (decodingScore->threshold > score_tensor[idx])
                        continue;
@@ -77,7 +78,7 @@ ObjectDetectionResult &MobilenetV1Ssd::result()
        auto decodingBox = static_pointer_cast<DecodingBox>(boxMetaInfo->decodingTypeMap[DecodingType::BOX]);
        vector<float> box_tensor;
 
-       ObjectDetection::getOutputTensor(names[0], box_tensor);
+       ObjectDetection<T>::getOutputTensor(names[0], box_tensor);
 
        for (auto &score : sorted_score) {
                _result.number_of_objects++;
@@ -109,5 +110,8 @@ ObjectDetectionResult &MobilenetV1Ssd::result()
        return _result;
 }
 
+template class MobilenetV1Ssd<float>;
+template class MobilenetV1Ssd<unsigned char>;
+
 }
 }
index 3032909..3af71b7 100644 (file)
@@ -32,11 +32,12 @@ namespace mediavision
 {
 namespace machine_learning
 {
-MobilenetV2Ssd::MobilenetV2Ssd(ObjectDetectionTaskType task_type, std::shared_ptr<MachineLearningConfig> config)
-               : ObjectDetection(task_type, config), _result()
+template<typename T>
+MobilenetV2Ssd<T>::MobilenetV2Ssd(ObjectDetectionTaskType task_type, std::shared_ptr<MachineLearningConfig> config)
+               : ObjectDetection<T>(task_type, config), _result()
 {}
 
-MobilenetV2Ssd::~MobilenetV2Ssd()
+template<typename T> MobilenetV2Ssd<T>::~MobilenetV2Ssd()
 {}
 
 static bool compareScore(Box box0, Box box1)
@@ -80,8 +81,9 @@ static float calcIntersectionOverUnion(Box box0, Box box1)
        return intersectArea / (area0 + area1 - intersectArea);
 }
 
-void MobilenetV2Ssd::ApplyNms(vector<vector<Box> > &box_lists, BoxNmsMode mode, float threshold,
-                                                         vector<Box> &box_vector)
+template<typename T>
+void MobilenetV2Ssd<T>::ApplyNms(vector<vector<Box> > &box_lists, BoxNmsMode mode, float threshold,
+                                                                vector<Box> &box_vector)
 {
        LOGI("ENTER");
 
@@ -129,8 +131,9 @@ void MobilenetV2Ssd::ApplyNms(vector<vector<Box> > &box_lists, BoxNmsMode mode,
        LOGI("LEAVE");
 }
 
-Box MobilenetV2Ssd::decodeBox(const DecodingBox *decodingBox, vector<float> &bb_tensor, int idx, float score, int label,
-                                                         int box_offset)
+template<typename T>
+Box MobilenetV2Ssd<T>::decodeBox(const DecodingBox *decodingBox, vector<float> &bb_tensor, int idx, float score,
+                                                                int label, int box_offset)
 {
        // assume type is (cx,cy,w,h)
        // left or cx
@@ -149,7 +152,8 @@ Box MobilenetV2Ssd::decodeBox(const DecodingBox *decodingBox, vector<float> &bb_
        return box;
 }
 
-Box MobilenetV2Ssd::decodeBoxWithAnchor(const BoxAnchorParam *boxAnchorParam, Box &box, cv::Rect2f &anchor)
+template<typename T>
+Box MobilenetV2Ssd<T>::decodeBoxWithAnchor(const BoxAnchorParam *boxAnchorParam, Box &box, cv::Rect2f &anchor)
 {
        if (boxAnchorParam->isFixedAnchorSize) {
                box.location.x += anchor.x;
@@ -170,7 +174,7 @@ Box MobilenetV2Ssd::decodeBoxWithAnchor(const BoxAnchorParam *boxAnchorParam, Bo
        return box;
 }
 
-ObjectDetectionResult &MobilenetV2Ssd::result()
+template<typename T> ObjectDetectionResult &MobilenetV2Ssd<T>::result()
 {
        // Clear _result object because result() function can be called every time user wants
        // so make sure to clear existing result data before getting the data again.
@@ -178,12 +182,12 @@ ObjectDetectionResult &MobilenetV2Ssd::result()
 
        vector<string> names;
 
-       ObjectDetection::getOutputNames(names);
+       ObjectDetection<T>::getOutputNames(names);
 
        vector<float> score_tensor;
 
        // raw_outputs/class_predictions
-       ObjectDetection::getOutputTensor(names[1], score_tensor);
+       ObjectDetection<T>::getOutputTensor(names[1], score_tensor);
 
        auto scoreMetaInfo = _config->getOutputMetaMap().at(names[1]);
        auto decodingScore = static_pointer_cast<DecodingScore>(scoreMetaInfo->decodingTypeMap[DecodingType::SCORE]);
@@ -196,7 +200,7 @@ ObjectDetectionResult &MobilenetV2Ssd::result()
        vector<float> bb_tensor;
 
        // raw_outputs/box_encodings
-       ObjectDetection::getOutputTensor(names[0], bb_tensor);
+       ObjectDetection<T>::getOutputTensor(names[0], bb_tensor);
 
        vector<Box> box_vec;
        vector<vector<Box> > box_list_vec;
@@ -258,5 +262,8 @@ ObjectDetectionResult &MobilenetV2Ssd::result()
        return _result;
 }
 
+template class MobilenetV2Ssd<float>;
+template class MobilenetV2Ssd<unsigned char>;
+
 }
 }
index 8216b30..5802070 100644 (file)
@@ -34,13 +34,14 @@ namespace mediavision
 {
 namespace machine_learning
 {
-ObjectDetection::ObjectDetection(ObjectDetectionTaskType task_type, shared_ptr<MachineLearningConfig> config)
+template<typename T>
+ObjectDetection<T>::ObjectDetection(ObjectDetectionTaskType task_type, shared_ptr<MachineLearningConfig> config)
                : _task_type(task_type), _config(config)
 {
        _inference = make_unique<Inference>();
 }
 
-void ObjectDetection::preDestroy()
+template<typename T> void ObjectDetection<T>::preDestroy()
 {
        if (!_async_manager)
                return;
@@ -48,12 +49,12 @@ void ObjectDetection::preDestroy()
        _async_manager->stop();
 }
 
-ObjectDetectionTaskType ObjectDetection::getTaskType()
+template<typename T> ObjectDetectionTaskType ObjectDetection<T>::getTaskType()
 {
        return _task_type;
 }
 
-void ObjectDetection::getEngineList()
+template<typename T> void ObjectDetection<T>::getEngineList()
 {
        for (auto idx = MV_INFERENCE_BACKEND_NONE + 1; idx < MV_INFERENCE_BACKEND_MAX; ++idx) {
                auto backend = _inference->getSupportedInferenceBackend(idx);
@@ -65,7 +66,7 @@ void ObjectDetection::getEngineList()
        }
 }
 
-void ObjectDetection::getDeviceList(const char *engine_type)
+template<typename T> void ObjectDetection<T>::getDeviceList(const char *engine_type)
 {
        // TODO. add device types available for a given engine type later.
        //       In default, cpu and gpu only.
@@ -73,7 +74,7 @@ void ObjectDetection::getDeviceList(const char *engine_type)
        _valid_devices.push_back("gpu");
 }
 
-void ObjectDetection::setEngineInfo(std::string engine_type_name, std::string device_type_name)
+template<typename T> void ObjectDetection<T>::setEngineInfo(std::string engine_type_name, std::string device_type_name)
 {
        if (engine_type_name.empty() || device_type_name.empty())
                throw InvalidParameter("Invalid engine info.");
@@ -94,7 +95,7 @@ void ObjectDetection::setEngineInfo(std::string engine_type_name, std::string de
                 device_type_name.c_str(), device_type);
 }
 
-void ObjectDetection::getNumberOfEngines(unsigned int *number_of_engines)
+template<typename T> void ObjectDetection<T>::getNumberOfEngines(unsigned int *number_of_engines)
 {
        if (!_valid_backends.empty()) {
                *number_of_engines = _valid_backends.size();
@@ -105,7 +106,7 @@ void ObjectDetection::getNumberOfEngines(unsigned int *number_of_engines)
        *number_of_engines = _valid_backends.size();
 }
 
-void ObjectDetection::getEngineType(unsigned int engine_index, char **engine_type)
+template<typename T> void ObjectDetection<T>::getEngineType(unsigned int engine_index, char **engine_type)
 {
        if (!_valid_backends.empty()) {
                if (_valid_backends.size() <= engine_index)
@@ -123,7 +124,8 @@ void ObjectDetection::getEngineType(unsigned int engine_index, char **engine_typ
        *engine_type = const_cast<char *>(_valid_backends[engine_index].data());
 }
 
-void ObjectDetection::getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices)
+template<typename T>
+void ObjectDetection<T>::getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices)
 {
        if (!_valid_devices.empty()) {
                *number_of_devices = _valid_devices.size();
@@ -134,7 +136,8 @@ void ObjectDetection::getNumberOfDevices(const char *engine_type, unsigned int *
        *number_of_devices = _valid_devices.size();
 }
 
-void ObjectDetection::getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type)
+template<typename T>
+void ObjectDetection<T>::getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type)
 {
        if (!_valid_devices.empty()) {
                if (_valid_devices.size() <= device_index)
@@ -152,7 +155,7 @@ void ObjectDetection::getDeviceType(const char *engine_type, const unsigned int
        *device_type = const_cast<char *>(_valid_devices[device_index].data());
 }
 
-void ObjectDetection::loadLabel()
+template<typename T> void ObjectDetection<T>::loadLabel()
 {
        if (_config->getLabelFilePath().empty())
                return;
@@ -173,9 +176,8 @@ void ObjectDetection::loadLabel()
        readFile.close();
 }
 
-void ObjectDetection::configure()
+template<typename T> void ObjectDetection<T>::configure()
 {
-       _config->loadMetaFile(make_unique<ObjectDetectionParser>(static_cast<int>(_task_type)));
        loadLabel();
 
        int ret = _inference->bind(_config->getBackendType(), _config->getTargetDeviceType());
@@ -183,7 +185,7 @@ void ObjectDetection::configure()
                throw InvalidOperation("Fail to bind a backend engine.");
 }
 
-void ObjectDetection::prepare()
+template<typename T> void ObjectDetection<T>::prepare()
 {
        int ret = _inference->configureInputMetaInfo(_config->getInputMetaMap());
        if (ret != MEDIA_VISION_ERROR_NONE)
@@ -201,7 +203,7 @@ void ObjectDetection::prepare()
                throw InvalidOperation("Fail to load model files.");
 }
 
-shared_ptr<MetaInfo> ObjectDetection::getInputMetaInfo()
+template<typename T> shared_ptr<MetaInfo> ObjectDetection<T>::getInputMetaInfo()
 {
        TensorBuffer &tensor_buffer = _inference->getInputTensorBuffer();
        IETensorBuffer &tensor_info_map = tensor_buffer.getIETensorBuffer();
@@ -217,7 +219,7 @@ shared_ptr<MetaInfo> ObjectDetection::getInputMetaInfo()
 }
 
 template<typename T>
-void ObjectDetection::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo, vector<T> &inputVector)
+void ObjectDetection<T>::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo, vector<T> &inputVector)
 {
        LOGI("ENTER");
 
@@ -249,7 +251,7 @@ void ObjectDetection::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> metaI
        LOGI("LEAVE");
 }
 
-template<typename T> void ObjectDetection::inference(vector<vector<T> > &inputVectors)
+template<typename T> void ObjectDetection<T>::inference(vector<vector<T> > &inputVectors)
 {
        LOGI("ENTER");
 
@@ -260,37 +262,24 @@ template<typename T> void ObjectDetection::inference(vector<vector<T> > &inputVe
        LOGI("LEAVE");
 }
 
-template<typename T> void ObjectDetection::perform(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo)
+template<typename T> void ObjectDetection<T>::perform(mv_source_h &mv_src)
 {
+       shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
        vector<T> inputVector;
 
-       preprocess<T>(mv_src, metaInfo, inputVector);
+       preprocess(mv_src, metaInfo, inputVector);
 
        vector<vector<T> > inputVectors = { inputVector };
-
-       inference<T>(inputVectors);
-
-       // TODO. Update operation status here.
-}
-
-void ObjectDetection::perform(mv_source_h &mv_src)
-{
-       shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
-       if (metaInfo->dataType == MV_INFERENCE_DATA_UINT8)
-               perform<unsigned char>(mv_src, metaInfo);
-       else if (metaInfo->dataType == MV_INFERENCE_DATA_FLOAT32)
-               perform<float>(mv_src, metaInfo);
-       else
-               throw InvalidOperation("Invalid model data type.");
+       inference(inputVectors);
 }
 
-template<typename T> void ObjectDetection::performAsync(ObjectDetectionInput &input, shared_ptr<MetaInfo> metaInfo)
+template<typename T> void ObjectDetection<T>::performAsync(ObjectDetectionInput &input)
 {
        if (!_async_manager) {
                _async_manager = make_unique<AsyncManager<ObjectDetectionResult> >([this]() {
                        AsyncInputQueue<T> inputQueue = _async_manager->popFromInput<T>();
 
-                       inference<T>(inputQueue.inputs);
+                       inference(inputQueue.inputs);
 
                        ObjectDetectionResult &resultQueue = result();
 
@@ -299,30 +288,16 @@ template<typename T> void ObjectDetection::performAsync(ObjectDetectionInput &in
                });
        }
 
+       shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
        vector<T> inputVector;
 
-       preprocess<T>(input.inference_src, metaInfo, inputVector);
+       preprocess(input.inference_src, metaInfo, inputVector);
 
        vector<vector<T> > inputVectors = { inputVector };
-
        _async_manager->push(inputVectors);
 }
 
-void ObjectDetection::performAsync(ObjectDetectionInput &input)
-{
-       shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
-
-       if (metaInfo->dataType == MV_INFERENCE_DATA_UINT8) {
-               performAsync<unsigned char>(input, metaInfo);
-       } else if (metaInfo->dataType == MV_INFERENCE_DATA_FLOAT32) {
-               performAsync<float>(input, metaInfo);
-               // TODO
-       } else {
-               throw InvalidOperation("Invalid model data type.");
-       }
-}
-
-ObjectDetectionResult &ObjectDetection::getOutput()
+template<typename T> ObjectDetectionResult &ObjectDetection<T>::getOutput()
 {
        if (_async_manager) {
                if (!_async_manager->isWorking())
@@ -338,12 +313,12 @@ ObjectDetectionResult &ObjectDetection::getOutput()
        return _current_result;
 }
 
-ObjectDetectionResult &ObjectDetection::getOutputCache()
+template<typename T> ObjectDetectionResult &ObjectDetection<T>::getOutputCache()
 {
        return _current_result;
 }
 
-void ObjectDetection::getOutputNames(vector<string> &names)
+template<typename T> void ObjectDetection<T>::getOutputNames(vector<string> &names)
 {
        TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
        IETensorBuffer &ie_tensor_buffer = tensor_buffer_obj.getIETensorBuffer();
@@ -352,7 +327,7 @@ void ObjectDetection::getOutputNames(vector<string> &names)
                names.push_back(it->first);
 }
 
-void ObjectDetection::getOutputTensor(string target_name, vector<float> &tensor)
+template<typename T> void ObjectDetection<T>::getOutputTensor(string target_name, vector<float> &tensor)
 {
        TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
 
@@ -365,17 +340,8 @@ void ObjectDetection::getOutputTensor(string target_name, vector<float> &tensor)
        copy(&raw_buffer[0], &raw_buffer[tensor_buffer->size / sizeof(float)], back_inserter(tensor));
 }
 
-template void ObjectDetection::preprocess<float>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo,
-                                                                                                vector<float> &inputVector);
-template void ObjectDetection::inference<float>(vector<vector<float> > &inputVectors);
-template void ObjectDetection::perform<float>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo);
-template void ObjectDetection::performAsync<float>(ObjectDetectionInput &input, shared_ptr<MetaInfo> metaInfo);
-
-template void ObjectDetection::preprocess<unsigned char>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo,
-                                                                                                                vector<unsigned char> &inputVector);
-template void ObjectDetection::inference<unsigned char>(vector<vector<unsigned char> > &inputVectors);
-template void ObjectDetection::perform<unsigned char>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo);
-template void ObjectDetection::performAsync<unsigned char>(ObjectDetectionInput &input, shared_ptr<MetaInfo> metaInfo);
+template class ObjectDetection<float>;
+template class ObjectDetection<unsigned char>;
 
 }
 }
index 3bdd82a..c8431e2 100644 (file)
@@ -44,20 +44,35 @@ template<typename T, typename V> ObjectDetectionAdapter<T, V>::~ObjectDetectionA
 
 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::create(ObjectDetectionTaskType task_type)
 {
-       // If a concrete class object created already exists, reset the object
-       // so that other concrete class object can be created again according to a given task_type.
-       if (_object_detection) {
-               // If default task type is same as a given one then skip.
-               if (_object_detection->getTaskType() == task_type)
-                       return;
-       }
-
-       if (task_type == ObjectDetectionTaskType::MOBILENET_V1_SSD)
-               _object_detection = make_unique<MobilenetV1Ssd>(task_type, _config);
-       else if (task_type == ObjectDetectionTaskType::MOBILENET_V2_SSD)
-               _object_detection = make_unique<MobilenetV2Ssd>(task_type, _config);
-       else if (task_type == ObjectDetectionTaskType::OD_PLUGIN || task_type == ObjectDetectionTaskType::FD_PLUGIN)
+       _config->loadMetaFile(make_unique<ObjectDetectionParser>(static_cast<int>(task_type)));
+       mv_inference_data_type_e dataType = _config->getInputMetaMap().begin()->second->dataType;
+
+       switch (task_type) {
+       case ObjectDetectionTaskType::MOBILENET_V1_SSD:
+               if (dataType == MV_INFERENCE_DATA_UINT8)
+                       _object_detection = make_unique<MobilenetV1Ssd<unsigned char> >(task_type, _config);
+               else if (dataType == MV_INFERENCE_DATA_FLOAT32)
+                       _object_detection = make_unique<MobilenetV1Ssd<float> >(task_type, _config);
+               else
+                       throw InvalidOperation("Invalid model data type.");
+               break;
+       case ObjectDetectionTaskType::MOBILENET_V2_SSD:
+               if (dataType == MV_INFERENCE_DATA_UINT8)
+                       _object_detection = make_unique<MobilenetV2Ssd<unsigned char> >(task_type, _config);
+               else if (dataType == MV_INFERENCE_DATA_FLOAT32)
+                       _object_detection = make_unique<MobilenetV2Ssd<float> >(task_type, _config);
+               else
+                       throw InvalidOperation("Invalid model data type.");
+               break;
+       case ObjectDetectionTaskType::OD_PLUGIN:
                _object_detection = make_unique<ObjectDetectionExternal>(task_type);
+               break;
+       case ObjectDetectionTaskType::FD_PLUGIN:
+               _object_detection = make_unique<ObjectDetectionExternal>(task_type);
+               break;
+       default:
+               throw InvalidOperation("Invalid object detection task type.");
+       }
        // TODO.
 }