2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
7 #include "InferenceModel.hpp"
9 #include <armnn/ArmNN.hpp>
10 #include <armnn/Logging.hpp>
11 #include <armnn/TypesUtils.hpp>
12 #include <armnn/utility/IgnoreUnused.hpp>
14 #include <boost/program_options.hpp>
// Stream extraction for armnn::Compute, presumably so the type can be used as a
// boost::program_options value (failures are reported as a program_options error).
// The token text is converted with armnn::ParseComputeDevice. `compute` is
// assigned before it is validated, so on failure it is left as Compute::Undefined.
// For an unrecognised name the stream's failbit is set and
// boost::program_options::validation_error(invalid_option_value) is thrown.
// NOTE(review): the declaration/read of `token`, the closing braces and the
// return statement are not visible in this excerpt — confirm they read one token
// from `in` and return `in`.
20 inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
24 compute = armnn::ParseComputeDevice(token.c_str());
// Compute::Undefined is treated here as "name not recognised"; reject it.
25 if (compute == armnn::Compute::Undefined)
27 in.setstate(std::ios_base::failbit);
28 throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
// Stream extraction for armnn::BackendId.
// The token is validated by parsing it with armnn::ParseComputeDevice. Only names
// that map to a Compute value other than Compute::Undefined pass. Any other name
// sets failbit on the stream and throws
// boost::program_options::validation_error(invalid_option_value).
// NOTE(review): this means backend ids that have no Compute enumerator are
// rejected here — confirm that is intended. The declaration/read of `token`, the
// assignment to `backend` and the return are not visible in this excerpt.
33 inline std::istream& operator>>(std::istream& in, armnn::BackendId& backend)
37 armnn::Compute compute = armnn::ParseComputeDevice(token.c_str());
38 if (compute == armnn::Compute::Undefined)
40 in.setstate(std::ios_base::failbit);
41 throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
// Exception type presumably used for errors raised by the test framework itself
// (going by its name). It inherits every constructor of its base through
// `using Exception::Exception;`.
// NOTE(review): `Exception` is unqualified here; presumably armnn::Exception is
// reached through an enclosing namespace not visible in this excerpt.
50 class TestFrameworkException : public Exception
53 using Exception::Exception;
/// Options common to every inference test (presumably populated by
/// ParseCommandLine, which fills an InferenceTestOptions& out-parameter).
///
/// Every member has a well-defined default, so a default-constructed instance is
/// always safe to read:
///   - m_IterationCount      : 0
///   - m_InferenceTimesFile  : empty string
///   - m_EnableProfiling     : false
///   - m_DynamicBackendsPath : empty string
struct InferenceTestOptions
{
    /// Iteration count (presumably how many times the tests are run).
    /// It is zero-initialised in-class so it can never be read uninitialised.
    unsigned int m_IterationCount = 0;

    /// Presumably a file path for recording inference times; empty when unset.
    std::string m_InferenceTimesFile;

    /// Whether profiling is enabled. It is initialised with `false` rather than
    /// the integer literal 0.
    bool m_EnableProfiling = false;

    /// Presumably the search path for dynamic backends; empty when unset.
    std::string m_DynamicBackendsPath;

