From d356560be7eb3f3ff0c2acf45915d542624e4ee8 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Ivan=20Vagin/AI=20Tools=20Lab=20/SRR/Engineer/=EC=82=BC?= =?utf8?q?=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 9 Nov 2018 14:26:30 +0300 Subject: [PATCH] [nnc] 'nnmodel' option become a vector (#2157) Make 'nnmodel' option a vector option, because caffe2 importer consumes two files (init_net.pb and predict_net.pb) Signed-off-by: Ivan Vagin --- contrib/nnc/driver/Driver.cpp | 4 +- contrib/nnc/driver/Options.cpp | 21 ++++----- contrib/nnc/examples/caffe_frontend/model_dump.cpp | 2 +- .../nnc/examples/tflite_frontend/sanity_check.cpp | 2 +- contrib/nnc/include/option/Options.h | 2 +- contrib/nnc/include/support/CommandLine.h | 1 + contrib/nnc/support/CLOptionChecker.cpp | 54 +++++++++++----------- contrib/nnc/tests/import/caffe.cpp | 2 +- contrib/nnc/tests/import/tflite.cpp | 2 +- contrib/nnc/unittests/pass/CMakeLists.txt | 2 +- 10 files changed, 47 insertions(+), 45 deletions(-) diff --git a/contrib/nnc/driver/Driver.cpp b/contrib/nnc/driver/Driver.cpp index 7b168ea..511cb18 100644 --- a/contrib/nnc/driver/Driver.cpp +++ b/contrib/nnc/driver/Driver.cpp @@ -69,7 +69,7 @@ void Driver::registerFrontendPass() { if (cli::caffeFrontend) { #ifdef NNC_FRONTEND_CAFFE_ENABLED - pass = std::move(std::unique_ptr(new CaffeImporter(cli::inputFile))); + pass = std::move(std::unique_ptr(new CaffeImporter(cli::inputFiles[0]))); #endif // NNC_FRONTEND_CAFFE_ENABLED } else if ( cli::onnxFrontend ) @@ -81,7 +81,7 @@ void Driver::registerFrontendPass() { else if ( cli::tflFrontend ) { #ifdef NNC_FRONTEND_TFLITE_ENABLED - pass = std::move(std::unique_ptr(new TfliteImporter(cli::inputFile))); + pass = std::move(std::unique_ptr(new TfliteImporter(cli::inputFiles[0]))); #endif // NNC_FRONTEND_TFLITE_ENABLED } else { throw DriverException("one of the following options must be defined: '" diff --git a/contrib/nnc/driver/Options.cpp b/contrib/nnc/driver/Options.cpp index d578e85..0d6b714 100644 --- a/contrib/nnc/driver/Options.cpp +++ b/contrib/nnc/driver/Options.cpp @@ -20,10 +20,8 @@ #include "option/Options.h" #include "Definitions.h" -namespace nnc -{ -namespace cli -{ +namespace nnc { +namespace cli { /** * Options for *compiler driver* @@ -83,13 +81,14 @@ Option target(optname("--target"), /** * Options for *frontend* */ -Option inputFile(optname("--nnmodel, -m"), - overview("specify input file with NN model"), - std::string(), - optional(false), - optvalues(""), - checkInFile, - separators("=")); +Option> inputFiles(optname("--nnmodel, -m"), + overview("specify input files with serialized NN models: " + "single model file must be provided for caffe, tflite and onnx frameworks; " + "two model files must be specified for caffe2 framework (init_net and predict_net)"), + std::vector{}, + optional(false), + optvalues(""), + checkModelFiles); /** * Options for *backend* diff --git a/contrib/nnc/examples/caffe_frontend/model_dump.cpp b/contrib/nnc/examples/caffe_frontend/model_dump.cpp index cfcc057..34f0c8d 100644 --- a/contrib/nnc/examples/caffe_frontend/model_dump.cpp +++ b/contrib/nnc/examples/caffe_frontend/model_dump.cpp @@ -30,7 +30,7 @@ using namespace nnc::cli; int main(int argc, const char **argv) { cli::CommandLine::getParser()->parseCommandLine(argc, argv, false); - std::string model = cli::inputFile; + std::string model = cli::inputFiles[0]; nnc::CaffeImporter importer{model}; diff --git a/contrib/nnc/examples/tflite_frontend/sanity_check.cpp b/contrib/nnc/examples/tflite_frontend/sanity_check.cpp index 6451730..e40d560 100644 --- a/contrib/nnc/examples/tflite_frontend/sanity_check.cpp +++ b/contrib/nnc/examples/tflite_frontend/sanity_check.cpp @@ -30,7 +30,7 @@ using namespace nnc::cli; int main(int argc, const char **argv) { cli::CommandLine::getParser()->parseCommandLine(argc, argv, false); - std::string model = cli::inputFile; + std::string model = cli::inputFiles[0]; nnc::TfliteImporter importer{model}; diff --git a/contrib/nnc/include/option/Options.h b/contrib/nnc/include/option/Options.h index 8ca3452..f13ee4a 100644 --- a/contrib/nnc/include/option/Options.h +++ b/contrib/nnc/include/option/Options.h @@ -41,7 +41,7 @@ extern Option target; // kind of target for which compiler generat /** * Frontend options */ -extern Option inputFile; // file contains model of specific AI framework +extern Option> inputFiles; // files contains model of specific AI framework /** * Options for backend diff --git a/contrib/nnc/include/support/CommandLine.h b/contrib/nnc/include/support/CommandLine.h index 45a233c..a1f5525 100644 --- a/contrib/nnc/include/support/CommandLine.h +++ b/contrib/nnc/include/support/CommandLine.h @@ -520,6 +520,7 @@ Option::Option(const std::vector &optnames, // prototypes of option checker functions // void checkInFile(const Option &); +void checkModelFiles(const Option> &); void checkOutFile(const Option &); void checkOutDir(const Option &); void checkDebugFile(const Option &); diff --git a/contrib/nnc/support/CLOptionChecker.cpp b/contrib/nnc/support/CLOptionChecker.cpp index 09b8e36..5680b9f 100644 --- a/contrib/nnc/support/CLOptionChecker.cpp +++ b/contrib/nnc/support/CLOptionChecker.cpp @@ -14,58 +14,62 @@ * limitations under the License. */ -// -// Created by rrusyaev on 14.08.18. -// #include "support/CommandLine.h" +#include "option/Options.h" #include #include #include -namespace nnc -{ -namespace cli -{ -void checkInFile(const Option &in_file) -{ +namespace nnc { +namespace cli { + +void checkInFile(const Option &in_file) { if ( in_file.empty() ) - { throw BadOption("Input file name should not be empty"); - } auto f = fopen(in_file.c_str(), "rb"); - if (!f) { + if (!f) throw BadOption("Cannot open file <" + in_file + ">"); - } fclose(f); } // checkInFile -void checkOutFile(const Option &out_file) -{ +void checkModelFiles(const Option> &in_files) { + if (in_files.empty()) + throw BadOption("Model file names should not be empty"); + + if ((tflFrontend || caffeFrontend || onnxFrontend) && in_files.size() != 1) + throw BadOption("For caffe, tflite and onnx frameworks single model file must be specified"); + // else if (cli::caffe2Frontend && in_files.size() != 2) + // throw BadOption("For caffe2 framework two model files must be specified (init_net and predict_net)"); + + for (auto& f_name : in_files) { + auto f = fopen(f_name.c_str(), "rb"); + if (!f) + throw BadOption("Cannot open file <" + f_name + ">"); + fclose(f); + } +} // checkModelFiles + +void checkOutFile(const Option &out_file) { if ( out_file.empty() ) - { throw BadOption("Output file name should not be empty"); - } /// @todo: if file already exists need to check accessibility } // checkOutFile -void checkOutDir(const Option &out_dir) -{ +void checkOutDir(const Option &out_dir) { auto dir = opendir(out_dir.c_str()); - if (dir) - { + if (dir) { closedir(dir); return; } auto err = errno; - switch (err) - { + switch (err) { case ENOENT: return; case ENOTDIR: @@ -75,11 +79,9 @@ void checkOutDir(const Option &out_dir) default: throw BadOption("Can not open output directory"); } - } // checkOutDir -void checkDebugFile(const Option &in_file) -{ +void checkDebugFile(const Option &in_file) { if (access(in_file.c_str(), W_OK) != 0) { throw BadOption("Has no permission to open debug output file"); } diff --git a/contrib/nnc/tests/import/caffe.cpp b/contrib/nnc/tests/import/caffe.cpp index 939ce07..c277d0c 100644 --- a/contrib/nnc/tests/import/caffe.cpp +++ b/contrib/nnc/tests/import/caffe.cpp @@ -27,7 +27,7 @@ int main(int argc, const char **argv) { return 1; cli::CommandLine::getParser()->parseCommandLine(argc, argv); - std::string modelName = cli::inputFile; + std::string modelName = cli::inputFiles[0]; nnc::CaffeImporter importer{modelName}; diff --git a/contrib/nnc/tests/import/tflite.cpp b/contrib/nnc/tests/import/tflite.cpp index dddc9e8..1dc642d 100644 --- a/contrib/nnc/tests/import/tflite.cpp +++ b/contrib/nnc/tests/import/tflite.cpp @@ -30,7 +30,7 @@ int main(int argc, const char **argv) } cli::CommandLine::getParser()->parseCommandLine(argc, argv); - std::string modelName = cli::inputFile; + std::string modelName = cli::inputFiles[0]; nnc::TfliteImporter importer{modelName}; diff --git a/contrib/nnc/unittests/pass/CMakeLists.txt b/contrib/nnc/unittests/pass/CMakeLists.txt index d1812be..5f75d53 100644 --- a/contrib/nnc/unittests/pass/CMakeLists.txt +++ b/contrib/nnc/unittests/pass/CMakeLists.txt @@ -1,6 +1,6 @@ file(GLOB_RECURSE TEST_SOURCES "*.cpp") -add_nnc_unit_test(nnc_pass_test ${TEST_SOURCES}) +add_nnc_unit_test(nnc_pass_test ${TEST_SOURCES} ${OPTIONS_SRC}) if (TARGET nnc_pass_test) nncc_target_link_libraries(nnc_pass_test nnc_support nnc_core) endif() -- 2.7.4