2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
7 #include "armnn/ArmNN.hpp"
8 #include "armnn/TypesUtils.hpp"
11 #include <boost/log/core/core.hpp>
12 #include <boost/program_options.hpp>
/// Stream-extraction operator for armnn::Compute, presumably so that a compute device can be
/// supplied as a boost::program_options option value (it throws that library's validation_error).
/// The device name held in `token` is converted with armnn::ParseComputeDevice and assigned to
/// `compute` before it is checked, so on failure `compute` is left as armnn::Compute::Undefined.
/// NOTE(review): the declaration/extraction of `token`, the braces and the final `return in;`
/// are missing here — confirm the file is complete.
17 inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
21 compute = armnn::ParseComputeDevice(token.c_str());
22 if (compute == armnn::Compute::Undefined)
// Unknown device name: flag the stream as failed, then report it the way boost::program_options
// reports a bad option value. NOTE(review): if in.exceptions() includes failbit, setstate()
// itself throws std::ios_base::failure and the validation_error below is never reached.
24 in.setstate(std::ios_base::failbit);
25 throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
/// Exception type for failures of the inference test framework itself, going by its name.
/// It derives publicly from Exception (presumably armnn::Exception; the enclosing namespace is
/// not visible here). It inherits all of the base-class constructors, so it is constructed
/// exactly like the base.
33 class TestFrameworkException : public Exception
36 using Exception::Exception;
/// Generic options for an inference test run. Flow of an instance:
/// - ParseCommandLine takes one as its `outParams` argument;
/// - it is then passed to InferenceTest;
/// - and to every IInferenceTestCase::ProcessResult call.
39 struct InferenceTestOptions
// Presumably the number of times the test cases are run — confirm how 0 is interpreted.
41 unsigned int m_IterationCount;
// Presumably the path of a file that receives inference timings (empty = not requested) — confirm.
42 std::string m_InferenceTimesFile;
// NOTE(review): this constructor's initializer list/body is missing here. m_IterationCount is a
// built-in type, so confirm it is explicitly initialised.
44 InferenceTestOptions()
/// Outcome reported by IInferenceTestCase::ProcessResult for a single test case.
/// NOTE(review): only the enumerators' doc comments are present here, not their names.
49 enum class TestCaseResult
51 /// The test completed without any errors.
53 /// The test failed (e.g. the prediction didn't match the validation file).
54 /// This will eventually fail the whole program but the remaining test cases will still be run.
56 /// The test failed with a fatal error. The remaining tests will not be run.
/// Interface for a single inference test case.
/// - Run() executes the case.
/// - ProcessResult() judges the outcome against the run's options and returns a TestCaseResult.
/// Instances are handed out by IInferenceTestCaseProvider::GetTestCase as
/// std::unique_ptr<IInferenceTestCase>.
60 class IInferenceTestCase
// Virtual so a concrete test case can be destroyed through a pointer to this interface.
63 virtual ~IInferenceTestCase() {}
// Executes the test case (InferenceModelTestCase implements it by running its model on its input).
65 virtual void Run() = 0;
// Evaluates the outcome — presumably called after Run(); confirm the call order in InferenceTest.
66 virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
/// Interface for a source of inference test cases. Besides creating test cases by id
/// (GetTestCase), it can contribute and check command-line options, and presumably react when
/// the test run finishes (OnInferenceTestFinished). Only GetTestCase is pure virtual; the other
/// hooks have default implementations.
69 class IInferenceTestCaseProvider
// Virtual so providers can be destroyed through a pointer/reference to this interface.
72 virtual ~IInferenceTestCaseProvider() {}
74 virtual void AddCommandLineOptions(boost::program_options::options_description& options) {};
75 virtual bool ProcessCommandLineOptions() { return true; };
76 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
77 virtual bool OnInferenceTestFinished() { return true; };
/// Base test case that runs a model of type TModel on one fixed input and keeps the output so
/// derived classes can read it via GetOutput().
/// - No ProcessResult override is present here, so derived classes (e.g. ClassifierTestCase)
///   supply the result checking.
/// - TModel must provide a nested DataType and a Run(input, output) member, as used below.
80 template <typename TModel>
81 class InferenceModelTestCase : public IInferenceTestCase
/// @param model       Model to run. NOTE(review): the member that stores it (m_Model) and its
///                    initializer are missing here. If m_Model is a reference, the model must
///                    outlive this test case.
/// @param testCaseId  Identifier returned by GetTestCaseId().
/// @param modelInput  Input data; taken by value and moved into m_Input.
/// @param outputSize  Number of elements m_Output is resized to before the model is run.
84 InferenceModelTestCase(TModel& model,
85 unsigned int testCaseId,
86 std::vector<typename TModel::DataType> modelInput,
87 unsigned int outputSize)
89 , m_TestCaseId(testCaseId)
90 , m_Input(std::move(modelInput))
// Pre-size the output buffer that Run() hands to the model.
92 m_Output.resize(outputSize);
// Feeds the stored input to the model. The output vector is passed alongside, presumably for
// the model to fill in.
95 virtual void Run() override
97 m_Model.Run(m_Input, m_Output);
101 unsigned int GetTestCaseId() const { return m_TestCaseId; }
// Read-only access to the output buffer for derived classes' result processing.
102 const std::vector<typename TModel::DataType>& GetOutput() const { return m_Output; }
106 unsigned int m_TestCaseId;
107 std::vector<typename TModel::DataType> m_Input;
108 std::vector<typename TModel::DataType> m_Output;
/// Test case for a classification model.
/// - It extends InferenceModelTestCase with result checking (ProcessResult).
/// - It holds references to counters and prediction lists owned by the
///   ClassifierTestCaseProvider, which it presumably updates while processing its result.
111 template <typename TTestCaseDatabase, typename TModel>
112 class ClassifierTestCase : public InferenceModelTestCase<TModel>
/// Constructor. The reference/pointer arguments presumably end up in the m_*Ref and
/// m_ValidationPredictions* members below, so the objects they refer to must outlive the test case.
/// NOTE(review): some constructor parameters appear to be missing here — confirm the file is complete.
115 ClassifierTestCase(int& numInferencesRef,
116 int& numCorrectInferencesRef,
117 const std::vector<unsigned int>& validationPredictions,
118 std::vector<unsigned int>* validationPredictionsOut,
120 unsigned int testCaseId,
122 std::vector<typename TModel::DataType> modelInput);
// Judges the model output held by the InferenceModelTestCase base; the definition is not shown here.
124 virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
// Presumably the expected class for this test case — confirm where it is set.
127 unsigned int m_Label;
128 /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
// NOTE(review): stored as references/pointer, so the provider must outlive this test case.
// m_ValidationPredictionsOut is a raw pointer and may be null (presumably when no output file
// was requested) — confirm.
130 int& m_NumInferencesRef;
131 int& m_NumCorrectInferencesRef;
132 const std::vector<unsigned int>& m_ValidationPredictions;
133 std::vector<unsigned int>* m_ValidationPredictionsOut;
/// Test-case provider for classification models. It owns:
/// - the model (m_Model) and the test database (m_Database), presumably built through the
///   factories kept in m_ConstructModel and m_ConstructDatabase;
/// - the counters and prediction lists marked "Referenced by test cases" below, which the test
///   cases it creates refer to.
137 template <typename TDatabase, typename InferenceModel>
138 class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
/// @param constructDatabase  Database factory; presumably stored in m_ConstructDatabase, so it
///                           must be callable as TDatabase(const char*).
/// @param constructModel     Model factory; presumably stored in m_ConstructModel, so it must be
///                           callable as std::unique_ptr<InferenceModel>(CommandLineOptions).
141 template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
142 ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
// IInferenceTestCaseProvider hooks; definitions are presumably in InferenceTest.inl.
144 virtual void AddCommandLineOptions(boost::program_options::options_description& options) override;
145 virtual bool ProcessCommandLineOptions() override;
146 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
147 virtual bool OnInferenceTestFinished() override;
// Presumably loads m_ValidationPredictions from m_ValidationFileIn — confirm.
150 void ReadPredictions();
// Model options, presumably filled from the command line and passed to m_ConstructModel.
152 typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
153 std::function<std::unique_ptr<InferenceModel>(typename InferenceModel::CommandLineOptions)> m_ConstructModel;
154 std::unique_ptr<InferenceModel> m_Model;
// Presumably the directory whose path is handed to m_ConstructDatabase (which takes a const char*).
156 std::string m_DataDir;
157 std::function<TDatabase(const char*)> m_ConstructDatabase;
158 std::unique_ptr<TDatabase> m_Database;
// NOTE(review): these counters have no in-class initializer; confirm the constructor zeroes them.
// Test cases hold references into this object, so it must outlive every test case it creates.
160 int m_NumInferences; // Referenced by test cases
161 int m_NumCorrectInferences; // Referenced by test cases
// Optional file of expected predictions and its loaded contents.
163 std::string m_ValidationFileIn;
164 std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases
// Optional output file for the predictions collected in m_ValidationPredictionsOut
// (presumably written in OnInferenceTestFinished — confirm).
166 std::string m_ValidationFileOut;
167 std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases
170 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
171 InferenceTestOptions& outParams);
/// Validates the directory path held in @p directory.
/// The string is taken by non-const reference, so the definition may rewrite it in place
/// (e.g. to normalise the path) — confirm against the implementation.
/// @return true when the directory is accepted — presumably; confirm.
bool ValidateDirectory(std::string &directory);
175 bool InferenceTest(const InferenceTestOptions& params,
176 const std::vector<unsigned int>& defaultTestCaseIds,
177 IInferenceTestCaseProvider& testCaseProvider);
/// Shared main() helper for inference test programs.
/// - The test-case provider is created through @p constructTestCaseProvider; its exact
///   signature and return type are not visible here.
/// - @p defaultTestCaseIds presumably lists the test cases run when none are selected on the
///   command line.
/// Returns an int, presumably the process exit code — confirm.
/// NOTE(review): a parameter between argc and defaultTestCaseIds (presumably argv) is missing
/// here — confirm the file is complete.
179 template<typename TConstructTestCaseProvider>
180 int InferenceTestMain(int argc,
182 const std::vector<unsigned int>& defaultTestCaseIds,
183 TConstructTestCaseProvider constructTestCaseProvider);
/// Convenience main() for classifier tests. Presumably it builds a ClassifierTestCaseProvider
/// for the given model and runs it through InferenceTestMain — confirm in InferenceTest.inl.
/// @param argc, argv          command line as received by main()
/// @param modelFilename       path of the model file to load
/// @param isModelBinary       presumably selects binary vs text model format — confirm
/// @param inputBindingName    name of the network input to bind
/// @param outputBindingName   name of the network output to bind
/// @param defaultTestCaseIds  test case ids used by default
/// @param constructDatabase   callable that creates the TDatabase
/// @param inputTensorShape    optional input shape; defaults to nullptr (no shape supplied)
/// @return an int, presumably the process exit code.
/// NOTE(review): a template parameter between TDatabase and TConstructDatabaseCallable appears
/// to be missing here — confirm the file is complete.
185 template<typename TDatabase,
187 typename TConstructDatabaseCallable>
188 int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
189 const char* inputBindingName, const char* outputBindingName,
190 const std::vector<unsigned int>& defaultTestCaseIds,
191 TConstructDatabaseCallable constructDatabase,
192 const armnn::TensorShape* inputTensorShape = nullptr);
197 #include "InferenceTest.inl"