2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
5 #include "InferenceTest.hpp"
7 #include <armnn/utility/Assert.hpp>
8 #include <Filesystem.hpp>
10 #include "../src/armnn/Profiling.hpp"
11 #include <boost/numeric/conversion/cast.hpp>
12 #include <boost/format.hpp>
13 #include <boost/program_options.hpp>
21 using namespace std::chrono;
22 using namespace armnn::test;
28 /// Parse the command line of an ArmNN (or referencetests) inference test program.
29 /// \return false if any error occurred during options processing, otherwise true
30 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
31 InferenceTestOptions& outParams)
33 namespace po = boost::program_options;
35 po::options_description desc("Options");
39 // Adds generic options needed for all inference tests.
41 ("help", "Display help messages")
42 ("iterations,i", po::value<unsigned int>(&outParams.m_IterationCount)->default_value(0),
43 "Sets the number number of inferences to perform. If unset, a default number will be ran.")
44 ("inference-times-file", po::value<std::string>(&outParams.m_InferenceTimesFile)->default_value(""),
45 "If non-empty, each individual inference time will be recorded and output to this file")
46 ("event-based-profiling,e", po::value<bool>(&outParams.m_EnableProfiling)->default_value(0),
47 "Enables built in profiler. If unset, defaults to off.");
49 // Adds options specific to the ITestCaseProvider.
50 testCaseProvider.AddCommandLineOptions(desc);
52 catch (const std::exception& e)
54 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
55 // and that desc.add_options() can throw boost::io::too_few_args.
56 // They really won't in any of these cases.
57 ARMNN_ASSERT_MSG(false, "Caught unexpected exception");
58 std::cerr << "Fatal internal error: " << e.what() << std::endl;
66 po::store(po::parse_command_line(argc, argv, desc), vm);
70 std::cout << desc << std::endl;
78 std::cerr << e.what() << std::endl << std::endl;
79 std::cerr << desc << std::endl;
83 if (!testCaseProvider.ProcessCommandLineOptions(outParams))
91 bool ValidateDirectory(std::string& dir)
95 std::cerr << "No directory specified" << std::endl;
99 if (dir[dir.length() - 1] != '/')
104 if (!fs::exists(dir))
106 std::cerr << "Given directory " << dir << " does not exist" << std::endl;
110 if (!fs::is_directory(dir))
112 std::cerr << "Given directory [" << dir << "] is not a directory" << std::endl;
119 bool InferenceTest(const InferenceTestOptions& params,
120 const std::vector<unsigned int>& defaultTestCaseIds,
121 IInferenceTestCaseProvider& testCaseProvider)
123 #if !defined (NDEBUG)
124 if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn.
126 ARMNN_LOG(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
130 double totalTime = 0;
131 unsigned int nbProcessed = 0;
134 // Opens the file to write inference times too, if needed.
135 ofstream inferenceTimesFile;
136 const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
137 if (recordInferenceTimes)
139 inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
140 if (!inferenceTimesFile.good())
142 ARMNN_LOG(error) << "Failed to open inference times file for writing: "
143 << params.m_InferenceTimesFile;
148 // Create a profiler and register it for the current thread.
149 std::unique_ptr<Profiler> profiler = std::make_unique<Profiler>();
150 ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
152 // Enable profiling if requested.
153 profiler->EnableProfiling(params.m_EnableProfiling);
155 // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
156 std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
157 if (warmupTestCase == nullptr)
159 ARMNN_LOG(error) << "Failed to load test case";
165 warmupTestCase->Run();
167 catch (const TestFrameworkException& testError)
169 ARMNN_LOG(error) << testError.what();
173 const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
174 : static_cast<unsigned int>(defaultTestCaseIds.size());
176 for (; nbProcessed < nbTotalToProcess; nbProcessed++)
178 const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
179 std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);
181 if (testCase == nullptr)
183 ARMNN_LOG(error) << "Failed to load test case";
187 time_point<high_resolution_clock> predictStart;
188 time_point<high_resolution_clock> predictEnd;
190 TestCaseResult result = TestCaseResult::Ok;
194 predictStart = high_resolution_clock::now();
198 predictEnd = high_resolution_clock::now();
200 // duration<double> will convert the time difference into seconds as a double by default.
201 double timeTakenS = duration<double>(predictEnd - predictStart).count();
202 totalTime += timeTakenS;
204 // Outputss inference times, if needed.
205 if (recordInferenceTimes)
207 inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
210 result = testCase->ProcessResult(params);
213 catch (const TestFrameworkException& testError)
215 ARMNN_LOG(error) << testError.what();
216 result = TestCaseResult::Abort;
221 case TestCaseResult::Ok:
223 case TestCaseResult::Abort:
225 case TestCaseResult::Failed:
226 // This test failed so we will fail the entire program eventually, but keep going for now.
230 ARMNN_ASSERT_MSG(false, "Unexpected TestCaseResult");
235 const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f;
237 ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
238 "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
239 ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
240 "Average time per test case: " << averageTimePerTestCaseMs << " ms";
242 // if profiling is enabled print out the results
243 if (profiler && profiler->IsProfilingEnabled())
245 profiler->Print(std::cout);
250 ARMNN_LOG(error) << "One or more test cases failed";
254 return testCaseProvider.OnInferenceTestFinished();