IVGCVSW-4487 Remove boost::filesystem
[platform/upstream/armnn.git] / tests / InferenceTest.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "InferenceTest.hpp"
6
7 #include <armnn/utility/Assert.hpp>
8 #include <Filesystem.hpp>
9
10 #include "../src/armnn/Profiling.hpp"
11 #include <boost/numeric/conversion/cast.hpp>
12 #include <boost/format.hpp>
13 #include <boost/program_options.hpp>
14
15 #include <fstream>
16 #include <iostream>
17 #include <iomanip>
18 #include <array>
19
20 using namespace std;
21 using namespace std::chrono;
22 using namespace armnn::test;
23
24 namespace armnn
25 {
26 namespace test
27 {
28 /// Parse the command line of an ArmNN (or referencetests) inference test program.
29 /// \return false if any error occurred during options processing, otherwise true
30 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
31     InferenceTestOptions& outParams)
32 {
33     namespace po = boost::program_options;
34
35     po::options_description desc("Options");
36
37     try
38     {
39         // Adds generic options needed for all inference tests.
40         desc.add_options()
41             ("help", "Display help messages")
42             ("iterations,i", po::value<unsigned int>(&outParams.m_IterationCount)->default_value(0),
43                 "Sets the number number of inferences to perform. If unset, a default number will be ran.")
44             ("inference-times-file", po::value<std::string>(&outParams.m_InferenceTimesFile)->default_value(""),
45                 "If non-empty, each individual inference time will be recorded and output to this file")
46             ("event-based-profiling,e", po::value<bool>(&outParams.m_EnableProfiling)->default_value(0),
47                 "Enables built in profiler. If unset, defaults to off.");
48
49         // Adds options specific to the ITestCaseProvider.
50         testCaseProvider.AddCommandLineOptions(desc);
51     }
52     catch (const std::exception& e)
53     {
54         // Coverity points out that default_value(...) can throw a bad_lexical_cast,
55         // and that desc.add_options() can throw boost::io::too_few_args.
56         // They really won't in any of these cases.
57         ARMNN_ASSERT_MSG(false, "Caught unexpected exception");
58         std::cerr << "Fatal internal error: " << e.what() << std::endl;
59         return false;
60     }
61
62     po::variables_map vm;
63
64     try
65     {
66         po::store(po::parse_command_line(argc, argv, desc), vm);
67
68         if (vm.count("help"))
69         {
70             std::cout << desc << std::endl;
71             return false;
72         }
73
74         po::notify(vm);
75     }
76     catch (po::error& e)
77     {
78         std::cerr << e.what() << std::endl << std::endl;
79         std::cerr << desc << std::endl;
80         return false;
81     }
82
83     if (!testCaseProvider.ProcessCommandLineOptions(outParams))
84     {
85         return false;
86     }
87
88     return true;
89 }
90
91 bool ValidateDirectory(std::string& dir)
92 {
93     if (dir.empty())
94     {
95         std::cerr << "No directory specified" << std::endl;
96         return false;
97     }
98
99     if (dir[dir.length() - 1] != '/')
100     {
101         dir += "/";
102     }
103
104     if (!fs::exists(dir))
105     {
106         std::cerr << "Given directory " << dir << " does not exist" << std::endl;
107         return false;
108     }
109
110     if (!fs::is_directory(dir))
111     {
112         std::cerr << "Given directory [" << dir << "] is not a directory" << std::endl;
113         return false;
114     }
115
116     return true;
117 }
118
119 bool InferenceTest(const InferenceTestOptions& params,
120     const std::vector<unsigned int>& defaultTestCaseIds,
121     IInferenceTestCaseProvider& testCaseProvider)
122 {
123 #if !defined (NDEBUG)
124     if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn.
125     {
126         ARMNN_LOG(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
127     }
128 #endif
129
130     double totalTime = 0;
131     unsigned int nbProcessed = 0;
132     bool success = true;
133
134     // Opens the file to write inference times too, if needed.
135     ofstream inferenceTimesFile;
136     const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
137     if (recordInferenceTimes)
138     {
139         inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
140         if (!inferenceTimesFile.good())
141         {
142             ARMNN_LOG(error) << "Failed to open inference times file for writing: "
143                 << params.m_InferenceTimesFile;
144             return false;
145         }
146     }
147
148     // Create a profiler and register it for the current thread.
149     std::unique_ptr<Profiler> profiler = std::make_unique<Profiler>();
150     ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
151
152     // Enable profiling if requested.
153     profiler->EnableProfiling(params.m_EnableProfiling);
154
155     // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
156     std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
157     if (warmupTestCase == nullptr)
158     {
159         ARMNN_LOG(error) << "Failed to load test case";
160         return false;
161     }
162
163     try
164     {
165         warmupTestCase->Run();
166     }
167     catch (const TestFrameworkException& testError)
168     {
169         ARMNN_LOG(error) << testError.what();
170         return false;
171     }
172
173     const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
174         : static_cast<unsigned int>(defaultTestCaseIds.size());
175
176     for (; nbProcessed < nbTotalToProcess; nbProcessed++)
177     {
178         const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
179         std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);
180
181         if (testCase == nullptr)
182         {
183             ARMNN_LOG(error) << "Failed to load test case";
184             return false;
185         }
186
187         time_point<high_resolution_clock> predictStart;
188         time_point<high_resolution_clock> predictEnd;
189
190         TestCaseResult result = TestCaseResult::Ok;
191
192         try
193         {
194             predictStart = high_resolution_clock::now();
195
196             testCase->Run();
197
198             predictEnd = high_resolution_clock::now();
199
200             // duration<double> will convert the time difference into seconds as a double by default.
201             double timeTakenS = duration<double>(predictEnd - predictStart).count();
202             totalTime += timeTakenS;
203
204             // Outputss inference times, if needed.
205             if (recordInferenceTimes)
206             {
207                 inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
208             }
209
210             result = testCase->ProcessResult(params);
211
212         }
213         catch (const TestFrameworkException& testError)
214         {
215             ARMNN_LOG(error) << testError.what();
216             result = TestCaseResult::Abort;
217         }
218
219         switch (result)
220         {
221         case TestCaseResult::Ok:
222             break;
223         case TestCaseResult::Abort:
224             return false;
225         case TestCaseResult::Failed:
226             // This test failed so we will fail the entire program eventually, but keep going for now.
227             success = false;
228             break;
229         default:
230             ARMNN_ASSERT_MSG(false, "Unexpected TestCaseResult");
231             return false;
232         }
233     }
234
235     const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f;
236
237     ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
238         "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
239     ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
240         "Average time per test case: " << averageTimePerTestCaseMs << " ms";
241
242     // if profiling is enabled print out the results
243     if (profiler && profiler->IsProfilingEnabled())
244     {
245         profiler->Print(std::cout);
246     }
247
248     if (!success)
249     {
250         ARMNN_LOG(error) << "One or more test cases failed";
251         return false;
252     }
253
254     return testCaseProvider.OnInferenceTestFinished();
255 }
256
257 } // namespace test
258
259 } // namespace armnn