mv_machine_learning: update mobilenet v1 ssd model support
authorInki Dae <inki.dae@samsung.com>
Wed, 15 Feb 2023 08:22:18 +0000 (17:22 +0900)
committerKwanghoon Son <k.son@samsung.com>
Fri, 3 Mar 2023 08:11:58 +0000 (17:11 +0900)
[Issue type] : new feature

Updated mobilenet v1 ssd model for itask based object detection group.

With this patch, we introduce three new C APIs - using only standard data types -
for getting the object detection result and for setting a user-given model, and
rename mobilenet_ssd.h/cpp to mobilenet_v1_ssd.h/cpp because each child class
of the task group is specific to its model.

Change-Id: I5cae3436be028f9a38813883b8ed0c7836eb25f6
Signed-off-by: Inki Dae <inki.dae@samsung.com>
16 files changed:
include/mv_object_detection_internal.h
mv_machine_learning/meta/include/PostprocessParser.h
mv_machine_learning/meta/src/PostprocessParser.cpp
mv_machine_learning/object_detection/include/mobilenet_v1_ssd.h [moved from mv_machine_learning/object_detection/include/mobilenet_ssd.h with 85% similarity]
mv_machine_learning/object_detection/include/mv_object_detection_open.h
mv_machine_learning/object_detection/include/object_detection.h
mv_machine_learning/object_detection/include/object_detection_adapter.h
mv_machine_learning/object_detection/include/object_detection_type.h
mv_machine_learning/object_detection/src/ObjectDetectionParser.cpp
mv_machine_learning/object_detection/src/mobilenet_ssd.cpp [deleted file]
mv_machine_learning/object_detection/src/mobilenet_v1_ssd.cpp [new file with mode: 0644]
mv_machine_learning/object_detection/src/mv_object_detection.c
mv_machine_learning/object_detection/src/mv_object_detection_open.cpp
mv_machine_learning/object_detection/src/object_detection.cpp
mv_machine_learning/object_detection/src/object_detection_adapter.cpp
test/testsuites/machine_learning/object_detection/test_object_detection.cpp

index cd4be74..a78519d 100644 (file)
@@ -41,7 +41,7 @@ extern "C" {
  *          mv_object_detection_prepare() function to prepare a network
  *          for the inference.
  *
- * @since_tizen 7.0
+ * @since_tizen 7.5
  *
  * @remarks The @a infer should be released using mv_object_detection_destroy().
  *
@@ -61,7 +61,7 @@ int mv_object_detection_create(mv_object_detection_h *infer);
 /**
  * @brief Destroys inference handle and releases all its resources.
  *
- * @since_tizen 7.0
+ * @since_tizen 7.5
  *
  * @param[in] infer    The handle to the inference to be destroyed.
  *
@@ -77,9 +77,31 @@ int mv_object_detection_create(mv_object_detection_h *infer);
 int mv_object_detection_destroy(mv_object_detection_h infer);
 
 /**
+        * @brief Set user-given model information.
+        * @details Use this function to change the model information instead of default one after calling @ref mv_object_detection_create().
+        *
+        * @since_tizen 7.5
+        *
+        * @param[in] handle        The handle to the object detection object.
+        * @param[in] model_name    Model name.
+        * @param[in] model_file    Model file name.
+        * @param[in] meta_file     Model meta file name.
+        * @param[in] label_file    Label file name.
+        *
+        * @return @c 0 on success, otherwise a negative error value
+        * @retval #MEDIA_VISION_ERROR_NONE Successful
+        * @retval #MEDIA_VISION_ERROR_INVALID_PARAMETER Invalid parameter
+        * @retval #MEDIA_VISION_ERROR_INVALID_OPERATION Invalid operation
+        *
+        * @pre Create an object detection handle by calling @ref mv_object_detection_create()
+        */
+int mv_object_detection_set_model(mv_object_detection_h handle, const char *model_name, const char *model_file,
+                                                                 const char *meta_file, const char *label_file);
+
+/**
  * @brief Configures the backend for the object detection inference.
  *
- * @since_tizen 7.0
+ * @since_tizen 7.5
  *
  * @param [in] infer         The handle to the inference
  *
@@ -96,7 +118,7 @@ int mv_object_detection_configure(mv_object_detection_h infer);
  * @details Use this function to prepare the object detection inference based on
  *          the configured network.
  *
- * @since_tizen 7.0
+ * @since_tizen 7.5
  *
  * @param[in] infer         The handle to the inference.
  *
@@ -115,7 +137,7 @@ int mv_object_detection_prepare(mv_object_detection_h infer);
 /**
  * @brief Performs the object detection inference on the @a source.
  *
- * @since_tizen 7.0
+ * @since_tizen 7.5
  * @remarks This function is synchronous and may take considerable time to run.
  *
  * @param[in] source         The handle to the source of the media
@@ -131,13 +153,64 @@ int mv_object_detection_prepare(mv_object_detection_h infer);
  *
  * @pre Create a source handle by calling mv_create_source()
  * @pre Create an inference handle by calling mv_object_detect_create()
+ * @pre Configure an inference by calling mv_object_detection_configure()
  * @pre Prepare an inference by calling mv_object_detect_prepare()
- * @post
- *
- * @see mv_object_detect_result_s structure
  */
 int mv_object_detection_inference(mv_object_detection_h infer, mv_source_h source);
 
+/**
+ * @brief Gets the object detection inference result on the @a source.
+ *
+ * @since_tizen 7.5
+ *
+ * @param[in] infer               The handle to the inference
+ * @param[out] number_of_objects  The number of objects detected.
+ * @param[out] indices            Label indices of the detected objects.
+ * @param[out] confidences        Probabilities of the detected objects.
+ * @param[out] left               An array of left positions of the bounding boxes.
+ * @param[out] top                An array of top positions of the bounding boxes.
+ * @param[out] right              An array of right positions of the bounding boxes.
+ * @param[out] bottom             An array of bottom positions of the bounding boxes.
+ *
+ * @return @c 0 on success, otherwise a negative error value
+ * @retval #MEDIA_VISION_ERROR_NONE Successful
+ * @retval #MEDIA_VISION_ERROR_NOT_SUPPORTED Not supported
+ * @retval #MEDIA_VISION_ERROR_INVALID_PARAMETER Invalid parameter
+ * @retval #MEDIA_VISION_ERROR_INTERNAL          Internal error
+ *
+ * @pre Create a source handle by calling mv_create_source()
+ * @pre Create an inference handle by calling mv_object_detection_create()
+ * @pre Configure an inference by calling mv_object_detection_configure()
+ * @pre Prepare an inference by calling mv_object_detection_prepare()
+ * @pre Run an inference by calling mv_object_detection_inference()
+ */
+int mv_object_detection_get_result(mv_object_detection_h infer, unsigned int *number_of_objects,
+                                                                  const unsigned int **indices, const float **confidences, const int **left,
+                                                                  const int **top, const int **right, const int **bottom);
+
+/**
+ * @brief Gets the label string to a given index.
+ *
+ * @since_tizen 7.5
+ *
+ * @param[in] infer       The handle to the inference
+ * @param[in] index       Label index to get the label string.
+ * @param[out] out_label  Label string to a given index.
+ *
+ * @return @c 0 on success, otherwise a negative error value
+ * @retval #MEDIA_VISION_ERROR_NONE Successful
+ * @retval #MEDIA_VISION_ERROR_NOT_SUPPORTED Not supported
+ * @retval #MEDIA_VISION_ERROR_INVALID_PARAMETER Invalid parameter
+ * @retval #MEDIA_VISION_ERROR_INTERNAL          Internal error
+ *
+ * @pre Create a source handle by calling mv_create_source()
+ * @pre Create an inference handle by calling mv_object_detection_create()
+ * @pre Configure an inference by calling mv_object_detection_configure()
+ * @pre Prepare an inference by calling mv_object_detection_prepare()
+ * @pre Run an inference by calling mv_object_detection_inference()
+ */
+int mv_object_detection_get_label(mv_object_detection_h infer, const unsigned int index, const char **out_label);
+
 #ifdef __cplusplus
 }
 #endif /* __cplusplus */
