Release 18.05.01
[platform/upstream/armnn.git] / tests / InferenceTest.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #pragma once
6
7 #include "armnn/ArmNN.hpp"
8 #include "armnn/TypesUtils.hpp"
9 #include <Logging.hpp>
10
11 #include <boost/log/core/core.hpp>
12 #include <boost/program_options.hpp>
13
14 namespace armnn
15 {
16
17 inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
18 {
19     std::string token;
20     in >> token;
21     compute = armnn::ParseComputeDevice(token.c_str());
22     if (compute == armnn::Compute::Undefined)
23     {
24         in.setstate(std::ios_base::failbit);
25         throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
26     }
27     return in;
28 }
29
30 namespace test
31 {
32
33 class TestFrameworkException : public Exception
34 {
35 public:
36     using Exception::Exception;
37 };
38
39 struct InferenceTestOptions
40 {
41     unsigned int m_IterationCount;
42     std::string m_InferenceTimesFile;
43
44     InferenceTestOptions()
45         : m_IterationCount(0)
46     {}
47 };
48
49 enum class TestCaseResult
50 {
51     /// The test completed without any errors.
52     Ok,
53     /// The test failed (e.g. the prediction didn't match the validation file).
54     /// This will eventually fail the whole program but the remaining test cases will still be run.
55     Failed,
56     /// The test failed with a fatal error. The remaining tests will not be run.
57     Abort
58 };
59
60 class IInferenceTestCase
61 {
62 public:
63     virtual ~IInferenceTestCase() {}
64
65     virtual void Run() = 0;
66     virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
67 };
68
69 class IInferenceTestCaseProvider
70 {
71 public:
72     virtual ~IInferenceTestCaseProvider() {}
73
74     virtual void AddCommandLineOptions(boost::program_options::options_description& options) {};
75     virtual bool ProcessCommandLineOptions() { return true; };
76     virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
77     virtual bool OnInferenceTestFinished() { return true; };
78 };
79
80 template <typename TModel>
81 class InferenceModelTestCase : public IInferenceTestCase
82 {
83 public:
84     InferenceModelTestCase(TModel& model,
85         unsigned int testCaseId,
86         std::vector<typename TModel::DataType> modelInput,
87         unsigned int outputSize)
88         : m_Model(model)
89         , m_TestCaseId(testCaseId)
90         , m_Input(std::move(modelInput))
91     {
92         m_Output.resize(outputSize);
93     }
94
95     virtual void Run() override
96     {
97         m_Model.Run(m_Input, m_Output);
98     }
99
100 protected:
101     unsigned int GetTestCaseId() const { return m_TestCaseId; }
102     const std::vector<typename TModel::DataType>& GetOutput() const { return m_Output; }
103
104 private:
105     TModel& m_Model;
106     unsigned int m_TestCaseId;
107     std::vector<typename TModel::DataType> m_Input;
108     std::vector<typename TModel::DataType> m_Output;
109 };
110
111 template <typename TTestCaseDatabase, typename TModel>
112 class ClassifierTestCase : public InferenceModelTestCase<TModel>
113 {
114 public:
115     ClassifierTestCase(int& numInferencesRef,
116         int& numCorrectInferencesRef,
117         const std::vector<unsigned int>& validationPredictions,
118         std::vector<unsigned int>* validationPredictionsOut,
119         TModel& model,
120         unsigned int testCaseId,
121         unsigned int label,
122         std::vector<typename TModel::DataType> modelInput);
123
124     virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
125
126 private:
127     unsigned int m_Label;
128     /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
129     /// @{
130     int& m_NumInferencesRef;
131     int& m_NumCorrectInferencesRef;
132     const std::vector<unsigned int>& m_ValidationPredictions;
133     std::vector<unsigned int>* m_ValidationPredictionsOut;
134     /// @}
135 };
136
137 template <typename TDatabase, typename InferenceModel>
138 class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
139 {
140 public:
141     template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
142     ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
143
144     virtual void AddCommandLineOptions(boost::program_options::options_description& options) override;
145     virtual bool ProcessCommandLineOptions() override;
146     virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
147     virtual bool OnInferenceTestFinished() override;
148
149 private:
150     void ReadPredictions();
151
152     typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
153     std::function<std::unique_ptr<InferenceModel>(typename InferenceModel::CommandLineOptions)> m_ConstructModel;
154     std::unique_ptr<InferenceModel> m_Model;
155
156     std::string m_DataDir;
157     std::function<TDatabase(const char*)> m_ConstructDatabase;
158     std::unique_ptr<TDatabase> m_Database;
159
160     int m_NumInferences; // Referenced by test cases
161     int m_NumCorrectInferences; // Referenced by test cases
162
163     std::string m_ValidationFileIn;
164     std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases
165
166     std::string m_ValidationFileOut;
167     std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases
168 };
169
170 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
171     InferenceTestOptions& outParams);
172
173 bool ValidateDirectory(std::string& dir);
174
175 bool InferenceTest(const InferenceTestOptions& params,
176     const std::vector<unsigned int>& defaultTestCaseIds,
177     IInferenceTestCaseProvider& testCaseProvider);
178
179 template<typename TConstructTestCaseProvider>
180 int InferenceTestMain(int argc,
181     char* argv[],
182     const std::vector<unsigned int>& defaultTestCaseIds,
183     TConstructTestCaseProvider constructTestCaseProvider);
184
185 template<typename TDatabase,
186     typename TParser,
187     typename TConstructDatabaseCallable>
188 int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
189     const char* inputBindingName, const char* outputBindingName,
190     const std::vector<unsigned int>& defaultTestCaseIds,
191     TConstructDatabaseCallable constructDatabase,
192     const armnn::TensorShape* inputTensorShape = nullptr);
193
194 } // namespace test
195 } // namespace armnn
196
197 #include "InferenceTest.inl"