2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
5 #include "InferenceTest.hpp"
7 #include "InferenceModel.hpp"
9 #include <boost/algorithm/string.hpp>
10 #include <boost/numeric/conversion/cast.hpp>
11 #include <boost/log/trivial.hpp>
12 #include <boost/filesystem/path.hpp>
13 #include <boost/assert.hpp>
14 #include <boost/format.hpp>
15 #include <boost/program_options.hpp>
16 #include <boost/filesystem/operations.hpp>
25 using namespace std::chrono;
26 using namespace armnn::test;
33 template <typename TTestCaseDatabase, typename TModel>
34 ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase(
35 int& numInferencesRef,
36 int& numCorrectInferencesRef,
37 const std::vector<unsigned int>& validationPredictions,
38 std::vector<unsigned int>* validationPredictionsOut,
40 unsigned int testCaseId,
42 std::vector<typename TModel::DataType> modelInput)
43 : InferenceModelTestCase<TModel>(model, testCaseId, std::move(modelInput), model.GetOutputSize())
45 , m_NumInferencesRef(numInferencesRef)
46 , m_NumCorrectInferencesRef(numCorrectInferencesRef)
47 , m_ValidationPredictions(validationPredictions)
48 , m_ValidationPredictionsOut(validationPredictionsOut)
52 template <typename TTestCaseDatabase, typename TModel>
53 TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params)
55 auto& output = this->GetOutput();
56 const auto testCaseId = this->GetTestCaseId();
58 const unsigned int prediction = boost::numeric_cast<unsigned int>(
59 std::distance(output.begin(), std::max_element(output.begin(), output.end())));
61 // If we're just running the defaultTestCaseIds, each one must be classified correctly
62 if (params.m_IterationCount == 0 && prediction != m_Label)
64 BOOST_LOG_TRIVIAL(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
65 " is incorrect (should be " << m_Label << ")";
66 return TestCaseResult::Failed;
69 // If a validation file was provided as input, check that the prediction matches
70 if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId])
72 BOOST_LOG_TRIVIAL(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
73 " doesn't match the prediction in the validation file (" << m_ValidationPredictions[testCaseId] << ")";
74 return TestCaseResult::Failed;
77 // If a validation file was requested as output, store the predictions
78 if (m_ValidationPredictionsOut)
80 m_ValidationPredictionsOut->push_back(prediction);
83 // Update accuracy stats
85 if (prediction == m_Label)
87 m_NumCorrectInferencesRef++;
90 return TestCaseResult::Ok;
93 template <typename TDatabase, typename InferenceModel>
94 template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
95 ClassifierTestCaseProvider<TDatabase, InferenceModel>::ClassifierTestCaseProvider(
96 TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
97 : m_ConstructModel(constructModel)
98 , m_ConstructDatabase(constructDatabase)
100 , m_NumCorrectInferences(0)
104 template <typename TDatabase, typename InferenceModel>
105 void ClassifierTestCaseProvider<TDatabase, InferenceModel>::AddCommandLineOptions(
106 boost::program_options::options_description& options)
108 namespace po = boost::program_options;
110 options.add_options()
111 ("validation-file-in", po::value<std::string>(&m_ValidationFileIn)->default_value(""),
112 "Reads expected predictions from the given file and confirms they match the actual predictions.")
113 ("validation-file-out", po::value<std::string>(&m_ValidationFileOut)->default_value(""),
114 "Predictions are saved to the given file for later use via --validation-file-in.")
115 ("data-dir,d", po::value<std::string>(&m_DataDir)->required(),
116 "Path to directory containing test data");
118 InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions);
121 template <typename TDatabase, typename InferenceModel>
122 bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::ProcessCommandLineOptions()
124 if (!ValidateDirectory(m_DataDir))
131 m_Model = m_ConstructModel(m_ModelCommandLineOptions);
137 m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str()));
146 template <typename TDatabase, typename InferenceModel>
147 std::unique_ptr<IInferenceTestCase>
148 ClassifierTestCaseProvider<TDatabase, InferenceModel>::GetTestCase(unsigned int testCaseId)
150 std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
151 if (testCaseData == nullptr)
156 return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
158 m_NumCorrectInferences,
159 m_ValidationPredictions,
160 m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
163 testCaseData->m_Label,
164 std::move(testCaseData->m_InputImage));
167 template <typename TDatabase, typename InferenceModel>
168 bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::OnInferenceTestFinished()
170 const double accuracy = boost::numeric_cast<double>(m_NumCorrectInferences) /
171 boost::numeric_cast<double>(m_NumInferences);
172 BOOST_LOG_TRIVIAL(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy;
174 // If a validation file was requested as output, save the predictions to it
175 if (!m_ValidationFileOut.empty())
177 std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
178 if (validationFileOut.good())
180 for (const unsigned int prediction : m_ValidationPredictionsOut)
182 validationFileOut << prediction << std::endl;
187 BOOST_LOG_TRIVIAL(error) << "Failed to open output validation file: " << m_ValidationFileOut;
195 template <typename TDatabase, typename InferenceModel>
196 void ClassifierTestCaseProvider<TDatabase, InferenceModel>::ReadPredictions()
198 // Read expected predictions from the input validation file (if provided)
199 if (!m_ValidationFileIn.empty())
201 std::ifstream validationFileIn(m_ValidationFileIn.c_str(), std::ios_base::in);
202 if (validationFileIn.good())
204 while (!validationFileIn.eof())
207 validationFileIn >> i;
208 m_ValidationPredictions.emplace_back(i);
213 throw armnn::Exception(boost::str(boost::format("Failed to open input validation file: %1%")
214 % m_ValidationFileIn));
219 template<typename TConstructTestCaseProvider>
220 int InferenceTestMain(int argc,
222 const std::vector<unsigned int>& defaultTestCaseIds,
223 TConstructTestCaseProvider constructTestCaseProvider)
225 // Configure logging for both the ARMNN library and this test program
227 armnn::LogSeverity level = armnn::LogSeverity::Info;
229 armnn::LogSeverity level = armnn::LogSeverity::Debug;
231 armnn::ConfigureLogging(true, true, level);
232 armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);
236 std::unique_ptr<IInferenceTestCaseProvider> testCaseProvider = constructTestCaseProvider();
237 if (!testCaseProvider)
242 InferenceTestOptions inferenceTestOptions;
243 if (!ParseCommandLine(argc, argv, *testCaseProvider, inferenceTestOptions))
248 const bool success = InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider);
249 return success ? 0 : 1;
251 catch (armnn::Exception const& e)
253 BOOST_LOG_TRIVIAL(fatal) << "Armnn Error: " << e.what();
// NOTE(review): this definition is incomplete in the visible chunk — several
// lines were dropped by the paste (template parameter list, lambda openers)
// and the closing braces of the lambda/function lie beyond the last visible
// line, so the code is left untouched here; only comments are added.
// Purpose (from the visible code): entry point for classifier tests that
// forwards to InferenceTestMain with a provider factory; the factory builds a
// ClassifierTestCaseProvider whose model-construction callable validates the
// model directory and fills in InferenceModel::Params from the arguments
// captured here (model path, binding names, tensor shape, binary flag).
258 template<typename TDatabase,
260 typename TConstructDatabaseCallable>
261 int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
262 const char* inputBindingName, const char* outputBindingName,
263 const std::vector<unsigned int>& defaultTestCaseIds,
264 TConstructDatabaseCallable constructDatabase,
265 const armnn::TensorShape* inputTensorShape)
// Delegate to the shared test driver; the trailing callable constructs the
// test case provider on demand.
267 return InferenceTestMain(argc, argv, defaultTestCaseIds,
271 using InferenceModel = InferenceModel<TParser, float>;
272 using TestCaseProvider = ClassifierTestCaseProvider<TDatabase, InferenceModel>;
274 return make_unique<TestCaseProvider>(constructDatabase,
276 (typename InferenceModel::CommandLineOptions modelOptions)
// Bail out with a null model when the supplied model directory is invalid.
278 if (!ValidateDirectory(modelOptions.m_ModelDir))
280 return std::unique_ptr<InferenceModel>();
// Assemble the model parameters from the captured arguments and the parsed
// command-line options.
283 typename InferenceModel::Params modelParams;
284 modelParams.m_ModelPath = modelOptions.m_ModelDir + modelFilename;
285 modelParams.m_InputBinding = inputBindingName;
286 modelParams.m_OutputBinding = outputBindingName;
287 modelParams.m_InputTensorShape = inputTensorShape;
288 modelParams.m_IsModelBinary = isModelBinary;
289 modelParams.m_ComputeDevice = modelOptions.m_ComputeDevice;
291 return std::make_unique<InferenceModel>(modelParams);