Ensure that the model file consumed entirely, otherwise complain that the model cannot be loaded. Fixes #2050.
Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
using namespace nnc;
-int main(int argc, const char *argv[])
-{
+/*
+ * Prints the explanatory string of an exception. If the exception is nested, recurses to print
+ * the explanatory string of the exception it holds.
+ */
+static void printException(const std::exception& e, int indent = 0) {
+ std::cerr << std::string(indent, ' ') << e.what() << std::endl;
+ try {
+ std::rethrow_if_nested(e);
+ } catch (const std::exception& e) {
+ printException(e, indent + 2);
+ }
+}
+
+int main(int argc, const char* argv[]) {
int exit_code = EXIT_FAILURE;
- try
- {
+ try {
// Parse command line
cli::CommandLine::getParser()->parseCommandLine(argc, argv);
// errors didn't happen
exit_code = EXIT_SUCCESS;
- }
- catch ( const DriverException &e )
- {
- std::cerr << e.what() << std::endl;
+ } catch (const DriverException& e) {
+ printException(e);
std::cerr << "use --help for more information" << std::endl;
- }
- catch ( const PassException &e )
- {
- std::cerr << e.what() << std::endl;
+ } catch (const PassException& e) {
+ printException(e);
}
return exit_code;
+++ /dev/null
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef NNCC_PROTO_HELPER_H
-#define NNCC_PROTO_HELPER_H
-
-#include <iostream>
-#include <fcntl.h>
-#include <unistd.h>
-#include <memory>
-
-#include "google/protobuf/io/coded_stream.h"
-#include "google/protobuf/io/zero_copy_stream_impl.h"
-#include "google/protobuf/text_format.h"
-
-#include "passes/common_frontend/model_allocation.h"
-
-namespace nnc {
-
-const int protoBytesLimit = INT_MAX;
-const int protoBytesWarningLimit = 1024 * 1024 * 512;
-
-template <typename protoType>
-bool readProtoFromTextFile(const char* filename, protoType* proto) {
- std::unique_ptr<ModelAllocation> protoMap(new ModelAllocation{filename});
-
- google::protobuf::io::CodedInputStream coded_input(
- (const google::protobuf::uint8*)protoMap->getDataPnt(), protoMap->getNumBytes());
- coded_input.SetTotalBytesLimit(protoBytesLimit, protoBytesWarningLimit);
-
- bool success = google::protobuf::TextFormat::Parse(&coded_input, proto);
-
- return success;
-}
-
-template <typename protoType>
-bool readProtoFromBinaryFile(const char* filename, protoType* proto) {
- std::unique_ptr<ModelAllocation> protoMap(new ModelAllocation{filename});
-
- google::protobuf::io::CodedInputStream coded_input(
- (const google::protobuf::uint8*)protoMap->getDataPnt(), protoMap->getNumBytes());
- coded_input.SetTotalBytesLimit(protoBytesLimit, protoBytesWarningLimit);
-
- bool success = proto->ParseFromCodedStream(&coded_input);
-
- return success;
-}
-
-} // namespace nnc
-
-#endif // NNCC_PROTO_HELPER_H
--- /dev/null
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNCC_PROTOBUF_HELPER_H
+#define NNCC_PROTOBUF_HELPER_H
+
+#include <google/protobuf/message.h>
+#include <string>
+
+namespace nnc {
+
+void readBinaryProto(const std::string& filename, google::protobuf::Message* message);
+
+} // namespace nnc
+
+#endif // NNCC_PROTOBUF_HELPER_H
#include "caffe2_importer.h"
#include "passes/common_frontend/shape_helper.h"
-#include "passes/common_frontend/proto_helper.h"
+#include "support/ProtobufHelper.h"
#include "caffe2_op_types.h"
#include "caffe2_op_creator.h"
delete _graph;
}
+static void loadModelFile(const std::string& filename, caffe2::NetDef* net) {
+ try {
+ readBinaryProto(filename, net);
+ } catch (...) {
+ std::throw_with_nested(PassException("Couldn't load the model file \"" + filename + "\""));
+ }
+}
+
void Caffe2Importer::import() {
GOOGLE_PROTOBUF_VERIFY_VERSION;
_net.reset(new NetDef());
- if (!readProtoFromBinaryFile<::caffe2::NetDef>(_predictNet.c_str(), _net.get()))
- throw PassException("Could not load model: " + _predictNet + "\n");
+ loadModelFile(_predictNet, _net.get());
std::unique_ptr<NetDef> net2;
net2.reset(new NetDef());
- if (!readProtoFromBinaryFile<::caffe2::NetDef>(_initNet.c_str(), net2.get()))
- throw PassException("Could not load model: " + _initNet + "\n");
+ loadModelFile(_initNet, net2.get());
_net->MergeFrom(*net2);
collectUnsupportedOps();
#include "pass/PassException.h"
#include "passes/common_frontend/shape_helper.h"
-#include "passes/common_frontend/proto_helper.h"
+#include "support/ProtobufHelper.h"
namespace nnc {
CaffeImporter::~CaffeImporter() {}
+static void loadModelFile(const std::string& filename, caffe::NetParameter* net) {
+ try {
+ readBinaryProto(filename, net);
+ } catch (...) {
+ std::throw_with_nested(PassException("Couldn't load the model file \"" + filename + "\""));
+ }
+}
+
void CaffeImporter::import() {
GOOGLE_PROTOBUF_VERIFY_VERSION;
_net.reset(new NetParameter());
- if (!readProtoFromBinaryFile<::caffe::NetParameter>(_modelFilename.c_str(), _net.get()))
- throw PassException("Could not load model: " + _modelFilename + "\n");
+ loadModelFile(_modelFilename, _net.get());
collectUnsupportedLayers();
}
set(SUPPORT_SOURCES
CommandLine.cpp
- CLOptionChecker.cpp)
+ CLOptionChecker.cpp
+ ProtobufHelper.cpp)
-add_library(nnc_support SHARED ${SUPPORT_SOURCES})
+add_library(nnc_support STATIC ${SUPPORT_SOURCES})
set_target_properties(nnc_support PROPERTIES LINKER_LANGUAGE CXX)
-target_link_libraries(nnc_support PRIVATE dl)
-install_nnc_library(nnc_support)
-
+set_target_properties(nnc_support PROPERTIES POSITION_INDEPENDENT_CODE ON)
--- /dev/null
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "support/ProtobufHelper.h"
+#include <google/protobuf/io/zero_copy_stream_impl.h>
+#include <google/protobuf/io/coded_stream.h>
+#include <cerrno>
+#include <cstring>
+#include <stdexcept>
+#include <fcntl.h>
+#include <unistd.h>
+
+namespace nnc {
+
+// Allow files up to 1GB, but warn about files larger than 512MB.
+static constexpr int protoTotalBytesLimit = 1024 << 20;
+static constexpr int protoTotalBytesWarningThreshold = 512 << 20;
+
+void readBinaryProto(const std::string& filename, google::protobuf::Message* message) {
+ int file_handle = open(filename.c_str(), O_RDONLY);
+
+ if (file_handle == -1)
+ throw std::runtime_error("Couldn't open file \"" + filename + "\": " +
+ std::strerror(errno) + ".");
+
+ google::protobuf::io::FileInputStream file_stream(file_handle);
+ file_stream.SetCloseOnDelete(true);
+
+ google::protobuf::io::CodedInputStream coded_stream(&file_stream);
+ coded_stream.SetTotalBytesLimit(protoTotalBytesLimit, protoTotalBytesWarningThreshold);
+
+ if (!message->ParseFromCodedStream(&coded_stream))
+ throw std::runtime_error("Couldn't parse file \"" + filename + "\".");
+
+ // If the file has not been consumed entirely, assume that the file is in the wrong format.
+ if (!coded_stream.ConsumedEntireMessage())
+ throw std::runtime_error("File \"" + filename + "\" has not been consumed entirely.");
+}
+
+} // namespace nnc