2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
7 #include "InferenceModel.hpp"
9 #include <armnn/ArmNN.hpp>
10 #include <armnn/Logging.hpp>
11 #include <armnn/TypesUtils.hpp>
12 #include <armnn/utility/IgnoreUnused.hpp>
14 #include <cxxopts/cxxopts.hpp>
15 #include <fmt/format.h>
21 inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
25 compute = armnn::ParseComputeDevice(token.c_str());
26 if (compute == armnn::Compute::Undefined)
28 in.setstate(std::ios_base::failbit);
29 throw cxxopts::OptionException(fmt::format("Unrecognised compute device: {}", token));
34 inline std::istream& operator>>(std::istream& in, armnn::BackendId& backend)
38 armnn::Compute compute = armnn::ParseComputeDevice(token.c_str());
39 if (compute == armnn::Compute::Undefined)
41 in.setstate(std::ios_base::failbit);
42 throw cxxopts::OptionException(fmt::format("Unrecognised compute device: {}", token));
51 class TestFrameworkException : public Exception
54 using Exception::Exception;
/// Options shared by every inference test.
/// Filled in by ParseCommandLine (via its `outParams` argument) and handed, read-only, to the
/// test-case provider and to each test case's ProcessResult.
/// Defaults come from the member initializers below, so no user-written constructor is needed.
struct InferenceTestOptions
{
    /// Iteration count for the test run; 0 by default (interpretation lives in the runner — TODO confirm).
    unsigned int m_IterationCount = 0;
    /// Path of a file for recording inference times; empty by default.
    std::string m_InferenceTimesFile;
    /// Whether profiling is enabled; off by default.
    bool m_EnableProfiling = false;
    /// Path searched for dynamically loaded backends; empty by default.
    std::string m_DynamicBackendsPath;
};
/// Outcome of a single test case, as returned by IInferenceTestCase::ProcessResult.
71 enum class TestCaseResult
73 /// The test completed without any errors.
75 /// The test failed (e.g. the prediction didn't match the validation file).
76 /// This will eventually fail the whole program, but the remaining test cases will still be run.
78 /// The test failed with a fatal error. The remaining tests will not be run.
82 class IInferenceTestCase
85 virtual ~IInferenceTestCase() {}
87 virtual void Run() = 0;
88 virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
/// Source of test cases: may register and consume extra command-line options, and creates
/// the IInferenceTestCase for each requested test case id.
91 class IInferenceTestCaseProvider
94 virtual ~IInferenceTestCaseProvider() {}
/// Hook for adding provider-specific options to `options`; names of mandatory options are
/// presumably appended to `required`. The visible default body only marks both arguments unused.
96 virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required)
98 IgnoreUnused(options, required);
/// Hook for validating/consuming the parsed options. The visible default body ignores
/// `commonOptions`; NOTE(review): confirm the default return value in the full definition.
100 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
102 IgnoreUnused(commonOptions);
/// Creates the test case for `testCaseId`; ownership is transferred to the caller.
105 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
/// Presumably called once all test cases have run; the default returns true.
/// NOTE(review): the ';' after the inline body is redundant.
106 virtual bool OnInferenceTestFinished() { return true; };
/// Test case that holds its input tensors and pre-sized output buffers and, in Run(),
/// passes both to m_Model.Run(). Subclasses implement ProcessResult.
109 template <typename TModel>
110 class InferenceModelTestCase : public IInferenceTestCase
/// A tensor's data: a vector of float, int or unsigned char elements.
113 using TContainer = mapbox::util::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
/// @param model        Model that Run() executes. If it is stored by reference (the member
///                     declaration is not shown here), it must outlive this test case.
/// @param testCaseId   Identifier later returned by GetTestCaseId().
/// @param inputs       Input tensors; copied into m_Inputs.
/// @param outputSizes  Element count of each output; one value-initialized
///                     std::vector<TModel::DataType> of that size is allocated per entry.
115 InferenceModelTestCase(TModel& model,
116 unsigned int testCaseId,
117 const std::vector<TContainer>& inputs,
118 const std::vector<unsigned int>& outputSizes)
120 , m_TestCaseId(testCaseId)
// NOTE(review): `inputs` is a const reference, so std::move(inputs) yields a const rvalue that
// binds to the copy constructor - this is a plain copy. Take `inputs` by value and move it if
// avoiding the copy matters.
121 , m_Inputs(std::move(inputs))
123 // Allocate one output buffer per entry of outputSizes, sized to that entry.
124 const size_t numOutputs = outputSizes.size();
125 m_Outputs.reserve(numOutputs);
127 for (size_t i = 0; i < numOutputs; i++)
129 m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
/// Runs the model on m_Inputs; m_Outputs is passed as the output argument (presumably filled in place).
133 virtual void Run() override
135 m_Model.Run(m_Inputs, m_Outputs);
/// Identifier supplied at construction.
139 unsigned int GetTestCaseId() const { return m_TestCaseId; }
/// Output buffers passed to m_Model.Run().
140 const std::vector<TContainer>& GetOutputs() const { return m_Outputs; }
144 unsigned int m_TestCaseId;
145 std::vector<TContainer> m_Inputs;
146 std::vector<TContainer> m_Outputs;
/// Test case for a classifier model, built on InferenceModelTestCase.
/// ProcessResult is declared here and defined elsewhere (presumably InferenceTest.inl).
/// It presumably compares the model output with the expected label and updates the
/// provider-owned counters and prediction lists referenced below.
149 template <typename TTestCaseDatabase, typename TModel>
150 class ClassifierTestCase : public InferenceModelTestCase<TModel>
/// The reference/pointer arguments alias members of the owning ClassifierTestCaseProvider,
/// which must therefore outlive this test case.
153 ClassifierTestCase(int& numInferencesRef,
154 int& numCorrectInferencesRef,
155 const std::vector<unsigned int>& validationPredictions,
156 std::vector<unsigned int>* validationPredictionsOut,
158 unsigned int testCaseId,
160 std::vector<typename TModel::DataType> modelInput);
162 virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
165 unsigned int m_Label;
166 InferenceModelInternal::QuantizationParams m_QuantizationParams;
168 /// These fields reference the corresponding members of the ClassifierTestCaseProvider, which must outlive this object.
170 int& m_NumInferencesRef;
171 int& m_NumCorrectInferencesRef;
172 const std::vector<unsigned int>& m_ValidationPredictions;
// Pointer rather than reference; presumably null when no output validation file is requested - confirm.
173 std::vector<unsigned int>* m_ValidationPredictionsOut;
/// Test-case provider for classifier models.
/// It creates the model and the database through the factory callables given to the constructor
/// (stored in m_ConstructModel / m_ConstructDatabase). It owns the counters and prediction
/// vectors that each created test case references.
/// NOTE(review): the template parameter `InferenceModel` presumably shadows the InferenceModel
/// class template from InferenceModel.hpp within this class; a distinct name would be clearer.
177 template <typename TDatabase, typename InferenceModel>
178 class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
181 template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
182 ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
184 virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override;
185 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override;
186 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
187 virtual bool OnInferenceTestFinished() override;
// Presumably loads m_ValidationPredictions from m_ValidationFileIn - confirm in the definition.
190 void ReadPredictions();
192 typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
// Factory producing the model from the common options and the model-specific options.
193 std::function<std::unique_ptr<InferenceModel>(const InferenceTestOptions& commonOptions,
194 typename InferenceModel::CommandLineOptions)> m_ConstructModel;
195 std::unique_ptr<InferenceModel> m_Model;
197 std::string m_DataDir;
// Factory producing the database from a C string (presumably the data directory) and the model.
198 std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
199 std::unique_ptr<TDatabase> m_Database;
// The members marked "Referenced by test cases" are aliased by each ClassifierTestCase and must outlive them.
201 int m_NumInferences; // Referenced by test cases.
202 int m_NumCorrectInferences; // Referenced by test cases.
204 std::string m_ValidationFileIn;
205 std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.
207 std::string m_ValidationFileOut;
208 std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
211 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
212 InferenceTestOptions& outParams);
214 bool ValidateDirectory(std::string& dir);
216 bool InferenceTest(const InferenceTestOptions& params,
217 const std::vector<unsigned int>& defaultTestCaseIds,
218 IInferenceTestCaseProvider& testCaseProvider);
/// Template entry-point helper for a test executable. It is declared here without a body;
/// the definition is presumably in the InferenceTest.inl file included at the end of this header.
/// Presumably it obtains a provider from `constructTestCaseProvider`, parses the command line
/// and runs `defaultTestCaseIds`, returning a process exit code - confirm in the definition.
220 template<typename TConstructTestCaseProvider>
221 int InferenceTestMain(int argc,
223 const std::vector<unsigned int>& defaultTestCaseIds,
224 TConstructTestCaseProvider constructTestCaseProvider);
/// Convenience entry point for classifier tests, declared here and presumably defined in
/// InferenceTest.inl. It receives the model file (`modelFilename`, with `isModelBinary`
/// selecting the file format), the input/output binding names, the default test case ids and a
/// database factory (`constructDatabase`).
/// `inputTensorShape` is optional and defaults to nullptr (presumably meaning "use the model's
/// own input shape" - confirm). The int result is presumably the process exit code.
226 template<typename TDatabase,
228 typename TConstructDatabaseCallable>
229 int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
230 const char* inputBindingName, const char* outputBindingName,
231 const std::vector<unsigned int>& defaultTestCaseIds,
232 TConstructDatabaseCallable constructDatabase,
233 const armnn::TensorShape* inputTensorShape = nullptr);
238 #include "InferenceTest.inl"