Release 18.08
[platform/upstream/armnn.git] / tests / InferenceTest.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #pragma once
6
7 #include "armnn/ArmNN.hpp"
8 #include "armnn/TypesUtils.hpp"
9 #include "InferenceModel.hpp"
10
11 #include <Logging.hpp>
12
13 #include <boost/log/core/core.hpp>
14 #include <boost/program_options.hpp>
15
16
17 namespace armnn
18 {
19
20 inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
21 {
22     std::string token;
23     in >> token;
24     compute = armnn::ParseComputeDevice(token.c_str());
25     if (compute == armnn::Compute::Undefined)
26     {
27         in.setstate(std::ios_base::failbit);
28         throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
29     }
30     return in;
31 }
32
33 namespace test
34 {
35
36 class TestFrameworkException : public Exception
37 {
38 public:
39     using Exception::Exception;
40 };
41
42 struct InferenceTestOptions
43 {
44     unsigned int m_IterationCount;
45     std::string m_InferenceTimesFile;
46     bool m_EnableProfiling;
47
48     InferenceTestOptions()
49         : m_IterationCount(0),
50           m_EnableProfiling(0)
51     {}
52 };
53
54 enum class TestCaseResult
55 {
56     /// The test completed without any errors.
57     Ok,
58     /// The test failed (e.g. the prediction didn't match the validation file).
59     /// This will eventually fail the whole program but the remaining test cases will still be run.
60     Failed,
61     /// The test failed with a fatal error. The remaining tests will not be run.
62     Abort
63 };
64
65 class IInferenceTestCase
66 {
67 public:
68     virtual ~IInferenceTestCase() {}
69
70     virtual void Run() = 0;
71     virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
72 };
73
74 class IInferenceTestCaseProvider
75 {
76 public:
77     virtual ~IInferenceTestCaseProvider() {}
78
79     virtual void AddCommandLineOptions(boost::program_options::options_description& options) {};
80     virtual bool ProcessCommandLineOptions() { return true; };
81     virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
82     virtual bool OnInferenceTestFinished() { return true; };
83 };
84
85 template <typename TModel>
86 class InferenceModelTestCase : public IInferenceTestCase
87 {
88 public:
89     InferenceModelTestCase(TModel& model,
90         unsigned int testCaseId,
91         std::vector<typename TModel::DataType> modelInput,
92         unsigned int outputSize)
93         : m_Model(model)
94         , m_TestCaseId(testCaseId)
95         , m_Input(std::move(modelInput))
96     {
97         m_Output.resize(outputSize);
98     }
99
100     virtual void Run() override
101     {
102         m_Model.Run(m_Input, m_Output);
103     }
104
105 protected:
106     unsigned int GetTestCaseId() const { return m_TestCaseId; }
107     const std::vector<typename TModel::DataType>& GetOutput() const { return m_Output; }
108
109 private:
110     TModel& m_Model;
111     unsigned int m_TestCaseId;
112     std::vector<typename TModel::DataType> m_Input;
113     std::vector<typename TModel::DataType> m_Output;
114 };
115
116 template <typename TDataType>
117 struct ToFloat { }; // nothing defined for the generic case
118
119 template <>
120 struct ToFloat<float>
121 {
122     static inline float Convert(float value, const InferenceModelInternal::QuantizationParams &)
123     {
124         // assuming that float models are not quantized
125         return value;
126     }
127 };
128
129 template <>
130 struct ToFloat<uint8_t>
131 {
132     static inline float Convert(uint8_t value,
133                                 const InferenceModelInternal::QuantizationParams & quantizationParams)
134     {
135         return armnn::Dequantize<uint8_t>(value,
136                                           quantizationParams.first,
137                                           quantizationParams.second);
138     }
139 };
140
141 template <typename TTestCaseDatabase, typename TModel>
142 class ClassifierTestCase : public InferenceModelTestCase<TModel>
143 {
144 public:
145     ClassifierTestCase(int& numInferencesRef,
146         int& numCorrectInferencesRef,
147         const std::vector<unsigned int>& validationPredictions,
148         std::vector<unsigned int>* validationPredictionsOut,
149         TModel& model,
150         unsigned int testCaseId,
151         unsigned int label,
152         std::vector<typename TModel::DataType> modelInput);
153
154     virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
155
156 private:
157     unsigned int m_Label;
158     InferenceModelInternal::QuantizationParams m_QuantizationParams;
159
160     /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
161     /// @{
162     int& m_NumInferencesRef;
163     int& m_NumCorrectInferencesRef;
164     const std::vector<unsigned int>& m_ValidationPredictions;
165     std::vector<unsigned int>* m_ValidationPredictionsOut;
166     /// @}
167 };
168
169 template <typename TDatabase, typename InferenceModel>
170 class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
171 {
172 public:
173     template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
174     ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
175
176     virtual void AddCommandLineOptions(boost::program_options::options_description& options) override;
177     virtual bool ProcessCommandLineOptions() override;
178     virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
179     virtual bool OnInferenceTestFinished() override;
180
181 private:
182     void ReadPredictions();
183
184     typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
185     std::function<std::unique_ptr<InferenceModel>(typename InferenceModel::CommandLineOptions)> m_ConstructModel;
186     std::unique_ptr<InferenceModel> m_Model;
187
188     std::string m_DataDir;
189     std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
190     std::unique_ptr<TDatabase> m_Database;
191
192     int m_NumInferences; // Referenced by test cases.
193     int m_NumCorrectInferences; // Referenced by test cases.
194
195     std::string m_ValidationFileIn;
196     std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.
197
198     std::string m_ValidationFileOut;
199     std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
200 };
201
202 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
203     InferenceTestOptions& outParams);
204
205 bool ValidateDirectory(std::string& dir);
206
207 bool InferenceTest(const InferenceTestOptions& params,
208     const std::vector<unsigned int>& defaultTestCaseIds,
209     IInferenceTestCaseProvider& testCaseProvider);
210
211 template<typename TConstructTestCaseProvider>
212 int InferenceTestMain(int argc,
213     char* argv[],
214     const std::vector<unsigned int>& defaultTestCaseIds,
215     TConstructTestCaseProvider constructTestCaseProvider);
216
217 template<typename TDatabase,
218     typename TParser,
219     typename TConstructDatabaseCallable>
220 int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
221     const char* inputBindingName, const char* outputBindingName,
222     const std::vector<unsigned int>& defaultTestCaseIds,
223     TConstructDatabaseCallable constructDatabase,
224     const armnn::TensorShape* inputTensorShape = nullptr);
225
226 } // namespace test
227 } // namespace armnn
228
229 #include "InferenceTest.inl"