mv_machine_learning: introduce get_result_count API for image classification 11/304911/4
authorInki Dae <inki.dae@samsung.com>
Thu, 25 Jan 2024 08:13:15 +0000 (17:13 +0900)
committerInki Dae <inki.dae@samsung.com>
Fri, 26 Jan 2024 04:38:15 +0000 (04:38 +0000)
[Issue type] : new feature

Introduce get_result_count API for image classification task group.

In user perspective, this API provides information on how many results exist
so that user can request each result corresponding to a user-given index.
And also, in framework perspective, it provides consistent API behavior -
get_result_count API call updates _current_result of task group by calling
getOutput function of ITask, and get_result API call returns _current_result
value by calling getOutputCache function of ITask.

And we are enough with get_result_count and get_result API so drop existing
get_label API.

Change-Id: I9c1cb9e855494474af1510bbaf94febbdc57f05e
Signed-off-by: Inki Dae <inki.dae@samsung.com>
include/mv_image_classification_internal.h
mv_machine_learning/image_classification/include/IImageClassification.h
mv_machine_learning/image_classification/include/ImageClassification.h
mv_machine_learning/image_classification/src/ImageClassification.cpp
mv_machine_learning/image_classification/src/ImageClassificationAdapter.cpp
mv_machine_learning/image_classification/src/mv_image_classification.cpp
test/testsuites/machine_learning/image_classification/test_image_classification.cpp
test/testsuites/machine_learning/image_classification/test_image_classification_async.cpp