index b95923a..18cfd2f 100644 (file)
@@ -41,6 +41,7 @@ public:
 
        void parseBox(std::shared_ptr<MetaInfo> metaInfo, JsonObject *root);
        void parseScore(std::shared_ptr<MetaInfo> metaInfo, JsonObject *root);
+       void parseNumber(std::shared_ptr<MetaInfo> metaInfo, JsonObject *root);
 
        /**
         * Add new parsing functions.
index 2c3ad70..e5deff6 100644 (file)
@@ -74,6 +74,18 @@ void PostprocessParser::parseBox(shared_ptr<MetaInfo> metaInfo, JsonObject *root
                        decodingBox->type =
                                        GetSupportedType<BoxType, map<string, BoxType> >(object, "box_type", gSupportedBoxTypes);
 
+               if (json_object_has_member(object, "box_order")) {
+                       JsonArray *array = json_object_get_array_member(object, "box_order");
+                       unsigned int elements = json_array_get_length(array);
+                       LOGI("box order should have 4 elements and it has [%u]", elements);
+
+                       for (unsigned int elem_idx = 0; elem_idx < elements; ++elem_idx) {
+                               auto val = static_cast<unsigned int>(json_array_get_int_element(array, elem_idx));
+                               decodingBox->order.push_back(val);
+                               LOGI("%d", val);
+                       }
+               }
+
                if (json_object_has_member(object, "box_coordinate"))
                        decodingBox->coordinateType = GetSupportedType<BoxCoordinateType, map<string, BoxCoordinateType> >(
                                        object, "box_coordinate", gSupportedBoxCoordinateTypes);
@@ -86,6 +98,8 @@ void PostprocessParser::parseBox(shared_ptr<MetaInfo> metaInfo, JsonObject *root
                throw InvalidOperation("Invalid box meta information.");
        }
 
+       metaInfo->decodingTypeMap[DecodingType::BOX] = decodingBox;
+
        // In case of bypss, we don't need to parse decoding_info.
        if (decodingBox->decodingType == BoxDecodingType::BYPASS)
                return;
@@ -106,8 +120,6 @@ void PostprocessParser::parseBox(shared_ptr<MetaInfo> metaInfo, JsonObject *root
                }
        }
 
-       metaInfo->decodingTypeMap[DecodingType::BOX] = decodingBox;
-
        LOGI("LEAVE");
 }
 
@@ -126,7 +138,7 @@ void PostprocessParser::parseScore(shared_ptr<MetaInfo> metaInfo, JsonObject *ro
                        decodingScore->topNumber = static_cast<unsigned int>(json_object_get_int_member(object, "top_number"));
 
                if (json_object_has_member(object, "threshold"))
-                       decodingScore->topNumber = static_cast<float>(json_object_get_double_member(object, "threshold"));
+                       decodingScore->threshold = static_cast<float>(json_object_get_double_member(object, "threshold"));
 
                if (json_object_has_member(object, "score_type"))
                        decodingScore->type =
@@ -141,5 +153,21 @@ void PostprocessParser::parseScore(shared_ptr<MetaInfo> metaInfo, JsonObject *ro
        LOGI("LEAVE");
 }
 
+/**
+ * Registers the NUMBER decoding entry on @a metaInfo.
+ * Throws InvalidOperation when @a root has no "number" member.
+ */
+void PostprocessParser::parseNumber(shared_ptr<MetaInfo> metaInfo, JsonObject *root)
+{
+       LOGI("ENTER");
+
+       const bool hasNumber = json_object_has_member(root, "number");
+       if (!hasNumber)
+               throw InvalidOperation("member number not exists");
+
+       // TODO.
+       auto numberInfo = make_shared<unsigned int>();
+
+       metaInfo->decodingTypeMap[DecodingType::NUMBER] = numberInfo;
+
+       LOGI("LEAVE");
+}
+
 } /* machine_learning */
 } /* mediavision */
