// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
#include <armnn/ArmNN.hpp>
#include <armnn/TypesUtils.hpp>
#include "InferenceModel.hpp"

#include <Logging.hpp>

#include <boost/log/core/core.hpp>
#include <boost/program_options.hpp>
#include <boost/variant.hpp>

#include <functional>
#include <ios>
#include <istream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
20 inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
24 compute = armnn::ParseComputeDevice(token.c_str());
25 if (compute == armnn::Compute::Undefined)
27 in.setstate(std::ios_base::failbit);
28 throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
33 inline std::istream& operator>>(std::istream& in, armnn::BackendId& backend)
37 armnn::Compute compute = armnn::ParseComputeDevice(token.c_str());
38 if (compute == armnn::Compute::Undefined)
40 in.setstate(std::ios_base::failbit);
41 throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
50 class TestFrameworkException : public Exception
53 using Exception::Exception;
/// Options shared by every inference test (filled in by ParseCommandLine's outParams).
///
/// Every member has an in-class default. A default-constructed instance is therefore fully
/// initialised: counters are zero, flags are off and strings are empty.
struct InferenceTestOptions
{
    /// Iteration count; defaults to 0 (its exact meaning is defined by the test driver).
    unsigned int m_IterationCount = 0;
    /// Presumably the file that inference timings are written to; empty by default.
    std::string m_InferenceTimesFile;
    /// Profiling switch; disabled by default.
    bool m_EnableProfiling = false;
    /// Presumably the search path for dynamic backends; empty by default.
    std::string m_DynamicBackendsPath;
};
/// Outcome of processing a single test case (returned by IInferenceTestCase::ProcessResult).
// NOTE(review): the enumerator names were missing from the damaged listing and have been
// restored as Ok/Failed/Abort. Confirm against their users in the .inl/.cpp files.
enum class TestCaseResult
{
    /// The test completed without any errors.
    Ok,
    /// The test failed (e.g. the prediction didn't match the validation file).
    /// This will eventually fail the whole program but the remaining test cases will still be run.
    Failed,
    /// The test failed with a fatal error. The remaining tests will not be run.
    Abort
};
81 class IInferenceTestCase
84 virtual ~IInferenceTestCase() {}
86 virtual void Run() = 0;
87 virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
90 class IInferenceTestCaseProvider
93 virtual ~IInferenceTestCaseProvider() {}
95 virtual void AddCommandLineOptions(boost::program_options::options_description& options) {};
96 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) { return true; };
97 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
98 virtual bool OnInferenceTestFinished() { return true; };
101 template <typename TModel>
102 class InferenceModelTestCase : public IInferenceTestCase
105 using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
107 InferenceModelTestCase(TModel& model,
108 unsigned int testCaseId,
109 const std::vector<TContainer>& inputs,
110 const std::vector<unsigned int>& outputSizes)
112 , m_TestCaseId(testCaseId)
113 , m_Inputs(std::move(inputs))
115 // Initialize output vector
116 const size_t numOutputs = outputSizes.size();
117 m_Outputs.reserve(numOutputs);
119 for (size_t i = 0; i < numOutputs; i++)
121 m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
125 virtual void Run() override
127 m_Model.Run(m_Inputs, m_Outputs);
131 unsigned int GetTestCaseId() const { return m_TestCaseId; }
132 const std::vector<TContainer>& GetOutputs() const { return m_Outputs; }
136 unsigned int m_TestCaseId;
137 std::vector<TContainer> m_Inputs;
138 std::vector<TContainer> m_Outputs;
141 template <typename TTestCaseDatabase, typename TModel>
142 class ClassifierTestCase : public InferenceModelTestCase<TModel>
145 ClassifierTestCase(int& numInferencesRef,
146 int& numCorrectInferencesRef,
147 const std::vector<unsigned int>& validationPredictions,
148 std::vector<unsigned int>* validationPredictionsOut,
150 unsigned int testCaseId,
152 std::vector<typename TModel::DataType> modelInput);
154 virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
157 unsigned int m_Label;
158 InferenceModelInternal::QuantizationParams m_QuantizationParams;
160 /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
162 int& m_NumInferencesRef;
163 int& m_NumCorrectInferencesRef;
164 const std::vector<unsigned int>& m_ValidationPredictions;
165 std::vector<unsigned int>* m_ValidationPredictionsOut;
169 template <typename TDatabase, typename InferenceModel>
170 class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
173 template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
174 ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
176 virtual void AddCommandLineOptions(boost::program_options::options_description& options) override;
177 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override;
178 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
179 virtual bool OnInferenceTestFinished() override;
182 void ReadPredictions();
184 typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
185 std::function<std::unique_ptr<InferenceModel>(const InferenceTestOptions& commonOptions,
186 typename InferenceModel::CommandLineOptions)> m_ConstructModel;
187 std::unique_ptr<InferenceModel> m_Model;
189 std::string m_DataDir;
190 std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
191 std::unique_ptr<TDatabase> m_Database;
193 int m_NumInferences; // Referenced by test cases.
194 int m_NumCorrectInferences; // Referenced by test cases.
196 std::string m_ValidationFileIn;
197 std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.
199 std::string m_ValidationFileOut;
200 std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
203 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
204 InferenceTestOptions& outParams);
/// Validates the directory path @p dir and returns whether it is acceptable.
/// NOTE(review): @p dir is a non-const reference, so the definition may rewrite the path in
/// place. Confirm before passing a string that must stay unchanged.
bool ValidateDirectory(std::string& dir);
208 bool InferenceTest(const InferenceTestOptions& params,
209 const std::vector<unsigned int>& defaultTestCaseIds,
210 IInferenceTestCaseProvider& testCaseProvider);
/// Generic entry point for an inference test executable.
///
/// @param argc, argv                Program arguments (the `argv` parameter was missing in
///                                  the damaged listing and has been restored).
/// @param defaultTestCaseIds        Default test-case selection.
/// @param constructTestCaseProvider Callable that builds the IInferenceTestCaseProvider to use.
/// @return Process exit code.
template<typename TConstructTestCaseProvider>
int InferenceTestMain(int argc,
                      char* argv[],
                      const std::vector<unsigned int>& defaultTestCaseIds,
                      TConstructTestCaseProvider constructTestCaseProvider);
218 template<typename TDatabase,
220 typename TConstructDatabaseCallable>
221 int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
222 const char* inputBindingName, const char* outputBindingName,
223 const std::vector<unsigned int>& defaultTestCaseIds,
224 TConstructDatabaseCallable constructDatabase,
225 const armnn::TensorShape* inputTensorShape = nullptr);
230 #include "InferenceTest.inl"