index f81bc54027c8427df55180fa5ef43a9d1222767e..94f702d2596290a5c02ae3b2e99a60384fb5609e 100644 (file)
@@ -147,24 +147,54 @@ int mv_image_classification_inference(mv_image_classification_h handle, mv_sourc
 int mv_image_classification_inference_async(mv_image_classification_h handle, mv_source_h source);
 
 /**
-        * @brief Gets the label value as a image classification inference result.
-        * @details Use this function to get the label value after calling @ref mv_image_classification_inference().
-        *
-        * @since_tizen 9.0
-        *
-        * @remarks The @a result must NOT be released using free()
-        *
-        * @param[in] handle        The handle to the image classification object.
-        * @param[out] out_label    A pointer to the label string.
-        *
-        * @return @c 0 on success, otherwise a negative error value
-        * @retval #MEDIA_VISION_ERROR_NONE Successful
-        * @retval #MEDIA_VISION_ERROR_INVALID_PARAMETER Invalid parameter
-        * @retval #MEDIA_VISION_ERROR_INVALID_OPERATION Invalid operation
-        *
-        * @pre Request an inference by calling @ref mv_image_classification_inference()
-        */
-int mv_image_classification_get_label(mv_image_classification_h handle, const char **out_label);
+ * @internal
+ * @brief Gets the image classification inference result count on the @a source.
+ *
+ * @since_tizen 9.0
+ *
+ * @param[in] handle       The handle to the inference
+ * @param[out] result_cnt  A number of results.
+ *
+ * @return @c 0 on success, otherwise a negative error value
+ * @retval #MEDIA_VISION_ERROR_NONE Successful
+ * @retval #MEDIA_VISION_ERROR_NOT_SUPPORTED Not supported
+ * @retval #MEDIA_VISION_ERROR_INVALID_PARAMETER Invalid parameter
+ * @retval #MEDIA_VISION_ERROR_INTERNAL          Internal error
+ *
+ * @pre Create a source handle by calling mv_create_source()
+ * @pre Create an inference handle by calling mv_image_classification_create()
+ * @pre Prepare an inference by calling mv_image_classification_configure()
+ * @pre Prepare an inference by calling mv_image_classification_prepare()
+ * @pre Request an inference by calling mv_image_classification_inference()
+ */
+int mv_image_classification_get_result_count(mv_image_classification_h handle, unsigned int *result_cnt);
+
+/**
+ * @internal
+ * @brief Gets the image classification inference result to a given index.
+ *
+ * @since_tizen 9.0
+ *
+ * @param[in] handle              The handle to the inference
+ * @param[in] index               A result index.
+ * @param[out] frame_number       A frame number inferenced.
+ * @param[out] label              A label name to a detected object.
+ *
+ * @return @c 0 on success, otherwise a negative error value
+ * @retval #MEDIA_VISION_ERROR_NONE Successful
+ * @retval #MEDIA_VISION_ERROR_NOT_SUPPORTED Not supported
+ * @retval #MEDIA_VISION_ERROR_INVALID_PARAMETER Invalid parameter
+ * @retval #MEDIA_VISION_ERROR_INTERNAL          Internal error
+ *
+ * @pre Create a source handle by calling mv_create_source()
+ * @pre Create an inference handle by calling mv_image_classification_create()
+ * @pre Prepare an inference by calling mv_image_classification_configure()
+ * @pre Prepare an inference by calling mv_image_classification_prepare()
+ * @pre Request an inference by calling mv_image_classification_inference()
+ * @pre Get result count by calling mv_image_classification_get_result_count()
+ */
+int mv_image_classification_get_result(mv_image_classification_h handle, unsigned int index,
+                                                                          unsigned long *frame_number, const char **label);
 
 /**
         * @brief Set user-given model information.
index ed3412ab28fde0766c5f7315a6a4af0c1ea8c6a5..b53f6528af93857a6283bcc45a5e33949db890f9 100644 (file)
@@ -41,6 +41,7 @@ public:
        virtual void perform(mv_source_h &mv_src) = 0;
        virtual void performAsync(ImageClassificationInput &input) = 0;
        virtual ImageClassificationResult &getOutput() = 0;
+       virtual ImageClassificationResult &getOutputCache() = 0;
 };
 
 } // machine_learning
index 7e4dfca735ffc58174dfd92a1305da6d89b4477d..956aa6a9c3a7d53f0825b84ab2d6182c17048458 100644 (file)
@@ -65,17 +65,18 @@ public:
        explicit ImageClassification(std::shared_ptr<Config> config);
        virtual ~ImageClassification() = default;
 
-       void preDestroy();
-       void setEngineInfo(std::string engine_type_name, std::string device_type_name);
-       unsigned int getNumberOfEngines();
-       const std::string &getEngineType(unsigned int engine_index);
-       unsigned int getNumberOfDevices(const std::string &engine_type);
-       const std::string &getDeviceType(const std::string &engine_type, unsigned int device_index);
-       void configure();
-       void prepare();
-       void perform(mv_source_h &mv_src);
-       void performAsync(ImageClassificationInput &input);
-       ImageClassificationResult &getOutput();
+       void preDestroy() override;
+       void setEngineInfo(std::string engine_type_name, std::string device_type_name) override;
+       unsigned int getNumberOfEngines() override;
+       const std::string &getEngineType(unsigned int engine_index) override;
+       unsigned int getNumberOfDevices(const std::string &engine_type) override;
+       const std::string &getDeviceType(const std::string &engine_type, unsigned int device_index) override;
+       void configure() override;
+       void prepare() override;
+       void perform(mv_source_h &mv_src) override;
+       void performAsync(ImageClassificationInput &input) override;
+       ImageClassificationResult &getOutput() override;
+       ImageClassificationResult &getOutputCache() override;
 };
 
 } // machine_learning
index 0a5533026460ab7f7aae5effc0bf758ff6924631..d5c36fc27e0384dbdf9ae0e3a2d1e3caa616311a 100644 (file)
@@ -279,6 +279,11 @@ template<typename T> ImageClassificationResult &ImageClassification<T>::getOutpu
        return _current_result;
 }
 
+template<typename T> ImageClassificationResult &ImageClassification<T>::getOutputCache()
+{
+       return _current_result;
+}
+
 template<typename T> void ImageClassification<T>::performAsync(ImageClassificationInput &input)
 {
        if (!_async_manager) {
index b4b78b7e3a8f23aede2eac840ee85af7673b928c..9a445ab0fb4281f9071763993063f44fa6310d08 100644 (file)
@@ -123,7 +123,7 @@ OutputBaseType &ImageClassificationAdapter::getOutput()
 
 OutputBaseType &ImageClassificationAdapter::getOutputCache()
 {
-       throw InvalidOperation("Not support yet.");
+       return _image_classification->getOutputCache();
 }
 
 }
index 89296e426e033a7edd40da9387c12145e53f31e3..9012c15383892e57f9a1987181b1eb54fccb9e46 100644 (file)
@@ -162,19 +162,51 @@ int mv_image_classification_inference_async(mv_image_classification_h handle, mv
        return MEDIA_VISION_ERROR_NONE;
 }
 
-int mv_image_classification_get_label(mv_image_classification_h handle, const char **out_label)
+int mv_image_classification_get_result_count(mv_image_classification_h handle, unsigned int *result_cnt)
 {
-       MEDIA_VISION_SUPPORT_CHECK(mv_check_feature_key(feature_keys, num_keys, false));
+       MEDIA_VISION_SUPPORT_CHECK(mv_check_feature_key(feature_keys, num_keys, true));
        MEDIA_VISION_INSTANCE_CHECK(handle);
-       MEDIA_VISION_NULL_ARG_CHECK(out_label);
+       MEDIA_VISION_INSTANCE_CHECK(result_cnt);
 
        MEDIA_VISION_FUNCTION_ENTER();
 
        try {
-               const auto &result =
-                               static_cast<ImageClassificationResult &>(machine_learning_native_get_result(handle, TASK_NAME));
+               auto &result = static_cast<ImageClassificationResult &>(machine_learning_native_get_result(handle, TASK_NAME));
+               // As of now, the result count of image classification task is always 1.
+               // TODO. consider for Multilabel image classification model later.
+               *result_cnt = static_cast<unsigned int>(!result.label.empty());
+       } catch (const BaseException &e) {
+               LOGE("%s", e.what());
+               return e.getError();
+       }
+
+       MEDIA_VISION_FUNCTION_LEAVE();
+
+       return MEDIA_VISION_ERROR_NONE;
+}
+
+int mv_image_classification_get_result(mv_image_classification_h handle, unsigned int index,
+                                                                          unsigned long *frame_number, const char **label)
+{
+       MEDIA_VISION_SUPPORT_CHECK(mv_check_feature_key(feature_keys, num_keys, true));
+       MEDIA_VISION_INSTANCE_CHECK(handle);
+       MEDIA_VISION_INSTANCE_CHECK(frame_number);
+       MEDIA_VISION_INSTANCE_CHECK(label);
+
+       MEDIA_VISION_FUNCTION_ENTER();
 
-               *out_label = result.label.c_str();
+       try {
+               auto &result =
+                               static_cast<ImageClassificationResult &>(machine_learning_native_get_result_cache(handle, TASK_NAME));
+               // As of now, the result count of image classification task is always 1.
+               unsigned int result_cnt = static_cast<unsigned int>(!result.label.empty());
+               if (index >= result_cnt) {
+                       LOGE("Invalid index(index = %u, result count = %u).", index, result_cnt);
+                       return MEDIA_VISION_ERROR_INVALID_PARAMETER;
+               }
+
+               *frame_number = result.frame_number;
+               *label = result.label.c_str();
        } catch (const BaseException &e) {
                LOGE("%s", e.what());
                return e.getError();
index 97282232e7cf4b353008e4a2594c166a0e2dbec3..05f15eec8fa85041894acdb120f9ec6a76a3a2cf 100644 (file)
@@ -130,12 +130,20 @@ TEST(ImageClassificationTest, InferenceShouldBeOk)
                ret = mv_image_classification_inference(handle, mv_source);
                ASSERT_EQ(ret, 0);
 
-               const char *label = NULL;
+               unsigned int cnt;
 
-               ret = mv_image_classification_get_label(handle, &label);
+               ret = mv_image_classification_get_result_count(handle, &cnt);
                ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
 
-               cout << label << endl;
+               for (unsigned long idx = 0; idx < cnt; ++idx) {
+                       unsigned long frame_number;
+                       const char *label = NULL;
+
+                       ret = mv_image_classification_get_result(handle, idx, &frame_number, &label);
+                       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+                       cout << "frame number = " << frame_number << " label = " << label << endl;
+               }
 
                ret = mv_image_classification_destroy(handle);
                ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
index d99c934a6a566b01d13e2422ee4f4d24fc69784b..ad9e98fc586c971916479aded9097ac4483d14ff 100644 (file)
@@ -42,25 +42,36 @@ struct model_info {
 void image_classification_callback(void *user_data)
 {
        mv_image_classification_h handle = static_cast<mv_image_classification_h>(user_data);
-       unsigned int frame_number = 0;
+       bool is_loop_exit = false;
 
-       while (frame_number++ < MAX_INFERENCE_ITERATION - 10) {
-               const char *label = NULL;
+       while (!is_loop_exit) {
+               unsigned int cnt;
 
-               int ret = mv_image_classification_get_label(handle, &label);
+               int ret = mv_image_classification_get_result_count(handle, &cnt);
                if (ret == MEDIA_VISION_ERROR_INVALID_OPERATION)
                        break;
 
                ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
 
-               cout << "Expected label = BANANA"
-                        << " Actual label = " << label << endl;
+               for (unsigned long idx = 0; idx < cnt; ++idx) {
+                       const char *label = NULL;
+                       unsigned long frame_number;
 
-               string label_str(label);
+                       ret = mv_image_classification_get_result(handle, idx, &frame_number, &label);
+                       ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE);
+
+                       if (frame_number > MAX_INFERENCE_ITERATION - 10)
+                               is_loop_exit = true;
+
+                       cout << "Expected label = BANANA"
+                                << " Actual label = " << label << endl;
 
-               transform(label_str.begin(), label_str.end(), label_str.begin(), ::toupper);
+                       string label_str(label);
 
-               ASSERT_EQ(label_str, "BANANA");
+                       transform(label_str.begin(), label_str.end(), label_str.begin(), ::toupper);
+
+                       ASSERT_EQ(label_str, "BANANA");
+               }
        }
 }