MLCE-347 'REDUCE_MIN, REDUCE_MAX, REDUCE_SUM Support'
[platform/upstream/armnn.git] / tests / MnistDatabase.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "MnistDatabase.hpp"
6
7 #include <armnn/Logging.hpp>
8
9 #include <fstream>
10 #include <vector>
11
12 constexpr int g_kMnistImageByteSize = 28 * 28;
13
14 void EndianSwap(unsigned int &x)
15 {
16     x = (x >> 24) | ((x << 8) & 0x00FF0000) | ((x >> 8) & 0x0000FF00) | (x << 24);
17 }
18
19 MnistDatabase::MnistDatabase(const std::string& binaryFileDirectory, bool scaleValues)
20     : m_BinaryDirectory(binaryFileDirectory)
21     , m_ScaleValues(scaleValues)
22 {
23 }
24
25 std::unique_ptr<MnistDatabase::TTestCaseData> MnistDatabase::GetTestCaseData(unsigned int testCaseId)
26 {
27     std::vector<unsigned char> I(g_kMnistImageByteSize);
28     unsigned int label = 0;
29
30     std::string imagePath = m_BinaryDirectory + std::string("t10k-images.idx3-ubyte");
31     std::string labelPath = m_BinaryDirectory + std::string("t10k-labels.idx1-ubyte");
32
33     std::ifstream imageStream(imagePath, std::ios::binary);
34     std::ifstream labelStream(labelPath, std::ios::binary);
35
36     if (!imageStream.is_open())
37     {
38         ARMNN_LOG(fatal) << "Failed to load " << imagePath;
39         return nullptr;
40     }
41     if (!labelStream.is_open())
42     {
43         ARMNN_LOG(fatal) << "Failed to load " << imagePath;
44         return nullptr;
45     }
46
47     unsigned int magic, num, row, col;
48
49     // Checks the files have the correct header.
50     imageStream.read(reinterpret_cast<char*>(&magic), sizeof(magic));
51     if (magic != 0x03080000)
52     {
53         ARMNN_LOG(fatal) << "Failed to read " << imagePath;
54         return nullptr;
55     }
56     labelStream.read(reinterpret_cast<char*>(&magic), sizeof(magic));
57     if (magic != 0x01080000)
58     {
59         ARMNN_LOG(fatal) << "Failed to read " << labelPath;
60         return nullptr;
61     }
62
63     // Endian swaps the image and label file - all the integers in the files are stored in MSB first(high endian)
64     // format, hence it needs to flip the bytes of the header if using it on Intel processors or low-endian machines
65     labelStream.read(reinterpret_cast<char*>(&num), sizeof(num));
66     imageStream.read(reinterpret_cast<char*>(&num), sizeof(num));
67     EndianSwap(num);
68     imageStream.read(reinterpret_cast<char*>(&row), sizeof(row));
69     EndianSwap(row);
70     imageStream.read(reinterpret_cast<char*>(&col), sizeof(col));
71     EndianSwap(col);
72
73     // Reads image and label into memory.
74     imageStream.seekg(testCaseId * g_kMnistImageByteSize, std::ios_base::cur);
75     imageStream.read(reinterpret_cast<char*>(&I[0]), g_kMnistImageByteSize);
76     labelStream.seekg(testCaseId, std::ios_base::cur);
77     labelStream.read(reinterpret_cast<char*>(&label), 1);
78
79     if (!imageStream.good())
80     {
81         ARMNN_LOG(fatal) << "Failed to read " << imagePath;
82         return nullptr;
83     }
84     if (!labelStream.good())
85     {
86         ARMNN_LOG(fatal) << "Failed to read " << labelPath;
87         return nullptr;
88     }
89
90     std::vector<float> inputImageData;
91     inputImageData.resize(g_kMnistImageByteSize);
92
93     for (unsigned int i = 0; i < col * row; ++i)
94     {
95         // Static_cast of unsigned char is safe with float
96         inputImageData[i] = static_cast<float>(I[i]);
97
98         if(m_ScaleValues)
99         {
100             inputImageData[i] /= 255.0f;
101         }
102     }
103
104     return std::make_unique<TTestCaseData>(label, std::move(inputImageData));
105 }