    InferenceTestOptions() = default;
};
/// Outcome of a single test case, as returned by IInferenceTestCase::ProcessResult().
/// NOTE(review): the enumerator lines themselves are not visible in this excerpt;
/// the comments below describe them in declaration order.
70 enum class TestCaseResult
72 /// The test completed without any errors.
74 /// The test failed (e.g. the prediction didn't match the validation file).
75 /// This will eventually fail the whole program but the remaining test cases will still be run.
77 /// The test failed with a fatal error. The remaining tests will not be run.
81 class IInferenceTestCase
84 virtual ~IInferenceTestCase() {}
86 virtual void Run() = 0;
87 virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
// Interface for an object that supplies test cases to the test driver
// (ParseCommandLine and InferenceTest both take one by reference).
// Hooks with default implementations:
//  - AddCommandLineOptions: the default registers nothing (it only calls IgnoreUnused).
//  - ProcessCommandLineOptions: the default ignores its argument.
//  - OnInferenceTestFinished: the default returns true.
// GetTestCase is pure virtual and returns an owning pointer to the test case with
// the given id.
// NOTE(review): the return statement of the default ProcessCommandLineOptions is not
// visible in this excerpt — presumably `return true;`; confirm.
90 class IInferenceTestCaseProvider
93 virtual ~IInferenceTestCaseProvider() {}
95 virtual void AddCommandLineOptions(boost::program_options::options_description& options)
97 IgnoreUnused(options);
99 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
101 IgnoreUnused(commonOptions);
104 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
// NOTE(review): the trailing ';' after this inline function body is redundant.
105 virtual bool OnInferenceTestFinished() { return true; };
// Test case that runs a TModel on a fixed set of inputs.
// TContainer is a boost::variant holding a std::vector of float, int or
// unsigned char. It is used for both input and output buffers.
// The constructor stores the test case id, copies the inputs, and allocates one
// output buffer per entry of outputSizes. Each buffer is a value-initialised
// std::vector<typename TModel::DataType> of that size.
// Run() passes the inputs and output buffers to m_Model.Run().
// NOTE(review): the declaration and initialisation of m_Model are not visible in
// this excerpt — presumably a reference bound to `model`; confirm.
108 template <typename TModel>
109 class InferenceModelTestCase : public IInferenceTestCase
112 using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
114 InferenceModelTestCase(TModel& model,
115 unsigned int testCaseId,
116 const std::vector<TContainer>& inputs,
117 const std::vector<unsigned int>& outputSizes)
119 , m_TestCaseId(testCaseId)
// NOTE(review): std::move on the const reference `inputs` still copies; taking
// `inputs` by value would let callers move their buffers in.
120 , m_Inputs(std::move(inputs))
122 // Initialize output vector
// One output buffer per requested output, sized from outputSizes.
123 const size_t numOutputs = outputSizes.size();
124 m_Outputs.reserve(numOutputs);
126 for (size_t i = 0; i < numOutputs; i++)
128 m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
132 virtual void Run() override
134 m_Model.Run(m_Inputs, m_Outputs);
138 unsigned int GetTestCaseId() const { return m_TestCaseId; }
139 const std::vector<TContainer>& GetOutputs() const { return m_Outputs; }
143 unsigned int m_TestCaseId;
144 std::vector<TContainer> m_Inputs;
145 std::vector<TContainer> m_Outputs;
// Classification test case built on InferenceModelTestCase<TModel>.
// ProcessResult() is overridden. Its definition is not in this excerpt and is
// presumably in the included InferenceTest.inl.
// The *Ref members and m_ValidationPredictions / m_ValidationPredictionsOut refer
// to state owned by the ClassifierTestCaseProvider (see the field comment below).
// The provider must therefore outlive every test case it creates.
// NOTE(review): some constructor parameter lines are not visible in this excerpt
// (the gaps before `testCaseId` and before `modelInput`). TTestCaseDatabase is
// not used on any visible line; confirm.
148 template <typename TTestCaseDatabase, typename TModel>
149 class ClassifierTestCase : public InferenceModelTestCase<TModel>
152 ClassifierTestCase(int& numInferencesRef,
153 int& numCorrectInferencesRef,
154 const std::vector<unsigned int>& validationPredictions,
155 std::vector<unsigned int>* validationPredictionsOut,
157 unsigned int testCaseId,
159 std::vector<typename TModel::DataType> modelInput);
161 virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
// Presumably the expected class label for this test case and the quantization
// parameters used when interpreting the output; confirm against the definition.
164 unsigned int m_Label;
165 InferenceModelInternal::QuantizationParams m_QuantizationParams;
167 /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
169 int& m_NumInferencesRef;
170 int& m_NumCorrectInferencesRef;
171 const std::vector<unsigned int>& m_ValidationPredictions;
// Pointer rather than reference; presumably null when no output file is requested.
172 std::vector<unsigned int>* m_ValidationPredictionsOut;
// IInferenceTestCaseProvider implementation for classification tests.
// Through std::unique_ptr members it owns the model (m_Model) and database
// (m_Database). These are presumably built with the callables stored in
// m_ConstructModel / m_ConstructDatabase, which come from the templated
// constructor arguments.
// It also holds the inference counters and the validation prediction vectors
// that test cases reference ("Referenced by test cases" below). The provider
// therefore must outlive every test case returned by GetTestCase().
// NOTE(review): m_NumInferences / m_NumCorrectInferences have no in-class
// initialiser; confirm the constructor (not visible here) zeroes them.
176 template <typename TDatabase, typename InferenceModel>
177 class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
180 template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
181 ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
183 virtual void AddCommandLineOptions(boost::program_options::options_description& options) override;
184 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override;
185 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
186 virtual bool OnInferenceTestFinished() override;
// Presumably loads m_ValidationPredictions from m_ValidationFileIn; confirm.
189 void ReadPredictions();
191 typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
192 std::function<std::unique_ptr<InferenceModel>(const InferenceTestOptions& commonOptions,
193 typename InferenceModel::CommandLineOptions)> m_ConstructModel;
194 std::unique_ptr<InferenceModel> m_Model;
196 std::string m_DataDir;
// Builds a TDatabase from a C-string (presumably the data directory) and the model.
197 std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
198 std::unique_ptr<TDatabase> m_Database;
200 int m_NumInferences; // Referenced by test cases.
201 int m_NumCorrectInferences; // Referenced by test cases.
203 std::string m_ValidationFileIn;
204 std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.
206 std::string m_ValidationFileOut;
207 std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
// Parses argc/argv into outParams. testCaseProvider presumably takes part in option
// handling through its AddCommandLineOptions / ProcessCommandLineOptions hooks.
// Returns a bool success flag (presumably false when parsing fails); confirm
// against the definition.
210 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
211 InferenceTestOptions& outParams);
// Presumably checks that `dir` names a usable directory, and returns a bool result.
// `dir` is taken by non-const reference, so the function may modify it
// (NOTE(review): e.g. normalising the path — confirm against the definition).
213 bool ValidateDirectory(std::string& dir);
// Runs the test cases supplied by testCaseProvider using the common options in
// params. defaultTestCaseIds presumably lists the ids to run when none are chosen
// explicitly. Returns a bool success flag.
215 bool InferenceTest(const InferenceTestOptions& params,
216 const std::vector<unsigned int>& defaultTestCaseIds,
217 IInferenceTestCaseProvider& testCaseProvider);
// Templated entry point for an inference test program. Its definition is
// presumably in the included InferenceTest.inl. constructTestCaseProvider is a
// callable that presumably creates the IInferenceTestCaseProvider to run.
// Returns an int, presumably the process exit code.
// NOTE(review): the parameter line between `argc` and `defaultTestCaseIds` is not
// visible in this excerpt (presumably `argv`).
219 template<typename TConstructTestCaseProvider>
220 int InferenceTestMain(int argc,
222 const std::vector<unsigned int>& defaultTestCaseIds,
223 TConstructTestCaseProvider constructTestCaseProvider);
// Convenience entry point for classifier tests. It presumably builds a
// ClassifierTestCaseProvider for the given model file and input/output binding
// names, then runs it. inputTensorShape is optional and defaults to nullptr.
// Returns an int, presumably the process exit code.
// NOTE(review): one template parameter line between TDatabase and
// TConstructDatabaseCallable is not visible in this excerpt.
225 template<typename TDatabase,
227 typename TConstructDatabaseCallable>
228 int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
229 const char* inputBindingName, const char* outputBindingName,
230 const std::vector<unsigned int>& defaultTestCaseIds,
231 TConstructDatabaseCallable constructDatabase,
232 const armnn::TensorShape* inputTensorShape = nullptr);
237 #include "InferenceTest.inl"