NNXSW-1853 Change SubgraphViewSelector algorithm
[platform/upstream/armnn.git] / tests / InferenceTest.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include <armnn/ArmNN.hpp>
8 #include <armnn/TypesUtils.hpp>
9 #include "InferenceModel.hpp"
10
11 #include <Logging.hpp>
12
13 #include <boost/log/core/core.hpp>
14 #include <boost/program_options.hpp>
15
16
17 namespace armnn
18 {
19
20 inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
21 {
22     std::string token;
23     in >> token;
24     compute = armnn::ParseComputeDevice(token.c_str());
25     if (compute == armnn::Compute::Undefined)
26     {
27         in.setstate(std::ios_base::failbit);
28         throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
29     }
30     return in;
31 }
32
33 inline std::istream& operator>>(std::istream& in, armnn::BackendId& backend)
34 {
35     std::string token;
36     in >> token;
37     armnn::Compute compute = armnn::ParseComputeDevice(token.c_str());
38     if (compute == armnn::Compute::Undefined)
39     {
40         in.setstate(std::ios_base::failbit);
41         throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
42     }
43     backend = compute;
44     return in;
45 }
46
47 namespace test
48 {
49
50 class TestFrameworkException : public Exception
51 {
52 public:
53     using Exception::Exception;
54 };
55
56 struct InferenceTestOptions
57 {
58     unsigned int m_IterationCount;
59     std::string m_InferenceTimesFile;
60     bool m_EnableProfiling;
61     std::string m_DynamicBackendsPath;
62
63     InferenceTestOptions()
64         : m_IterationCount(0)
65         , m_EnableProfiling(0)
66         , m_DynamicBackendsPath()
67     {}
68 };
69
70 enum class TestCaseResult
71 {
72     /// The test completed without any errors.
73     Ok,
74     /// The test failed (e.g. the prediction didn't match the validation file).
75     /// This will eventually fail the whole program but the remaining test cases will still be run.
76     Failed,
77     /// The test failed with a fatal error. The remaining tests will not be run.
78     Abort
79 };
80
81 class IInferenceTestCase
82 {
83 public:
84     virtual ~IInferenceTestCase() {}
85
86     virtual void Run() = 0;
87     virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
88 };
89
90 class IInferenceTestCaseProvider
91 {
92 public:
93     virtual ~IInferenceTestCaseProvider() {}
94
95     virtual void AddCommandLineOptions(boost::program_options::options_description& options) {};
96     virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) { return true; };
97     virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
98     virtual bool OnInferenceTestFinished() { return true; };
99 };
100
101 template <typename TModel>
102 class InferenceModelTestCase : public IInferenceTestCase
103 {
104 public:
105     using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
106
107     InferenceModelTestCase(TModel& model,
108                            unsigned int testCaseId,
109                            const std::vector<TContainer>& inputs,
110                            const std::vector<unsigned int>& outputSizes)
111         : m_Model(model)
112         , m_TestCaseId(testCaseId)
113         , m_Inputs(std::move(inputs))
114     {
115         // Initialize output vector
116         const size_t numOutputs = outputSizes.size();
117         m_Outputs.reserve(numOutputs);
118
119         for (size_t i = 0; i < numOutputs; i++)
120         {
121             m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
122         }
123     }
124
125     virtual void Run() override
126     {
127         m_Model.Run(m_Inputs, m_Outputs);
128     }
129
130 protected:
131     unsigned int GetTestCaseId() const { return m_TestCaseId; }
132     const std::vector<TContainer>& GetOutputs() const { return m_Outputs; }
133
134 private:
135     TModel&                 m_Model;
136     unsigned int            m_TestCaseId;
137     std::vector<TContainer> m_Inputs;
138     std::vector<TContainer> m_Outputs;
139 };
140
141 template <typename TTestCaseDatabase, typename TModel>
142 class ClassifierTestCase : public InferenceModelTestCase<TModel>
143 {
144 public:
145     ClassifierTestCase(int& numInferencesRef,
146         int& numCorrectInferencesRef,
147         const std::vector<unsigned int>& validationPredictions,
148         std::vector<unsigned int>* validationPredictionsOut,
149         TModel& model,
150         unsigned int testCaseId,
151         unsigned int label,
152         std::vector<typename TModel::DataType> modelInput);
153
154     virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
155
156 private:
157     unsigned int m_Label;
158     InferenceModelInternal::QuantizationParams m_QuantizationParams;
159
160     /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
161     /// @{
162     int& m_NumInferencesRef;
163     int& m_NumCorrectInferencesRef;
164     const std::vector<unsigned int>& m_ValidationPredictions;
165     std::vector<unsigned int>* m_ValidationPredictionsOut;
166     /// @}
167 };
168
169 template <typename TDatabase, typename InferenceModel>
170 class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
171 {
172 public:
173     template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
174     ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
175
176     virtual void AddCommandLineOptions(boost::program_options::options_description& options) override;
177     virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override;
178     virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
179     virtual bool OnInferenceTestFinished() override;
180
181 private:
182     void ReadPredictions();
183
184     typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
185     std::function<std::unique_ptr<InferenceModel>(const InferenceTestOptions& commonOptions,
186                                                   typename InferenceModel::CommandLineOptions)> m_ConstructModel;
187     std::unique_ptr<InferenceModel> m_Model;
188
189     std::string m_DataDir;
190     std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
191     std::unique_ptr<TDatabase> m_Database;
192
193     int m_NumInferences; // Referenced by test cases.
194     int m_NumCorrectInferences; // Referenced by test cases.
195
196     std::string m_ValidationFileIn;
197     std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.
198
199     std::string m_ValidationFileOut;
200     std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
201 };
202
203 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
204     InferenceTestOptions& outParams);
205
206 bool ValidateDirectory(std::string& dir);
207
208 bool InferenceTest(const InferenceTestOptions& params,
209     const std::vector<unsigned int>& defaultTestCaseIds,
210     IInferenceTestCaseProvider& testCaseProvider);
211
212 template<typename TConstructTestCaseProvider>
213 int InferenceTestMain(int argc,
214     char* argv[],
215     const std::vector<unsigned int>& defaultTestCaseIds,
216     TConstructTestCaseProvider constructTestCaseProvider);
217
218 template<typename TDatabase,
219     typename TParser,
220     typename TConstructDatabaseCallable>
221 int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
222     const char* inputBindingName, const char* outputBindingName,
223     const std::vector<unsigned int>& defaultTestCaseIds,
224     TConstructDatabaseCallable constructDatabase,
225     const armnn::TensorShape* inputTensorShape = nullptr);
226
227 } // namespace test
228 } // namespace armnn
229
230 #include "InferenceTest.inl"