// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
#pragma once

#include "armnn/ArmNN.hpp"
#include "armnn/TypesUtils.hpp"
#include "InferenceModel.hpp"

#include <Logging.hpp>

#include <boost/log/core/core.hpp>
#include <boost/program_options.hpp>

#include <cstdint>
#include <functional>
#include <istream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
20 inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
24 compute = armnn::ParseComputeDevice(token.c_str());
25 if (compute == armnn::Compute::Undefined)
27 in.setstate(std::ios_base::failbit);
28 throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
36 class TestFrameworkException : public Exception
39 using Exception::Exception;
/// Options controlling an inference test run, filled in from the command line.
struct InferenceTestOptions
{
    unsigned int m_IterationCount;      ///< Number of times to iterate over the test cases (0 = default behaviour).
    std::string  m_InferenceTimesFile;  ///< Path to write per-inference timings to; empty = disabled.
    bool         m_EnableProfiling;     ///< Whether profiling is enabled for the run.

    InferenceTestOptions()
        : m_IterationCount(0)
        , m_EnableProfiling(false)  // was "0" in the original; 'false' states intent for a bool
    {
    }
};
/// Outcome of a single test case.
enum class TestCaseResult
{
    /// The test completed without any errors.
    Ok,
    /// The test failed (e.g. the prediction didn't match the validation file).
    /// This will eventually fail the whole program but the remaining test cases will still be run.
    Failed,
    /// The test failed with a fatal error. The remaining tests will not be run.
    Abort
};
65 class IInferenceTestCase
68 virtual ~IInferenceTestCase() {}
70 virtual void Run() = 0;
71 virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
74 class IInferenceTestCaseProvider
77 virtual ~IInferenceTestCaseProvider() {}
79 virtual void AddCommandLineOptions(boost::program_options::options_description& options) {};
80 virtual bool ProcessCommandLineOptions() { return true; };
81 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
82 virtual bool OnInferenceTestFinished() { return true; };
85 template <typename TModel>
86 class InferenceModelTestCase : public IInferenceTestCase
89 InferenceModelTestCase(TModel& model,
90 unsigned int testCaseId,
91 std::vector<typename TModel::DataType> modelInput,
92 unsigned int outputSize)
94 , m_TestCaseId(testCaseId)
95 , m_Input(std::move(modelInput))
97 m_Output.resize(outputSize);
100 virtual void Run() override
102 m_Model.Run(m_Input, m_Output);
106 unsigned int GetTestCaseId() const { return m_TestCaseId; }
107 const std::vector<typename TModel::DataType>& GetOutput() const { return m_Output; }
111 unsigned int m_TestCaseId;
112 std::vector<typename TModel::DataType> m_Input;
113 std::vector<typename TModel::DataType> m_Output;
/// Converts a model output value of type TDataType to float. Only the
/// specializations below are defined, so instantiating ToFloat<T>::Convert for
/// any other type is a compile-time error.
template <typename TDataType>
struct ToFloat { }; // nothing defined for the generic case
120 struct ToFloat<float>
122 static inline float Convert(float value, const InferenceModelInternal::QuantizationParams &)
124 // assuming that float models are not quantized
130 struct ToFloat<uint8_t>
132 static inline float Convert(uint8_t value,
133 const InferenceModelInternal::QuantizationParams & quantizationParams)
135 return armnn::Dequantize<uint8_t>(value,
136 quantizationParams.first,
137 quantizationParams.second);
141 template <typename TTestCaseDatabase, typename TModel>
142 class ClassifierTestCase : public InferenceModelTestCase<TModel>
145 ClassifierTestCase(int& numInferencesRef,
146 int& numCorrectInferencesRef,
147 const std::vector<unsigned int>& validationPredictions,
148 std::vector<unsigned int>* validationPredictionsOut,
150 unsigned int testCaseId,
152 std::vector<typename TModel::DataType> modelInput);
154 virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
157 unsigned int m_Label;
158 InferenceModelInternal::QuantizationParams m_QuantizationParams;
160 /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
162 int& m_NumInferencesRef;
163 int& m_NumCorrectInferencesRef;
164 const std::vector<unsigned int>& m_ValidationPredictions;
165 std::vector<unsigned int>* m_ValidationPredictionsOut;
169 template <typename TDatabase, typename InferenceModel>
170 class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
173 template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
174 ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
176 virtual void AddCommandLineOptions(boost::program_options::options_description& options) override;
177 virtual bool ProcessCommandLineOptions() override;
178 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
179 virtual bool OnInferenceTestFinished() override;
182 void ReadPredictions();
184 typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
185 std::function<std::unique_ptr<InferenceModel>(typename InferenceModel::CommandLineOptions)> m_ConstructModel;
186 std::unique_ptr<InferenceModel> m_Model;
188 std::string m_DataDir;
189 std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
190 std::unique_ptr<TDatabase> m_Database;
192 int m_NumInferences; // Referenced by test cases.
193 int m_NumCorrectInferences; // Referenced by test cases.
195 std::string m_ValidationFileIn;
196 std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.
198 std::string m_ValidationFileOut;
199 std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
202 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
203 InferenceTestOptions& outParams);
205 bool ValidateDirectory(std::string& dir);
207 bool InferenceTest(const InferenceTestOptions& params,
208 const std::vector<unsigned int>& defaultTestCaseIds,
209 IInferenceTestCaseProvider& testCaseProvider);
/// Helper intended as the body of an inference test program's main():
/// constructs the test case provider via constructTestCaseProvider, then
/// parses the command line and runs the test. Returns a process exit code.
/// Implemented in InferenceTest.inl.
template<typename TConstructTestCaseProvider>
int InferenceTestMain(int argc,
    char* argv[],
    const std::vector<unsigned int>& defaultTestCaseIds,
    TConstructTestCaseProvider constructTestCaseProvider);
217 template<typename TDatabase,
219 typename TConstructDatabaseCallable>
220 int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
221 const char* inputBindingName, const char* outputBindingName,
222 const std::vector<unsigned int>& defaultTestCaseIds,
223 TConstructDatabaseCallable constructDatabase,
224 const armnn::TensorShape* inputTensorShape = nullptr);
229 #include "InferenceTest.inl"