IVGCVSW-4487 Remove boost::filesystem
[platform/upstream/armnn.git] / tests / InferenceTest.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include "InferenceModel.hpp"
8
9 #include <armnn/ArmNN.hpp>
10 #include <armnn/Logging.hpp>
11 #include <armnn/TypesUtils.hpp>
12 #include <armnn/utility/IgnoreUnused.hpp>
13
14 #include <boost/program_options.hpp>
15
16
17 namespace armnn
18 {
19
20 inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
21 {
22     std::string token;
23     in >> token;
24     compute = armnn::ParseComputeDevice(token.c_str());
25     if (compute == armnn::Compute::Undefined)
26     {
27         in.setstate(std::ios_base::failbit);
28         throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
29     }
30     return in;
31 }
32
33 inline std::istream& operator>>(std::istream& in, armnn::BackendId& backend)
34 {
35     std::string token;
36     in >> token;
37     armnn::Compute compute = armnn::ParseComputeDevice(token.c_str());
38     if (compute == armnn::Compute::Undefined)
39     {
40         in.setstate(std::ios_base::failbit);
41         throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
42     }
43     backend = compute;
44     return in;
45 }
46
47 namespace test
48 {
49
50 class TestFrameworkException : public Exception
51 {
52 public:
53     using Exception::Exception;
54 };
55
56 struct InferenceTestOptions
57 {
58     unsigned int m_IterationCount;
59     std::string m_InferenceTimesFile;
60     bool m_EnableProfiling;
61     std::string m_DynamicBackendsPath;
62
63     InferenceTestOptions()
64         : m_IterationCount(0)
65         , m_EnableProfiling(0)
66         , m_DynamicBackendsPath()
67     {}
68 };
69
70 enum class TestCaseResult
71 {
72     /// The test completed without any errors.
73     Ok,
74     /// The test failed (e.g. the prediction didn't match the validation file).
75     /// This will eventually fail the whole program but the remaining test cases will still be run.
76     Failed,
77     /// The test failed with a fatal error. The remaining tests will not be run.
78     Abort
79 };
80
81 class IInferenceTestCase
82 {
83 public:
84     virtual ~IInferenceTestCase() {}
85
86     virtual void Run() = 0;
87     virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
88 };
89
90 class IInferenceTestCaseProvider
91 {
92 public:
93     virtual ~IInferenceTestCaseProvider() {}
94
95     virtual void AddCommandLineOptions(boost::program_options::options_description& options)
96     {
97         IgnoreUnused(options);
98     };
99     virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
100     {
101         IgnoreUnused(commonOptions);
102         return true;
103     };
104     virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
105     virtual bool OnInferenceTestFinished() { return true; };
106 };
107
108 template <typename TModel>
109 class InferenceModelTestCase : public IInferenceTestCase
110 {
111 public:
112     using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
113
114     InferenceModelTestCase(TModel& model,
115                            unsigned int testCaseId,
116                            const std::vector<TContainer>& inputs,
117                            const std::vector<unsigned int>& outputSizes)
118         : m_Model(model)
119         , m_TestCaseId(testCaseId)
120         , m_Inputs(std::move(inputs))
121     {
122         // Initialize output vector
123         const size_t numOutputs = outputSizes.size();
124         m_Outputs.reserve(numOutputs);
125
126         for (size_t i = 0; i < numOutputs; i++)
127         {
128             m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
129         }
130     }
131
132     virtual void Run() override
133     {
134         m_Model.Run(m_Inputs, m_Outputs);
135     }
136
137 protected:
138     unsigned int GetTestCaseId() const { return m_TestCaseId; }
139     const std::vector<TContainer>& GetOutputs() const { return m_Outputs; }
140
141 private:
142     TModel&                 m_Model;
143     unsigned int            m_TestCaseId;
144     std::vector<TContainer> m_Inputs;
145     std::vector<TContainer> m_Outputs;
146 };
147
148 template <typename TTestCaseDatabase, typename TModel>
149 class ClassifierTestCase : public InferenceModelTestCase<TModel>
150 {
151 public:
152     ClassifierTestCase(int& numInferencesRef,
153         int& numCorrectInferencesRef,
154         const std::vector<unsigned int>& validationPredictions,
155         std::vector<unsigned int>* validationPredictionsOut,
156         TModel& model,
157         unsigned int testCaseId,
158         unsigned int label,
159         std::vector<typename TModel::DataType> modelInput);
160
161     virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
162
163 private:
164     unsigned int m_Label;
165     InferenceModelInternal::QuantizationParams m_QuantizationParams;
166
167     /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
168     /// @{
169     int& m_NumInferencesRef;
170     int& m_NumCorrectInferencesRef;
171     const std::vector<unsigned int>& m_ValidationPredictions;
172     std::vector<unsigned int>* m_ValidationPredictionsOut;
173     /// @}
174 };
175
176 template <typename TDatabase, typename InferenceModel>
177 class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
178 {
179 public:
180     template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
181     ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
182
183     virtual void AddCommandLineOptions(boost::program_options::options_description& options) override;
184     virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override;
185     virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
186     virtual bool OnInferenceTestFinished() override;
187
188 private:
189     void ReadPredictions();
190
191     typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
192     std::function<std::unique_ptr<InferenceModel>(const InferenceTestOptions& commonOptions,
193                                                   typename InferenceModel::CommandLineOptions)> m_ConstructModel;
194     std::unique_ptr<InferenceModel> m_Model;
195
196     std::string m_DataDir;
197     std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
198     std::unique_ptr<TDatabase> m_Database;
199
200     int m_NumInferences; // Referenced by test cases.
201     int m_NumCorrectInferences; // Referenced by test cases.
202
203     std::string m_ValidationFileIn;
204     std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.
205
206     std::string m_ValidationFileOut;
207     std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
208 };
209
210 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
211     InferenceTestOptions& outParams);
212
213 bool ValidateDirectory(std::string& dir);
214
215 bool InferenceTest(const InferenceTestOptions& params,
216     const std::vector<unsigned int>& defaultTestCaseIds,
217     IInferenceTestCaseProvider& testCaseProvider);
218
219 template<typename TConstructTestCaseProvider>
220 int InferenceTestMain(int argc,
221     char* argv[],
222     const std::vector<unsigned int>& defaultTestCaseIds,
223     TConstructTestCaseProvider constructTestCaseProvider);
224
225 template<typename TDatabase,
226     typename TParser,
227     typename TConstructDatabaseCallable>
228 int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
229     const char* inputBindingName, const char* outputBindingName,
230     const std::vector<unsigned int>& defaultTestCaseIds,
231     TConstructDatabaseCallable constructDatabase,
232     const armnn::TensorShape* inputTensorShape = nullptr);
233
234 } // namespace test
235 } // namespace armnn
236
237 #include "InferenceTest.inl"