Release 18.08
[platform/upstream/armnn.git] / tests / InferenceTest.inl
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #include "InferenceTest.hpp"
6
7 #include <boost/algorithm/string.hpp>
8 #include <boost/numeric/conversion/cast.hpp>
9 #include <boost/log/trivial.hpp>
10 #include <boost/filesystem/path.hpp>
11 #include <boost/assert.hpp>
12 #include <boost/format.hpp>
13 #include <boost/program_options.hpp>
14 #include <boost/filesystem/operations.hpp>
15
16 #include <fstream>
17 #include <iostream>
18 #include <iomanip>
19 #include <array>
20 #include <chrono>
21
22 using namespace std;
23 using namespace std::chrono;
24 using namespace armnn::test;
25
26 namespace armnn
27 {
28 namespace test
29 {
30
31
32 template <typename TTestCaseDatabase, typename TModel>
33 ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase(
34     int& numInferencesRef,
35     int& numCorrectInferencesRef,
36     const std::vector<unsigned int>& validationPredictions,
37     std::vector<unsigned int>* validationPredictionsOut,
38     TModel& model,
39     unsigned int testCaseId,
40     unsigned int label,
41     std::vector<typename TModel::DataType> modelInput)
42     : InferenceModelTestCase<TModel>(model, testCaseId, std::move(modelInput), model.GetOutputSize())
43     , m_Label(label)
44     , m_QuantizationParams(model.GetQuantizationParams())
45     , m_NumInferencesRef(numInferencesRef)
46     , m_NumCorrectInferencesRef(numCorrectInferencesRef)
47     , m_ValidationPredictions(validationPredictions)
48     , m_ValidationPredictionsOut(validationPredictionsOut)
49 {
50 }
51
52 template <typename TTestCaseDatabase, typename TModel>
53 TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params)
54 {
55     auto& output = this->GetOutput();
56     const auto testCaseId = this->GetTestCaseId();
57
58     std::map<float,int> resultMap;
59     {
60         int index = 0;
61         for (const auto & o : output)
62         {
63             resultMap[ToFloat<typename TModel::DataType>::Convert(o, m_QuantizationParams)] = index++;
64         }
65     }
66
67     {
68         BOOST_LOG_TRIVIAL(info) << "= Prediction values for test #" << testCaseId;
69         auto it = resultMap.rbegin();
70         for (int i=0; i<5 && it != resultMap.rend(); ++i)
71         {
72             BOOST_LOG_TRIVIAL(info) << "Top(" << (i+1) << ") prediction is " << it->second <<
73               " with confidence: " << 100.0*(it->first) << "%";
74             ++it;
75         }
76     }
77
78     const unsigned int prediction = boost::numeric_cast<unsigned int>(
79         std::distance(output.begin(), std::max_element(output.begin(), output.end())));
80
81     // If we're just running the defaultTestCaseIds, each one must be classified correctly.
82     if (params.m_IterationCount == 0 && prediction != m_Label)
83     {
84         BOOST_LOG_TRIVIAL(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
85             " is incorrect (should be " << m_Label << ")";
86         return TestCaseResult::Failed;
87     }
88
89     // If a validation file was provided as input, it checks that the prediction matches.
90     if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId])
91     {
92         BOOST_LOG_TRIVIAL(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
93             " doesn't match the prediction in the validation file (" << m_ValidationPredictions[testCaseId] << ")";
94         return TestCaseResult::Failed;
95     }
96
97     // If a validation file was requested as output, it stores the predictions.
98     if (m_ValidationPredictionsOut)
99     {
100         m_ValidationPredictionsOut->push_back(prediction);
101     }
102
103     // Updates accuracy stats.
104     m_NumInferencesRef++;
105     if (prediction == m_Label)
106     {
107         m_NumCorrectInferencesRef++;
108     }
109
110     return TestCaseResult::Ok;
111 }
112
113 template <typename TDatabase, typename InferenceModel>
114 template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
115 ClassifierTestCaseProvider<TDatabase, InferenceModel>::ClassifierTestCaseProvider(
116     TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
117     : m_ConstructModel(constructModel)
118     , m_ConstructDatabase(constructDatabase)
119     , m_NumInferences(0)
120     , m_NumCorrectInferences(0)
121 {
122 }
123
124 template <typename TDatabase, typename InferenceModel>
125 void ClassifierTestCaseProvider<TDatabase, InferenceModel>::AddCommandLineOptions(
126     boost::program_options::options_description& options)
127 {
128     namespace po = boost::program_options;
129
130     options.add_options()
131         ("validation-file-in", po::value<std::string>(&m_ValidationFileIn)->default_value(""),
132             "Reads expected predictions from the given file and confirms they match the actual predictions.")
133         ("validation-file-out", po::value<std::string>(&m_ValidationFileOut)->default_value(""),
134             "Predictions are saved to the given file for later use via --validation-file-in.")
135         ("data-dir,d", po::value<std::string>(&m_DataDir)->required(),
136             "Path to directory containing test data");
137
138     InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions);
139 }
140
141 template <typename TDatabase, typename InferenceModel>
142 bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::ProcessCommandLineOptions()
143 {
144     if (!ValidateDirectory(m_DataDir))
145     {
146         return false;
147     }
148
149     ReadPredictions();
150
151     m_Model = m_ConstructModel(m_ModelCommandLineOptions);
152     if (!m_Model)
153     {
154         return false;
155     }
156
157     m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model));
158     if (!m_Database)
159     {
160         return false;
161     }
162
163     return true;
164 }
165
166 template <typename TDatabase, typename InferenceModel>
167 std::unique_ptr<IInferenceTestCase>
168 ClassifierTestCaseProvider<TDatabase, InferenceModel>::GetTestCase(unsigned int testCaseId)
169 {
170     std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
171     if (testCaseData == nullptr)
172     {
173         return nullptr;
174     }
175
176     return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
177         m_NumInferences,
178         m_NumCorrectInferences,
179         m_ValidationPredictions,
180         m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
181         *m_Model,
182         testCaseId,
183         testCaseData->m_Label,
184         std::move(testCaseData->m_InputImage));
185 }
186
187 template <typename TDatabase, typename InferenceModel>
188 bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::OnInferenceTestFinished()
189 {
190     const double accuracy = boost::numeric_cast<double>(m_NumCorrectInferences) /
191         boost::numeric_cast<double>(m_NumInferences);
192     BOOST_LOG_TRIVIAL(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy;
193
194     // If a validation file was requested as output, the predictions are saved to it.
195     if (!m_ValidationFileOut.empty())
196     {
197         std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
198         if (validationFileOut.good())
199         {
200             for (const unsigned int prediction : m_ValidationPredictionsOut)
201             {
202                 validationFileOut << prediction << std::endl;
203             }
204         }
205         else
206         {
207             BOOST_LOG_TRIVIAL(error) << "Failed to open output validation file: " << m_ValidationFileOut;
208             return false;
209         }
210     }
211
212     return true;
213 }
214
215 template <typename TDatabase, typename InferenceModel>
216 void ClassifierTestCaseProvider<TDatabase, InferenceModel>::ReadPredictions()
217 {
218     // Reads the expected predictions from the input validation file (if provided).
219     if (!m_ValidationFileIn.empty())
220     {
221         std::ifstream validationFileIn(m_ValidationFileIn.c_str(), std::ios_base::in);
222         if (validationFileIn.good())
223         {
224             while (!validationFileIn.eof())
225             {
226                 unsigned int i;
227                 validationFileIn >> i;
228                 m_ValidationPredictions.emplace_back(i);
229             }
230         }
231         else
232         {
233             throw armnn::Exception(boost::str(boost::format("Failed to open input validation file: %1%")
234                 % m_ValidationFileIn));
235         }
236     }
237 }
238
239 template<typename TConstructTestCaseProvider>
240 int InferenceTestMain(int argc,
241     char* argv[],
242     const std::vector<unsigned int>& defaultTestCaseIds,
243     TConstructTestCaseProvider constructTestCaseProvider)
244 {
245     // Configures logging for both the ARMNN library and this test program.
246 #ifdef NDEBUG
247     armnn::LogSeverity level = armnn::LogSeverity::Info;
248 #else
249     armnn::LogSeverity level = armnn::LogSeverity::Debug;
250 #endif
251     armnn::ConfigureLogging(true, true, level);
252     armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);
253
254     try
255     {
256         std::unique_ptr<IInferenceTestCaseProvider> testCaseProvider = constructTestCaseProvider();
257         if (!testCaseProvider)
258         {
259             return 1;
260         }
261
262         InferenceTestOptions inferenceTestOptions;
263         if (!ParseCommandLine(argc, argv, *testCaseProvider, inferenceTestOptions))
264         {
265             return 1;
266         }
267
268         const bool success = InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider);
269         return success ? 0 : 1;
270     }
271     catch (armnn::Exception const& e)
272     {
273         BOOST_LOG_TRIVIAL(fatal) << "Armnn Error: " << e.what();
274         return 1;
275     }
276 }
277
278 //
279 // This function allows us to create a classifier inference test based on:
280 //  - a model file name
281 //  - which can be a binary or a text file for protobuf formats
282 //  - an input tensor name
283 //  - an output tensor name
284 //  - a set of test case ids
285 //  - a callback method which creates an object that can return images
286 //    called 'Database' in these tests
287 //  - and an input tensor shape
288 //
289 template<typename TDatabase,
290          typename TParser,
291          typename TConstructDatabaseCallable>
292 int ClassifierInferenceTestMain(int argc,
293                                 char* argv[],
294                                 const char* modelFilename,
295                                 bool isModelBinary,
296                                 const char* inputBindingName,
297                                 const char* outputBindingName,
298                                 const std::vector<unsigned int>& defaultTestCaseIds,
299                                 TConstructDatabaseCallable constructDatabase,
300                                 const armnn::TensorShape* inputTensorShape)
301 {
302     return InferenceTestMain(argc, argv, defaultTestCaseIds,
303         [=]
304         ()
305         {
306             using InferenceModel = InferenceModel<TParser, typename TDatabase::DataType>;
307             using TestCaseProvider = ClassifierTestCaseProvider<TDatabase, InferenceModel>;
308
309             return make_unique<TestCaseProvider>(constructDatabase,
310                 [&]
311                 (typename InferenceModel::CommandLineOptions modelOptions)
312                 {
313                     if (!ValidateDirectory(modelOptions.m_ModelDir))
314                     {
315                         return std::unique_ptr<InferenceModel>();
316                     }
317
318                     typename InferenceModel::Params modelParams;
319                     modelParams.m_ModelPath = modelOptions.m_ModelDir + modelFilename;
320                     modelParams.m_InputBinding = inputBindingName;
321                     modelParams.m_OutputBinding = outputBindingName;
322                     modelParams.m_InputTensorShape = inputTensorShape;
323                     modelParams.m_IsModelBinary = isModelBinary;
324                     modelParams.m_ComputeDevice = modelOptions.m_ComputeDevice;
325                     modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel;
326                     modelParams.m_EnableFp16TurboMode = modelOptions.m_EnableFp16TurboMode;
327
328                     return std::make_unique<InferenceModel>(modelParams);
329             });
330         });
331 }
332
333 } // namespace test
334 } // namespace armnn