Release 18.08
[platform/upstream/armnn.git] / tests / InferenceTest.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #include "InferenceTest.hpp"
6
7 #include "../src/armnn/Profiling.hpp"
8 #include <boost/algorithm/string.hpp>
9 #include <boost/numeric/conversion/cast.hpp>
10 #include <boost/log/trivial.hpp>
11 #include <boost/filesystem/path.hpp>
12 #include <boost/assert.hpp>
13 #include <boost/format.hpp>
14 #include <boost/program_options.hpp>
15 #include <boost/filesystem/operations.hpp>
16
17 #include <fstream>
18 #include <iostream>
19 #include <iomanip>
20 #include <array>
21
22 using namespace std;
23 using namespace std::chrono;
24 using namespace armnn::test;
25
26 namespace armnn
27 {
28 namespace test
29 {
30 /// Parse the command line of an ArmNN (or referencetests) inference test program.
31 /// \return false if any error occurred during options processing, otherwise true
32 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
33     InferenceTestOptions& outParams)
34 {
35     namespace po = boost::program_options;
36
37     std::string computeDeviceStr;
38
39     po::options_description desc("Options");
40
41     try
42     {
43         // Adds generic options needed for all inference tests.
44         desc.add_options()
45             ("help", "Display help messages")
46             ("iterations,i", po::value<unsigned int>(&outParams.m_IterationCount)->default_value(0),
47                 "Sets the number number of inferences to perform. If unset, a default number will be ran.")
48             ("inference-times-file", po::value<std::string>(&outParams.m_InferenceTimesFile)->default_value(""),
49                 "If non-empty, each individual inference time will be recorded and output to this file")
50             ("event-based-profiling,e", po::value<bool>(&outParams.m_EnableProfiling)->default_value(0),
51                 "Enables built in profiler. If unset, defaults to off.");
52
53         // Adds options specific to the ITestCaseProvider.
54         testCaseProvider.AddCommandLineOptions(desc);
55     }
56     catch (const std::exception& e)
57     {
58         // Coverity points out that default_value(...) can throw a bad_lexical_cast,
59         // and that desc.add_options() can throw boost::io::too_few_args.
60         // They really won't in any of these cases.
61         BOOST_ASSERT_MSG(false, "Caught unexpected exception");
62         std::cerr << "Fatal internal error: " << e.what() << std::endl;
63         return false;
64     }
65
66     po::variables_map vm;
67
68     try
69     {
70         po::store(po::parse_command_line(argc, argv, desc), vm);
71
72         if (vm.count("help"))
73         {
74             std::cout << desc << std::endl;
75             return false;
76         }
77
78         po::notify(vm);
79     }
80     catch (po::error& e)
81     {
82         std::cerr << e.what() << std::endl << std::endl;
83         std::cerr << desc << std::endl;
84         return false;
85     }
86
87     if (!testCaseProvider.ProcessCommandLineOptions())
88     {
89         return false;
90     }
91
92     return true;
93 }
94
95 bool ValidateDirectory(std::string& dir)
96 {
97     if (dir[dir.length() - 1] != '/')
98     {
99         dir += "/";
100     }
101
102     if (!boost::filesystem::exists(dir))
103     {
104         std::cerr << "Given directory " << dir << " does not exist" << std::endl;
105         return false;
106     }
107
108     return true;
109 }
110
111 bool InferenceTest(const InferenceTestOptions& params,
112     const std::vector<unsigned int>& defaultTestCaseIds,
113     IInferenceTestCaseProvider& testCaseProvider)
114 {
115 #if !defined (NDEBUG)
116     if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn.
117     {
118         BOOST_LOG_TRIVIAL(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
119     }
120 #endif
121
122     double totalTime = 0;
123     unsigned int nbProcessed = 0;
124     bool success = true;
125
126     // Opens the file to write inference times too, if needed.
127     ofstream inferenceTimesFile;
128     const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
129     if (recordInferenceTimes)
130     {
131         inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
132         if (!inferenceTimesFile.good())
133         {
134             BOOST_LOG_TRIVIAL(error) << "Failed to open inference times file for writing: "
135                 << params.m_InferenceTimesFile;
136             return false;
137         }
138     }
139
140     // Create a profiler and register it for the current thread.
141     std::unique_ptr<Profiler> profiler = std::make_unique<Profiler>();
142     ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
143
144     // Enable profiling if requested.
145     profiler->EnableProfiling(params.m_EnableProfiling);
146
147     // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
148     std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
149     if (warmupTestCase == nullptr)
150     {
151         BOOST_LOG_TRIVIAL(error) << "Failed to load test case";
152         return false;
153     }
154
155     try
156     {
157         warmupTestCase->Run();
158     }
159     catch (const TestFrameworkException& testError)
160     {
161         BOOST_LOG_TRIVIAL(error) << testError.what();
162         return false;
163     }
164
165     const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
166         : static_cast<unsigned int>(defaultTestCaseIds.size());
167
168     for (; nbProcessed < nbTotalToProcess; nbProcessed++)
169     {
170         const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
171         std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);
172
173         if (testCase == nullptr)
174         {
175             BOOST_LOG_TRIVIAL(error) << "Failed to load test case";
176             return false;
177         }
178
179         time_point<high_resolution_clock> predictStart;
180         time_point<high_resolution_clock> predictEnd;
181
182         TestCaseResult result = TestCaseResult::Ok;
183
184         try
185         {
186             predictStart = high_resolution_clock::now();
187
188             testCase->Run();
189
190             predictEnd = high_resolution_clock::now();
191
192             // duration<double> will convert the time difference into seconds as a double by default.
193             double timeTakenS = duration<double>(predictEnd - predictStart).count();
194             totalTime += timeTakenS;
195
196             // Outputss inference times, if needed.
197             if (recordInferenceTimes)
198             {
199                 inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
200             }
201
202             result = testCase->ProcessResult(params);
203
204         }
205         catch (const TestFrameworkException& testError)
206         {
207             BOOST_LOG_TRIVIAL(error) << testError.what();
208             result = TestCaseResult::Abort;
209         }
210
211         switch (result)
212         {
213         case TestCaseResult::Ok:
214             break;
215         case TestCaseResult::Abort:
216             return false;
217         case TestCaseResult::Failed:
218             // This test failed so we will fail the entire program eventually, but keep going for now.
219             success = false;
220             break;
221         default:
222             BOOST_ASSERT_MSG(false, "Unexpected TestCaseResult");
223             return false;
224         }
225     }
226
227     const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f;
228
229     BOOST_LOG_TRIVIAL(info) << std::fixed << std::setprecision(3) <<
230         "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
231     BOOST_LOG_TRIVIAL(info) << std::fixed << std::setprecision(3) <<
232         "Average time per test case: " << averageTimePerTestCaseMs << " ms";
233
234     if (!success)
235     {
236         BOOST_LOG_TRIVIAL(error) << "One or more test cases failed";
237         return false;
238     }
239
240     return testCaseProvider.OnInferenceTestFinished();
241 }
242
243 } // namespace test
244
245 } // namespace armnn