Release 18.02
[platform/upstream/armnn.git] / tests / InferenceTest.inl
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #include "InferenceTest.hpp"
6
7 #include "InferenceModel.hpp"
8
9 #include <boost/algorithm/string.hpp>
10 #include <boost/numeric/conversion/cast.hpp>
11 #include <boost/log/trivial.hpp>
12 #include <boost/filesystem/path.hpp>
13 #include <boost/assert.hpp>
14 #include <boost/format.hpp>
15 #include <boost/program_options.hpp>
16 #include <boost/filesystem/operations.hpp>
17
18 #include <fstream>
19 #include <iostream>
20 #include <iomanip>
21 #include <array>
22 #include <chrono>
23
24 using namespace std;
25 using namespace std::chrono;
26 using namespace armnn::test;
27
28 namespace armnn
29 {
30 namespace test
31 {
32
33 template <typename TTestCaseDatabase, typename TModel>
34 ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase(
35     int& numInferencesRef,
36     int& numCorrectInferencesRef,
37     const std::vector<unsigned int>& validationPredictions,
38     std::vector<unsigned int>* validationPredictionsOut,
39     TModel& model,
40     unsigned int testCaseId,
41     unsigned int label,
42     std::vector<typename TModel::DataType> modelInput)
43     : InferenceModelTestCase<TModel>(model, testCaseId, std::move(modelInput), model.GetOutputSize())
44     , m_Label(label)
45     , m_NumInferencesRef(numInferencesRef)
46     , m_NumCorrectInferencesRef(numCorrectInferencesRef)
47     , m_ValidationPredictions(validationPredictions)
48     , m_ValidationPredictionsOut(validationPredictionsOut)
49 {
50 }
51
52 template <typename TTestCaseDatabase, typename TModel>
53 TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params)
54 {
55     auto& output = this->GetOutput();
56     const auto testCaseId = this->GetTestCaseId();
57
58     const unsigned int prediction = boost::numeric_cast<unsigned int>(
59         std::distance(output.begin(), std::max_element(output.begin(), output.end())));
60
61     // If we're just running the defaultTestCaseIds, each one must be classified correctly
62     if (params.m_IterationCount == 0 && prediction != m_Label)
63     {
64         BOOST_LOG_TRIVIAL(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
65             " is incorrect (should be " << m_Label << ")";
66         return TestCaseResult::Failed;
67     }
68
69     // If a validation file was provided as input, check that the prediction matches
70     if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId])
71     {
72         BOOST_LOG_TRIVIAL(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
73             " doesn't match the prediction in the validation file (" << m_ValidationPredictions[testCaseId] << ")";
74         return TestCaseResult::Failed;
75     }
76
77     // If a validation file was requested as output, store the predictions
78     if (m_ValidationPredictionsOut)
79     {
80         m_ValidationPredictionsOut->push_back(prediction);
81     }
82
83     // Update accuracy stats
84     m_NumInferencesRef++;
85     if (prediction == m_Label)
86     {
87         m_NumCorrectInferencesRef++;
88     }
89
90     return TestCaseResult::Ok;
91 }
92
93 template <typename TDatabase, typename InferenceModel>
94 template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
95 ClassifierTestCaseProvider<TDatabase, InferenceModel>::ClassifierTestCaseProvider(
96     TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
97     : m_ConstructModel(constructModel)
98     , m_ConstructDatabase(constructDatabase)
99     , m_NumInferences(0)
100     , m_NumCorrectInferences(0)
101 {
102 }
103
104 template <typename TDatabase, typename InferenceModel>
105 void ClassifierTestCaseProvider<TDatabase, InferenceModel>::AddCommandLineOptions(
106     boost::program_options::options_description& options)
107 {
108     namespace po = boost::program_options;
109
110     options.add_options()
111         ("validation-file-in", po::value<std::string>(&m_ValidationFileIn)->default_value(""),
112             "Reads expected predictions from the given file and confirms they match the actual predictions.")
113         ("validation-file-out", po::value<std::string>(&m_ValidationFileOut)->default_value(""),
114             "Predictions are saved to the given file for later use via --validation-file-in.")
115         ("data-dir,d", po::value<std::string>(&m_DataDir)->required(),
116             "Path to directory containing test data");
117
118     InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions);
119 }
120
121 template <typename TDatabase, typename InferenceModel>
122 bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::ProcessCommandLineOptions()
123 {
124     if (!ValidateDirectory(m_DataDir))
125     {
126         return false;
127     }
128
129     ReadPredictions();
130
131     m_Model = m_ConstructModel(m_ModelCommandLineOptions);
132     if (!m_Model)
133     {
134         return false;
135     }
136
137     m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str()));
138     if (!m_Database)
139     {
140         return false;
141     }
142
143     return true;
144 }
145
146 template <typename TDatabase, typename InferenceModel>
147 std::unique_ptr<IInferenceTestCase>
148 ClassifierTestCaseProvider<TDatabase, InferenceModel>::GetTestCase(unsigned int testCaseId)
149 {
150     std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
151     if (testCaseData == nullptr)
152     {
153         return nullptr;
154     }
155
156     return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
157         m_NumInferences,
158         m_NumCorrectInferences,
159         m_ValidationPredictions,
160         m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
161         *m_Model,
162         testCaseId,
163         testCaseData->m_Label,
164         std::move(testCaseData->m_InputImage));
165 }
166
167 template <typename TDatabase, typename InferenceModel>
168 bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::OnInferenceTestFinished()
169 {
170     const double accuracy = boost::numeric_cast<double>(m_NumCorrectInferences) /
171         boost::numeric_cast<double>(m_NumInferences);
172     BOOST_LOG_TRIVIAL(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy;
173
174     // If a validation file was requested as output, save the predictions to it
175     if (!m_ValidationFileOut.empty())
176     {
177         std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
178         if (validationFileOut.good())
179         {
180             for (const unsigned int prediction : m_ValidationPredictionsOut)
181             {
182                 validationFileOut << prediction << std::endl;
183             }
184         }
185         else
186         {
187             BOOST_LOG_TRIVIAL(error) << "Failed to open output validation file: " << m_ValidationFileOut;
188             return false;
189         }
190     }
191
192     return true;
193 }
194
195 template <typename TDatabase, typename InferenceModel>
196 void ClassifierTestCaseProvider<TDatabase, InferenceModel>::ReadPredictions()
197 {
198     // Read expected predictions from the input validation file (if provided)
199     if (!m_ValidationFileIn.empty())
200     {
201         std::ifstream validationFileIn(m_ValidationFileIn.c_str(), std::ios_base::in);
202         if (validationFileIn.good())
203         {
204             while (!validationFileIn.eof())
205             {
206                 unsigned int i;
207                 validationFileIn >> i;
208                 m_ValidationPredictions.emplace_back(i);
209             }
210         }
211         else
212         {
213             throw armnn::Exception(boost::str(boost::format("Failed to open input validation file: %1%")
214                 % m_ValidationFileIn));
215         }
216     }
217 }
218
219 template<typename TConstructTestCaseProvider>
220 int InferenceTestMain(int argc,
221     char* argv[],
222     const std::vector<unsigned int>& defaultTestCaseIds,
223     TConstructTestCaseProvider constructTestCaseProvider)
224 {
225     // Configure logging for both the ARMNN library and this test program
226 #ifdef NDEBUG
227     armnn::LogSeverity level = armnn::LogSeverity::Info;
228 #else
229     armnn::LogSeverity level = armnn::LogSeverity::Debug;
230 #endif
231     armnn::ConfigureLogging(true, true, level);
232     armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);
233
234     try
235     {
236         std::unique_ptr<IInferenceTestCaseProvider> testCaseProvider = constructTestCaseProvider();
237         if (!testCaseProvider)
238         {
239             return 1;
240         }
241
242         InferenceTestOptions inferenceTestOptions;
243         if (!ParseCommandLine(argc, argv, *testCaseProvider, inferenceTestOptions))
244         {
245             return 1;
246         }
247
248         const bool success = InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider);
249         return success ? 0 : 1;
250     }
251     catch (armnn::Exception const& e)
252     {
253         BOOST_LOG_TRIVIAL(fatal) << "Armnn Error: " << e.what();
254         return 1;
255     }
256 }
257
258 template<typename TDatabase,
259     typename TParser,
260     typename TConstructDatabaseCallable>
261 int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
262     const char* inputBindingName, const char* outputBindingName,
263     const std::vector<unsigned int>& defaultTestCaseIds,
264     TConstructDatabaseCallable constructDatabase,
265     const armnn::TensorShape* inputTensorShape)
266 {
267     return InferenceTestMain(argc, argv, defaultTestCaseIds,
268         [=]
269         ()
270         {
271             using InferenceModel = InferenceModel<TParser, float>;
272             using TestCaseProvider = ClassifierTestCaseProvider<TDatabase, InferenceModel>;
273
274             return make_unique<TestCaseProvider>(constructDatabase,
275                 [&]
276                 (typename InferenceModel::CommandLineOptions modelOptions)
277                 {
278                     if (!ValidateDirectory(modelOptions.m_ModelDir))
279                     {
280                         return std::unique_ptr<InferenceModel>();
281                     }
282
283                     typename InferenceModel::Params modelParams;
284                     modelParams.m_ModelPath = modelOptions.m_ModelDir + modelFilename;
285                     modelParams.m_InputBinding = inputBindingName;
286                     modelParams.m_OutputBinding = outputBindingName;
287                     modelParams.m_InputTensorShape = inputTensorShape;
288                     modelParams.m_IsModelBinary = isModelBinary;
289                     modelParams.m_ComputeDevice = modelOptions.m_ComputeDevice;
290
291                     return std::make_unique<InferenceModel>(modelParams);
292             });
293         });
294 }
295
296 } // namespace test
297 } // namespace armnn