#include <gtest/gtest.h>
+#include <string>
#include <ImageHelper.h>
#include "test_inference_helper.hpp"
#define IMG_BANANA \
MV_CONFIG_PATH \
"/res/inference/images/banana.jpg"
+#define IC_TFLITE_WEIGHT_MOBILENET_V2_224_PATH \
+ MV_CONFIG_PATH \
+ "/models/IC/tflite/ic_mobilenet_v2_224x224.tflite"
+#define IC_TFLITE_WEIGHT_DENSENET_224_PATH \
+ MV_CONFIG_PATH \
+ "/models/IC/tflite/ic_densenet_224x224.tflite"
+#define IC_TFLITE_WEIGHT_NASNET_224_PATH \
+ MV_CONFIG_PATH \
+ "/models/IC/tflite/ic_nasnet_224x224.tflite"
+#define IC_TFLITE_WEIGHT_MNASNET_224_PATH \
+ MV_CONFIG_PATH \
+ "/models/IC/tflite/ic_mnasnet_224x224.tflite"
+#define IC_TFLITE_WEIGHT_SQUEEZENET_224_PATH \
+ MV_CONFIG_PATH \
+ "/models/IC/tflite/ic_squeezenet_224x224.tflite"
+#define IC_TFLITE_WEIGHT_QUANT_MOBILENET_V1_224_PATH \
+ MV_CONFIG_PATH \
+ "/models/IC/tflite/quant_mobilenet_v1_224x224.tflite"
void _image_classified_cb(mv_source_h source, const int number_of_classes,
const int *indices, const char **names,
const float *confidences, void *user_data)
{
- ASSERT_EQ(number_of_classes, 1);
- EXPECT_STREQ(names[0], "banana");
+ const std::string answer = "banana";
+ auto answer_found = false;
+ for (int i = 0; i < number_of_classes; i++) {
+ if (answer == names[i]) {
+ answer_found = true;
+ }
+ }
+ EXPECT_TRUE(answer_found);
}
class TestImageClassification : public ::testing::Test
EXPECT_EQ(mv_destroy_engine_config(engine_cfg),
MEDIA_VISION_ERROR_NONE);
}
+ void inference_banana()
+ {
+ ASSERT_EQ(mv_inference_configure(infer, engine_cfg),
+ MEDIA_VISION_ERROR_NONE);
+ ASSERT_EQ(mv_inference_prepare(infer), MEDIA_VISION_ERROR_NONE);
+ ASSERT_EQ(MediaVision::Common::ImageHelper::loadImageToSource(
+ IMG_BANANA, mv_source),
+ MEDIA_VISION_ERROR_NONE);
+ ASSERT_EQ(mv_inference_image_classify(mv_source, infer, NULL,
+ _image_classified_cb, NULL),
+ MEDIA_VISION_ERROR_NONE);
+ }
mv_engine_config_h engine_cfg;
mv_inference_h infer;
mv_source_h mv_source;
engine_config_hosted_cpu_tflite_user_model(
engine_cfg, IC_TFLITE_WEIGHT_MOBILENET_V1_224_PATH,
IC_LABEL_MOBILENET_V1_224_PATH);
- EXPECT_EQ(mv_inference_configure(infer, engine_cfg),
- MEDIA_VISION_ERROR_NONE);
- EXPECT_EQ(mv_inference_prepare(infer), MEDIA_VISION_ERROR_NONE);
- EXPECT_EQ(MediaVision::Common::ImageHelper::loadImageToSource(IMG_BANANA,
- mv_source),
- MEDIA_VISION_ERROR_NONE);
- EXPECT_EQ(mv_inference_image_classify(mv_source, infer, NULL,
- _image_classified_cb, NULL),
- MEDIA_VISION_ERROR_NONE);
- EXPECT_EQ(mv_source_clear(mv_source), MEDIA_VISION_ERROR_NONE);
+ inference_banana();
+}
+
+TEST_F(TestImageClassification, CPU_TFLITE_MobilenetV2)
+{
+ engine_config_hosted_cpu_tflite_user_model(
+ engine_cfg, IC_TFLITE_WEIGHT_MOBILENET_V2_224_PATH,
+ IC_LABEL_MOBILENET_V1_224_PATH);
+ inference_banana();
+}
+
+TEST_F(TestImageClassification, CPU_TFLITE_Densenet)
+{
+ engine_config_hosted_cpu_tflite_user_model(
+ engine_cfg, IC_TFLITE_WEIGHT_DENSENET_224_PATH,
+ IC_LABEL_MOBILENET_V1_224_PATH);
+ inference_banana();
}
+
+TEST_F(TestImageClassification, CPU_TFLITE_Nasnet)
+{
+ engine_config_hosted_cpu_tflite_user_model(engine_cfg,
+ IC_TFLITE_WEIGHT_NASNET_224_PATH,
+ IC_LABEL_MOBILENET_V1_224_PATH);
+ inference_banana();
+}
+
+TEST_F(TestImageClassification, CPU_TFLITE_MNasnet)
+{
+ engine_config_hosted_cpu_tflite_user_model(
+ engine_cfg, IC_TFLITE_WEIGHT_MNASNET_224_PATH,
+ IC_LABEL_MOBILENET_V1_224_PATH);
+ inference_banana();
+}
+
+TEST_F(TestImageClassification, CPU_TFLITE_Squeezenet)
+{
+ engine_config_hosted_cpu_tflite_user_model(
+ engine_cfg, IC_TFLITE_WEIGHT_SQUEEZENET_224_PATH,
+ IC_LABEL_MOBILENET_V1_224_PATH);
+ inference_banana();
+}
+
+TEST_F(TestImageClassification, CPU_TFLITE_QUANT_MobilenetV1)
+{
+ engine_config_hosted_cpu_tflite_user_model(
+ engine_cfg, IC_TFLITE_WEIGHT_QUANT_MOBILENET_V1_224_PATH,
+ IC_LABEL_MOBILENET_V1_224_PATH);
+ inference_banana();
+}
\ No newline at end of file