2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
5 #include "InferenceTest.hpp"
7 #include <boost/algorithm/string.hpp>
8 #include <boost/numeric/conversion/cast.hpp>
9 #include <boost/log/trivial.hpp>
10 #include <boost/filesystem/path.hpp>
11 #include <boost/assert.hpp>
12 #include <boost/format.hpp>
13 #include <boost/program_options.hpp>
14 #include <boost/filesystem/operations.hpp>
// Bring std::chrono and the armnn::test helpers into scope for the definitions below.
// NOTE(review): if this file is included from a header, these using-directives leak into
// every includer — confirm how this file is consumed.
23 using namespace std::chrono;
24 using namespace armnn::test;
// Variant holding one network input/output buffer. The alternatives are the element types
// the test cases handle: float, int, and unsigned char (quantized uint8 outputs are
// dequantized by ClassifierResultProcessor).
31 using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
// Constructs a classifier test case for a single input sample.
//   numInferencesRef / numCorrectInferencesRef: counters that ProcessResult increments;
//     presumably held by reference (members named *Ref) so they must outlive this test
//     case — TODO confirm against the member declarations.
//   validationPredictions: expected predictions from the input validation file; may be
//     empty, in which case ProcessResult skips the validation check.
//   validationPredictionsOut: if non-null, ProcessResult appends each prediction to it.
//   modelInput: wrapped as the single input container of the base test case.
33 template <typename TTestCaseDatabase, typename TModel>
34 ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase(
35 int& numInferencesRef,
36 int& numCorrectInferencesRef,
37 const std::vector<unsigned int>& validationPredictions,
38 std::vector<unsigned int>* validationPredictionsOut,
40 unsigned int testCaseId,
42 std::vector<typename TModel::DataType> modelInput)
// One output buffer of model.GetOutputSize() elements is requested from the base class.
// NOTE(review): modelInput is taken by value and then copied into the TContainer vector;
// std::move(modelInput) would avoid the extra copy.
43 : InferenceModelTestCase<TModel>(
44 model, testCaseId, std::vector<TContainer>{ modelInput }, { model.GetOutputSize() })
// Quantization (scale, offset) pair used to dequantize uint8 outputs in ProcessResult.
46 , m_QuantizationParams(model.GetQuantizationParams())
47 , m_NumInferencesRef(numInferencesRef)
48 , m_NumCorrectInferencesRef(numCorrectInferencesRef)
49 , m_ValidationPredictions(validationPredictions)
50 , m_ValidationPredictionsOut(validationPredictionsOut)
54 struct ClassifierResultProcessor : public boost::static_visitor<>
56 using ResultMap = std::map<float,int>;
58 ClassifierResultProcessor(float scale, int offset)
63 void operator()(const std::vector<float>& values)
65 SortPredictions(values, [](float value)
71 void operator()(const std::vector<uint8_t>& values)
73 auto& scale = m_Scale;
74 auto& offset = m_Offset;
75 SortPredictions(values, [&scale, &offset](uint8_t value)
77 return armnn::Dequantize(value, scale, offset);
81 void operator()(const std::vector<int>& values)
83 BOOST_ASSERT_MSG(false, "Non-float predictions output not supported.");
86 ResultMap& GetResultMap() { return m_ResultMap; }
89 template<typename Container, typename Delegate>
90 void SortPredictions(const Container& c, Delegate delegate)
93 for (const auto& value : c)
95 int classification = index++;
96 // Take the first class with each probability
97 // This avoids strange results when looping over batched results produced
98 // with identical test data.
99 ResultMap::iterator lb = m_ResultMap.lower_bound(value);
101 if (lb == m_ResultMap.end() || !m_ResultMap.key_comp()(value, lb->first))
103 // If the key is not already in the map, insert it.
104 m_ResultMap.insert(lb, ResultMap::value_type(delegate(value), classification));
109 ResultMap m_ResultMap;
115 template <typename TTestCaseDatabase, typename TModel>
116 TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params)
118 auto& output = this->GetOutputs()[0];
119 const auto testCaseId = this->GetTestCaseId();
121 ClassifierResultProcessor resultProcessor(m_QuantizationParams.first, m_QuantizationParams.second);
122 boost::apply_visitor(resultProcessor, output);
124 BOOST_LOG_TRIVIAL(info) << "= Prediction values for test #" << testCaseId;
125 auto it = resultProcessor.GetResultMap().rbegin();
126 for (int i=0; i<5 && it != resultProcessor.GetResultMap().rend(); ++i)
128 BOOST_LOG_TRIVIAL(info) << "Top(" << (i+1) << ") prediction is " << it->second <<
129 " with value: " << (it->first);
133 unsigned int prediction = 0;
134 boost::apply_visitor([&](auto&& value)
136 prediction = boost::numeric_cast<unsigned int>(
137 std::distance(value.begin(), std::max_element(value.begin(), value.end())));
141 // If we're just running the defaultTestCaseIds, each one must be classified correctly.
142 if (params.m_IterationCount == 0 && prediction != m_Label)
144 BOOST_LOG_TRIVIAL(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
145 " is incorrect (should be " << m_Label << ")";
146 return TestCaseResult::Failed;
149 // If a validation file was provided as input, it checks that the prediction matches.
150 if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId])
152 BOOST_LOG_TRIVIAL(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
153 " doesn't match the prediction in the validation file (" << m_ValidationPredictions[testCaseId] << ")";
154 return TestCaseResult::Failed;
157 // If a validation file was requested as output, it stores the predictions.
158 if (m_ValidationPredictionsOut)
160 m_ValidationPredictionsOut->push_back(prediction);
163 // Updates accuracy stats.
164 m_NumInferencesRef++;
165 if (prediction == m_Label)
167 m_NumCorrectInferencesRef++;
170 return TestCaseResult::Ok;
// Stores the factory callables that ProcessCommandLineOptions later invokes to build the
// model (constructModel) and the test-image database (constructDatabase), and
// zero-initialises the correct-inference counter.
173 template <typename TDatabase, typename InferenceModel>
174 template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
175 ClassifierTestCaseProvider<TDatabase, InferenceModel>::ClassifierTestCaseProvider(
176 TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
177 : m_ConstructModel(constructModel)
178 , m_ConstructDatabase(constructDatabase)
180 , m_NumCorrectInferences(0)
184 template <typename TDatabase, typename InferenceModel>
185 void ClassifierTestCaseProvider<TDatabase, InferenceModel>::AddCommandLineOptions(
186 boost::program_options::options_description& options)
188 namespace po = boost::program_options;
190 options.add_options()
191 ("validation-file-in", po::value<std::string>(&m_ValidationFileIn)->default_value(""),
192 "Reads expected predictions from the given file and confirms they match the actual predictions.")
193 ("validation-file-out", po::value<std::string>(&m_ValidationFileOut)->default_value(""),
194 "Predictions are saved to the given file for later use via --validation-file-in.")
195 ("data-dir,d", po::value<std::string>(&m_DataDir)->required(),
196 "Path to directory containing test data");
198 InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions);
// Acts on the parsed command line: validates the --data-dir directory, builds the model
// from the common and model-specific options, then builds the test-image database from
// the data directory and the model.
// The bool result reports whether setup succeeded; presumably an invalid data directory
// makes it return false — confirm against the branch body.
201 template <typename TDatabase, typename InferenceModel>
202 bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::ProcessCommandLineOptions(
203 const InferenceTestOptions& commonOptions)
205 if (!ValidateDirectory(m_DataDir))
212 m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
// NOTE(review): *m_Model is dereferenced below, but a model factory may return an empty
// pointer (ClassifierInferenceTestMain's does for an invalid model directory), so a null
// check on m_Model must precede this line — confirm.
218 m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model));
// Creates the ClassifierTestCase for testCaseId from the database's test-case data.
// The new test case is wired to this provider's state: m_NumCorrectInferences (bound to an
// int& parameter), m_ValidationPredictions, and m_ValidationPredictionsOut — the latter
// only when --validation-file-out was given, otherwise a null pointer is passed.
// The sample's label is copied and its input image is moved into the test case.
// NOTE(review): the branch for testCaseData == nullptr (id not available in the database)
// presumably returns nullptr — confirm.
227 template <typename TDatabase, typename InferenceModel>
228 std::unique_ptr<IInferenceTestCase>
229 ClassifierTestCaseProvider<TDatabase, InferenceModel>::GetTestCase(unsigned int testCaseId)
231 std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
232 if (testCaseData == nullptr)
237 return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
239 m_NumCorrectInferences,
240 m_ValidationPredictions,
241 m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
244 testCaseData->m_Label,
245 std::move(testCaseData->m_InputImage));
248 template <typename TDatabase, typename InferenceModel>
249 bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::OnInferenceTestFinished()
251 const double accuracy = boost::numeric_cast<double>(m_NumCorrectInferences) /
252 boost::numeric_cast<double>(m_NumInferences);
253 BOOST_LOG_TRIVIAL(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy;
255 // If a validation file was requested as output, the predictions are saved to it.
256 if (!m_ValidationFileOut.empty())
258 std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
259 if (validationFileOut.good())
261 for (const unsigned int prediction : m_ValidationPredictionsOut)
263 validationFileOut << prediction << std::endl;
268 BOOST_LOG_TRIVIAL(error) << "Failed to open output validation file: " << m_ValidationFileOut;
276 template <typename TDatabase, typename InferenceModel>
277 void ClassifierTestCaseProvider<TDatabase, InferenceModel>::ReadPredictions()
279 // Reads the expected predictions from the input validation file (if provided).
280 if (!m_ValidationFileIn.empty())
282 std::ifstream validationFileIn(m_ValidationFileIn.c_str(), std::ios_base::in);
283 if (validationFileIn.good())
285 while (!validationFileIn.eof())
288 validationFileIn >> i;
289 m_ValidationPredictions.emplace_back(i);
294 throw armnn::Exception(boost::str(boost::format("Failed to open input validation file: %1%")
295 % m_ValidationFileIn));
// Common entry point for the inference test executables: configures logging, builds the
// test-case provider with constructTestCaseProvider, parses the command line into
// InferenceTestOptions and runs InferenceTest over defaultTestCaseIds.
// Returns 0 when InferenceTest succeeds and 1 when it fails; an armnn::Exception is caught
// and logged at fatal severity.
300 template<typename TConstructTestCaseProvider>
301 int InferenceTestMain(int argc,
303 const std::vector<unsigned int>& defaultTestCaseIds,
304 TConstructTestCaseProvider constructTestCaseProvider)
306 // Configures logging for both the ARMNN library and this test program.
// Info or Debug severity depending on the build configuration (presumably chosen by an
// NDEBUG conditional — confirm).
308 armnn::LogSeverity level = armnn::LogSeverity::Info;
310 armnn::LogSeverity level = armnn::LogSeverity::Debug;
312 armnn::ConfigureLogging(true, true, level);
313 armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);
317 std::unique_ptr<IInferenceTestCaseProvider> testCaseProvider = constructTestCaseProvider();
// A null provider means the test could not be set up.
318 if (!testCaseProvider)
323 InferenceTestOptions inferenceTestOptions;
// The provider registers and receives its own options during parsing.
324 if (!ParseCommandLine(argc, argv, *testCaseProvider, inferenceTestOptions))
329 const bool success = InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider);
330 return success ? 0 : 1;
332 catch (armnn::Exception const& e)
334 BOOST_LOG_TRIVIAL(fatal) << "Armnn Error: " << e.what();
340 // Creates and runs a classifier inference test from:
341 // - a model file name
342 //   (for protobuf formats, either a binary or a text file)
343 // - an input tensor name
344 // - an output tensor name
345 // - a set of default test case ids
346 // - a callback that creates the object supplying the test images
347 //   (called the 'Database' in these tests)
348 // - an optional input tensor shape (may be null)
350 template<typename TDatabase,
352 typename TConstructDatabaseCallable>
353 int ClassifierInferenceTestMain(int argc,
355 const char* modelFilename,
357 const char* inputBindingName,
358 const char* outputBindingName,
359 const std::vector<unsigned int>& defaultTestCaseIds,
360 TConstructDatabaseCallable constructDatabase,
361 const armnn::TensorShape* inputTensorShape)
364 BOOST_ASSERT(modelFilename);
365 BOOST_ASSERT(inputBindingName);
366 BOOST_ASSERT(outputBindingName);
368 return InferenceTestMain(argc, argv, defaultTestCaseIds,
372 using InferenceModel = InferenceModel<TParser, typename TDatabase::DataType>;
373 using TestCaseProvider = ClassifierTestCaseProvider<TDatabase, InferenceModel>;
375 return make_unique<TestCaseProvider>(constructDatabase,
377 (const InferenceTestOptions &commonOptions,
378 typename InferenceModel::CommandLineOptions modelOptions)
380 if (!ValidateDirectory(modelOptions.m_ModelDir))
382 return std::unique_ptr<InferenceModel>();
385 typename InferenceModel::Params modelParams;
386 modelParams.m_ModelPath = modelOptions.m_ModelDir + modelFilename;
387 modelParams.m_InputBindings = { inputBindingName };
388 modelParams.m_OutputBindings = { outputBindingName };
390 if (inputTensorShape)
392 modelParams.m_InputShapes.push_back(*inputTensorShape);
395 modelParams.m_IsModelBinary = isModelBinary;
396 modelParams.m_ComputeDevices = modelOptions.GetComputeDevicesAsBackendIds();
397 modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel;
398 modelParams.m_EnableFp16TurboMode = modelOptions.m_EnableFp16TurboMode;
400 return std::make_unique<InferenceModel>(modelParams,
401 commonOptions.m_EnableProfiling,
402 commonOptions.m_DynamicBackendsPath);