Release 18.02
[platform/upstream/armnn.git] / tests / InferenceTest.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #include "InferenceTest.hpp"
6
7 #include <boost/algorithm/string.hpp>
8 #include <boost/numeric/conversion/cast.hpp>
9 #include <boost/log/trivial.hpp>
10 #include <boost/filesystem/path.hpp>
11 #include <boost/assert.hpp>
12 #include <boost/format.hpp>
13 #include <boost/program_options.hpp>
14 #include <boost/filesystem/operations.hpp>
15
16 #include <fstream>
17 #include <iostream>
18 #include <iomanip>
19 #include <array>
20
21 using namespace std;
22 using namespace std::chrono;
23 using namespace armnn::test;
24
25 namespace armnn
26 {
27 namespace test
28 {
29
30 /// Parse the command line of an ArmNN (or referencetests) inference test program.
31 /// \return false if any error occurred during options processing, otherwise true
32 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
33     InferenceTestOptions& outParams)
34 {
35     namespace po = boost::program_options;
36
37     std::string computeDeviceStr;
38
39     po::options_description desc("Options");
40
41     try
42     {
43         // Add generic options needed for all inference tests
44         desc.add_options()
45             ("help", "Display help messages")
46             ("iterations,i", po::value<unsigned int>(&outParams.m_IterationCount)->default_value(0),
47                 "Sets the number number of inferences to perform. If unset, a default number will be ran.")
48             ("inference-times-file", po::value<std::string>(&outParams.m_InferenceTimesFile)->default_value(""),
49                 "If non-empty, each individual inference time will be recorded and output to this file");
50
51         // Add options specific to the ITestCaseProvider
52         testCaseProvider.AddCommandLineOptions(desc);
53     }
54     catch (const std::exception& e)
55     {
56         // Coverity points out that default_value(...) can throw a bad_lexical_cast,
57         // and that desc.add_options() can throw boost::io::too_few_args.
58         // They really won't in any of these cases.
59         BOOST_ASSERT_MSG(false, "Caught unexpected exception");
60         std::cerr << "Fatal internal error: " << e.what() << std::endl;
61         return false;
62     }
63
64     po::variables_map vm;
65
66     try
67     {
68         po::store(po::parse_command_line(argc, argv, desc), vm);
69
70         if (vm.count("help"))
71         {
72             std::cout << desc << std::endl;
73             return false;
74         }
75
76         po::notify(vm);
77     }
78     catch (po::error& e)
79     {
80         std::cerr << e.what() << std::endl << std::endl;
81         std::cerr << desc << std::endl;
82         return false;
83     }
84
85     if (!testCaseProvider.ProcessCommandLineOptions())
86     {
87         return false;
88     }
89
90     return true;
91 }
92
93 bool ValidateDirectory(std::string& dir)
94 {
95     if (dir[dir.length() - 1] != '/')
96     {
97         dir += "/";
98     }
99
100     if (!boost::filesystem::exists(dir))
101     {
102         std::cerr << "Given directory " << dir << " does not exist" << std::endl;
103         return false;
104     }
105
106     return true;
107 }
108
109 bool InferenceTest(const InferenceTestOptions& params,
110     const std::vector<unsigned int>& defaultTestCaseIds,
111     IInferenceTestCaseProvider& testCaseProvider)
112 {
113 #if !defined (NDEBUG)
114     if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn
115     {
116         BOOST_LOG_TRIVIAL(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
117     }
118 #endif
119
120     double totalTime = 0;
121     unsigned int nbProcessed = 0;
122     bool success = true;
123
124     // Open the file to write inference times to, if needed
125     ofstream inferenceTimesFile;
126     const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
127     if (recordInferenceTimes)
128     {
129         inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
130         if (!inferenceTimesFile.good())
131         {
132             BOOST_LOG_TRIVIAL(error) << "Failed to open inference times file for writing: "
133                 << params.m_InferenceTimesFile;
134             return false;
135         }
136     }
137
138     // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
139     std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
140     if (warmupTestCase == nullptr)
141     {
142         BOOST_LOG_TRIVIAL(error) << "Failed to load test case";
143         return false;
144     }
145
146     try
147     {
148         warmupTestCase->Run();
149     }
150     catch (const TestFrameworkException& testError)
151     {
152         BOOST_LOG_TRIVIAL(error) << testError.what();
153         return false;
154     }
155
156     const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
157         : boost::numeric_cast<unsigned int>(defaultTestCaseIds.size());
158
159     for (; nbProcessed < nbTotalToProcess; nbProcessed++)
160     {
161         const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
162         std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);
163
164         if (testCase == nullptr)
165         {
166             BOOST_LOG_TRIVIAL(error) << "Failed to load test case";
167             return false;
168         }
169
170         time_point<high_resolution_clock> predictStart;
171         time_point<high_resolution_clock> predictEnd;
172
173         TestCaseResult result = TestCaseResult::Ok;
174
175         try
176         {
177             predictStart = high_resolution_clock::now();
178
179             testCase->Run();
180
181             predictEnd = high_resolution_clock::now();
182
183             // duration<double> will convert the time difference into seconds as a double by default.
184             double timeTakenS = duration<double>(predictEnd - predictStart).count();
185             totalTime += timeTakenS;
186
187             // Output inference times if needed
188             if (recordInferenceTimes)
189             {
190                 inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
191             }
192
193             result = testCase->ProcessResult(params);
194
195         }
196         catch (const TestFrameworkException& testError)
197         {
198             BOOST_LOG_TRIVIAL(error) << testError.what();
199             result = TestCaseResult::Abort;
200         }
201
202         switch (result)
203         {
204         case TestCaseResult::Ok:
205             break;
206         case TestCaseResult::Abort:
207             return false;
208         case TestCaseResult::Failed:
209             // This test failed so we will fail the entire program eventually, but keep going for now.
210             success = false;
211             break;
212         default:
213             BOOST_ASSERT_MSG(false, "Unexpected TestCaseResult");
214             return false;
215         }
216     }
217
218     const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f;
219
220     BOOST_LOG_TRIVIAL(info) << std::fixed << std::setprecision(3) <<
221         "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
222     BOOST_LOG_TRIVIAL(info) << std::fixed << std::setprecision(3) <<
223         "Average time per test case: " << averageTimePerTestCaseMs << " ms";
224
225     if (!success)
226     {
227         BOOST_LOG_TRIVIAL(error) << "One or more test cases failed";
228         return false;
229     }
230
231     return testCaseProvider.OnInferenceTestFinished();
232 }
233
234 } // namespace test
235
236 } // namespace armnn