IVGCVSW-5465 ExecuteNetworkTestsDynamicBackends Bug Fix
[platform/upstream/armnn.git] / tests / InferenceTest.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "InferenceTest.hpp"
6
7 #include <armnn/utility/Assert.hpp>
8 #include <Filesystem.hpp>
9
10 #include "../src/armnn/Profiling.hpp"
11 #include <cxxopts/cxxopts.hpp>
12
13 #include <fstream>
14 #include <iostream>
15 #include <iomanip>
16 #include <array>
17
18 using namespace std;
19 using namespace std::chrono;
20 using namespace armnn::test;
21
22 namespace armnn
23 {
24 namespace test
25 {
26 /// Parse the command line of an ArmNN (or referencetests) inference test program.
27 /// \return false if any error occurred during options processing, otherwise true
28 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
29     InferenceTestOptions& outParams)
30 {
31     cxxopts::Options options("InferenceTest", "Inference iteration parameters");
32
33     try
34     {
35         // Adds generic options needed for all inference tests.
36         options
37             .allow_unrecognised_options()
38             .add_options()
39                 ("h,help", "Display help messages")
40                 ("i,iterations", "Sets the number of inferences to perform. If unset, will only be run once.",
41                  cxxopts::value<unsigned int>(outParams.m_IterationCount)->default_value("0"))
42                 ("inference-times-file",
43                  "If non-empty, each individual inference time will be recorded and output to this file",
44                  cxxopts::value<std::string>(outParams.m_InferenceTimesFile)->default_value(""))
45                 ("e,event-based-profiling", "Enables built in profiler. If unset, defaults to off.",
46                  cxxopts::value<bool>(outParams.m_EnableProfiling)->default_value("0"));
47
48         std::vector<std::string> required; //to be passed as reference to derived inference tests
49
50         // Adds options specific to the ITestCaseProvider.
51         testCaseProvider.AddCommandLineOptions(options, required);
52
53         auto result = options.parse(argc, argv);
54
55         if (result.count("help"))
56         {
57             std::cout << options.help() << std::endl;
58             return false;
59         }
60
61         CheckRequiredOptions(result, required);
62
63     }
64     catch (const cxxopts::OptionException& e)
65     {
66         std::cerr << e.what() << std::endl << options.help() << std::endl;
67         return false;
68     }
69     catch (const std::exception& e)
70     {
71         // Coverity points out that default_value(...) can throw a bad_lexical_cast,
72         // and that desc.add_options() can throw boost::io::too_few_args.
73         // They really won't in any of these cases.
74         ARMNN_ASSERT_MSG(false, "Caught unexpected exception");
75         std::cerr << "Fatal internal error: " << e.what() << std::endl;
76         return false;
77     }
78
79     if (!testCaseProvider.ProcessCommandLineOptions(outParams))
80     {
81         return false;
82     }
83
84     return true;
85 }
86
87 bool ValidateDirectory(std::string& dir)
88 {
89     if (dir.empty())
90     {
91         std::cerr << "No directory specified" << std::endl;
92         return false;
93     }
94
95     if (dir[dir.length() - 1] != '/')
96     {
97         dir += "/";
98     }
99
100     if (!fs::exists(dir))
101     {
102         std::cerr << "Given directory " << dir << " does not exist" << std::endl;
103         return false;
104     }
105
106     if (!fs::is_directory(dir))
107     {
108         std::cerr << "Given directory [" << dir << "] is not a directory" << std::endl;
109         return false;
110     }
111
112     return true;
113 }
114
115 bool InferenceTest(const InferenceTestOptions& params,
116     const std::vector<unsigned int>& defaultTestCaseIds,
117     IInferenceTestCaseProvider& testCaseProvider)
118 {
119 #if !defined (NDEBUG)
120     if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn.
121     {
122         ARMNN_LOG(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
123     }
124 #endif
125
126     double totalTime = 0;
127     unsigned int nbProcessed = 0;
128     bool success = true;
129
130     // Opens the file to write inference times too, if needed.
131     ofstream inferenceTimesFile;
132     const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
133     if (recordInferenceTimes)
134     {
135         inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
136         if (!inferenceTimesFile.good())
137         {
138             ARMNN_LOG(error) << "Failed to open inference times file for writing: "
139                 << params.m_InferenceTimesFile;
140             return false;
141         }
142     }
143
144     // Create a profiler and register it for the current thread.
145     std::unique_ptr<Profiler> profiler = std::make_unique<Profiler>();
146     ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
147
148     // Enable profiling if requested.
149     profiler->EnableProfiling(params.m_EnableProfiling);
150
151     // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
152     std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
153     if (warmupTestCase == nullptr)
154     {
155         ARMNN_LOG(error) << "Failed to load test case";
156         return false;
157     }
158
159     try
160     {
161         warmupTestCase->Run();
162     }
163     catch (const TestFrameworkException& testError)
164     {
165         ARMNN_LOG(error) << testError.what();
166         return false;
167     }
168
169     const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
170         : static_cast<unsigned int>(defaultTestCaseIds.size());
171
172     for (; nbProcessed < nbTotalToProcess; nbProcessed++)
173     {
174         const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
175         std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);
176
177         if (testCase == nullptr)
178         {
179             ARMNN_LOG(error) << "Failed to load test case";
180             return false;
181         }
182
183         time_point<high_resolution_clock> predictStart;
184         time_point<high_resolution_clock> predictEnd;
185
186         TestCaseResult result = TestCaseResult::Ok;
187
188         try
189         {
190             predictStart = high_resolution_clock::now();
191
192             testCase->Run();
193
194             predictEnd = high_resolution_clock::now();
195
196             // duration<double> will convert the time difference into seconds as a double by default.
197             double timeTakenS = duration<double>(predictEnd - predictStart).count();
198             totalTime += timeTakenS;
199
200             // Outputss inference times, if needed.
201             if (recordInferenceTimes)
202             {
203                 inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
204             }
205
206             result = testCase->ProcessResult(params);
207
208         }
209         catch (const TestFrameworkException& testError)
210         {
211             ARMNN_LOG(error) << testError.what();
212             result = TestCaseResult::Abort;
213         }
214
215         switch (result)
216         {
217         case TestCaseResult::Ok:
218             break;
219         case TestCaseResult::Abort:
220             return false;
221         case TestCaseResult::Failed:
222             // This test failed so we will fail the entire program eventually, but keep going for now.
223             success = false;
224             break;
225         default:
226             ARMNN_ASSERT_MSG(false, "Unexpected TestCaseResult");
227             return false;
228         }
229     }
230
231     const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f;
232
233     ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
234         "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
235     ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
236         "Average time per test case: " << averageTimePerTestCaseMs << " ms";
237
238     // if profiling is enabled print out the results
239     if (profiler && profiler->IsProfilingEnabled())
240     {
241         profiler->Print(std::cout);
242     }
243
244     if (!success)
245     {
246         ARMNN_LOG(error) << "One or more test cases failed";
247         return false;
248     }
249
250     return testCaseProvider.OnInferenceTestFinished();
251 }
252
253 } // namespace test
254
255 } // namespace armnn