7 #if defined(ARMNN_CAFFE_PARSER) 10 #if defined(ARMNN_ONNX_PARSER) 13 #if defined(ARMNN_SERIALIZER) 16 #if defined(ARMNN_TF_PARSER) 19 #if defined(ARMNN_TF_LITE_PARSER) 25 #include <boost/format.hpp> 26 #include <boost/algorithm/string/split.hpp> 27 #include <boost/algorithm/string/classification.hpp> 28 #include <boost/program_options.hpp> 37 namespace po = boost::program_options;
// Builds an armnn::TensorShape from the integers found in 'stream': each line read with
// std::getline is split on ',' and every token is converted to an unsigned dimension.
// NOTE(review): the function signature, the declaration of 'line', and the try/brace lines
// around the conversions are not present in this listing.
41 std::vector<unsigned int> result;
44 while (std::getline(stream, line))
46 std::vector<std::string> tokens;
// token_compress_on treats runs of adjacent ',' characters as a single separator.
50 boost::split(tokens, line, boost::algorithm::is_any_of(
","), boost::token_compress_on);
// A failure while splitting is caught and logged at error level.
52 catch (
const std::exception& e)
54 ARMNN_LOG(error) <<
"An error occurred when splitting tokens: " << e.what();
57 for (
const std::string& token : tokens)
// std::stoi throws on non-numeric text and boost::numeric_cast throws on values that do not
// fit in unsigned int (e.g. negatives); the handler below (its try line is not visible here)
// presumably catches both via std::exception.
// NOTE(review): std::stoi parses only a leading integer prefix, so a token such as "3x" is
// accepted as 3 — confirm whether stricter validation is wanted.
63 result.push_back(boost::numeric_cast<unsigned int>(std::stoi((token))));
// Invalid tokens are logged and skipped; the loop continues with the next token.
65 catch (
const std::exception&)
67 ARMNN_LOG(error) <<
"'" << token <<
"' is not a valid number. It has been ignored.";
// The number of dimensions is the count of successfully converted tokens.
73 return armnn::TensorShape(boost::numeric_cast<unsigned int>(result.size()), result.data());
// Returns whether 'option' is present in the parsed options map 'vm'.
// NOTE(review): the declaration of 'option' and the body of the null check are not visible
// in this listing; presumably a null name returns false before vm.find is reached — confirm.
// po::store also records options that only carry a default value, so "present" here includes
// defaulted entries; callers needing an explicitly supplied option also test defaulted().
76 bool CheckOption(
const po::variables_map& vm,
79 if (option ==
nullptr)
85 return vm.find(option) != vm.end();
88 void CheckOptionDependency(
const po::variables_map& vm,
92 if (option ==
nullptr || required ==
nullptr)
94 throw po::error(
"Invalid option to check dependency for");
98 if (CheckOption(vm, option) && !vm[option].defaulted())
100 if (CheckOption(vm, required) == 0 || vm[required].defaulted())
102 throw po::error(std::string(
"Option '") + option +
"' requires option '" + required +
"'.");
107 void CheckOptionDependencies(
const po::variables_map& vm)
109 CheckOptionDependency(vm,
"model-path",
"model-format");
110 CheckOptionDependency(vm,
"model-path",
"input-name");
111 CheckOptionDependency(vm,
"model-path",
"output-name");
112 CheckOptionDependency(vm,
"input-tensor-shape",
"model-path");
// Parses the converter's command line with boost::program_options and fills the output
// parameters: model format/path, input names, raw input-tensor-shape strings, output names,
// output path, and whether the model format denotes a binary or a text file.
// NOTE(review): the int return statements are not visible in this listing; presumably
// EXIT_SUCCESS / EXIT_FAILURE as checked by main() — confirm.
115 int ParseCommandLineArgs(
int argc,
const char* argv[],
116 std::string& modelFormat,
117 std::string& modelPath,
118 std::vector<std::string>& inputNames,
119 std::vector<std::string>& inputTensorShapeStrs,
120 std::vector<std::string>& outputNames,
121 std::string& outputPath,
bool& isModelBinary)
// model-format, model-path and output-path are bound directly to the output strings and
// marked required(); input-name, input-tensor-shape and output-name are multitoken string
// vectors that are copied out of 'vm' at the end of this function.
123 po::options_description desc(
"Options");
126 (
"help",
"Display usage information")
127 (
"model-format,f", po::value(&modelFormat)->required(),
"Format of the model file" 128 #if defined(ARMNN_CAFFE_PARSER) 129 ", caffe-binary, caffe-text" 131 #if defined(ARMNN_ONNX_PARSER) 132 ", onnx-binary, onnx-text" 134 #if defined(ARMNN_TF_PARSER) 135 ", tensorflow-binary, tensorflow-text" 137 #if defined(ARMNN_TF_LITE_PARSER) 141 (
"model-path,m", po::value(&modelPath)->required(),
"Path to model file.")
142 (
"input-name,i", po::value<std::vector<std::string>>()->multitoken(),
143 "Identifier of the input tensors in the network, separated by whitespace.")
144 (
"input-tensor-shape,s", po::value<std::vector<std::string>>()->multitoken(),
145 "The shape of the input tensor in the network as a flat array of integers, separated by comma." 146 " Multiple shapes are separated by whitespace." 147 " This parameter is optional, depending on the network.")
148 (
"output-name,o", po::value<std::vector<std::string>>()->multitoken(),
149 "Identifier of the output tensor in the network.")
150 (
"output-path,p", po::value(&outputPath)->required(),
"Path to serialize the network to.");
152 po::variables_map vm;
155 po::store(po::parse_command_line(argc, argv, desc), vm);
// Print usage when --help is given or when no arguments were passed at all.
157 if (CheckOption(vm,
"help") || argc <= 1)
159 std::cout <<
"Convert a neural network model from provided file to ArmNN format." << std::endl;
160 std::cout << std::endl;
161 std::cout << desc << std::endl;
// A parse error is reported on stderr, followed by the usage text.
166 catch (
const po::error& e)
168 std::cerr << e.what() << std::endl << std::endl;
169 std::cerr << desc << std::endl;
// Enforce inter-option requirements (e.g. model-path needs model-format, input-name and
// output-name); a violation is reported on stderr in the same way.
175 CheckOptionDependencies(vm);
177 catch (
const po::error& e)
179 std::cerr << e.what() << std::endl << std::endl;
180 std::cerr << desc << std::endl;
// Binary vs. text is decided by substring: "bin" is tested first, then "text"; a format
// containing neither is logged as fatal.
184 if (modelFormat.find(
"bin") != std::string::npos)
186 isModelBinary =
true;
188 else if (modelFormat.find(
"text") != std::string::npos)
190 isModelBinary =
false;
194 ARMNN_LOG(fatal) <<
"Unknown model format: '" << modelFormat <<
"'. Please include 'binary' or 'text'";
// input-tensor-shape is optional, so it is copied only when present.
// NOTE(review): input-name and output-name are read unconditionally and .as<>() throws if an
// option is absent; this relies on the model-path dependency checks above — confirm that the
// required() options are enforced (po::notify) on a line not shown here.
198 if (!vm[
"input-tensor-shape"].empty())
200 inputTensorShapeStrs = vm[
"input-tensor-shape"].as<std::vector<std::string>>();
203 inputNames = vm[
"input-name"].as<std::vector<std::string>>();
204 outputNames = vm[
"output-name"].as<std::vector<std::string>>();
// NOTE(review): fragment of a small template whose header is not visible here — presumably
// ParserType<T>, the tag type used to select a CreateNetwork overload; this typedef exposes
// the wrapped parser type.
212 typedef T parserType;
// Captures the conversion settings by copying each argument into the matching member.
// NOTE(review): the isModelBinary parameter declaration and any initializer for the network
// pointer are on lines not present in this listing.
218 ArmnnConverter(
const std::string& modelPath,
219 const std::vector<std::string>& inputNames,
220 const std::vector<armnn::TensorShape>& inputShapes,
221 const std::vector<std::string>& outputNames,
222 const std::string& outputPath,
225 m_ModelPath(modelPath),
226 m_InputNames(inputNames),
227 m_InputShapes(inputShapes),
228 m_OutputNames(outputNames),
229 m_OutputPath(outputPath),
230 m_IsModelBinary(isModelBinary) {}
// NOTE(review): fragment of the serialization method (presumably Serialize(), called from
// main); its header, the creation of 'serializer' and the return are not visible here.
// Guard: there is nothing to serialize if no network has been created.
234 if (m_NetworkPtr.get() ==
nullptr)
// Open the output file in binary mode; the serializer then writes the network into it.
243 std::ofstream file(m_OutputPath, std::ios::out | std::ios::binary);
// NOTE(review): the visible lines do not check whether 'file' opened successfully before
// writing — confirm that such a failure is surfaced through retVal or elsewhere.
245 bool retVal =
serializer->SaveSerializedToStream(file);
250 template <
typename IParser>
251 bool CreateNetwork ()
253 return CreateNetwork (ParserType<IParser>());
// Conversion settings copied in by the constructor.
// NOTE(review): the network pointer member (m_NetworkPtr) is declared on a line not shown here.
258 std::string m_ModelPath;
259 std::vector<std::string> m_InputNames;
260 std::vector<armnn::TensorShape> m_InputShapes;
261 std::vector<std::string> m_OutputNames;
262 std::string m_OutputPath;
// Derived from the model-format string: true for binary model files, false for text ones.
263 bool m_IsModelBinary;
265 template <
typename IParser>
266 bool CreateNetwork (ParserType<IParser>)
269 auto parser(IParser::Create());
271 std::map<std::string, armnn::TensorShape> inputShapes;
272 if (!m_InputShapes.empty())
274 const size_t numInputShapes = m_InputShapes.size();
275 const size_t numInputBindings = m_InputNames.size();
276 if (numInputShapes < numInputBindings)
279 "Not every input has its tensor shape specified: expected=%1%, got=%2%")
280 % numInputBindings % numInputShapes));
283 for (
size_t i = 0; i < numInputShapes; i++)
285 inputShapes[m_InputNames[i]] = m_InputShapes[i];
291 m_NetworkPtr = (m_IsModelBinary ?
292 parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str(), inputShapes, m_OutputNames) :
293 parser->CreateNetworkFromTextFile(m_ModelPath.c_str(), inputShapes, m_OutputNames));
296 return m_NetworkPtr.get() !=
nullptr;
// TfLite overload: CreateNetworkFromBinaryFile is called regardless of m_IsModelBinary, and
// the parser receives only the model path.
// Input shapes are count-checked against the input names but, in the visible lines, are not
// passed to the parser.
299 #if defined(ARMNN_TF_LITE_PARSER) 300 bool CreateNetwork (ParserType<armnnTfLiteParser::ITfLiteParser>)
305 if (!m_InputShapes.empty())
307 const size_t numInputShapes = m_InputShapes.size();
308 const size_t numInputBindings = m_InputNames.size();
// Fewer shapes than named inputs is an error (the throw statement begins on a line not shown).
309 if (numInputShapes < numInputBindings)
312 "Not every input has its tensor shape specified: expected=%1%, got=%2%")
313 % numInputBindings % numInputShapes));
319 m_NetworkPtr = parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str());
// Success means the parser produced a non-null network.
322 return m_NetworkPtr.get() !=
nullptr;
// ONNX overload: m_IsModelBinary selects binary or text parsing, and the parser receives only
// the model path. Input shapes are count-checked against the input names but, in the visible
// lines, are not passed to the parser.
326 #if defined(ARMNN_ONNX_PARSER) 327 bool CreateNetwork (ParserType<armnnOnnxParser::IOnnxParser>)
332 if (!m_InputShapes.empty())
334 const size_t numInputShapes = m_InputShapes.size();
335 const size_t numInputBindings = m_InputNames.size();
// Fewer shapes than named inputs is an error (the throw statement begins on a line not shown).
336 if (numInputShapes < numInputBindings)
339 "Not every input has its tensor shape specified: expected=%1%, got=%2%")
340 % numInputBindings % numInputShapes));
346 m_NetworkPtr = (m_IsModelBinary ?
347 parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str()) :
348 parser->CreateNetworkFromTextFile(m_ModelPath.c_str()));
// Success means the parser produced a non-null network.
351 return m_NetworkPtr.get() !=
nullptr;
// Converter entry point: reports missing parser/serializer support, parses the command line,
// turns each --input-tensor-shape string into an armnn::TensorShape, creates the network with
// the parser selected by --model-format, and serializes it to --output-path.
359 int main(
int argc,
const char* argv[])
362 #if (!defined(ARMNN_CAFFE_PARSER) \ 363 && !defined(ARMNN_ONNX_PARSER) \ 364 && !defined(ARMNN_TF_PARSER) \ 365 && !defined(ARMNN_TF_LITE_PARSER)) 366 ARMNN_LOG(fatal) <<
"Not built with any of the supported parsers, Caffe, Onnx, Tensorflow, or TfLite.";
370 #if !defined(ARMNN_SERIALIZER) 371 ARMNN_LOG(fatal) <<
"Not built with Serializer support.";
// Values filled in by ParseCommandLineArgs; isModelBinary starts out true.
383 std::string modelFormat;
384 std::string modelPath;
386 std::vector<std::string> inputNames;
387 std::vector<std::string> inputTensorShapeStrs;
388 std::vector<armnn::TensorShape> inputTensorShapes;
390 std::vector<std::string> outputNames;
391 std::string outputPath;
393 bool isModelBinary =
true;
395 if (ParseCommandLineArgs(
396 argc, argv, modelFormat, modelPath, inputNames, inputTensorShapeStrs, outputNames, outputPath, isModelBinary)
// Convert every non-empty shape string; a conversion error is logged as fatal.
// NOTE(review): the shape-parsing call and the surrounding try are on lines not shown here.
402 for (
const std::string& shapeStr : inputTensorShapeStrs)
404 if (!shapeStr.empty())
406 std::stringstream ss(shapeStr);
411 inputTensorShapes.push_back(shape);
415 ARMNN_LOG(fatal) <<
"Cannot create tensor shape: " << e.
what();
// NOTE(review): the per-format CreateNetwork<...>() calls on 'converter' are on lines not
// shown in this listing.
421 ArmnnConverter converter(modelPath, inputNames, inputTensorShapes, outputNames, outputPath, isModelBinary);
// Select the parser by substring of the model format, tested in the order caffe, onnx,
// tensorflow, tflite; a parser that was not compiled in produces a fatal log instead.
// Only binary TfLite models are accepted; an unrecognised format is logged as fatal.
423 if (modelFormat.find(
"caffe") != std::string::npos)
425 #if defined(ARMNN_CAFFE_PARSER) 428 ARMNN_LOG(fatal) <<
"Failed to load model from file";
432 ARMNN_LOG(fatal) <<
"Not built with Caffe parser support.";
436 else if (modelFormat.find(
"onnx") != std::string::npos)
438 #if defined(ARMNN_ONNX_PARSER) 441 ARMNN_LOG(fatal) <<
"Failed to load model from file";
445 ARMNN_LOG(fatal) <<
"Not built with Onnx parser support.";
449 else if (modelFormat.find(
"tensorflow") != std::string::npos)
451 #if defined(ARMNN_TF_PARSER) 454 ARMNN_LOG(fatal) <<
"Failed to load model from file";
458 ARMNN_LOG(fatal) <<
"Not built with Tensorflow parser support.";
462 else if (modelFormat.find(
"tflite") != std::string::npos)
464 #if defined(ARMNN_TF_LITE_PARSER) 467 ARMNN_LOG(fatal) <<
"Unknown model format: '" << modelFormat <<
"'. Only 'binary' format supported \ 474 ARMNN_LOG(fatal) <<
"Failed to load model from file";
478 ARMNN_LOG(fatal) <<
"Not built with TfLite parser support.";
484 ARMNN_LOG(fatal) <<
"Unknown model format: '" << modelFormat <<
"'";
// Write the created network out; a failure is logged as fatal.
488 if (!converter.Serialize())
490 ARMNN_LOG(fatal) <<
"Failed to serialize model";
void ConfigureLogging(bool printToStandardOutput, bool printToDebugOutput, LogSeverity severity)
static ITfLiteParserPtr Create(const armnn::Optional< TfLiteParserOptions > &options=armnn::EmptyOptional())
#define ARMNN_LOG(severity)
virtual const char * what() const noexcept override
#define ARMNN_SCOPED_HEAP_PROFILING(TAG)
int main(int argc, const char *argv[])
static ISerializerPtr Create()
Base class for all ArmNN exceptions so that users can filter to just those.
std::unique_ptr< INetwork, void(*)(INetwork *network)> INetworkPtr
static IOnnxParserPtr Create()
Parses a directed acyclic graph from a tensorflow protobuf file.