@@ -29,16 +29,16 @@ namespace mediavision
 {
 namespace machine_learning
 {
-class MobilenetSsd : public ObjectDetection
+class MobilenetV1Ssd : public ObjectDetection
 {
 private:
-       object_detection_result_s _result;
+       ObjectDetectionResult _result;
 
 public:
-       MobilenetSsd();
-       ~MobilenetSsd();
+       MobilenetV1Ssd();
+       ~MobilenetV1Ssd();
 
-       object_detection_result_s &result() override;
+       ObjectDetectionResult &result() override;
 };
 
 } // machine_learning
index e7163b8..10aa90b 100644 (file)
@@ -67,6 +67,28 @@ int mv_object_detection_create_open(mv_object_detection_h *out_handle);
 int mv_object_detection_destroy_open(mv_object_detection_h handle);
 
 /**
+        * @brief Set user-given model information.
+        * @details Use this function to change the model information instead of default one after calling @ref mv_object_detection_create().
+        *
+        * @since_tizen 7.5
+        *
+        * @param[in] handle        The handle to the object detection object.
+        * @param[in] model_name    Model name.
+        * @param[in] model_file    Model file name.
+        * @param[in] meta_file     Model meta file name.
+        * @param[in] label_file    Label file name.
+        *
+        * @return @c 0 on success, otherwise a negative error value
+        * @retval #MEDIA_VISION_ERROR_NONE Successful
+        * @retval #MEDIA_VISION_ERROR_INVALID_PARAMETER Invalid parameter
+        * @retval #MEDIA_VISION_ERROR_INVALID_OPERATION Invalid operation
+        *
+        * @pre Create an object detection handle by calling @ref mv_object_detection_create()
+        */
+int mv_object_detection_set_model_open(mv_object_detection_h handle, const char *model_name, const char *model_file,
+                                                                          const char *meta_file, const char *label_file);
+
+/**
         * @brief Configure the backend to the inference handle
         *
         * @since_tizen 7.0
@@ -116,11 +138,64 @@ int mv_object_detection_prepare_open(mv_object_detection_h handle);
         *
         * @pre Create a source handle by calling @ref mv_create_source()
         * @pre Create an object detection handle by calling @ref mv_object_detection_create_open()
+        * @pre Configure an inference by calling @ref mv_object_detection_configure_open()
         * @pre Prepare an object detection by calling @ref mv_object_detection_prepare_open()
-        * @pre Register a new face by calling @ref mv_object_detection_register_open()
         */
 int mv_object_detection_inference_open(mv_object_detection_h handle, mv_source_h source);
 
+/**
+ * @brief Gets the object detection inference result on the @a source.
+ *
+ * @since_tizen 7.5
+ *
+ * @param[in] handle              The handle to the inference
+ * @param[out] number_of_objects  The number of objects detected.
+ * @param[out] indices            Label indices of the detected objects.
+ * @param[out] confidences        Probabilities of the detected objects.
+ * @param[out] left               An array of left positions of the bounding boxes.
+ * @param[out] top                An array of top positions of the bounding boxes.
+ * @param[out] right              An array of right positions of the bounding boxes.
+ * @param[out] bottom             An array of bottom positions of the bounding boxes.
+ *
+ * @return @c 0 on success, otherwise a negative error value
+ * @retval #MEDIA_VISION_ERROR_NONE Successful
+ * @retval #MEDIA_VISION_ERROR_NOT_SUPPORTED Not supported
+ * @retval #MEDIA_VISION_ERROR_INVALID_PARAMETER Invalid parameter
+ * @retval #MEDIA_VISION_ERROR_INTERNAL          Internal error
+ *
+ * @pre Create a source handle by calling mv_create_source()
+ * @pre Create an inference handle by calling mv_object_detection_create_open()
+ * @pre Configure an inference by calling mv_object_detection_configure_open()
+ * @pre Prepare an inference by calling mv_object_detection_prepare_open()
+ * @pre Run an inference by calling mv_object_detection_inference_open()
+ */
+int mv_object_detection_get_result_open(mv_object_detection_h handle, unsigned int *number_of_objects,
+                                                                               const unsigned int **indices, const float **confidences, const int **left,
+                                                                               const int **top, const int **right, const int **bottom);
+
+/**
+ * @brief Gets the label string to a given index.
+ *
+ * @since_tizen 7.5
+ *
+ * @param[in] handle      The handle to the inference
+ * @param[in] index       Label index to get the label string.
+ * @param[out] out_label  Label string to a given index.
+ *
+ * @return @c 0 on success, otherwise a negative error value
+ * @retval #MEDIA_VISION_ERROR_NONE Successful
+ * @retval #MEDIA_VISION_ERROR_NOT_SUPPORTED Not supported
+ * @retval #MEDIA_VISION_ERROR_INVALID_PARAMETER Invalid parameter
+ * @retval #MEDIA_VISION_ERROR_INTERNAL          Internal error
+ *
+ * @pre Create a source handle by calling mv_create_source()
+ * @pre Create an inference handle by calling mv_object_detection_create_open()
+ * @pre Configure an inference by calling mv_object_detection_configure_open()
+ * @pre Prepare an inference by calling mv_object_detection_prepare_open()
+ * @pre Run an inference by calling mv_object_detection_inference_open()
+ */
+int mv_object_detection_get_label_open(mv_object_detection_h handle, const unsigned int index, const char **out_label);
+
 #ifdef __cplusplus
 }
 #endif /* __cplusplus */
index c67ba3d..806371a 100644 (file)
@@ -34,11 +34,16 @@ namespace machine_learning
 {
 class ObjectDetection
 {
+private:
+       void loadLabel();
+
 protected:
        std::unique_ptr<mediavision::inference::Inference> _inference;
        std::unique_ptr<MediaVision::Common::EngineConfig> _config;
        std::unique_ptr<MetaParser> _parser;
+       std::vector<std::string> _labels;
        Preprocess _preprocess;
+       std::string _modelName;
        std::string _modelFilePath;
        std::string _modelMetaFilePath;
        std::string _modelDefaultPath;
@@ -47,17 +52,19 @@ protected:
        int _targetDeviceType;
 
        void getOutputNames(std::vector<std::string> &names);
-       void getOutputTensor(std::string &target_name, std::vector<float> &tensor);
+       void getOutputTensor(std::string target_name, std::vector<float> &tensor);
 
 public:
        ObjectDetection();
        virtual ~ObjectDetection() = default;
+       void setUserModel(std::string &model_name, std::string &model_file, std::string &meta_file,
+                                         std::string &label_file);
        void parseMetaFile();
        void configure();
        void prepare();
        void preprocess(mv_source_h &mv_src);
        void inference(mv_source_h source);
-       virtual object_detection_result_s &result() = 0;
+       virtual ObjectDetectionResult &result() = 0;
 };
 
 } // machine_learning
