Release 18.03
[platform/upstream/armnn.git] / tests / ImageNetDatabase.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #include "InferenceTestImage.hpp"
6 #include "ImageNetDatabase.hpp"
7
8 #include <boost/numeric/conversion/cast.hpp>
9 #include <boost/log/trivial.hpp>
10 #include <boost/assert.hpp>
11 #include <boost/format.hpp>
12
13 #include <iostream>
14 #include <fcntl.h>
15 #include <array>
16
17 const std::vector<ImageSet> g_DefaultImageSet =
18 {
19     {"shark.jpg", 2}
20 };
21
22 ImageNetDatabase::ImageNetDatabase(const std::string& binaryFileDirectory, unsigned int width, unsigned int height,
23                                    const std::vector<ImageSet>& imageSet)
24 :   m_BinaryDirectory(binaryFileDirectory)
25 ,   m_Height(height)
26 ,   m_Width(width)
27 ,   m_ImageSet(imageSet.empty() ? g_DefaultImageSet : imageSet)
28 {
29 }
30
31 std::unique_ptr<ImageNetDatabase::TTestCaseData> ImageNetDatabase::GetTestCaseData(unsigned int testCaseId)
32 {
33     testCaseId = testCaseId % boost::numeric_cast<unsigned int>(m_ImageSet.size());
34     const ImageSet& imageSet = m_ImageSet[testCaseId];
35     const std::string fullPath = m_BinaryDirectory + imageSet.first;
36
37     InferenceTestImage image(fullPath.c_str());
38     image.Resize(m_Width, m_Height);
39
40     // The model expects image data in BGR format
41     std::vector<float> inputImageData = GetImageDataInArmNnLayoutAsFloatsSubtractingMean(ImageChannelLayout::Bgr,
42                                                                                          image, m_MeanBgr);
43
44     // list of labels: https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a
45     const unsigned int label = imageSet.second;
46     return std::make_unique<TTestCaseData>(label, std::move(inputImageData));
47 }