IVGCVSW-4661 Add include Assert to GatordMockService.cpp
[platform/upstream/armnn.git] / tests / InferenceTest.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "InferenceTest.hpp"
6
7 #include <armnn/utility/Assert.hpp>
8
9 #include "../src/armnn/Profiling.hpp"
10 #include <boost/algorithm/string.hpp>
11 #include <boost/numeric/conversion/cast.hpp>
12 #include <boost/filesystem/path.hpp>
13 #include <boost/format.hpp>
14 #include <boost/program_options.hpp>
15 #include <boost/filesystem/operations.hpp>
16
17 #include <fstream>
18 #include <iostream>
19 #include <iomanip>
20 #include <array>
21
22 using namespace std;
23 using namespace std::chrono;
24 using namespace armnn::test;
25
26 namespace armnn
27 {
28 namespace test
29 {
30 /// Parse the command line of an ArmNN (or referencetests) inference test program.
31 /// \return false if any error occurred during options processing, otherwise true
32 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
33     InferenceTestOptions& outParams)
34 {
35     namespace po = boost::program_options;
36
37     po::options_description desc("Options");
38
39     try
40     {
41         // Adds generic options needed for all inference tests.
42         desc.add_options()
43             ("help", "Display help messages")
44             ("iterations,i", po::value<unsigned int>(&outParams.m_IterationCount)->default_value(0),
45                 "Sets the number number of inferences to perform. If unset, a default number will be ran.")
46             ("inference-times-file", po::value<std::string>(&outParams.m_InferenceTimesFile)->default_value(""),
47                 "If non-empty, each individual inference time will be recorded and output to this file")
48             ("event-based-profiling,e", po::value<bool>(&outParams.m_EnableProfiling)->default_value(0),
49                 "Enables built in profiler. If unset, defaults to off.");
50
51         // Adds options specific to the ITestCaseProvider.
52         testCaseProvider.AddCommandLineOptions(desc);
53     }
54     catch (const std::exception& e)
55     {
56         // Coverity points out that default_value(...) can throw a bad_lexical_cast,
57         // and that desc.add_options() can throw boost::io::too_few_args.
58         // They really won't in any of these cases.
59         ARMNN_ASSERT_MSG(false, "Caught unexpected exception");
60         std::cerr << "Fatal internal error: " << e.what() << std::endl;
61         return false;
62     }
63
64     po::variables_map vm;
65
66     try
67     {
68         po::store(po::parse_command_line(argc, argv, desc), vm);
69
70         if (vm.count("help"))
71         {
72             std::cout << desc << std::endl;
73             return false;
74         }
75
76         po::notify(vm);
77     }
78     catch (po::error& e)
79     {
80         std::cerr << e.what() << std::endl << std::endl;
81         std::cerr << desc << std::endl;
82         return false;
83     }
84
85     if (!testCaseProvider.ProcessCommandLineOptions(outParams))
86     {
87         return false;
88     }
89
90     return true;
91 }
92
93 bool ValidateDirectory(std::string& dir)
94 {
95     if (dir.empty())
96     {
97         std::cerr << "No directory specified" << std::endl;
98         return false;
99     }
100
101     if (dir[dir.length() - 1] != '/')
102     {
103         dir += "/";
104     }
105
106     if (!boost::filesystem::exists(dir))
107     {
108         std::cerr << "Given directory " << dir << " does not exist" << std::endl;
109         return false;
110     }
111
112     if (!boost::filesystem::is_directory(dir))
113     {
114         std::cerr << "Given directory [" << dir << "] is not a directory" << std::endl;
115         return false;
116     }
117
118     return true;
119 }
120
121 bool InferenceTest(const InferenceTestOptions& params,
122     const std::vector<unsigned int>& defaultTestCaseIds,
123     IInferenceTestCaseProvider& testCaseProvider)
124 {
125 #if !defined (NDEBUG)
126     if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn.
127     {
128         ARMNN_LOG(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
129     }
130 #endif
131
132     double totalTime = 0;
133     unsigned int nbProcessed = 0;
134     bool success = true;
135
136     // Opens the file to write inference times too, if needed.
137     ofstream inferenceTimesFile;
138     const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
139     if (recordInferenceTimes)
140     {
141         inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
142         if (!inferenceTimesFile.good())
143         {
144             ARMNN_LOG(error) << "Failed to open inference times file for writing: "
145                 << params.m_InferenceTimesFile;
146             return false;
147         }
148     }
149
150     // Create a profiler and register it for the current thread.
151     std::unique_ptr<Profiler> profiler = std::make_unique<Profiler>();
152     ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
153
154     // Enable profiling if requested.
155     profiler->EnableProfiling(params.m_EnableProfiling);
156
157     // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
158     std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
159     if (warmupTestCase == nullptr)
160     {
161         ARMNN_LOG(error) << "Failed to load test case";
162         return false;
163     }
164
165     try
166     {
167         warmupTestCase->Run();
168     }
169     catch (const TestFrameworkException& testError)
170     {
171         ARMNN_LOG(error) << testError.what();
172         return false;
173     }
174
175     const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
176         : static_cast<unsigned int>(defaultTestCaseIds.size());
177
178     for (; nbProcessed < nbTotalToProcess; nbProcessed++)
179     {
180         const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
181         std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);
182
183         if (testCase == nullptr)
184         {
185             ARMNN_LOG(error) << "Failed to load test case";
186             return false;
187         }
188
189         time_point<high_resolution_clock> predictStart;
190         time_point<high_resolution_clock> predictEnd;
191
192         TestCaseResult result = TestCaseResult::Ok;
193
194         try
195         {
196             predictStart = high_resolution_clock::now();
197
198             testCase->Run();
199
200             predictEnd = high_resolution_clock::now();
201
202             // duration<double> will convert the time difference into seconds as a double by default.
203             double timeTakenS = duration<double>(predictEnd - predictStart).count();
204             totalTime += timeTakenS;
205
206             // Outputss inference times, if needed.
207             if (recordInferenceTimes)
208             {
209                 inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
210             }
211
212             result = testCase->ProcessResult(params);
213
214         }
215         catch (const TestFrameworkException& testError)
216         {
217             ARMNN_LOG(error) << testError.what();
218             result = TestCaseResult::Abort;
219         }
220
221         switch (result)
222         {
223         case TestCaseResult::Ok:
224             break;
225         case TestCaseResult::Abort:
226             return false;
227         case TestCaseResult::Failed:
228             // This test failed so we will fail the entire program eventually, but keep going for now.
229             success = false;
230             break;
231         default:
232             ARMNN_ASSERT_MSG(false, "Unexpected TestCaseResult");
233             return false;
234         }
235     }
236
237     const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f;
238
239     ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
240         "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
241     ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
242         "Average time per test case: " << averageTimePerTestCaseMs << " ms";
243
244     // if profiling is enabled print out the results
245     if (profiler && profiler->IsProfilingEnabled())
246     {
247         profiler->Print(std::cout);
248     }
249
250     if (!success)
251     {
252         ARMNN_LOG(error) << "One or more test cases failed";
253         return false;
254     }
255
256     return testCaseProvider.OnInferenceTestFinished();
257 }
258
259 } // namespace test
260
261 } // namespace armnn