IVGCVSW-2437 Inference test for TensorFlow Lite MobileNet SSD
[platform/upstream/armnn.git] / tests / MobileNetSsdDatabase.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include "ObjectDetectionCommon.hpp"
8
9 #include <memory>
10 #include <string>
11 #include <vector>
12
13 #include <armnn/TypesUtils.hpp>
14
15 #include <boost/log/trivial.hpp>
16 #include <boost/numeric/conversion/cast.hpp>
17
18 #include <array>
19 #include <string>
20
21 #include "InferenceTestImage.hpp"
22
23 namespace
24 {
25
26 struct MobileNetSsdTestCaseData
27 {
28     MobileNetSsdTestCaseData(
29         std::vector<float> inputData,
30         std::vector<DetectedObject> expectedOutput)
31         : m_InputData(std::move(inputData))
32         , m_ExpectedOutput(std::move(expectedOutput))
33     {}
34
35     std::vector<float>          m_InputData;
36     std::vector<DetectedObject> m_ExpectedOutput;
37 };
38
39 class MobileNetSsdDatabase
40 {
41 public:
42     explicit MobileNetSsdDatabase(const std::string& imageDir);
43
44     std::unique_ptr<MobileNetSsdTestCaseData> GetTestCaseData(unsigned int testCaseId);
45
46 private:
47     std::string m_ImageDir;
48 };
49
50 constexpr unsigned int k_MobileNetSsdImageWidth  = 300u;
51 constexpr unsigned int k_MobileNetSsdImageHeight = k_MobileNetSsdImageWidth;
52
53 // Test cases
54 const std::array<ObjectDetectionInput, 1> g_PerTestCaseInput =
55 {
56     ObjectDetectionInput
57     {
58         "Cat.jpg",
59         DetectedObject(16, BoundingBox(0.21678525f, 0.0859828f, 0.9271242f, 0.9453231f), 0.79296875f)
60     }
61 };
62
63 MobileNetSsdDatabase::MobileNetSsdDatabase(const std::string& imageDir)
64     : m_ImageDir(imageDir)
65 {}
66
67 std::unique_ptr<MobileNetSsdTestCaseData> MobileNetSsdDatabase::GetTestCaseData(unsigned int testCaseId)
68 {
69     const unsigned int safeTestCaseId =
70         testCaseId % boost::numeric_cast<unsigned int>(g_PerTestCaseInput.size());
71     const ObjectDetectionInput& testCaseInput = g_PerTestCaseInput[safeTestCaseId];
72
73     // Load test case input
74     const std::string imagePath = m_ImageDir + testCaseInput.first;
75     std::vector<float> imageData;
76     try
77     {
78         InferenceTestImage image(imagePath.c_str());
79
80         // Resize image (if needed)
81         const unsigned int width  = image.GetWidth();
82         const unsigned int height = image.GetHeight();
83         if (width != k_MobileNetSsdImageWidth || height != k_MobileNetSsdImageHeight)
84         {
85             image.Resize(k_MobileNetSsdImageWidth, k_MobileNetSsdImageHeight, CHECK_LOCATION());
86         }
87
88         // Get image data as a vector of floats
89         imageData = GetImageDataInArmNnLayoutAsNormalizedFloats(ImageChannelLayout::Rgb, image);
90     }
91     catch (const InferenceTestImageException& e)
92     {
93         BOOST_LOG_TRIVIAL(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
94         return nullptr;
95     }
96
97     // Prepare test case expected output
98     std::vector<DetectedObject> expectedOutput;
99     expectedOutput.reserve(1);
100     expectedOutput.push_back(testCaseInput.second);
101
102     return std::make_unique<MobileNetSsdTestCaseData>(std::move(imageData), std::move(expectedOutput));
103 }
104
105 } // anonymous namespace