[nnc] Check that the model has the correct format in Caffe importer (#2753)
authorСергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 <s.barannikov@samsung.com>
Thu, 10 Jan 2019 13:38:53 +0000 (16:38 +0300)
committerРоман Михайлович Русяев/AI Tools Lab /SRR/Staff Engineer/삼성전자 <r.rusyaev@samsung.com>
Thu, 10 Jan 2019 13:38:53 +0000 (16:38 +0300)
Ensure that the model file consumed entirely, otherwise complain that the model cannot be loaded. Fixes #2050.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
contrib/nnc/driver/main.cpp
contrib/nnc/include/passes/common_frontend/proto_helper.h [deleted file]
contrib/nnc/include/support/ProtobufHelper.h [new file with mode: 0644]
contrib/nnc/passes/caffe2_frontend/caffe2_importer.cpp
contrib/nnc/passes/caffe_frontend/caffe_importer.cpp
contrib/nnc/support/CMakeLists.txt
contrib/nnc/support/ProtobufHelper.cpp [new file with mode: 0644]

index 1e19bd7..2b5c239 100644 (file)
 
 using namespace nnc;
 
-int main(int argc, const char *argv[])
-{
+/*
+ * Prints the explanatory string of an exception. If the exception is nested, recurses to print
+ * the explanatory string of the exception it holds.
+ */
+static void printException(const std::exception& e, int indent = 0) {
+  std::cerr << std::string(indent, ' ') << e.what() << std::endl;
+  try {
+    std::rethrow_if_nested(e);
+  } catch (const std::exception& e) {
+    printException(e, indent + 2);
+  }
+}
+
+int main(int argc, const char* argv[]) {
   int exit_code = EXIT_FAILURE;
 
-  try
-  {
+  try {
     // Parse command line
     cli::CommandLine::getParser()->parseCommandLine(argc, argv);
 
@@ -43,15 +54,11 @@ int main(int argc, const char *argv[])
 
     // errors didn't happen
     exit_code = EXIT_SUCCESS;
-  }
-  catch ( const DriverException &e )
-  {
-    std::cerr << e.what() << std::endl;
+  } catch (const DriverException& e) {
+    printException(e);
     std::cerr << "use --help for more information" << std::endl;
-  }
-  catch ( const PassException &e )
-  {
-    std::cerr << e.what() << std::endl;
+  } catch (const PassException& e) {
+    printException(e);
   }
 
   return exit_code;
diff --git a/contrib/nnc/include/passes/common_frontend/proto_helper.h b/contrib/nnc/include/passes/common_frontend/proto_helper.h
deleted file mode 100644 (file)
index 3d402e2..0000000
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef NNCC_PROTO_HELPER_H
-#define NNCC_PROTO_HELPER_H
-
-#include <iostream>
-#include <fcntl.h>
-#include <unistd.h>
-#include <memory>
-
-#include "google/protobuf/io/coded_stream.h"
-#include "google/protobuf/io/zero_copy_stream_impl.h"
-#include "google/protobuf/text_format.h"
-
-#include "passes/common_frontend/model_allocation.h"
-
-namespace nnc {
-
-const int protoBytesLimit = INT_MAX;
-const int protoBytesWarningLimit = 1024 * 1024 * 512;
-
-template <typename protoType>
-bool readProtoFromTextFile(const char* filename, protoType* proto) {
-  std::unique_ptr<ModelAllocation> protoMap(new ModelAllocation{filename});
-
-  google::protobuf::io::CodedInputStream coded_input(
-          (const google::protobuf::uint8*)protoMap->getDataPnt(), protoMap->getNumBytes());
-  coded_input.SetTotalBytesLimit(protoBytesLimit, protoBytesWarningLimit);
-
-  bool success = google::protobuf::TextFormat::Parse(&coded_input, proto);
-
-  return success;
-}
-
-template <typename protoType>
-bool readProtoFromBinaryFile(const char* filename, protoType* proto) {
-  std::unique_ptr<ModelAllocation> protoMap(new ModelAllocation{filename});
-
-  google::protobuf::io::CodedInputStream coded_input(
-          (const google::protobuf::uint8*)protoMap->getDataPnt(), protoMap->getNumBytes());
-  coded_input.SetTotalBytesLimit(protoBytesLimit, protoBytesWarningLimit);
-
-  bool success = proto->ParseFromCodedStream(&coded_input);
-
-  return success;
-}
-
-} // namespace nnc
-
-#endif // NNCC_PROTO_HELPER_H
diff --git a/contrib/nnc/include/support/ProtobufHelper.h b/contrib/nnc/include/support/ProtobufHelper.h
new file mode 100644 (file)
index 0000000..8fc82ee
--- /dev/null
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNCC_PROTOBUF_HELPER_H
+#define NNCC_PROTOBUF_HELPER_H
+
+#include <google/protobuf/message.h>
+#include <string>
+
+namespace nnc {
+
+void readBinaryProto(const std::string& filename, google::protobuf::Message* message);
+
+} // namespace nnc
+
+#endif // NNCC_PROTOBUF_HELPER_H
index 5a0e613..a6c543b 100644 (file)
@@ -21,7 +21,7 @@
 
 #include "caffe2_importer.h"
 #include "passes/common_frontend/shape_helper.h"
-#include "passes/common_frontend/proto_helper.h"
+#include "support/ProtobufHelper.h"
 
 #include "caffe2_op_types.h"
 #include "caffe2_op_creator.h"
@@ -54,17 +54,23 @@ void Caffe2Importer::cleanup() {
   delete _graph;
 }
 
+static void loadModelFile(const std::string& filename, caffe2::NetDef* net) {
+  try {
+    readBinaryProto(filename, net);
+  } catch (...) {
+    std::throw_with_nested(PassException("Couldn't load the model file \"" + filename + "\""));
+  }
+}
+
 void Caffe2Importer::import() {
   GOOGLE_PROTOBUF_VERIFY_VERSION;
 
   _net.reset(new NetDef());
-  if (!readProtoFromBinaryFile<::caffe2::NetDef>(_predictNet.c_str(), _net.get()))
-    throw PassException("Could not load model: " + _predictNet + "\n");
+  loadModelFile(_predictNet, _net.get());
 
   std::unique_ptr<NetDef> net2;
   net2.reset(new NetDef());
-  if (!readProtoFromBinaryFile<::caffe2::NetDef>(_initNet.c_str(), net2.get()))
-    throw PassException("Could not load model: " + _initNet + "\n");
+  loadModelFile(_initNet, net2.get());
   _net->MergeFrom(*net2);
 
   collectUnsupportedOps();
index 62f64ce..5258c0c 100644 (file)
@@ -28,7 +28,7 @@
 #include "pass/PassException.h"
 
 #include "passes/common_frontend/shape_helper.h"
-#include "passes/common_frontend/proto_helper.h"
+#include "support/ProtobufHelper.h"
 
 namespace nnc {
 
@@ -43,12 +43,19 @@ CaffeImporter::CaffeImporter(std::string filename) : _modelFilename(std::move(fi
 
 CaffeImporter::~CaffeImporter() {}
 
+static void loadModelFile(const std::string& filename, caffe::NetParameter* net) {
+  try {
+    readBinaryProto(filename, net);
+  } catch (...) {
+    std::throw_with_nested(PassException("Couldn't load the model file \"" + filename + "\""));
+  }
+}
+
 void CaffeImporter::import() {
   GOOGLE_PROTOBUF_VERIFY_VERSION;
 
   _net.reset(new NetParameter());
-  if (!readProtoFromBinaryFile<::caffe::NetParameter>(_modelFilename.c_str(), _net.get()))
-    throw PassException("Could not load model: " + _modelFilename + "\n");
+  loadModelFile(_modelFilename, _net.get());
 
   collectUnsupportedLayers();
 }
index 03d980e..f4b4188 100644 (file)
@@ -1,9 +1,8 @@
 set(SUPPORT_SOURCES
         CommandLine.cpp
-        CLOptionChecker.cpp)
+        CLOptionChecker.cpp
+        ProtobufHelper.cpp)
 
-add_library(nnc_support SHARED ${SUPPORT_SOURCES})
+add_library(nnc_support STATIC ${SUPPORT_SOURCES})
 set_target_properties(nnc_support PROPERTIES LINKER_LANGUAGE CXX)
-target_link_libraries(nnc_support PRIVATE dl)
-install_nnc_library(nnc_support)
-
+set_target_properties(nnc_support PROPERTIES POSITION_INDEPENDENT_CODE ON)
diff --git a/contrib/nnc/support/ProtobufHelper.cpp b/contrib/nnc/support/ProtobufHelper.cpp
new file mode 100644 (file)
index 0000000..34256a8
--- /dev/null
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "support/ProtobufHelper.h"
+#include <google/protobuf/io/zero_copy_stream_impl.h>
+#include <google/protobuf/io/coded_stream.h>
+#include <cerrno>
+#include <cstring>
+#include <stdexcept>
+#include <fcntl.h>
+#include <unistd.h>
+
+namespace nnc {
+
+// Allow files up to 1GB, but warn about files larger than 512MB.
+static constexpr int protoTotalBytesLimit = 1024 << 20;
+static constexpr int protoTotalBytesWarningThreshold = 512 << 20;
+
+void readBinaryProto(const std::string& filename, google::protobuf::Message* message) {
+  int file_handle = open(filename.c_str(), O_RDONLY);
+
+  if (file_handle == -1)
+    throw std::runtime_error("Couldn't open file \"" + filename + "\": " +
+                             std::strerror(errno) + ".");
+
+  google::protobuf::io::FileInputStream file_stream(file_handle);
+  file_stream.SetCloseOnDelete(true);
+
+  google::protobuf::io::CodedInputStream coded_stream(&file_stream);
+  coded_stream.SetTotalBytesLimit(protoTotalBytesLimit, protoTotalBytesWarningThreshold);
+
+  if (!message->ParseFromCodedStream(&coded_stream))
+    throw std::runtime_error("Couldn't parse file \"" + filename + "\".");
+
+  // If the file has not been consumed entirely, assume that the file is in the wrong format.
+  if (!coded_stream.ConsumedEntireMessage())
+    throw std::runtime_error("File \"" + filename + "\" has not been consumed entirely.");
+}
+
+} // namespace nnc