Update ACL pin to 1c2ff950071c5b4fd6e83487083d23c96637545f
[platform/upstream/armnn.git] / tests / InferenceTest.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "InferenceTest.hpp"
6
7 #include "../src/armnn/Profiling.hpp"
8 #include <boost/algorithm/string.hpp>
9 #include <boost/numeric/conversion/cast.hpp>
10 #include <boost/filesystem/path.hpp>
11 #include <boost/assert.hpp>
12 #include <boost/format.hpp>
13 #include <boost/program_options.hpp>
14 #include <boost/filesystem/operations.hpp>
15
16 #include <fstream>
17 #include <iostream>
18 #include <iomanip>
19 #include <array>
20
21 using namespace std;
22 using namespace std::chrono;
23 using namespace armnn::test;
24
25 namespace armnn
26 {
27 namespace test
28 {
29 /// Parse the command line of an ArmNN (or referencetests) inference test program.
30 /// \return false if any error occurred during options processing, otherwise true
31 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
32     InferenceTestOptions& outParams)
33 {
34     namespace po = boost::program_options;
35
36     po::options_description desc("Options");
37
38     try
39     {
40         // Adds generic options needed for all inference tests.
41         desc.add_options()
42             ("help", "Display help messages")
43             ("iterations,i", po::value<unsigned int>(&outParams.m_IterationCount)->default_value(0),
44                 "Sets the number number of inferences to perform. If unset, a default number will be ran.")
45             ("inference-times-file", po::value<std::string>(&outParams.m_InferenceTimesFile)->default_value(""),
46                 "If non-empty, each individual inference time will be recorded and output to this file")
47             ("event-based-profiling,e", po::value<bool>(&outParams.m_EnableProfiling)->default_value(0),
48                 "Enables built in profiler. If unset, defaults to off.");
49
50         // Adds options specific to the ITestCaseProvider.
51         testCaseProvider.AddCommandLineOptions(desc);
52     }
53     catch (const std::exception& e)
54     {
55         // Coverity points out that default_value(...) can throw a bad_lexical_cast,
56         // and that desc.add_options() can throw boost::io::too_few_args.
57         // They really won't in any of these cases.
58         BOOST_ASSERT_MSG(false, "Caught unexpected exception");
59         std::cerr << "Fatal internal error: " << e.what() << std::endl;
60         return false;
61     }
62
63     po::variables_map vm;
64
65     try
66     {
67         po::store(po::parse_command_line(argc, argv, desc), vm);
68
69         if (vm.count("help"))
70         {
71             std::cout << desc << std::endl;
72             return false;
73         }
74
75         po::notify(vm);
76     }
77     catch (po::error& e)
78     {
79         std::cerr << e.what() << std::endl << std::endl;
80         std::cerr << desc << std::endl;
81         return false;
82     }
83
84     if (!testCaseProvider.ProcessCommandLineOptions(outParams))
85     {
86         return false;
87     }
88
89     return true;
90 }
91
92 bool ValidateDirectory(std::string& dir)
93 {
94     if (dir.empty())
95     {
96         std::cerr << "No directory specified" << std::endl;
97         return false;
98     }
99
100     if (dir[dir.length() - 1] != '/')
101     {
102         dir += "/";
103     }
104
105     if (!boost::filesystem::exists(dir))
106     {
107         std::cerr << "Given directory " << dir << " does not exist" << std::endl;
108         return false;
109     }
110
111     if (!boost::filesystem::is_directory(dir))
112     {
113         std::cerr << "Given directory [" << dir << "] is not a directory" << std::endl;
114         return false;
115     }
116
117     return true;
118 }
119
120 bool InferenceTest(const InferenceTestOptions& params,
121     const std::vector<unsigned int>& defaultTestCaseIds,
122     IInferenceTestCaseProvider& testCaseProvider)
123 {
124 #if !defined (NDEBUG)
125     if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn.
126     {
127         ARMNN_LOG(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
128     }
129 #endif
130
131     double totalTime = 0;
132     unsigned int nbProcessed = 0;
133     bool success = true;
134
135     // Opens the file to write inference times too, if needed.
136     ofstream inferenceTimesFile;
137     const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
138     if (recordInferenceTimes)
139     {
140         inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
141         if (!inferenceTimesFile.good())
142         {
143             ARMNN_LOG(error) << "Failed to open inference times file for writing: "
144                 << params.m_InferenceTimesFile;
145             return false;
146         }
147     }
148
149     // Create a profiler and register it for the current thread.
150     std::unique_ptr<Profiler> profiler = std::make_unique<Profiler>();
151     ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
152
153     // Enable profiling if requested.
154     profiler->EnableProfiling(params.m_EnableProfiling);
155
156     // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
157     std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
158     if (warmupTestCase == nullptr)
159     {
160         ARMNN_LOG(error) << "Failed to load test case";
161         return false;
162     }
163
164     try
165     {
166         warmupTestCase->Run();
167     }
168     catch (const TestFrameworkException& testError)
169     {
170         ARMNN_LOG(error) << testError.what();
171         return false;
172     }
173
174     const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
175         : static_cast<unsigned int>(defaultTestCaseIds.size());
176
177     for (; nbProcessed < nbTotalToProcess; nbProcessed++)
178     {
179         const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
180         std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);
181
182         if (testCase == nullptr)
183         {
184             ARMNN_LOG(error) << "Failed to load test case";
185             return false;
186         }
187
188         time_point<high_resolution_clock> predictStart;
189         time_point<high_resolution_clock> predictEnd;
190
191         TestCaseResult result = TestCaseResult::Ok;
192
193         try
194         {
195             predictStart = high_resolution_clock::now();
196
197             testCase->Run();
198
199             predictEnd = high_resolution_clock::now();
200
201             // duration<double> will convert the time difference into seconds as a double by default.
202             double timeTakenS = duration<double>(predictEnd - predictStart).count();
203             totalTime += timeTakenS;
204
205             // Outputss inference times, if needed.
206             if (recordInferenceTimes)
207             {
208                 inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
209             }
210
211             result = testCase->ProcessResult(params);
212
213         }
214         catch (const TestFrameworkException& testError)
215         {
216             ARMNN_LOG(error) << testError.what();
217             result = TestCaseResult::Abort;
218         }
219
220         switch (result)
221         {
222         case TestCaseResult::Ok:
223             break;
224         case TestCaseResult::Abort:
225             return false;
226         case TestCaseResult::Failed:
227             // This test failed so we will fail the entire program eventually, but keep going for now.
228             success = false;
229             break;
230         default:
231             BOOST_ASSERT_MSG(false, "Unexpected TestCaseResult");
232             return false;
233         }
234     }
235
236     const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f;
237
238     ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
239         "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
240     ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
241         "Average time per test case: " << averageTimePerTestCaseMs << " ms";
242
243     // if profiling is enabled print out the results
244     if (profiler && profiler->IsProfilingEnabled())
245     {
246         profiler->Print(std::cout);
247     }
248
249     if (!success)
250     {
251         ARMNN_LOG(error) << "One or more test cases failed";
252         return false;
253     }
254
255     return testCaseProvider.OnInferenceTestFinished();
256 }
257
258 } // namespace test
259
260 } // namespace armnn