From: Сергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 Date: Thu, 10 Jan 2019 13:38:53 +0000 (+0300) Subject: [nnc] Check that the model has the correct format in Caffe importer (#2753) X-Git-Tag: nncc_backup~987 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ce32750f796762eeebc56b35c7ae6d19ab59e696;p=platform%2Fcore%2Fml%2Fnnfw.git [nnc] Check that the model has the correct format in Caffe importer (#2753) Ensure that the model file consumed entirely, otherwise complain that the model cannot be loaded. Fixes #2050. Signed-off-by: Sergei Barannikov --- diff --git a/contrib/nnc/driver/main.cpp b/contrib/nnc/driver/main.cpp index 1e19bd7..2b5c239 100644 --- a/contrib/nnc/driver/main.cpp +++ b/contrib/nnc/driver/main.cpp @@ -23,12 +23,23 @@ using namespace nnc; -int main(int argc, const char *argv[]) -{ +/* + * Prints the explanatory string of an exception. If the exception is nested, recurses to print + * the explanatory string of the exception it holds. + */ +static void printException(const std::exception& e, int indent = 0) { + std::cerr << std::string(indent, ' ') << e.what() << std::endl; + try { + std::rethrow_if_nested(e); + } catch (const std::exception& e) { + printException(e, indent + 2); + } +} + +int main(int argc, const char* argv[]) { int exit_code = EXIT_FAILURE; - try - { + try { // Parse command line cli::CommandLine::getParser()->parseCommandLine(argc, argv); @@ -43,15 +54,11 @@ int main(int argc, const char *argv[]) // errors didn't happen exit_code = EXIT_SUCCESS; - } - catch ( const DriverException &e ) - { - std::cerr << e.what() << std::endl; + } catch (const DriverException& e) { + printException(e); std::cerr << "use --help for more information" << std::endl; - } - catch ( const PassException &e ) - { - std::cerr << e.what() << std::endl; + } catch (const PassException& e) { + printException(e); } return exit_code; diff --git a/contrib/nnc/include/passes/common_frontend/proto_helper.h b/contrib/nnc/include/passes/common_frontend/proto_helper.h deleted file mode 100644 index 3d402e2..0000000 --- a/contrib/nnc/include/passes/common_frontend/proto_helper.h +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef NNCC_PROTO_HELPER_H -#define NNCC_PROTO_HELPER_H - -#include -#include -#include -#include - -#include "google/protobuf/io/coded_stream.h" -#include "google/protobuf/io/zero_copy_stream_impl.h" -#include "google/protobuf/text_format.h" - -#include "passes/common_frontend/model_allocation.h" - -namespace nnc { - -const int protoBytesLimit = INT_MAX; -const int protoBytesWarningLimit = 1024 * 1024 * 512; - -template -bool readProtoFromTextFile(const char* filename, protoType* proto) { - std::unique_ptr protoMap(new ModelAllocation{filename}); - - google::protobuf::io::CodedInputStream coded_input( - (const google::protobuf::uint8*)protoMap->getDataPnt(), protoMap->getNumBytes()); - coded_input.SetTotalBytesLimit(protoBytesLimit, protoBytesWarningLimit); - - bool success = google::protobuf::TextFormat::Parse(&coded_input, proto); - - return success; -} - -template -bool readProtoFromBinaryFile(const char* filename, protoType* proto) { - std::unique_ptr protoMap(new ModelAllocation{filename}); - - google::protobuf::io::CodedInputStream coded_input( - (const google::protobuf::uint8*)protoMap->getDataPnt(), protoMap->getNumBytes()); - coded_input.SetTotalBytesLimit(protoBytesLimit, protoBytesWarningLimit); - - bool success = proto->ParseFromCodedStream(&coded_input); - - return success; -} - -} // namespace nnc - -#endif // NNCC_PROTO_HELPER_H diff --git a/contrib/nnc/include/support/ProtobufHelper.h b/contrib/nnc/include/support/ProtobufHelper.h new file mode 100644 index 0000000..8fc82ee --- /dev/null +++ b/contrib/nnc/include/support/ProtobufHelper.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNCC_PROTOBUF_HELPER_H +#define NNCC_PROTOBUF_HELPER_H + +#include +#include + +namespace nnc { + +void readBinaryProto(const std::string& filename, google::protobuf::Message* message); + +} // namespace nnc + +#endif // NNCC_PROTOBUF_HELPER_H diff --git a/contrib/nnc/passes/caffe2_frontend/caffe2_importer.cpp b/contrib/nnc/passes/caffe2_frontend/caffe2_importer.cpp index 5a0e613..a6c543b 100644 --- a/contrib/nnc/passes/caffe2_frontend/caffe2_importer.cpp +++ b/contrib/nnc/passes/caffe2_frontend/caffe2_importer.cpp @@ -21,7 +21,7 @@ #include "caffe2_importer.h" #include "passes/common_frontend/shape_helper.h" -#include "passes/common_frontend/proto_helper.h" +#include "support/ProtobufHelper.h" #include "caffe2_op_types.h" #include "caffe2_op_creator.h" @@ -54,17 +54,23 @@ void Caffe2Importer::cleanup() { delete _graph; } +static void loadModelFile(const std::string& filename, caffe2::NetDef* net) { + try { + readBinaryProto(filename, net); + } catch (...) { + std::throw_with_nested(PassException("Couldn't load the model file \"" + filename + "\"")); + } +} + void Caffe2Importer::import() { GOOGLE_PROTOBUF_VERIFY_VERSION; _net.reset(new NetDef()); - if (!readProtoFromBinaryFile<::caffe2::NetDef>(_predictNet.c_str(), _net.get())) - throw PassException("Could not load model: " + _predictNet + "\n"); + loadModelFile(_predictNet, _net.get()); std::unique_ptr net2; net2.reset(new NetDef()); - if (!readProtoFromBinaryFile<::caffe2::NetDef>(_initNet.c_str(), net2.get())) - throw PassException("Could not load model: " + _initNet + "\n"); + loadModelFile(_initNet, net2.get()); _net->MergeFrom(*net2); collectUnsupportedOps(); diff --git a/contrib/nnc/passes/caffe_frontend/caffe_importer.cpp b/contrib/nnc/passes/caffe_frontend/caffe_importer.cpp index 62f64ce..5258c0c 100644 --- a/contrib/nnc/passes/caffe_frontend/caffe_importer.cpp +++ b/contrib/nnc/passes/caffe_frontend/caffe_importer.cpp @@ -28,7 +28,7 @@ #include "pass/PassException.h" #include "passes/common_frontend/shape_helper.h" -#include "passes/common_frontend/proto_helper.h" +#include "support/ProtobufHelper.h" namespace nnc { @@ -43,12 +43,19 @@ CaffeImporter::CaffeImporter(std::string filename) : _modelFilename(std::move(fi CaffeImporter::~CaffeImporter() {} +static void loadModelFile(const std::string& filename, caffe::NetParameter* net) { + try { + readBinaryProto(filename, net); + } catch (...) { + std::throw_with_nested(PassException("Couldn't load the model file \"" + filename + "\"")); + } +} + void CaffeImporter::import() { GOOGLE_PROTOBUF_VERIFY_VERSION; _net.reset(new NetParameter()); - if (!readProtoFromBinaryFile<::caffe::NetParameter>(_modelFilename.c_str(), _net.get())) - throw PassException("Could not load model: " + _modelFilename + "\n"); + loadModelFile(_modelFilename, _net.get()); collectUnsupportedLayers(); } diff --git a/contrib/nnc/support/CMakeLists.txt b/contrib/nnc/support/CMakeLists.txt index 03d980e..f4b4188 100644 --- a/contrib/nnc/support/CMakeLists.txt +++ b/contrib/nnc/support/CMakeLists.txt @@ -1,9 +1,8 @@ set(SUPPORT_SOURCES CommandLine.cpp - CLOptionChecker.cpp) + CLOptionChecker.cpp + ProtobufHelper.cpp) -add_library(nnc_support SHARED ${SUPPORT_SOURCES}) +add_library(nnc_support STATIC ${SUPPORT_SOURCES}) set_target_properties(nnc_support PROPERTIES LINKER_LANGUAGE CXX) -target_link_libraries(nnc_support PRIVATE dl) -install_nnc_library(nnc_support) - +set_target_properties(nnc_support PROPERTIES POSITION_INDEPENDENT_CODE ON) diff --git a/contrib/nnc/support/ProtobufHelper.cpp b/contrib/nnc/support/ProtobufHelper.cpp new file mode 100644 index 0000000..34256a8 --- /dev/null +++ b/contrib/nnc/support/ProtobufHelper.cpp @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "support/ProtobufHelper.h" +#include +#include +#include +#include +#include +#include +#include + +namespace nnc { + +// Allow files up to 1GB, but warn about files larger than 512MB. +static constexpr int protoTotalBytesLimit = 1024 << 20; +static constexpr int protoTotalBytesWarningThreshold = 512 << 20; + +void readBinaryProto(const std::string& filename, google::protobuf::Message* message) { + int file_handle = open(filename.c_str(), O_RDONLY); + + if (file_handle == -1) + throw std::runtime_error("Couldn't open file \"" + filename + "\": " + + std::strerror(errno) + "."); + + google::protobuf::io::FileInputStream file_stream(file_handle); + file_stream.SetCloseOnDelete(true); + + google::protobuf::io::CodedInputStream coded_stream(&file_stream); + coded_stream.SetTotalBytesLimit(protoTotalBytesLimit, protoTotalBytesWarningThreshold); + + if (!message->ParseFromCodedStream(&coded_stream)) + throw std::runtime_error("Couldn't parse file \"" + filename + "\"."); + + // If the file has not been consumed entirely, assume that the file is in the wrong format. + if (!coded_stream.ConsumedEntireMessage()) + throw std::runtime_error("File \"" + filename + "\" has not been consumed entirely."); +} + +} // namespace nnc