index ead4817..e4f138f 100644 (file)
@@ -24,22 +24,32 @@ namespace mediavision
 {
 namespace machine_learning
 {
-struct object_detection_input_s {
+struct ObjectDetectionInput {
        mv_source_h inference_src;
+       std::string model_name;
+       std::string model_file;
+       std::string meta_file;
+       std::string label_file;
 };
 
 /**
  * @brief The object detection result structure.
  * @details Contains object detection result.
  */
-struct object_detection_result_s {
-       std::vector<unsigned int> x_vec;
-       std::vector<unsigned int> y_vec;
+struct ObjectDetectionResult {
+       unsigned int number_of_objects {};
+       std::vector<unsigned int> indices;
+       std::vector<std::string> names;
+       std::vector<float> confidences;
+       std::vector<int> left;
+       std::vector<int> top;
+       std::vector<int> right;
+       std::vector<int> bottom;
 };
 
-enum class object_detection_task_type_e {
+enum class ObjectDetectionTaskType {
        OBJECT_DETECTION_TASK_NONE = 0,
-       MOBILENET_SSD_V1
+       MOBILENET_V1_SSD
        // TODO
 };
 
index 0290d2d..c88b1d2 100644 (file)
@@ -43,6 +43,12 @@ void ObjectDetectionParser::parsePostprocess(shared_ptr<MetaInfo> meta_info, Jso
        if (json_object_has_member(in_obj, "box"))
                _postprocessParser.parseBox(meta_info, in_obj);
 
+       if (json_object_has_member(in_obj, "score"))
+               _postprocessParser.parseScore(meta_info, in_obj);
+
+       if (json_object_has_member(in_obj, "number"))
+               _postprocessParser.parseNumber(meta_info, in_obj);
+
        LOGI("LEAVE");
 }
 
diff --git a/mv_machine_learning/object_detection/src/mobilenet_ssd.cpp b/mv_machine_learning/object_detection/src/mobilenet_ssd.cpp
deleted file mode 100644 (file)
index a7ea300..0000000
+++ /dev/null
@@ -1,56 +0,0 @@
-/**
- * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <string.h>
-#include <map>
-#include <algorithm>
-
-#include "machine_learning_exception.h"
-#include "mv_object_detection_config.h"
-#include "mobilenet_ssd.h"
-#include "Postprocess.h"
-
-using namespace std;
-using namespace mediavision::inference;
-using namespace mediavision::machine_learning::exception;
-
-namespace mediavision
-{
-namespace machine_learning
-{
-MobilenetSsd::MobilenetSsd() : _result()
-{}
-
-MobilenetSsd::~MobilenetSsd()
-{}
-
-object_detection_result_s &MobilenetSsd::result()
-{
-       vector<string> names;
-
-       ObjectDetection::getOutputNames(names);
-
-       vector<float> output_tensor;
-
-       ObjectDetection::getOutputTensor(names[1], output_tensor);
-
-       // TODO.
-
-       return _result;
-}
-
-}
-}
diff --git a/mv_machine_learning/object_detection/src/mobilenet_v1_ssd.cpp b/mv_machine_learning/object_detection/src/mobilenet_v1_ssd.cpp
new file mode 100644 (file)
index 0000000..60f21fc
--- /dev/null
@@ -0,0 +1,112 @@
+/**
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <string.h>
+#include <map>
+#include <algorithm>
+
+#include "machine_learning_exception.h"
+#include "mv_object_detection_config.h"
+#include "mobilenet_v1_ssd.h"
+#include "Postprocess.h"
+
+using namespace std;
+using namespace mediavision::inference;
+using namespace mediavision::machine_learning::exception;
+
+namespace mediavision
+{
+namespace machine_learning
+{
+// No setup beyond default-constructing the members is needed.
+MobilenetV1Ssd::MobilenetV1Ssd() = default;
+
+// Nothing to release explicitly; the members clean up after themselves.
+MobilenetV1Ssd::~MobilenetV1Ssd() = default;
+
+/**
+ * @brief Decodes the output tensors of the last inference into _result.
+ * @details Candidates whose score is below the threshold taken from the score meta
+ *          information are dropped. The remaining ones are visited in descending score
+ *          order and decoding stops once topNumber objects have been stored. Each box
+ *          coordinate is multiplied by _preprocess.getImageWidth()[0] or
+ *          _preprocess.getImageHeight()[0].
+ *
+ * @return Reference to the internally held result; the next call overwrites it.
+ * @throws InvalidOperation if the outputs or the meta information do not have the
+ *         layout this decoder indexes into.
+ */
+ObjectDetectionResult &MobilenetV1Ssd::result()
+{
+       // result() can be called any number of times, so drop the previous data first.
+       // Assign a fresh object rather than memset(): the struct owns std::vector and
+       // std::string members, and overwriting them with zero bytes is undefined behavior.
+       _result = ObjectDetectionResult();
+
+       vector<string> names;
+
+       ObjectDetection::getOutputNames(names);
+
+       // The tensors are addressed by position below, so all four names must be present.
+       if (names.size() < 4)
+               throw InvalidOperation("Invalid number of output tensors.");
+
+       vector<float> number_tensor;
+
+       // TFLite_Detection_PostProcess:3
+       ObjectDetection::getOutputTensor(names[3], number_tensor);
+
+       vector<float> label_tensor;
+
+       // TFLite_Detection_PostProcess:1
+       ObjectDetection::getOutputTensor(names[1], label_tensor);
+
+       auto scoreMetaInfo = _parser->getOutputMetaMap().at(names[2]);
+       auto decodingScore = static_pointer_cast<DecodingScore>(scoreMetaInfo->decodingTypeMap[DecodingType::SCORE]);
+
+       if (!decodingScore)
+               throw InvalidOperation("Score meta information not found.");
+
+       vector<float> score_tensor;
+       // A multimap keeps candidates that share the same score; a plain map would
+       // silently overwrite all but the last of them.
+       multimap<float, unsigned int, std::greater<float> > sorted_score;
+
+       // TFLite_Detection_PostProcess:2
+       ObjectDetection::getOutputTensor(names[2], score_tensor);
+       for (size_t idx = 0; idx < score_tensor.size(); ++idx) {
+               if (decodingScore->threshold > score_tensor[idx])
+                       continue;
+
+               sorted_score.emplace(score_tensor[idx], static_cast<unsigned int>(idx));
+       }
+
+       auto boxMetaInfo = _parser->getOutputMetaMap().at(names[0]);
+       auto decodingBox = static_pointer_cast<DecodingBox>(boxMetaInfo->decodingTypeMap[DecodingType::BOX]);
+
+       if (!decodingBox)
+               throw InvalidOperation("Box meta information not found.");
+
+       // "box_order" is optional in the meta file, but four entries in the range 0..3
+       // are needed to pick the four coordinates of a box.
+       const vector<unsigned int> &order = decodingBox->order;
+
+       if (order.size() < 4 || order[0] > 3 || order[1] > 3 || order[2] > 3 || order[3] > 3)
+               throw InvalidOperation("Invalid box order.");
+
+       vector<float> box_tensor;
+
+       ObjectDetection::getOutputTensor(names[0], box_tensor);
+
+       for (auto &score : sorted_score) {
+               // second is idx
+               const size_t box_base = static_cast<size_t>(score.second) * 4;
+
+               if (score.second >= label_tensor.size() || box_base + 3 >= box_tensor.size())
+                       throw InvalidOperation("Output tensors are smaller than expected.");
+
+               const float label = label_tensor[score.second];
+
+               // The negated comparison also rejects NaN before the value is used as an index.
+               if (!(label >= 0.0f) || static_cast<size_t>(label) >= _labels.size())
+                       throw InvalidOperation("Invalid label index.");
+
+               _result.number_of_objects++;
+               _result.names.push_back(_labels[static_cast<size_t>(label)]);
+               // NOTE(review): this stores the position of the object inside the result,
+               // not the label index from label_tensor - confirm this is what callers expect.
+               _result.indices.push_back(_result.number_of_objects - 1);
+               _result.confidences.push_back(score.first);
+
+               _result.left.push_back(
+                               static_cast<int>(box_tensor[box_base + order[0]] * _preprocess.getImageWidth()[0]));
+               _result.top.push_back(
+                               static_cast<int>(box_tensor[box_base + order[1]] * _preprocess.getImageHeight()[0]));
+               _result.right.push_back(
+                               static_cast<int>(box_tensor[box_base + order[2]] * _preprocess.getImageWidth()[0]));
+               _result.bottom.push_back(
+                               static_cast<int>(box_tensor[box_base + order[3]] * _preprocess.getImageHeight()[0]));
+
+               LOGI("idx = %d, name = %s, score = %f, %dx%d, %dx%d", score.second,
+                        _result.names[_result.number_of_objects - 1].c_str(), _result.confidences[_result.number_of_objects - 1],
+                        _result.left[_result.number_of_objects - 1], _result.top[_result.number_of_objects - 1],
+                        _result.right[_result.number_of_objects - 1], _result.bottom[_result.number_of_objects - 1]);
+
+               if (decodingScore->topNumber == _result.number_of_objects)
+                       break;
+       }
+
+       return _result;
+}
+
+}
+}
index 313fcc2..04f1125 100644 (file)
@@ -30,9 +30,7 @@ int mv_object_detection_create(mv_object_detection_h *infer)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       int ret = MEDIA_VISION_ERROR_NONE;
-
-       ret = mv_object_detection_create_open(infer);
+       int ret = mv_object_detection_create_open(infer);
 
        MEDIA_VISION_FUNCTION_LEAVE();
        return ret;
@@ -45,11 +43,29 @@ int mv_object_detection_destroy(mv_object_detection_h infer)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       int ret = MEDIA_VISION_ERROR_NONE;
+       int ret = mv_object_detection_destroy_open(infer);
+
+       MEDIA_VISION_FUNCTION_LEAVE();
+       return ret;
+}
+
+/* Public entry point for replacing the default model information: once the checks
+ * below have passed, all arguments are forwarded unchanged to
+ * mv_object_detection_set_model_open() and its return value is returned. */
+int mv_object_detection_set_model(mv_object_detection_h handle, const char *model_name, const char *model_file,
+                                                                 const char *meta_file, const char *label_file)
+{
+       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_face_check_system_info_feature_supported());
+
+       MEDIA_VISION_INSTANCE_CHECK(handle);
+       /* model_name is a plain string argument like the three below, so it gets the same check. */
+       MEDIA_VISION_NULL_ARG_CHECK(model_name);
+       MEDIA_VISION_NULL_ARG_CHECK(model_file);
+       MEDIA_VISION_NULL_ARG_CHECK(meta_file);
+       MEDIA_VISION_NULL_ARG_CHECK(label_file);
+
+       MEDIA_VISION_FUNCTION_ENTER();
 
-       ret = mv_object_detection_destroy_open(infer);
+       int ret = mv_object_detection_set_model_open(handle, model_name, model_file, meta_file, label_file);
 
        MEDIA_VISION_FUNCTION_LEAVE();
+
        return ret;
 }
 
@@ -60,9 +76,7 @@ int mv_object_detection_configure(mv_object_detection_h infer)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       int ret = MEDIA_VISION_ERROR_NONE;
-
-       ret = mv_object_detection_configure_open(infer);
+       int ret = mv_object_detection_configure_open(infer);
 
        MEDIA_VISION_FUNCTION_LEAVE();
        return ret;
@@ -75,9 +89,7 @@ int mv_object_detection_prepare(mv_object_detection_h infer)
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       int ret = MEDIA_VISION_ERROR_NONE;
-
-       ret = mv_object_detection_prepare_open(infer);
+       int ret = mv_object_detection_prepare_open(infer);
 
        MEDIA_VISION_FUNCTION_LEAVE();
        return ret;
@@ -91,9 +103,46 @@ int mv_object_detection_inference(mv_object_detection_h infer, mv_source_h sourc
 
        MEDIA_VISION_FUNCTION_ENTER();
 
-       int ret = MEDIA_VISION_ERROR_NONE;
+       int ret = mv_object_detection_inference_open(infer, source);
+
+       MEDIA_VISION_FUNCTION_LEAVE();
+
+       return ret;
+}
+
+/* Public entry point for fetching the detection result: once the checks below have
+ * passed, every argument is forwarded unchanged to
+ * mv_object_detection_get_result_open() and its return value is returned. */
+int mv_object_detection_get_result(mv_object_detection_h infer, unsigned int *number_of_objects,
+                                                                  const unsigned int **indices, const float **confidences, const int **left,
+                                                                  const int **top, const int **right, const int **bottom)
+{
+       /* NOTE(review): the check macros are defined outside this file; presumably each one
+        * returns an error code from this function when its condition fails - confirm. */
+       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_image_check_system_info_feature_supported());
+       MEDIA_VISION_INSTANCE_CHECK(infer);
+       MEDIA_VISION_INSTANCE_CHECK(number_of_objects);
+       MEDIA_VISION_INSTANCE_CHECK(indices);
+       MEDIA_VISION_INSTANCE_CHECK(confidences);
+       MEDIA_VISION_INSTANCE_CHECK(left);
+       MEDIA_VISION_INSTANCE_CHECK(top);
+       MEDIA_VISION_INSTANCE_CHECK(right);
+       MEDIA_VISION_INSTANCE_CHECK(bottom);
+
+       MEDIA_VISION_FUNCTION_ENTER();
+
+       int ret = mv_object_detection_get_result_open(infer, number_of_objects, indices, confidences, left, top, right,
+                                                                                                 bottom);
+
+       MEDIA_VISION_FUNCTION_LEAVE();
+
+       return ret;
+}
+
+int mv_object_detection_get_label(mv_object_detection_h infer, const unsigned int index, const char **out_label)
+{
+       MEDIA_VISION_SUPPORT_CHECK(_mv_inference_image_check_system_info_feature_supported());
+       MEDIA_VISION_INSTANCE_CHECK(infer);
+       MEDIA_VISION_INSTANCE_CHECK(out_label);
+
+       MEDIA_VISION_FUNCTION_ENTER();
 
-       ret = mv_object_detection_inference_open(infer, source);
+       int ret = mv_object_detection_get_label_open(infer, index, out_label);
 
        MEDIA_VISION_FUNCTION_LEAVE();
 
index 0ce6d8b..0c07f0b 100644 (file)
@@ -33,7 +33,7 @@ using namespace mediavision::common;
 using namespace mediavision::machine_learning;
 using namespace MediaVision::Common;
 using namespace mediavision::machine_learning::exception;
-using ObjectDetectionTask = ITask<object_detection_input_s, object_detection_result_s>;
+using ObjectDetectionTask = ITask<ObjectDetectionInput, ObjectDetectionResult>;
 
 int mv_object_detection_create_open(mv_object_detection_h *out_handle)
 {
@@ -47,9 +47,7 @@ int mv_object_detection_create_open(mv_object_detection_h *out_handle)
 
        try {
                context = new Context();
-               task = new ObjectDetectionAdapter<object_detection_input_s, object_detection_result_s>();
-               task->create(static_cast<int>(object_detection_task_type_e::MOBILENET_SSD_V1));
-
+               task = new ObjectDetectionAdapter<ObjectDetectionInput, ObjectDetectionResult>();
                context->__tasks.insert(make_pair("object_detection", task));
                *out_handle = static_cast<mv_object_detection_h>(context);
        } catch (const BaseException &e) {
@@ -58,7 +56,7 @@ int mv_object_detection_create_open(mv_object_detection_h *out_handle)
                return e.getError();
        }
 
-       LOGD("object detection 3d handle [%p] has been created", *out_handle);
+       LOGD("object detection handle [%p] has been created", *out_handle);
 
        return MEDIA_VISION_ERROR_NONE;
 }
@@ -72,12 +70,44 @@ int mv_object_detection_destroy_open(mv_object_detection_h handle)
 
        auto context = static_cast<Context *>(handle);
 
+       std::lock_guard<std::mutex> lock(context->_mutex);
+
        for (auto &m : context->__tasks)
                delete static_cast<ObjectDetectionTask *>(m.second);
 
        delete context;
 
-       LOGD("Object detection 3d handle has been destroyed.");
+       LOGD("Object detection handle has been destroyed.");
+
+       return MEDIA_VISION_ERROR_NONE;
+}
+
+int mv_object_detection_set_model_open(mv_object_detection_h handle, const char *model_name, const char *model_file,
+                                                                          const char *meta_file, const char *label_file)
+{
+       if (!handle) {
+               LOGE("Handle is NULL.");
+               return MEDIA_VISION_ERROR_INVALID_PARAMETER;
+       }
+
+       try {
+               auto context = static_cast<Context *>(handle);
+               auto task = static_cast<ObjectDetectionTask *>(context->__tasks.at("object_detection"));
+
+               ObjectDetectionInput input;
+
+               input.model_name = string(model_name);
+               input.model_file = string(model_file);
+               input.meta_file = string(meta_file);
+               input.label_file = string(label_file);
+
+               task->setInput(input);
+       } catch (const BaseException &e) {
+               LOGE("%s", e.what());
+               return e.getError();
+       }
+
+       LOGD("LEAVE");
 
        return MEDIA_VISION_ERROR_NONE;
 }
