test: Add Image classicifation models 57/262657/3
authorKwang Son <k.son@samsung.com>
Tue, 17 Aug 2021 09:26:32 +0000 (05:26 -0400)
committerKwang Son <k.son@samsung.com>
Wed, 18 Aug 2021 01:07:42 +0000 (21:07 -0400)
Change-Id: Ie394dedaf72ffbfc76b2f69d41ddf6ee4c83cfaf
Signed-off-by: Kwang Son <k.son@samsung.com>
test/testsuites/machine_learning/inference/test_image_classification.cpp

index 96f8075..7c6bc0c 100644 (file)
@@ -1,4 +1,5 @@
 #include <gtest/gtest.h>
+#include <string>
 #include <ImageHelper.h>
 #include "test_inference_helper.hpp"
 
 #define IMG_BANANA \
        MV_CONFIG_PATH \
        "/res/inference/images/banana.jpg"
+#define IC_TFLITE_WEIGHT_MOBILENET_V2_224_PATH \
+       MV_CONFIG_PATH                             \
+       "/models/IC/tflite/ic_mobilenet_v2_224x224.tflite"
+#define IC_TFLITE_WEIGHT_DENSENET_224_PATH \
+       MV_CONFIG_PATH                         \
+       "/models/IC/tflite/ic_densenet_224x224.tflite"
+#define IC_TFLITE_WEIGHT_NASNET_224_PATH \
+       MV_CONFIG_PATH                       \
+       "/models/IC/tflite/ic_nasnet_224x224.tflite"
+#define IC_TFLITE_WEIGHT_MNASNET_224_PATH \
+       MV_CONFIG_PATH                        \
+       "/models/IC/tflite/ic_mnasnet_224x224.tflite"
+#define IC_TFLITE_WEIGHT_SQUEEZENET_224_PATH \
+       MV_CONFIG_PATH                           \
+       "/models/IC/tflite/ic_squeezenet_224x224.tflite"
+#define IC_TFLITE_WEIGHT_QUANT_MOBILENET_V1_224_PATH \
+       MV_CONFIG_PATH                                   \
+       "/models/IC/tflite/quant_mobilenet_v1_224x224.tflite"
 
 void _image_classified_cb(mv_source_h source, const int number_of_classes,
                                                  const int *indices, const char **names,
                                                  const float *confidences, void *user_data)
 {
-       ASSERT_EQ(number_of_classes, 1);
-       EXPECT_STREQ(names[0], "banana");
+       const std::string answer = "banana";
+       auto answer_found = false;
+       for (int i = 0; i < number_of_classes; i++) {
+               if (answer == names[i]) {
+                       answer_found = true;
+               }
+       }
+       EXPECT_TRUE(answer_found);
 }
 
 class TestImageClassification : public ::testing::Test
@@ -37,6 +62,18 @@ public:
                EXPECT_EQ(mv_destroy_engine_config(engine_cfg),
                                  MEDIA_VISION_ERROR_NONE);
        }
+       void inference_banana()
+       {
+               ASSERT_EQ(mv_inference_configure(infer, engine_cfg),
+                                 MEDIA_VISION_ERROR_NONE);
+               ASSERT_EQ(mv_inference_prepare(infer), MEDIA_VISION_ERROR_NONE);
+               ASSERT_EQ(MediaVision::Common::ImageHelper::loadImageToSource(
+                                                 IMG_BANANA, mv_source),
+                                 MEDIA_VISION_ERROR_NONE);
+               ASSERT_EQ(mv_inference_image_classify(mv_source, infer, NULL,
+                                                                                         _image_classified_cb, NULL),
+                                 MEDIA_VISION_ERROR_NONE);
+       }
        mv_engine_config_h engine_cfg;
        mv_inference_h infer;
        mv_source_h mv_source;
@@ -47,14 +84,53 @@ TEST_F(TestImageClassification, CPU_TFLITE_MobilenetV1)
        engine_config_hosted_cpu_tflite_user_model(
                        engine_cfg, IC_TFLITE_WEIGHT_MOBILENET_V1_224_PATH,
                        IC_LABEL_MOBILENET_V1_224_PATH);
-       EXPECT_EQ(mv_inference_configure(infer, engine_cfg),
-                         MEDIA_VISION_ERROR_NONE);
-       EXPECT_EQ(mv_inference_prepare(infer), MEDIA_VISION_ERROR_NONE);
-       EXPECT_EQ(MediaVision::Common::ImageHelper::loadImageToSource(IMG_BANANA,
-                                                                                                                                 mv_source),
-                         MEDIA_VISION_ERROR_NONE);
-       EXPECT_EQ(mv_inference_image_classify(mv_source, infer, NULL,
-                                                                                 _image_classified_cb, NULL),
-                         MEDIA_VISION_ERROR_NONE);
-       EXPECT_EQ(mv_source_clear(mv_source), MEDIA_VISION_ERROR_NONE);
+       inference_banana();
+}
+
+TEST_F(TestImageClassification, CPU_TFLITE_MobilenetV2)
+{
+       engine_config_hosted_cpu_tflite_user_model(
+                       engine_cfg, IC_TFLITE_WEIGHT_MOBILENET_V2_224_PATH,
+                       IC_LABEL_MOBILENET_V1_224_PATH);
+       inference_banana();
+}
+
+TEST_F(TestImageClassification, CPU_TFLITE_Densenet)
+{
+       engine_config_hosted_cpu_tflite_user_model(
+                       engine_cfg, IC_TFLITE_WEIGHT_DENSENET_224_PATH,
+                       IC_LABEL_MOBILENET_V1_224_PATH);
+       inference_banana();
 }
+
+TEST_F(TestImageClassification, CPU_TFLITE_Nasnet)
+{
+       engine_config_hosted_cpu_tflite_user_model(engine_cfg,
+                                                                                          IC_TFLITE_WEIGHT_NASNET_224_PATH,
+                                                                                          IC_LABEL_MOBILENET_V1_224_PATH);
+       inference_banana();
+}
+
+TEST_F(TestImageClassification, CPU_TFLITE_MNasnet)
+{
+       engine_config_hosted_cpu_tflite_user_model(
+                       engine_cfg, IC_TFLITE_WEIGHT_MNASNET_224_PATH,
+                       IC_LABEL_MOBILENET_V1_224_PATH);
+       inference_banana();
+}
+
+TEST_F(TestImageClassification, CPU_TFLITE_Squeezenet)
+{
+       engine_config_hosted_cpu_tflite_user_model(
+                       engine_cfg, IC_TFLITE_WEIGHT_SQUEEZENET_224_PATH,
+                       IC_LABEL_MOBILENET_V1_224_PATH);
+       inference_banana();
+}
+
+TEST_F(TestImageClassification, CPU_TFLITE_QUANT_MobilenetV1)
+{
+       engine_config_hosted_cpu_tflite_user_model(
+                       engine_cfg, IC_TFLITE_WEIGHT_QUANT_MOBILENET_V1_224_PATH,
+                       IC_LABEL_MOBILENET_V1_224_PATH);
+       inference_banana();
+}
\ No newline at end of file