2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
5 #include "armnn/ArmNN.hpp"
6 #if defined(ARMNN_CAFFE_PARSER)
7 #include "armnnCaffeParser/ICaffeParser.hpp"
10 #include "../InferenceTest.hpp"
12 #include <boost/program_options.hpp>
13 #include <boost/algorithm/string/split.hpp>
14 #include <boost/algorithm/string/classification.hpp>
22 template<typename T, typename TParseElementFunc>
23 std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc)
25 std::vector<T> result;
26 // Process line-by-line
28 while (std::getline(stream, line))
30 std::vector<std::string> tokens;
31 boost::split(tokens, line, boost::algorithm::is_any_of("\t ,;:"), boost::token_compress_on);
32 for (const std::string& token : tokens)
34 if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
38 result.push_back(parseElementFunc(token));
40 catch (const std::exception&)
42 BOOST_LOG_TRIVIAL(error) << "'" << token << "' is not a valid number. It has been ignored.";
54 std::vector<T> ParseArray(std::istream& stream);
57 std::vector<float> ParseArray(std::istream& stream)
59 return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
63 std::vector<unsigned int> ParseArray(std::istream& stream)
65 return ParseArrayImpl<unsigned int>(stream,
66 [](const std::string& s) { return boost::numeric_cast<unsigned int>(std::stoi(s)); });
// Prints the tensor values space-separated on one line, followed by a newline.
// NOTE(review): the loop body was elided in this view; "%f " + trailing newline is the
// reconstruction — confirm against the original output format.
void PrintArray(const std::vector<float>& v)
{
    for (size_t i = 0; i < v.size(); i++)
    {
        printf("%f ", v[i]);
    }
    printf("\n");
}
78 template<typename TParser, typename TDataType>
79 int MainImpl(const char* modelPath, bool isModelBinary, armnn::Compute computeDevice,
80 const char* inputName, const armnn::TensorShape* inputTensorShape, const char* inputTensorDataFilePath,
81 const char* outputName)
84 std::vector<TDataType> input;
86 std::ifstream inputTensorFile(inputTensorDataFilePath);
87 if (!inputTensorFile.good())
89 BOOST_LOG_TRIVIAL(fatal) << "Failed to load input tensor data file from " << inputTensorDataFilePath;
92 input = ParseArray<TDataType>(inputTensorFile);
97 // Create an InferenceModel, which will parse the model and load it into an IRuntime
98 typename InferenceModel<TParser, TDataType>::Params params;
99 params.m_ModelPath = modelPath;
100 params.m_IsModelBinary = isModelBinary;
101 params.m_ComputeDevice = computeDevice;
102 params.m_InputBinding = inputName;
103 params.m_InputTensorShape = inputTensorShape;
104 params.m_OutputBinding = outputName;
105 InferenceModel<TParser, TDataType> model(params);
108 std::vector<TDataType> output(model.GetOutputSize());
109 model.Run(input, output);
111 // Print the output tensor
114 catch (armnn::Exception const& e)
116 BOOST_LOG_TRIVIAL(fatal) << "Armnn Error: " << e.what();
123 int main(int argc, char* argv[])
125 // Configure logging for both the ARMNN library and this test program
127 armnn::LogSeverity level = armnn::LogSeverity::Info;
129 armnn::LogSeverity level = armnn::LogSeverity::Debug;
131 armnn::ConfigureLogging(true, true, level);
132 armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);
134 // Configure boost::program_options for command-line parsing
135 namespace po = boost::program_options;
137 std::string modelFormat;
138 std::string modelPath;
139 std::string inputName;
140 std::string inputTensorShapeStr;
141 std::string inputTensorDataFilePath;
142 std::string outputName;
143 armnn::Compute computeDevice;
145 po::options_description desc("Options");
149 ("help", "Display usage information")
150 ("model-format,f", po::value(&modelFormat)->required(),
151 "caffe-binary, caffe-text, tensorflow-binary or tensorflow-text.")
152 ("model-path,m", po::value(&modelPath)->required(), "Path to model file, e.g. .caffemodel, .prototxt")
153 ("compute,c", po::value<armnn::Compute>(&computeDevice)->required(),
154 "Which device to run layers on by default. Possible choices: CpuAcc, CpuRef, GpuAcc")
155 ("input-name,i", po::value(&inputName)->required(), "Identifier of the input tensor in the network.")
156 ("input-tensor-shape,s", po::value(&inputTensorShapeStr),
157 "The shape of the input tensor in the network as a flat array of integers separated by whitespace. "
158 "This parameter is optional, depending on the network.")
159 ("input-tensor-data,d", po::value(&inputTensorDataFilePath)->required(),
160 "Path to a file containing the input data as a flat array separated by whitespace.")
161 ("output-name,o", po::value(&outputName)->required(), "Identifier of the output tensor in the network.");
163 catch (const std::exception& e)
165 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
166 // and that desc.add_options() can throw boost::io::too_few_args.
167 // They really won't in any of these cases.
168 BOOST_ASSERT_MSG(false, "Caught unexpected exception");
169 BOOST_LOG_TRIVIAL(fatal) << "Fatal internal error: " << e.what();
173 // Parse the command-line
174 po::variables_map vm;
177 po::store(po::parse_command_line(argc, argv, desc), vm);
179 if (vm.count("help") || argc <= 1)
181 std::cout << "Executes a neural network model using the provided input tensor. " << std::endl;
182 std::cout << "Prints the resulting output tensor." << std::endl;
183 std::cout << std::endl;
184 std::cout << desc << std::endl;
192 std::cerr << e.what() << std::endl << std::endl;
193 std::cerr << desc << std::endl;
197 // Parse model binary flag from the model-format string we got from the command-line
199 if (modelFormat.find("bin") != std::string::npos)
201 isModelBinary = true;
203 else if (modelFormat.find("txt") != std::string::npos || modelFormat.find("text") != std::string::npos)
205 isModelBinary = false;
209 BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'";
213 // Parse input tensor shape from the string we got from the command-line.
214 std::unique_ptr<armnn::TensorShape> inputTensorShape;
215 if (!inputTensorShapeStr.empty())
217 std::stringstream ss(inputTensorShapeStr);
218 std::vector<unsigned int> dims = ParseArray<unsigned int>(ss);
219 inputTensorShape = std::make_unique<armnn::TensorShape>(dims.size(), dims.data());
222 // Forward to implementation based on the parser type
223 if (modelFormat.find("caffe") != std::string::npos)
225 #if defined(ARMNN_CAFFE_PARSER)
226 return MainImpl<armnnCaffeParser::ICaffeParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
227 inputName.c_str(), inputTensorShape.get(), inputTensorDataFilePath.c_str(), outputName.c_str());
229 BOOST_LOG_TRIVIAL(fatal) << "Not built with Caffe parser support.";
233 else if (modelFormat.find("tensorflow") != std::string::npos)
235 BOOST_LOG_TRIVIAL(fatal) << "Not built with Tensorflow parser support.";
240 BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat <<
241 "'. Please include 'caffe' or 'tensorflow'";