@@ -95,6 +125,9 @@ int mv_object_detection_configure_open(mv_object_detection_h handle)
                auto context = static_cast<Context *>(handle);
                auto task = static_cast<ObjectDetectionTask *>(context->__tasks.at("object_detection"));
 
+               std::lock_guard<std::mutex> lock(context->_mutex);
+
+               task->create(static_cast<int>(ObjectDetectionTaskType::MOBILENET_V1_SSD));
                task->configure();
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
@@ -119,6 +152,8 @@ int mv_object_detection_prepare_open(mv_object_detection_h handle)
                auto context = static_cast<Context *>(handle);
                auto task = static_cast<ObjectDetectionTask *>(context->__tasks.at("object_detection"));
 
+               std::lock_guard<std::mutex> lock(context->_mutex);
+
                task->prepare();
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
@@ -143,7 +178,9 @@ int mv_object_detection_inference_open(mv_object_detection_h handle, mv_source_h
                auto context = static_cast<Context *>(handle);
                auto task = static_cast<ObjectDetectionTask *>(context->__tasks.at("object_detection"));
 
-               object_detection_input_s input = { source };
+               std::lock_guard<std::mutex> lock(context->_mutex);
+
+               ObjectDetectionInput input = { source };
 
                task->setInput(input);
                task->perform();
@@ -155,4 +192,70 @@ int mv_object_detection_inference_open(mv_object_detection_h handle, mv_source_h
        LOGD("LEAVE");
 
        return MEDIA_VISION_ERROR_NONE;
+}
+
+int mv_object_detection_get_result_open(mv_object_detection_h handle, unsigned int *number_of_objects,
+                                                                               const unsigned int **indices, const float **confidences, const int **left,
+                                                                               const int **top, const int **right, const int **bottom)
+{
+       LOGD("ENTER");
+
+       if (!handle) {
+               LOGE("Handle is NULL.");
+               return MEDIA_VISION_ERROR_INVALID_PARAMETER;
+       }
+
+       try {
+               auto context = static_cast<Context *>(handle);
+               auto task = static_cast<ObjectDetectionTask *>(context->__tasks.at("object_detection"));
+
+               std::lock_guard<std::mutex> lock(context->_mutex);
+
+               ObjectDetectionResult &result = task->getOutput();
+               *number_of_objects = result.number_of_objects;
+               *indices = result.indices.data();
+               *confidences = result.confidences.data();
+               *left = result.left.data();
+               *top = result.top.data();
+               *right = result.right.data();
+               *bottom = result.bottom.data();
+       } catch (const BaseException &e) {
+               LOGE("%s", e.what());
+               return e.getError();
+       }
+
+       LOGD("LEAVE");
+
+       return MEDIA_VISION_ERROR_NONE;
+}
+
+int mv_object_detection_get_label_open(mv_object_detection_h handle, const unsigned int index, const char **out_label)
+{
+       LOGD("ENTER");
+
+       if (!handle) {
+               LOGE("Handle is NULL.");
+               return MEDIA_VISION_ERROR_INVALID_PARAMETER;
+       }
+
+       try {
+               auto context = static_cast<Context *>(handle);
+               auto task = static_cast<ObjectDetectionTask *>(context->__tasks.at("object_detection"));
+
+               std::lock_guard<std::mutex> lock(context->_mutex);
+
+               ObjectDetectionResult &result = task->getOutput();
+
+               if (result.number_of_objects <= index)
+                       throw InvalidParameter("Invalid index range.");
+
+               *out_label = result.names[index].c_str();
+       } catch (const BaseException &e) {
+               LOGE("%s", e.what());
+               return e.getError();
+       }
+
+       LOGD("LEAVE");
+
+       return MEDIA_VISION_ERROR_NONE;
 }
\ No newline at end of file
index e2fb899..9d1502d 100644 (file)
@@ -15,6 +15,7 @@
  */
 
 #include <string.h>
+#include <fstream>
 #include <map>
 #include <memory>
 #include <algorithm>
@@ -38,11 +39,37 @@ ObjectDetection::ObjectDetection() : _backendType(), _targetDeviceType()
        _parser = make_unique<ObjectDetectionParser>();
 }
 
