From: Inki Dae Date: Thu, 25 Jan 2024 08:13:15 +0000 (+0900) Subject: mv_machine_learning: introduce get_result_count API for image classification X-Git-Tag: accepted/tizen/unified/20240213.171947~2^2~15 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=a72aba9b845f7740ead6b9e74e18ccda52eb1a73;p=platform%2Fcore%2Fapi%2Fmediavision.git mv_machine_learning: introduce get_result_count API for image classification [Issue type] : new feature Introduce the get_result_count API for the image classification task group. From the user's perspective, this API reports how many results exist so that the user can request each result by a user-given index. From the framework's perspective, it provides consistent API behavior: a get_result_count API call updates _current_result of the task group by calling the getOutput function of ITask, and a get_result API call returns the _current_result value by calling the getOutputCache function of ITask. The get_result_count and get_result APIs are sufficient, so drop the existing get_label API. Change-Id: I9c1cb9e855494474af1510bbaf94febbdc57f05e Signed-off-by: Inki Dae --- diff --git a/include/mv_image_classification_internal.h b/include/mv_image_classification_internal.h index f81bc540..94f702d2 100644 --- a/include/mv_image_classification_internal.h +++ b/include/mv_image_classification_internal.h @@ -147,24 +147,54 @@ int mv_image_classification_inference(mv_image_classification_h handle, mv_sourc int mv_image_classification_inference_async(mv_image_classification_h handle, mv_source_h source); /** - * @brief Gets the label value as a image classification inference result. - * @details Use this function to get the label value after calling @ref mv_image_classification_inference(). - * - * @since_tizen 9.0 - * - * @remarks The @a result must NOT be released using free() - * - * @param[in] handle The handle to the image classification object. - * @param[out] out_label A pointer to the label string. 
- * - * @return @c 0 on success, otherwise a negative error value - * @retval #MEDIA_VISION_ERROR_NONE Successful - * @retval #MEDIA_VISION_ERROR_INVALID_PARAMETER Invalid parameter - * @retval #MEDIA_VISION_ERROR_INVALID_OPERATION Invalid operation - * - * @pre Request an inference by calling @ref mv_image_classification_inference() - */ -int mv_image_classification_get_label(mv_image_classification_h handle, const char **out_label); + * @internal + * @brief Gets the number of image classification inference results. + * + * @since_tizen 9.0 + * + * @param[in] handle The handle to the inference + * @param[out] result_cnt The number of results. + * + * @return @c 0 on success, otherwise a negative error value + * @retval #MEDIA_VISION_ERROR_NONE Successful + * @retval #MEDIA_VISION_ERROR_NOT_SUPPORTED Not supported + * @retval #MEDIA_VISION_ERROR_INVALID_PARAMETER Invalid parameter + * @retval #MEDIA_VISION_ERROR_INTERNAL Internal error + * + * @pre Create a source handle by calling mv_create_source() + * @pre Create an inference handle by calling mv_image_classification_create() + * @pre Prepare an inference by calling mv_image_classification_configure() + * @pre Prepare an inference by calling mv_image_classification_prepare() + * @pre Request an inference by calling mv_image_classification_inference() + */ +int mv_image_classification_get_result_count(mv_image_classification_h handle, unsigned int *result_cnt); + +/** + * @internal + * @brief Gets the image classification inference result for a given index. + * + * @since_tizen 9.0 + * + * @param[in] handle The handle to the inference + * @param[in] index A result index. + * @param[out] frame_number The frame number on which inference was performed. + * @param[out] label The label name of the classification result. 
+ * + * @return @c 0 on success, otherwise a negative error value + * @retval #MEDIA_VISION_ERROR_NONE Successful + * @retval #MEDIA_VISION_ERROR_NOT_SUPPORTED Not supported + * @retval #MEDIA_VISION_ERROR_INVALID_PARAMETER Invalid parameter + * @retval #MEDIA_VISION_ERROR_INTERNAL Internal error + * + * @pre Create a source handle by calling mv_create_source() + * @pre Create an inference handle by calling mv_image_classification_create() + * @pre Prepare an inference by calling mv_image_classification_configure() + * @pre Prepare an inference by calling mv_image_classification_prepare() + * @pre Request an inference by calling mv_image_classification_inference() + * @pre Get result count by calling mv_image_classification_get_result_count() + */ +int mv_image_classification_get_result(mv_image_classification_h handle, unsigned int index, + unsigned long *frame_number, const char **label); /** * @brief Set user-given model information. diff --git a/mv_machine_learning/image_classification/include/IImageClassification.h b/mv_machine_learning/image_classification/include/IImageClassification.h index ed3412ab..b53f6528 100644 --- a/mv_machine_learning/image_classification/include/IImageClassification.h +++ b/mv_machine_learning/image_classification/include/IImageClassification.h @@ -41,6 +41,7 @@ public: virtual void perform(mv_source_h &mv_src) = 0; virtual void performAsync(ImageClassificationInput &input) = 0; virtual ImageClassificationResult &getOutput() = 0; + virtual ImageClassificationResult &getOutputCache() = 0; }; } // machine_learning diff --git a/mv_machine_learning/image_classification/include/ImageClassification.h b/mv_machine_learning/image_classification/include/ImageClassification.h index 7e4dfca7..956aa6a9 100644 --- a/mv_machine_learning/image_classification/include/ImageClassification.h +++ b/mv_machine_learning/image_classification/include/ImageClassification.h @@ -65,17 +65,18 @@ public: explicit ImageClassification(std::shared_ptr config); 
virtual ~ImageClassification() = default; - void preDestroy(); - void setEngineInfo(std::string engine_type_name, std::string device_type_name); - unsigned int getNumberOfEngines(); - const std::string &getEngineType(unsigned int engine_index); - unsigned int getNumberOfDevices(const std::string &engine_type); - const std::string &getDeviceType(const std::string &engine_type, unsigned int device_index); - void configure(); - void prepare(); - void perform(mv_source_h &mv_src); - void performAsync(ImageClassificationInput &input); - ImageClassificationResult &getOutput(); + void preDestroy() override; + void setEngineInfo(std::string engine_type_name, std::string device_type_name) override; + unsigned int getNumberOfEngines() override; + const std::string &getEngineType(unsigned int engine_index) override; + unsigned int getNumberOfDevices(const std::string &engine_type) override; + const std::string &getDeviceType(const std::string &engine_type, unsigned int device_index) override; + void configure() override; + void prepare() override; + void perform(mv_source_h &mv_src) override; + void performAsync(ImageClassificationInput &input) override; + ImageClassificationResult &getOutput() override; + ImageClassificationResult &getOutputCache() override; }; } // machine_learning diff --git a/mv_machine_learning/image_classification/src/ImageClassification.cpp b/mv_machine_learning/image_classification/src/ImageClassification.cpp index 0a553302..d5c36fc2 100644 --- a/mv_machine_learning/image_classification/src/ImageClassification.cpp +++ b/mv_machine_learning/image_classification/src/ImageClassification.cpp @@ -279,6 +279,11 @@ template ImageClassificationResult &ImageClassification::getOutpu return _current_result; } +template ImageClassificationResult &ImageClassification::getOutputCache() +{ + return _current_result; +} + template void ImageClassification::performAsync(ImageClassificationInput &input) { if (!_async_manager) { diff --git 
a/mv_machine_learning/image_classification/src/ImageClassificationAdapter.cpp b/mv_machine_learning/image_classification/src/ImageClassificationAdapter.cpp index b4b78b7e..9a445ab0 100644 --- a/mv_machine_learning/image_classification/src/ImageClassificationAdapter.cpp +++ b/mv_machine_learning/image_classification/src/ImageClassificationAdapter.cpp @@ -123,7 +123,7 @@ OutputBaseType &ImageClassificationAdapter::getOutput() OutputBaseType &ImageClassificationAdapter::getOutputCache() { - throw InvalidOperation("Not support yet."); + return _image_classification->getOutputCache(); } } diff --git a/mv_machine_learning/image_classification/src/mv_image_classification.cpp b/mv_machine_learning/image_classification/src/mv_image_classification.cpp index 89296e42..9012c153 100644 --- a/mv_machine_learning/image_classification/src/mv_image_classification.cpp +++ b/mv_machine_learning/image_classification/src/mv_image_classification.cpp @@ -162,19 +162,51 @@ int mv_image_classification_inference_async(mv_image_classification_h handle, mv return MEDIA_VISION_ERROR_NONE; } -int mv_image_classification_get_label(mv_image_classification_h handle, const char **out_label) +int mv_image_classification_get_result_count(mv_image_classification_h handle, unsigned int *result_cnt) { - MEDIA_VISION_SUPPORT_CHECK(mv_check_feature_key(feature_keys, num_keys, false)); + MEDIA_VISION_SUPPORT_CHECK(mv_check_feature_key(feature_keys, num_keys, true)); MEDIA_VISION_INSTANCE_CHECK(handle); - MEDIA_VISION_NULL_ARG_CHECK(out_label); + MEDIA_VISION_INSTANCE_CHECK(result_cnt); MEDIA_VISION_FUNCTION_ENTER(); try { - const auto &result = - static_cast(machine_learning_native_get_result(handle, TASK_NAME)); + auto &result = static_cast(machine_learning_native_get_result(handle, TASK_NAME)); + // As of now, the result count of image classification task is always 1. + // TODO. consider for Multilabel image classification model later. 
+ *result_cnt = static_cast(!result.label.empty()); + } catch (const BaseException &e) { + LOGE("%s", e.what()); + return e.getError(); + } + + MEDIA_VISION_FUNCTION_LEAVE(); + + return MEDIA_VISION_ERROR_NONE; +} + +int mv_image_classification_get_result(mv_image_classification_h handle, unsigned int index, + unsigned long *frame_number, const char **label) +{ + MEDIA_VISION_SUPPORT_CHECK(mv_check_feature_key(feature_keys, num_keys, true)); + MEDIA_VISION_INSTANCE_CHECK(handle); + MEDIA_VISION_INSTANCE_CHECK(frame_number); + MEDIA_VISION_INSTANCE_CHECK(label); + + MEDIA_VISION_FUNCTION_ENTER(); - *out_label = result.label.c_str(); + try { + auto &result = + static_cast(machine_learning_native_get_result_cache(handle, TASK_NAME)); + // As of now, the result count of image classification task is always 1. + unsigned int result_cnt = static_cast(!result.label.empty()); + if (index >= result_cnt) { + LOGE("Invalid index(index = %u, result count = %u).", index, result_cnt); + return MEDIA_VISION_ERROR_INVALID_PARAMETER; + } + + *frame_number = result.frame_number; + *label = result.label.c_str(); } catch (const BaseException &e) { LOGE("%s", e.what()); return e.getError(); diff --git a/test/testsuites/machine_learning/image_classification/test_image_classification.cpp b/test/testsuites/machine_learning/image_classification/test_image_classification.cpp index 97282232..05f15eec 100644 --- a/test/testsuites/machine_learning/image_classification/test_image_classification.cpp +++ b/test/testsuites/machine_learning/image_classification/test_image_classification.cpp @@ -130,12 +130,20 @@ TEST(ImageClassificationTest, InferenceShouldBeOk) ret = mv_image_classification_inference(handle, mv_source); ASSERT_EQ(ret, 0); - const char *label = NULL; + unsigned int cnt; - ret = mv_image_classification_get_label(handle, &label); + ret = mv_image_classification_get_result_count(handle, &cnt); ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE); - cout << label << endl; + for (unsigned long idx = 
0; idx < cnt; ++idx) { + unsigned long frame_number; + const char *label = NULL; + + ret = mv_image_classification_get_result(handle, idx, &frame_number, &label); + ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE); + + cout << "frame number = " << frame_number << " label = " << label << endl; + } ret = mv_image_classification_destroy(handle); ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE); diff --git a/test/testsuites/machine_learning/image_classification/test_image_classification_async.cpp b/test/testsuites/machine_learning/image_classification/test_image_classification_async.cpp index d99c934a..ad9e98fc 100644 --- a/test/testsuites/machine_learning/image_classification/test_image_classification_async.cpp +++ b/test/testsuites/machine_learning/image_classification/test_image_classification_async.cpp @@ -42,25 +42,36 @@ struct model_info { void image_classification_callback(void *user_data) { mv_image_classification_h handle = static_cast(user_data); - unsigned int frame_number = 0; + bool is_loop_exit = false; - while (frame_number++ < MAX_INFERENCE_ITERATION - 10) { - const char *label = NULL; + while (!is_loop_exit) { + unsigned int cnt; - int ret = mv_image_classification_get_label(handle, &label); + int ret = mv_image_classification_get_result_count(handle, &cnt); if (ret == MEDIA_VISION_ERROR_INVALID_OPERATION) break; ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE); - cout << "Expected label = BANANA" - << " Actual label = " << label << endl; + for (unsigned long idx = 0; idx < cnt; ++idx) { + const char *label = NULL; + unsigned long frame_number; - string label_str(label); + ret = mv_image_classification_get_result(handle, idx, &frame_number, &label); + ASSERT_EQ(ret, MEDIA_VISION_ERROR_NONE); + + if (frame_number > MAX_INFERENCE_ITERATION - 10) + is_loop_exit = true; + + cout << "Expected label = BANANA" + << " Actual label = " << label << endl; - transform(label_str.begin(), label_str.end(), label_str.begin(), ::toupper); + string label_str(label); - ASSERT_EQ(label_str, 
"BANANA"); + transform(label_str.begin(), label_str.end(), label_str.begin(), ::toupper); + + ASSERT_EQ(label_str, "BANANA"); + } } }