From 491a9e8e12bd82df919c622b5a11528d2bccf42a Mon Sep 17 00:00:00 2001
From: Vibhav Aggarwal
Date: Fri, 10 Nov 2023 20:27:23 +0900
Subject: [PATCH] mv_machine_learning: convert ObjectDetection class into
 template class

[Issue type] code refactoring

ObjectDetection and its MobilenetV1Ssd/MobilenetV2Ssd subclasses are
converted into class templates parameterized on the input tensor data
type. The per-call data type branching that perform() and performAsync()
used to do is dropped; instead, each adapter's create() loads the meta
file, reads the model's input data type and picks the <unsigned char> or
<float> instantiation once, at creation time. (A simplified,
self-contained illustration of this creation-time dispatch follows the
patch.)

Change-Id: Ib84f8211a919c82e389659bdce1b652f4419e2fb
Signed-off-by: Vibhav Aggarwal
---
 .../object_detection/include/mobilenet_v1_ssd.h   |   6 +-
 .../object_detection/include/mobilenet_v2_ssd.h   |   6 +-
 .../object_detection/include/object_detection.h   |   7 +-
 .../src/face_detection_adapter.cpp                |  23 +++--
 .../object_detection/src/mobilenet_v1_ssd.cpp     |  22 +++--
 .../object_detection/src/mobilenet_v2_ssd.cpp     |  31 ++++---
 .../object_detection/src/object_detection.cpp     | 100 +++++++--------------
 .../src/object_detection_adapter.cpp              |  41 ++++++---
 8 files changed, 119 insertions(+), 117 deletions(-)

diff --git a/mv_machine_learning/object_detection/include/mobilenet_v1_ssd.h b/mv_machine_learning/object_detection/include/mobilenet_v1_ssd.h
index 3d78f1e..bb309bf 100644
--- a/mv_machine_learning/object_detection/include/mobilenet_v1_ssd.h
+++ b/mv_machine_learning/object_detection/include/mobilenet_v1_ssd.h
@@ -29,8 +29,12 @@ namespace mediavision
 {
 namespace machine_learning
 {
-class MobilenetV1Ssd : public ObjectDetection
+template<typename T> class MobilenetV1Ssd : public ObjectDetection<T>
 {
+	using ObjectDetection<T>::_config;
+	using ObjectDetection<T>::_preprocess;
+	using ObjectDetection<T>::_labels;
+
 private:
 	ObjectDetectionResult _result;
 
diff --git a/mv_machine_learning/object_detection/include/mobilenet_v2_ssd.h b/mv_machine_learning/object_detection/include/mobilenet_v2_ssd.h
index 232512d..c598c43 100644
--- a/mv_machine_learning/object_detection/include/mobilenet_v2_ssd.h
+++ b/mv_machine_learning/object_detection/include/mobilenet_v2_ssd.h
@@ -29,8 +29,12 @@ namespace mediavision
 {
 namespace machine_learning
 {
-class MobilenetV2Ssd : public ObjectDetection
+template<typename T> class MobilenetV2Ssd : public ObjectDetection<T>
 {
+	using ObjectDetection<T>::_config;
+	using ObjectDetection<T>::_preprocess;
+	using ObjectDetection<T>::_labels;
+
 private:
 	ObjectDetectionResult _result;
 
diff --git a/mv_machine_learning/object_detection/include/object_detection.h b/mv_machine_learning/object_detection/include/object_detection.h
index 1468cc3..4669b89 100644
--- a/mv_machine_learning/object_detection/include/object_detection.h
+++ b/mv_machine_learning/object_detection/include/object_detection.h
@@ -41,7 +41,7 @@ namespace mediavision
 {
 namespace machine_learning
 {
-class ObjectDetection : public IObjectDetection
+template<typename T> class ObjectDetection : public IObjectDetection
 {
 private:
 	ObjectDetectionTaskType _task_type { ObjectDetectionTaskType::OBJECT_DETECTION_TASK_NONE };
@@ -51,11 +51,8 @@ private:
 	void loadLabel();
 	void getEngineList();
 	void getDeviceList(const char *engine_type);
-	template<typename T>
 	void preprocess(mv_source_h &mv_src, std::shared_ptr<MetaInfo> metaInfo, std::vector<T> &inputVector);
 	std::shared_ptr<MetaInfo> getInputMetaInfo();
-	template<typename T> void perform(mv_source_h &mv_src, std::shared_ptr<MetaInfo> metaInfo);
-	template<typename T> void performAsync(ObjectDetectionInput &input, std::shared_ptr<MetaInfo> metaInfo);
 
 protected:
 	std::unique_ptr<mediavision::inference::Inference> _inference;
@@ -67,7 +64,7 @@ protected:
 	void getOutputNames(std::vector<std::string> &names);
 	void getOutputTensor(std::string target_name, std::vector<float> &tensor);
-	template<typename T> void inference(std::vector<std::vector<T> > &inputVectors);
+	void inference(std::vector<std::vector<T> > &inputVectors);
 	virtual ObjectDetectionResult &result() = 0;
 
 public:
diff --git a/mv_machine_learning/object_detection/src/face_detection_adapter.cpp b/mv_machine_learning/object_detection/src/face_detection_adapter.cpp
index 4c72a4a..526d805 100644
--- a/mv_machine_learning/object_detection/src/face_detection_adapter.cpp
+++ b/mv_machine_learning/object_detection/src/face_detection_adapter.cpp
@@ -43,16 +43,21 @@ template<typename T, typename V> FaceDetectionAdapter<T, V>::~FaceDetectionAdapt
 
 template<typename T, typename V> void FaceDetectionAdapter<T, V>::create(ObjectDetectionTaskType task_type)
 {
-	// If a concrete class object created already exists, reset the object
-	// so that other concrete class object can be created again according to a given task_type.
-	if (_object_detection) {
-		// If default task type is same as a given one then skip.
-		if (_object_detection->getTaskType() == task_type)
-			return;
+	_config->loadMetaFile(make_unique<ObjectDetectionParser>(static_cast<int>(task_type)));
+	mv_inference_data_type_e dataType = _config->getInputMetaMap().begin()->second->dataType;
+
+	switch (task_type) {
+	case ObjectDetectionTaskType::FD_MOBILENET_V1_SSD:
+		if (dataType == MV_INFERENCE_DATA_UINT8)
+			_object_detection = make_unique<MobilenetV1Ssd<unsigned char> >(task_type, _config);
+		else if (dataType == MV_INFERENCE_DATA_FLOAT32)
+			_object_detection = make_unique<MobilenetV1Ssd<float> >(task_type, _config);
+		else
+			throw InvalidOperation("Invalid model data type.");
+		break;
+	default:
+		throw InvalidOperation("Invalid face detection task type.");
 	}
-
-	if (task_type == ObjectDetectionTaskType::FD_MOBILENET_V1_SSD)
-		_object_detection = make_unique<MobilenetV1Ssd>(task_type, _config);
 	// TODO.
 }
 
diff --git a/mv_machine_learning/object_detection/src/mobilenet_v1_ssd.cpp b/mv_machine_learning/object_detection/src/mobilenet_v1_ssd.cpp
index cf70ac9..d08d239 100644
--- a/mv_machine_learning/object_detection/src/mobilenet_v1_ssd.cpp
+++ b/mv_machine_learning/object_detection/src/mobilenet_v1_ssd.cpp
@@ -31,14 +31,15 @@ namespace mediavision
 {
 namespace machine_learning
 {
-MobilenetV1Ssd::MobilenetV1Ssd(ObjectDetectionTaskType task_type, std::shared_ptr<MachineLearningConfig> config)
-	: ObjectDetection(task_type, config), _result()
+template<typename T>
+MobilenetV1Ssd<T>::MobilenetV1Ssd(ObjectDetectionTaskType task_type, std::shared_ptr<MachineLearningConfig> config)
+	: ObjectDetection<T>(task_type, config), _result()
 {}
 
-MobilenetV1Ssd::~MobilenetV1Ssd()
+template<typename T> MobilenetV1Ssd<T>::~MobilenetV1Ssd()
 {}
 
-ObjectDetectionResult &MobilenetV1Ssd::result()
+template<typename T> ObjectDetectionResult &MobilenetV1Ssd<T>::result()
 {
 	// Clear _result object because result() function can be called every time user wants
 	// so make sure to clear existing result data before getting the data again.
@@ -46,17 +47,17 @@ ObjectDetectionResult &MobilenetV1Ssd::result()
 
 	vector<string> names;
 
-	ObjectDetection::getOutputNames(names);
+	ObjectDetection<T>::getOutputNames(names);
 
 	vector<float> number_tensor;
 
 	// TFLite_Detection_PostProcess:3
-	ObjectDetection::getOutputTensor(names[3], number_tensor);
+	ObjectDetection<T>::getOutputTensor(names[3], number_tensor);
 
 	vector<float> label_tensor;
 
 	// TFLite_Detection_PostProcess:1
-	ObjectDetection::getOutputTensor(names[1], label_tensor);
+	ObjectDetection<T>::getOutputTensor(names[1], label_tensor);
 
 	vector<float> score_tensor;
 	map<float, unsigned int, std::greater<float> > sorted_score;
@@ -65,7 +66,7 @@ ObjectDetectionResult &MobilenetV1Ssd::result()
 	auto decodingScore = static_pointer_cast<DecodingScore>(scoreMetaInfo->decodingTypeMap[DecodingType::SCORE]);
 
 	// TFLite_Detection_PostProcess:2
-	ObjectDetection::getOutputTensor(names[2], score_tensor);
+	ObjectDetection<T>::getOutputTensor(names[2], score_tensor);
 
 	for (size_t idx = 0; idx < score_tensor.size(); ++idx) {
 		if (decodingScore->threshold > score_tensor[idx])
 			continue;
@@ -77,7 +78,7 @@ ObjectDetectionResult &MobilenetV1Ssd::result()
 	auto decodingBox = static_pointer_cast<DecodingBox>(boxMetaInfo->decodingTypeMap[DecodingType::BOX]);
 	vector<float> box_tensor;
 
-	ObjectDetection::getOutputTensor(names[0], box_tensor);
+	ObjectDetection<T>::getOutputTensor(names[0], box_tensor);
 
 	for (auto &score : sorted_score) {
 		_result.number_of_objects++;
@@ -109,5 +110,8 @@ ObjectDetectionResult &MobilenetV1Ssd::result()
 	return _result;
 }
 
+template class MobilenetV1Ssd<float>;
+template class MobilenetV1Ssd<unsigned char>;
+
 }
 }
diff --git a/mv_machine_learning/object_detection/src/mobilenet_v2_ssd.cpp b/mv_machine_learning/object_detection/src/mobilenet_v2_ssd.cpp
index 3032909..3af71b7 100644
--- a/mv_machine_learning/object_detection/src/mobilenet_v2_ssd.cpp
+++ b/mv_machine_learning/object_detection/src/mobilenet_v2_ssd.cpp
@@ -32,11 +32,12 @@ namespace mediavision
 {
 namespace machine_learning
 {
-MobilenetV2Ssd::MobilenetV2Ssd(ObjectDetectionTaskType task_type, std::shared_ptr<MachineLearningConfig> config)
-	: ObjectDetection(task_type, config), _result()
+template<typename T>
+MobilenetV2Ssd<T>::MobilenetV2Ssd(ObjectDetectionTaskType task_type, std::shared_ptr<MachineLearningConfig> config)
+	: ObjectDetection<T>(task_type, config), _result()
 {}
 
-MobilenetV2Ssd::~MobilenetV2Ssd()
+template<typename T> MobilenetV2Ssd<T>::~MobilenetV2Ssd()
 {}
 
 static bool compareScore(Box box0, Box box1)
@@ -80,8 +81,9 @@ static float calcIntersectionOverUnion(Box box0, Box box1)
 	return intersectArea / (area0 + area1 - intersectArea);
 }
 
-void MobilenetV2Ssd::ApplyNms(vector<vector<Box> > &box_lists, BoxNmsMode mode, float threshold,
-							  vector<Box> &box_vector)
+template<typename T>
+void MobilenetV2Ssd<T>::ApplyNms(vector<vector<Box> > &box_lists, BoxNmsMode mode, float threshold,
+								 vector<Box> &box_vector)
 {
 	LOGI("ENTER");
 
@@ -129,8 +131,9 @@ void MobilenetV2Ssd::ApplyNms(vector<vector<Box> > &box_lists, BoxNmsMode mode,
 	LOGI("LEAVE");
 }
 
-Box MobilenetV2Ssd::decodeBox(const DecodingBox *decodingBox, vector<float> &bb_tensor, int idx, float score, int label,
-							  int box_offset)
+template<typename T>
+Box MobilenetV2Ssd<T>::decodeBox(const DecodingBox *decodingBox, vector<float> &bb_tensor, int idx, float score,
+								 int label, int box_offset)
 {
 	// assume type is (cx,cy,w,h)
 	// left or cx
@@ -149,7 +152,8 @@ Box MobilenetV2Ssd::decodeBox(const DecodingBox *decodingBox, vector<float> &bb_
 	return box;
 }
 
-Box MobilenetV2Ssd::decodeBoxWithAnchor(const BoxAnchorParam *boxAnchorParam, Box &box, cv::Rect2f &anchor)
+template<typename T>
+Box MobilenetV2Ssd<T>::decodeBoxWithAnchor(const BoxAnchorParam *boxAnchorParam, Box &box, cv::Rect2f &anchor)
 {
 	if (boxAnchorParam->isFixedAnchorSize) {
 		box.location.x += anchor.x;
@@ -170,7 +174,7 @@ Box MobilenetV2Ssd::decodeBoxWithAnchor(const BoxAnchorParam *boxAnchorParam, Bo
 	return box;
 }
 
-ObjectDetectionResult &MobilenetV2Ssd::result()
+template<typename T> ObjectDetectionResult &MobilenetV2Ssd<T>::result()
 {
 	// Clear _result object because result() function can be called every time user wants
 	// so make sure to clear existing result data before getting the data again.
@@ -178,12 +182,12 @@ ObjectDetectionResult &MobilenetV2Ssd::result()
 
 	vector<string> names;
 
-	ObjectDetection::getOutputNames(names);
+	ObjectDetection<T>::getOutputNames(names);
 
 	vector<float> score_tensor;
 
 	// raw_outputs/class_predictions
-	ObjectDetection::getOutputTensor(names[1], score_tensor);
+	ObjectDetection<T>::getOutputTensor(names[1], score_tensor);
 
 	auto scoreMetaInfo = _config->getOutputMetaMap().at(names[1]);
 	auto decodingScore = static_pointer_cast<DecodingScore>(scoreMetaInfo->decodingTypeMap[DecodingType::SCORE]);
@@ -196,7 +200,7 @@ ObjectDetectionResult &MobilenetV2Ssd::result()
 	vector<float> bb_tensor;
 
 	// raw_outputs/box_encodings
-	ObjectDetection::getOutputTensor(names[0], bb_tensor);
+	ObjectDetection<T>::getOutputTensor(names[0], bb_tensor);
 
 	vector<Box> box_vec;
 	vector<vector<Box> > box_list_vec;
@@ -258,5 +262,8 @@ ObjectDetectionResult &MobilenetV2Ssd::result()
 	return _result;
 }
 
+template class MobilenetV2Ssd<float>;
+template class MobilenetV2Ssd<unsigned char>;
+
 }
 }
diff --git a/mv_machine_learning/object_detection/src/object_detection.cpp b/mv_machine_learning/object_detection/src/object_detection.cpp
index 8216b30..5802070 100644
--- a/mv_machine_learning/object_detection/src/object_detection.cpp
+++ b/mv_machine_learning/object_detection/src/object_detection.cpp
@@ -34,13 +34,14 @@ namespace mediavision
 {
 namespace machine_learning
 {
-ObjectDetection::ObjectDetection(ObjectDetectionTaskType task_type, shared_ptr<MachineLearningConfig> config)
+template<typename T>
+ObjectDetection<T>::ObjectDetection(ObjectDetectionTaskType task_type, shared_ptr<MachineLearningConfig> config)
 		: _task_type(task_type), _config(config)
 {
 	_inference = make_unique<Inference>();
 }
 
-void ObjectDetection::preDestroy()
+template<typename T> void ObjectDetection<T>::preDestroy()
 {
 	if (!_async_manager)
 		return;
@@ -48,12 +49,12 @@ void ObjectDetection::preDestroy()
 	_async_manager->stop();
 }
 
-ObjectDetectionTaskType ObjectDetection::getTaskType()
+template<typename T> ObjectDetectionTaskType ObjectDetection<T>::getTaskType()
 {
 	return _task_type;
 }
 
-void ObjectDetection::getEngineList()
+template<typename T> void ObjectDetection<T>::getEngineList()
 {
 	for (auto idx = MV_INFERENCE_BACKEND_NONE + 1; idx < MV_INFERENCE_BACKEND_MAX; ++idx) {
 		auto backend = _inference->getSupportedInferenceBackend(idx);
@@ -65,7 +66,7 @@ void ObjectDetection::getEngineList()
 	}
 }
 
-void ObjectDetection::getDeviceList(const char *engine_type)
+template<typename T> void ObjectDetection<T>::getDeviceList(const char *engine_type)
 {
 	// TODO. add device types available for a given engine type later.
 	// In default, cpu and gpu only.
@@ -73,7 +74,7 @@ void ObjectDetection::getDeviceList(const char *engine_type)
 	_valid_devices.push_back("gpu");
 }
 
-void ObjectDetection::setEngineInfo(std::string engine_type_name, std::string device_type_name)
+template<typename T> void ObjectDetection<T>::setEngineInfo(std::string engine_type_name, std::string device_type_name)
 {
 	if (engine_type_name.empty() || device_type_name.empty())
 		throw InvalidParameter("Invalid engine info.");
@@ -94,7 +95,7 @@ void ObjectDetection::setEngineInfo(std::string engine_type_name, std::string de
 		 device_type_name.c_str(), device_type);
 }
 
-void ObjectDetection::getNumberOfEngines(unsigned int *number_of_engines)
+template<typename T> void ObjectDetection<T>::getNumberOfEngines(unsigned int *number_of_engines)
 {
 	if (!_valid_backends.empty()) {
 		*number_of_engines = _valid_backends.size();
@@ -105,7 +106,7 @@ void ObjectDetection::getNumberOfEngines(unsigned int *number_of_engines)
 	*number_of_engines = _valid_backends.size();
 }
 
-void ObjectDetection::getEngineType(unsigned int engine_index, char **engine_type)
+template<typename T> void ObjectDetection<T>::getEngineType(unsigned int engine_index, char **engine_type)
 {
 	if (!_valid_backends.empty()) {
 		if (_valid_backends.size() <= engine_index)
@@ -123,7 +124,8 @@ void ObjectDetection::getEngineType(unsigned int engine_index, char **engine_typ
 	*engine_type = const_cast<char *>(_valid_backends[engine_index].data());
 }
 
-void ObjectDetection::getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices)
+template<typename T>
+void ObjectDetection<T>::getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices)
 {
 	if (!_valid_devices.empty()) {
 		*number_of_devices = _valid_devices.size();
@@ -134,7 +136,8 @@ void ObjectDetection::getNumberOfDevices(const char *engine_type, unsigned int *
 	*number_of_devices = _valid_devices.size();
 }
 
-void ObjectDetection::getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type)
+template<typename T>
+void ObjectDetection<T>::getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type)
 {
 	if (!_valid_devices.empty()) {
 		if (_valid_devices.size() <= device_index)
@@ -152,7 +155,7 @@ void ObjectDetection::getDeviceType(const char *engine_type, const unsigned int
 	*device_type = const_cast<char *>(_valid_devices[device_index].data());
 }
 
-void ObjectDetection::loadLabel()
+template<typename T> void ObjectDetection<T>::loadLabel()
 {
 	if (_config->getLabelFilePath().empty())
 		return;
@@ -173,9 +176,8 @@ void ObjectDetection::loadLabel()
 	readFile.close();
 }
 
-void ObjectDetection::configure()
+template<typename T> void ObjectDetection<T>::configure()
 {
-	_config->loadMetaFile(make_unique<ObjectDetectionParser>(static_cast<int>(_task_type)));
 	loadLabel();
 
 	int ret = _inference->bind(_config->getBackendType(), _config->getTargetDeviceType());
@@ -183,7 +185,7 @@ void ObjectDetection::configure()
 		throw InvalidOperation("Fail to bind a backend engine.");
 }
 
-void ObjectDetection::prepare()
+template<typename T> void ObjectDetection<T>::prepare()
 {
 	int ret = _inference->configureInputMetaInfo(_config->getInputMetaMap());
 	if (ret != MEDIA_VISION_ERROR_NONE)
@@ -201,7 +203,7 @@ void ObjectDetection::prepare()
 		throw InvalidOperation("Fail to load model files.");
 }
 
-shared_ptr<MetaInfo> ObjectDetection::getInputMetaInfo()
+template<typename T> shared_ptr<MetaInfo> ObjectDetection<T>::getInputMetaInfo()
 {
 	TensorBuffer &tensor_buffer = _inference->getInputTensorBuffer();
 	IETensorBuffer &tensor_info_map = tensor_buffer.getIETensorBuffer();
@@ -217,7 +219,7 @@ shared_ptr<MetaInfo> ObjectDetection::getInputMetaInfo()
 }
 
 template<typename T>
-void ObjectDetection::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo, vector<T> &inputVector)
+void ObjectDetection<T>::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo, vector<T> &inputVector)
 {
 	LOGI("ENTER");
 
@@ -249,7 +251,7 @@ void ObjectDetection::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> metaI
 	LOGI("LEAVE");
 }
 
-template<typename T> void ObjectDetection::inference(vector<vector<T> > &inputVectors)
+template<typename T> void ObjectDetection<T>::inference(vector<vector<T> > &inputVectors)
 {
 	LOGI("ENTER");
 
@@ -260,37 +262,24 @@ template<typename T> void ObjectDetection::inference(vector<vector<T> > &inputVe
 	LOGI("LEAVE");
 }
 
-template<typename T> void ObjectDetection::perform(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo)
+template<typename T> void ObjectDetection<T>::perform(mv_source_h &mv_src)
 {
+	shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
 	vector<T> inputVector;
 
-	preprocess<T>(mv_src, metaInfo, inputVector);
+	preprocess(mv_src, metaInfo, inputVector);
 
 	vector<vector<T> > inputVectors = { inputVector };
-
-	inference<T>(inputVectors);
-
-	// TODO. Update operation status here.
-}
-
-void ObjectDetection::perform(mv_source_h &mv_src)
-{
-	shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
-	if (metaInfo->dataType == MV_INFERENCE_DATA_UINT8)
-		perform<unsigned char>(mv_src, metaInfo);
-	else if (metaInfo->dataType == MV_INFERENCE_DATA_FLOAT32)
-		perform<float>(mv_src, metaInfo);
-	else
-		throw InvalidOperation("Invalid model data type.");
+	inference(inputVectors);
 }
 
-template<typename T> void ObjectDetection::performAsync(ObjectDetectionInput &input, shared_ptr<MetaInfo> metaInfo)
+template<typename T> void ObjectDetection<T>::performAsync(ObjectDetectionInput &input)
 {
 	if (!_async_manager) {
 		_async_manager = make_unique<AsyncManager<ObjectDetectionResult> >([this]() {
 			AsyncInputQueue inputQueue = _async_manager->popFromInput();
 
-			inference<T>(inputQueue.inputs);
+			inference(inputQueue.inputs);
 
 			ObjectDetectionResult &resultQueue = result();
 
@@ -299,30 +288,16 @@ template<typename T> void ObjectDetection::performAsync(ObjectDetectionInput &in
 		});
 	}
 
+	shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
 	vector<T> inputVector;
 
-	preprocess<T>(input.inference_src, metaInfo, inputVector);
+	preprocess(input.inference_src, metaInfo, inputVector);
 
 	vector<vector<T> > inputVectors = { inputVector };
-
 	_async_manager->push(inputVectors);
 }
 
-void ObjectDetection::performAsync(ObjectDetectionInput &input)
-{
-	shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
-
-	if (metaInfo->dataType == MV_INFERENCE_DATA_UINT8) {
-		performAsync<unsigned char>(input, metaInfo);
-	} else if (metaInfo->dataType == MV_INFERENCE_DATA_FLOAT32) {
-		performAsync<float>(input, metaInfo);
-		// TODO
-	} else {
-		throw InvalidOperation("Invalid model data type.");
-	}
-}
-
-ObjectDetectionResult &ObjectDetection::getOutput()
+template<typename T> ObjectDetectionResult &ObjectDetection<T>::getOutput()
 {
 	if (_async_manager) {
 		if (!_async_manager->isWorking())
@@ -338,12 +313,12 @@ ObjectDetectionResult &ObjectDetection::getOutput()
 	return _current_result;
 }
 
-ObjectDetectionResult &ObjectDetection::getOutputCache()
+template<typename T> ObjectDetectionResult &ObjectDetection<T>::getOutputCache()
 {
 	return _current_result;
 }
 
-void ObjectDetection::getOutputNames(vector<string> &names)
+template<typename T> void ObjectDetection<T>::getOutputNames(vector<string> &names)
 {
 	TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
 	IETensorBuffer &ie_tensor_buffer = tensor_buffer_obj.getIETensorBuffer();
@@ -352,7 +327,7 @@ void ObjectDetection::getOutputNames(vector<string> &names)
 		names.push_back(it->first);
 }
 
-void ObjectDetection::getOutputTensor(string target_name, vector<float> &tensor)
+template<typename T> void ObjectDetection<T>::getOutputTensor(string target_name, vector<float> &tensor)
 {
 	TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
 
@@ -365,17 +340,8 @@ void ObjectDetection::getOutputTensor(string target_name, vector<float> &tensor)
 	copy(&raw_buffer[0], &raw_buffer[tensor_buffer->size / sizeof(float)], back_inserter(tensor));
 }
 
-template void ObjectDetection::preprocess<float>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo,
-												 vector<float> &inputVector);
-template void ObjectDetection::inference<float>(vector<vector<float> > &inputVectors);
-template void ObjectDetection::perform<float>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo);
-template void ObjectDetection::performAsync<float>(ObjectDetectionInput &input, shared_ptr<MetaInfo> metaInfo);
-
-template void ObjectDetection::preprocess<unsigned char>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo,
-														 vector<unsigned char> &inputVector);
-template void ObjectDetection::inference<unsigned char>(vector<vector<unsigned char> > &inputVectors);
-template void ObjectDetection::perform<unsigned char>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo);
-template void ObjectDetection::performAsync<unsigned char>(ObjectDetectionInput &input, shared_ptr<MetaInfo> metaInfo);
+template class ObjectDetection<float>;
+template class ObjectDetection<unsigned char>;
 
 }
 }
diff --git a/mv_machine_learning/object_detection/src/object_detection_adapter.cpp b/mv_machine_learning/object_detection/src/object_detection_adapter.cpp
index 3bdd82a..c8431e2 100644
--- a/mv_machine_learning/object_detection/src/object_detection_adapter.cpp
+++ b/mv_machine_learning/object_detection/src/object_detection_adapter.cpp
@@ -44,20 +44,35 @@ template<typename T, typename V> ObjectDetectionAdapter<T, V>::~ObjectDetectionA
 
 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::create(ObjectDetectionTaskType task_type)
 {
-	// If a concrete class object created already exists, reset the object
-	// so that other concrete class object can be created again according to a given task_type.
-	if (_object_detection) {
-		// If default task type is same as a given one then skip.
-		if (_object_detection->getTaskType() == task_type)
-			return;
-	}
-
-	if (task_type == ObjectDetectionTaskType::MOBILENET_V1_SSD)
-		_object_detection = make_unique<MobilenetV1Ssd>(task_type, _config);
-	else if (task_type == ObjectDetectionTaskType::MOBILENET_V2_SSD)
-		_object_detection = make_unique<MobilenetV2Ssd>(task_type, _config);
-	else if (task_type == ObjectDetectionTaskType::OD_PLUGIN || task_type == ObjectDetectionTaskType::FD_PLUGIN)
+	_config->loadMetaFile(make_unique<ObjectDetectionParser>(static_cast<int>(task_type)));
+	mv_inference_data_type_e dataType = _config->getInputMetaMap().begin()->second->dataType;
+
+	switch (task_type) {
+	case ObjectDetectionTaskType::MOBILENET_V1_SSD:
+		if (dataType == MV_INFERENCE_DATA_UINT8)
+			_object_detection = make_unique<MobilenetV1Ssd<unsigned char> >(task_type, _config);
+		else if (dataType == MV_INFERENCE_DATA_FLOAT32)
+			_object_detection = make_unique<MobilenetV1Ssd<float> >(task_type, _config);
+		else
+			throw InvalidOperation("Invalid model data type.");
+		break;
+	case ObjectDetectionTaskType::MOBILENET_V2_SSD:
+		if (dataType == MV_INFERENCE_DATA_UINT8)
+			_object_detection = make_unique<MobilenetV2Ssd<unsigned char> >(task_type, _config);
+		else if (dataType == MV_INFERENCE_DATA_FLOAT32)
+			_object_detection = make_unique<MobilenetV2Ssd<float> >(task_type, _config);
+		else
+			throw InvalidOperation("Invalid model data type.");
+		break;
+	case ObjectDetectionTaskType::OD_PLUGIN:
 		_object_detection = make_unique<ObjectDetectionExternal>(task_type);
+		break;
+	case ObjectDetectionTaskType::FD_PLUGIN:
+		_object_detection = make_unique<ObjectDetectionExternal>(task_type);
+		break;
+	default:
+		throw InvalidOperation("Invalid object detection task type.");
+	}
 	// TODO.
 }
-- 
2.7.4
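
The following is a minimal, self-contained sketch of the dispatch pattern the patch introduces. The type and function names in it (DataType, IDetection, MobilenetV1SsdLike, create) are illustrative stand-ins, not the actual mediavision classes; it only mirrors how the adapters now read the model's input data type once and pick a template instantiation, instead of branching on the data type inside every perform()/performAsync() call.

    // Illustrative stand-ins only -- not the mediavision API.
    #include <cstdio>
    #include <memory>
    #include <stdexcept>
    #include <vector>

    enum class DataType { UINT8, FLOAT32 };

    // Plays the role of the non-template IObjectDetection interface.
    struct IDetection {
            virtual ~IDetection() = default;
            virtual void perform() = 0;
    };

    // Plays the role of ObjectDetection<T> / MobilenetV1Ssd<T>:
    // T is the element type of the model's input tensor.
    template<typename T> struct MobilenetV1SsdLike : IDetection {
            std::vector<T> _input; // preprocessed input kept in its natural element type

            void perform() override
            {
                    // preprocess into _input, run inference, decode the results...
                    std::printf("perform() with %zu-byte input elements\n", sizeof(T));
            }
    };

    // Plays the role of ObjectDetectionAdapter::create(): the data type is read
    // once (here passed in directly) and decides which instantiation to build.
    std::unique_ptr<IDetection> create(DataType dataType)
    {
            if (dataType == DataType::UINT8)
                    return std::make_unique<MobilenetV1SsdLike<unsigned char> >();
            if (dataType == DataType::FLOAT32)
                    return std::make_unique<MobilenetV1SsdLike<float> >();
            throw std::invalid_argument("Invalid model data type.");
    }

    int main()
    {
            create(DataType::FLOAT32)->perform(); // perform() itself no longer checks the data type
            return 0;
    }

Selecting the instantiation once at create() time keeps the per-call perform()/performAsync() path free of data-type checks and lets each instantiation hold its input buffers in the element type the model actually expects.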