+void ObjectDetection::setUserModel(string &model_name, string &model_file, string &meta_file, string &label_file)
+{ // Copies the four user-given strings into the matching members; nothing is validated or opened here.
+       _modelName = model_name;
+       _modelFilePath = model_file;
+       _modelMetaFilePath = meta_file;
+       _modelLabelFilePath = label_file;
+}
+
 static bool IsJsonFile(const string &fileName)
 { // True when the text after the last '.' is exactly "json"; without any '.', the whole name is compared.
        return (!fileName.substr(fileName.find_last_of(".") + 1).compare("json"));
 }
 
+// Reads the label file at _modelLabelFilePath into _labels, one entry per line.
+void ObjectDetection::loadLabel()
+{
+       // Labels of a previously loaded file are dropped even if opening the new file fails below.
+       _labels.clear();
+
+       ifstream labelStream(_modelLabelFilePath.c_str());
+
+       if (!labelStream)
+               throw InvalidOperation("Fail to open " + _modelLabelFilePath + " file.");
+
+       // getline() strips the newline; every line, including an empty one, becomes a label.
+       for (string label; getline(labelStream, label);)
+               _labels.push_back(label);
+
+       // No explicit close(): the ifstream destructor releases the file.
+}
+
 void ObjectDetection::parseMetaFile()
 {
        _config = make_unique<EngineConfig>(string(MV_CONFIG_PATH) + string(MV_OBJECT_DETECTION_META_FILE_NAME));
@@ -59,31 +86,45 @@ void ObjectDetection::parseMetaFile()
        if (ret != MEDIA_VISION_ERROR_NONE)
                throw InvalidOperation("Fail to get model default path");
 
-       ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_FILE_PATH, &_modelFilePath);
-       if (ret != MEDIA_VISION_ERROR_NONE)
-               throw InvalidOperation("Fail to get model file path");
+       if (_modelFilePath.empty()) {
+               ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_FILE_PATH, &_modelFilePath);
+               if (ret != MEDIA_VISION_ERROR_NONE)
+                       throw InvalidOperation("Fail to get model file path");
+       }
 
        _modelFilePath = _modelDefaultPath + _modelFilePath;
