#include <boost/program_options.hpp>

// Error path shared by the stream-extraction operators defined in this header
// (used by boost::program_options when parsing a compute device option): an
// unrecognised value fails the stream and reports an invalid option value.
in.setstate(std::ios_base::failbit);
throw boost::program_options::validation_error(
    boost::program_options::validation_error::invalid_option_value);
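For context, this is how such an extraction operator typically wraps that error path; the function body below is a sketch reconstructed around the fragment above rather than a verbatim quote of the header.

inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
{
    std::string token;
    in >> token;

    // ParseComputeDevice maps a device name such as "CpuRef", "CpuAcc" or
    // "GpuAcc" to the corresponding Compute value, or Compute::Undefined when
    // the name is not recognised.
    compute = armnn::ParseComputeDevice(token.c_str());
    if (compute == armnn::Compute::Undefined)
    {
        in.setstate(std::ios_base::failbit);
        throw boost::program_options::validation_error(
            boost::program_options::validation_error::invalid_option_value);
    }
    return in;
}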
// Default-constructed test options: profiling is disabled and no dynamic
// backends path is set.
, m_EnableProfiling(0)
, m_DynamicBackendsPath()
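Collecting the option members that appear elsewhere on this page (m_IterationCount, m_InferenceTimesFile, m_EnableProfiling, m_DynamicBackendsPath), the surrounding options struct plausibly looks like the sketch below; the exact member set and ordering are assumptions.

// Assumed layout of the common test options; only members referenced on this
// page are shown.
struct InferenceTestOptions
{
    unsigned int m_IterationCount;
    std::string  m_InferenceTimesFile;
    bool         m_EnableProfiling;
    std::string  m_DynamicBackendsPath;

    InferenceTestOptions()
        : m_IterationCount(0)
        , m_EnableProfiling(0)
        , m_DynamicBackendsPath()
    {}
};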
// IInferenceTestCase: a single runnable test case.
virtual void Run() = 0;

// IInferenceTestCaseProvider: hands out test cases by id.
virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
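Taken together with the virtual members referenced further down the page (AddCommandLineOptions, ProcessCommandLineOptions, OnInferenceTestFinished and the virtual destructors), the two interfaces can be sketched as follows; the non-pure default bodies are assumptions.

// Assumed shape of the two interfaces; only Run and GetTestCase are known to
// be pure virtual from the fragments above.
class IInferenceTestCase
{
public:
    virtual ~IInferenceTestCase() {}
    virtual void Run() = 0;
};

class IInferenceTestCaseProvider
{
public:
    virtual ~IInferenceTestCaseProvider() {}

    virtual void AddCommandLineOptions(boost::program_options::options_description& options) {}
    virtual bool ProcessCommandLineOptions(const InferenceTestOptions& commonOptions) { return true; }
    virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
    virtual bool OnInferenceTestFinished() { return true; }
};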
template <typename TModel>
class InferenceModelTestCase : public IInferenceTestCase
{
public:
    using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;

    InferenceModelTestCase(TModel& model,
                           unsigned int testCaseId,
                           const std::vector<TContainer>& inputs,
                           const std::vector<unsigned int>& outputSizes)
        : m_Model(model)
        , m_TestCaseId(testCaseId)
        , m_Inputs(std::move(inputs))
    {
        // Pre-allocate one output container per expected output, using the
        // model's data type and the requested element counts.
        const size_t numOutputs = outputSizes.size();
        m_Outputs.reserve(numOutputs);

        for (size_t i = 0; i < numOutputs; i++)
        {
            m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
        }
    }

    virtual void Run() override
    {
        m_Model.Run(m_Inputs, m_Outputs);
    }

protected:
    unsigned int GetTestCaseId() const { return m_TestCaseId; }
    const std::vector<TContainer>& GetOutputs() const { return m_Outputs; }

private:
    TModel&                 m_Model;
    unsigned int            m_TestCaseId;
    std::vector<TContainer> m_Inputs;
    std::vector<TContainer> m_Outputs;
};
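Because each output is stored in the TContainer variant, a derived test case has to pick the concrete vector type back out before it can interpret the results. A hypothetical classification-style post-processing step might look like this; boost::get and the float output type are assumptions made for the example.

// Illustrative only: extract the first output as a float vector and find the
// index of the highest-scoring class (needs <algorithm> and <iterator>).
const std::vector<float>& scores = boost::get<std::vector<float>>(this->GetOutputs()[0]);

auto maxIt = std::max_element(scores.begin(), scores.end());
const unsigned int predictedLabel =
    static_cast<unsigned int>(std::distance(scores.begin(), maxIt));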
template <typename TTestCaseDatabase, typename TModel>
class ClassifierTestCase : public InferenceModelTestCase<TModel>
{
public:
    ClassifierTestCase(int& numInferencesRef,
                       int& numCorrectInferencesRef,
                       const std::vector<unsigned int>& validationPredictions,
                       std::vector<unsigned int>* validationPredictionsOut,
                       TModel& model,
                       unsigned int testCaseId,
                       unsigned int label,
                       std::vector<typename TModel::DataType> modelInput);

private:
    unsigned int m_Label;

    // These members refer to state owned by the test case provider, so that
    // all test cases accumulate into the same counters and validation data.
    int& m_NumInferencesRef;
    int& m_NumCorrectInferencesRef;
    const std::vector<unsigned int>& m_ValidationPredictions;
    std::vector<unsigned int>* m_ValidationPredictionsOut;
};
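The reference members make the intent clear: every classifier test case bumps the provider's shared counters and, when validation data is in play, checks or records its prediction. The sketch below is a hypothetical illustration of that bookkeeping, not the header's actual result handling.

// Hypothetical result bookkeeping for a classifier test case (illustrative).
bool RecordPredictionSketch(unsigned int predictedLabel,
                            unsigned int expectedLabel,
                            unsigned int testCaseId,
                            int& numInferencesRef,
                            int& numCorrectInferencesRef,
                            const std::vector<unsigned int>& validationPredictions,
                            std::vector<unsigned int>* validationPredictionsOut)
{
    numInferencesRef++;
    if (predictedLabel == expectedLabel)
    {
        numCorrectInferencesRef++;
    }

    // If predictions from a previous run were loaded, the current prediction
    // must match the recorded one.
    if (testCaseId < validationPredictions.size() &&
        validationPredictions[testCaseId] != predictedLabel)
    {
        return false;
    }

    // If an output validation file was requested, remember this prediction so
    // it can be written out at the end of the run.
    if (validationPredictionsOut != nullptr)
    {
        validationPredictionsOut->push_back(predictedLabel);
    }

    return true;
}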
template <typename TDatabase, typename InferenceModel>
class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
{
public:
    template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
    ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase,
                               TConstructModelCallable constructModel);

    virtual void AddCommandLineOptions(boost::program_options::options_description& options) override;
    virtual bool ProcessCommandLineOptions(const InferenceTestOptions& commonOptions) override;
    virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
    virtual bool OnInferenceTestFinished() override;

private:
    void ReadPredictions();

    std::unique_ptr<InferenceModel> m_Model;

    std::string m_DataDir;
    std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
    std::unique_ptr<TDatabase> m_Database;

    int m_NumInferences;        // Referenced by the test cases.
    int m_NumCorrectInferences; // Referenced by the test cases.

    std::string m_ValidationFileIn;
    std::vector<unsigned int> m_ValidationPredictions;    // Referenced by the test cases.

    std::string m_ValidationFileOut;
    std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by the test cases.
};
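To show how the pieces connect, here is a hypothetical GetTestCase implementation for this provider: it pulls a sample from the database and wraps it in a classifier test case that shares the provider's counters. The database API used here (GetTestCaseData, m_Label, m_InputImage) is an assumption, not something declared in this header.

// Illustrative sketch only; the database's test case data members are assumed.
template <typename TDatabase, typename InferenceModel>
std::unique_ptr<IInferenceTestCase>
ClassifierTestCaseProvider<TDatabase, InferenceModel>::GetTestCase(unsigned int testCaseId)
{
    auto testCaseData = m_Database->GetTestCaseData(testCaseId);
    if (testCaseData == nullptr)
    {
        return nullptr;
    }

    return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
        m_NumInferences,
        m_NumCorrectInferences,
        m_ValidationPredictions,
        m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
        *m_Model,
        testCaseId,
        testCaseData->m_Label,
        std::move(testCaseData->m_InputImage));
}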
bool InferenceTest(const InferenceTestOptions& params,
                   const std::vector<unsigned int>& defaultTestCaseIds,
                   IInferenceTestCaseProvider& testCaseProvider);

template <typename TConstructTestCaseProvider>
int InferenceTestMain(int argc,
                      char* argv[],
                      const std::vector<unsigned int>& defaultTestCaseIds,
                      TConstructTestCaseProvider constructTestCaseProvider);

template <typename TDatabase, typename TConstructDatabaseCallable>
int ClassifierInferenceTestMain(int argc,
                                char* argv[],
                                const char* modelFilename,
                                bool isModelBinary,
                                const char* inputBindingName,
                                const char* outputBindingName,
                                const std::vector<unsigned int>& defaultTestCaseIds,
                                TConstructDatabaseCallable constructDatabase,
                                const armnn::TensorShape* inputTensorShape = nullptr);
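A typical classifier test executable forwards its main() straight to ClassifierInferenceTestMain. The example below is a hypothetical caller; the model file name, binding names, test case ids and the MyImageDatabase type are placeholders rather than values from this header.

// Hypothetical caller of ClassifierInferenceTestMain (all names are placeholders).
int main(int argc, char* argv[])
{
    const std::vector<unsigned int> defaultTestCaseIds = { 0, 1, 2 };

    return ClassifierInferenceTestMain<MyImageDatabase>(
        argc, argv,
        "my_classifier_model.tflite",   // modelFilename (placeholder)
        true,                           // isModelBinary
        "input",                        // inputBindingName (placeholder)
        "output",                       // outputBindingName (placeholder)
        defaultTestCaseIds,
        [](const char* dataDir, const auto& model)
        {
            // Placeholder: construct the image database from the data directory;
            // the model is available for e.g. querying the expected input size.
            return MyImageDatabase(dataDir);
        });
}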
Other declarations and referenced symbols from this header:

bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider, InferenceTestOptions& outParams)
    Parse the command line of an ArmNN (or referencetests) inference test program.

bool ValidateDirectory(std::string& dir)

std::istream& operator>>(std::istream& in, armnn::Compute& compute)

constexpr armnn::Compute ParseComputeDevice(const char* str)
    Deprecated function that will be removed together with the Compute enum.

armnn::Compute
    The Compute enum is now deprecated and is being replaced by BackendId.

using QuantizationParams = std::pair<float, int32_t>;

armnn::Exception(const std::string& message)
    Base class for all ArmNN exceptions so that users can filter to just those.

Test case results distinguish a test that completed without any errors from one
that failed with a fatal error, in which case the remaining tests are not run.
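Read together, ParseCommandLine, InferenceTest and the provider interface suggest how InferenceTestMain ties a test program together. The sketch below is an assumption about that flow, not the header's actual implementation.

// Assumed overall flow of InferenceTestMain (illustrative sketch).
template <typename TConstructTestCaseProvider>
int InferenceTestMainSketch(int argc,
                            char* argv[],
                            const std::vector<unsigned int>& defaultTestCaseIds,
                            TConstructTestCaseProvider constructTestCaseProvider)
{
    // Build the provider, e.g. a ClassifierTestCaseProvider bound to a model
    // and a test case database.
    auto testCaseProvider = constructTestCaseProvider();

    // Parse the common options plus any provider-specific command line options.
    InferenceTestOptions inferenceTestOptions;
    if (!ParseCommandLine(argc, argv, *testCaseProvider, inferenceTestOptions))
    {
        return 1;
    }

    // Run every requested test case and map the overall result to an exit code.
    const bool success = InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider);
    return success ? 0 : 1;
}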