+       LOGI("model file path = %s", _modelFilePath.c_str());
 
-       ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_META_FILE_PATH, &_modelMetaFilePath);
-       if (ret != MEDIA_VISION_ERROR_NONE)
-               throw InvalidOperation("Fail to get model meta file path");
+       if (_modelMetaFilePath.empty()) {
+               ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_META_FILE_PATH, &_modelMetaFilePath);
+               if (ret != MEDIA_VISION_ERROR_NONE)
+                       throw InvalidOperation("Fail to get model meta file path");
 
-       if (_modelMetaFilePath.empty())
-               throw InvalidOperation("Model meta file doesn't exist.");
+               if (_modelMetaFilePath.empty())
+                       throw InvalidOperation("Model meta file doesn't exist.");
 
-       if (!IsJsonFile(_modelMetaFilePath))
-               throw InvalidOperation("Model meta file should be json");
+               if (!IsJsonFile(_modelMetaFilePath))
+                       throw InvalidOperation("Model meta file should be json");
+       }
 
        _modelMetaFilePath = _modelDefaultPath + _modelMetaFilePath;
+       LOGI("meta file path = %s", _modelMetaFilePath.c_str());
 
        _parser->load(_modelMetaFilePath);
 
-       ret = _config->getStringAttribute(MV_OBJECT_DETECTION_LABEL_FILE_NAME, &_modelLabelFilePath);
-       if (ret != MEDIA_VISION_ERROR_NONE)
-               throw InvalidOperation("Fail to get label file path");
+       if (_modelLabelFilePath.empty()) {
+               ret = _config->getStringAttribute(MV_OBJECT_DETECTION_LABEL_FILE_NAME, &_modelLabelFilePath);
+               if (ret != MEDIA_VISION_ERROR_NONE)
+                       throw InvalidOperation("Fail to get label file path");
+
+               if (_modelLabelFilePath.empty())
+                       throw InvalidOperation("Model label file doesn't exist.");
+       }
 
        _modelLabelFilePath = _modelDefaultPath + _modelLabelFilePath;
+       LOGI("label file path = %s", _modelLabelFilePath.c_str());
+
+       loadLabel();
 }
 
 void ObjectDetection::configure()
@@ -110,6 +151,7 @@ void ObjectDetection::prepare()
        if (ret != MEDIA_VISION_ERROR_NONE)
                throw InvalidOperation("Fail to load model files.");
 }
+
 void ObjectDetection::preprocess(mv_source_h &mv_src)
 {
        LOGI("ENTER");
@@ -147,7 +189,7 @@ void ObjectDetection::getOutputNames(vector<string> &names)
                names.push_back(it->first);
 }
 
-void ObjectDetection::getOutputTensor(string &target_name, vector<float> &tensor)
+void ObjectDetection::getOutputTensor(string target_name, vector<float> &tensor)
 {
        LOGI("ENTER");
 
index 053d4bf..77ba2db 100644 (file)
@@ -34,9 +34,19 @@ template<typename T, typename V> ObjectDetectionAdapter<T, V>::~ObjectDetectionA
 
 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::create(int type)
 {
-       switch (static_cast<object_detection_task_type_e>(type)) {
-       case object_detection_task_type_e::MOBILENET_SSD_V1:
-               _object_detection = make_unique<MobilenetSsd>();
+       if (!_source.model_name.empty()) {
+               transform(_source.model_name.begin(), _source.model_name.end(), _source.model_name.begin(), ::toupper);
+
+               if (_source.model_name == string("MOBILENET_V1_SSD"))
+                       type = static_cast<int>(ObjectDetectionTaskType::MOBILENET_V1_SSD);
+               // TODO.
+               else
+                       throw InvalidParameter("Invalid object detection model name.");
+       }
+
+       switch (static_cast<ObjectDetectionTaskType>(type)) {
+       case ObjectDetectionTaskType::MOBILENET_V1_SSD:
+               _object_detection = make_unique<MobilenetV1Ssd>();
                break;
        default:
                throw InvalidParameter("Invalid object detection task type.");
@@ -45,6 +55,10 @@ template<typename T, typename V> void ObjectDetectionAdapter<T, V>::create(int t
 
 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::configure()
 { // Optionally forwards the user-given model, then parses the meta file and configures the detector.
+       if (!_source.model_name.empty() && !_source.model_file.empty() && !_source.meta_file.empty() &&
+               !_source.label_file.empty()) // forwarded only when all four fields are non-empty
+               _object_detection->setUserModel(_source.model_name, _source.model_file, _source.meta_file, _source.label_file);
+
        _object_detection->parseMetaFile();
        _object_detection->configure();
 }
@@ -70,6 +84,6 @@ template<typename T, typename V> V &ObjectDetectionAdapter<T, V>::getOutput()
        return _object_detection->result();
 }
 
-template class ObjectDetectionAdapter<object_detection_input_s, object_detection_result_s>;
+template class ObjectDetectionAdapter<ObjectDetectionInput, ObjectDetectionResult>;
 }
 }
\ No newline at end of file
index 357c1b6..7ab4624 100644 (file)
@@ -15,8 +15,8 @@
  */
 
 #include <iostream>
+#include <algorithm>
 #include <string.h>
-#include <map>
 
 #include "gtest/gtest.h"
 
@@ -30,34 +30,82 @@ using namespace std;
 
 using namespace MediaVision::Common;
 
+struct model_info { // Arguments given to mv_object_detection_set_model() for one test iteration.
+       string model_name;
+       string model_file;
+       string meta_file;
+       string label_file;
+};
+
 TEST(ObjectDetectionTest, InferenceShouldBeOk)
 { // Runs create/set_model/configure/prepare/inference once per test_models entry on one shared image.
        mv_object_detection_h handle;
-
-       int ret = mv_object_detection_create(&handle);
-       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
-
-       ret = mv_object_detection_configure(handle);
-       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
-
-       ret = mv_object_detection_prepare(handle);
-       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+       vector<model_info> test_models {
+               { "", "", "", "" }, // If empty then default model will be used.
+               { "mobilenet_v1_ssd", "od_mobilenet_v1_ssd_postop_300x300.tflite", "od_mobilenet_v1_ssd_postop_300x300.json",
+                 "od_mobilenet_v1_ssd_postop_label.txt" }
+               // TODO.
+       };
 
        const string image_path = IMAGE_PATH;
        mv_source_h mv_source = NULL;
 
-       ret = mv_create_source(&mv_source);
+       int ret = mv_create_source(&mv_source);
        ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
 
        ret = ImageHelper::loadImageToSource(image_path.c_str(), mv_source);
        ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
 
-       ret = mv_object_detection_inference(handle, mv_source);
-       ASSERT_EQ(ret, 0);
+       for (auto model : test_models) { // a fresh handle is created and destroyed for every entry
+               cout << "model name : " << model.model_file << endl; // NOTE(review): prints model_file under the "model name" caption
 
-       ret = mv_destroy_source(mv_source);
-       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+               ret = mv_object_detection_create(&handle);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               mv_object_detection_set_model(handle, model.model_name.c_str(), model.model_file.c_str(),
+                                                                         model.meta_file.c_str(), model.label_file.c_str()); // NOTE(review): return value is not checked
+
+               ret = mv_object_detection_configure(handle);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+               ret = mv_object_detection_prepare(handle);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
 
-       ret = mv_object_detection_destroy(handle);
+               ret = mv_object_detection_inference(handle, mv_source);
+               ASSERT_EQ(ret, 0);
+
+               unsigned int number_of_objects;
+               const int *left, *top, *right, *bottom;
+               const unsigned int *indices;
+               const float *confidences;
+
+               ret = mv_object_detection_get_result(handle, &number_of_objects, &indices, &confidences, &left, &top, &right,
+                                                                                        &bottom);
+               ASSERT_EQ(ret, 0);
+
+               for (unsigned int idx = 0; idx < number_of_objects; ++idx) { // dump every detected box
+                       cout << "index = " << indices[idx] << " probability = " << confidences[idx] << " " << left[idx] << " x "
+                                << top[idx] << " ~ " << right[idx] << " x " << bottom[idx] << endl;
+               }
+
+               for (unsigned int idx = 0; idx < number_of_objects; ++idx) {
+                       const char *label;
+
+                       ret = mv_object_detection_get_label(handle, indices[idx], &label); // label is requested by indices[idx], not by idx
+                       ASSERT_EQ(ret, 0);
+                       cout << "index = " << indices[idx] << " label = " << label << endl;
+
+                       string label_str(label);
+
+                       transform(label_str.begin(), label_str.end(), label_str.begin(), ::toupper);
+
+                       ASSERT_TRUE(label_str == string("DOG")); // every detected object must be labeled "dog", compared case-insensitively
+               }
+
+               ret = mv_object_detection_destroy(handle);
+               ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+       }
+
+       ret = mv_destroy_source(mv_source); // the shared image source is released once, after all models ran
        ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
 }