[min_onnx] Support models stored as text (#9426)
authorSergei Barannikov/AI Tools Lab /SRR/Engineer/Samsung Electronics <s.barannikov@samsung.com>
Fri, 6 Dec 2019 10:29:59 +0000 (13:29 +0300)
committerAlexander Efimov/AI Tools Lab /SRR/Engineer/Samsung Electronics <a.efimov@samsung.com>
Fri, 6 Dec 2019 10:29:59 +0000 (13:29 +0300)
Add support for models stored as text.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
compiler/mir-onnx-importer/ONNXImporterImpl.cpp
compiler/mir-onnx-importer/ONNXImporterImpl.h

index cbe03ed..800e982 100644 (file)
 #include "ONNXOpRegistration.h"
 #include "onnx/onnx.pb.h"
 
-#include "mir/Operation.h"
 #include "mir/Shape.h"
 #include "mir/TensorUtil.h"
-#include "mir/TensorVariant.h"
 
 #include "mir/ops/ConstantOp.h"
 
 
 #include <google/protobuf/io/zero_copy_stream_impl.h>
 #include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/text_format.h>
 #include <functional>
 #include <iostream>
 #include <stdex/Memory.h>
 #include <utility>
-#include <map>
 
-namespace
+namespace mir_onnx
 {
 
-using namespace mir_onnx;
+namespace
+{
 
 class ONNXImporterImpl final
 {
 public:
-  explicit ONNXImporterImpl(std::string filename);
+  ONNXImporterImpl();
   ~ONNXImporterImpl();
   /// @brief Load the model and convert it into a MIR Graph.
-  std::unique_ptr<mir::Graph> importModel();
+  std::unique_ptr<mir::Graph> importModelFromBinaryFile(const std::string &filename);
+  std::unique_ptr<mir::Graph> importModelFromTextFile(const std::string &filename);
 
 private:
-  void import();
   std::unique_ptr<mir::Graph> createIR();
   void createGraphInputs();
   void collectUnsupportedOps();
-  // Maps ONNX tensor names to corresponding MIR operation outputs.
-  std::string _modelFilename;
   std::unique_ptr<onnx::ModelProto> _model;
   std::unique_ptr<ConverterContext> _context;
   std::unique_ptr<mir::Graph> _graph;
 };
 
-ONNXImporterImpl::ONNXImporterImpl(std::string filename) : _modelFilename(std::move(filename))
-{
-  registerSupportedOps();
-}
+ONNXImporterImpl::ONNXImporterImpl() { registerSupportedOps(); }
 
 ONNXImporterImpl::~ONNXImporterImpl() = default;
 
-static void loadModelFile(const std::string &filename, onnx::ModelProto *model)
+void loadModelFromBinaryFile(const std::string &filename, onnx::ModelProto *model)
 {
   GOOGLE_PROTOBUF_VERIFY_VERSION;
 
@@ -92,12 +86,39 @@ static void loadModelFile(const std::string &filename, onnx::ModelProto *model)
     throw std::runtime_error("File \"" + filename + "\" has not been consumed entirely.");
 }
 
-void ONNXImporterImpl::import()
+void loadModelFromTextFile(const std::string &filename, onnx::ModelProto *model)
+{
+  GOOGLE_PROTOBUF_VERIFY_VERSION;
+
+  int file_handle = open(filename.c_str(), O_RDONLY);
+
+  if (file_handle == -1)
+    throw std::runtime_error("Couldn't open file \"" + filename + "\": " + std::strerror(errno) +
+                             ".");
+
+  google::protobuf::io::FileInputStream file_stream(file_handle);
+  file_stream.SetCloseOnDelete(true);
+
+  if (!google::protobuf::TextFormat::Parse(&file_stream, model))
+    throw std::runtime_error("Couldn't parse file \"" + filename + "\".");
+}
+
+std::unique_ptr<mir::Graph> ONNXImporterImpl::importModelFromBinaryFile(const std::string &filename)
 {
   _model = stdex::make_unique<onnx::ModelProto>();
-  loadModelFile(_modelFilename, _model.get());
+  loadModelFromBinaryFile(filename, _model.get());
 
   collectUnsupportedOps();
+  return createIR();
+}
+
+std::unique_ptr<mir::Graph> ONNXImporterImpl::importModelFromTextFile(const std::string &filename)
+{
+  _model = stdex::make_unique<onnx::ModelProto>();
+  loadModelFromTextFile(filename, _model.get());
+
+  collectUnsupportedOps();
+  return createIR();
 }
 
 void ONNXImporterImpl::collectUnsupportedOps()
@@ -201,20 +222,23 @@ std::unique_ptr<mir::Graph> ONNXImporterImpl::createIR()
   return std::move(_graph);
 }
 
-std::unique_ptr<mir::Graph> ONNXImporterImpl::importModel()
+} // namespace
+
+std::unique_ptr<mir::Graph> importModelFromBinaryFile(const std::string &filename)
 {
-  import();
-  return createIR();
-}
+  ONNXImporterImpl importer;
+  return importer.importModelFromBinaryFile(filename);
 }
 
-namespace mir_onnx
+std::unique_ptr<mir::Graph> importModelFromTextFile(const std::string &filename)
 {
+  ONNXImporterImpl importer;
+  return importer.importModelFromTextFile(filename);
+}
 
-std::unique_ptr<mir::Graph> loadModel(std::string filename)
+std::unique_ptr<mir::Graph> loadModel(const std::string &filename)
 {
-  ONNXImporterImpl importer(std::move(filename));
-  return importer.importModel();
+  return importModelFromBinaryFile(filename);
 }
 
 } // namespace mir_onnx
index 7b13370..02a49b3 100644 (file)
 namespace mir_onnx
 {
 
-std::unique_ptr<mir::Graph> loadModel(std::string filename);
+std::unique_ptr<mir::Graph> importModelFromBinaryFile(const std::string &filename);
+std::unique_ptr<mir::Graph> importModelFromTextFile(const std::string &filename);
+// TODO Remove after changing all uses.
+std::unique_ptr<mir::Graph> loadModel(const std::string &filename);
 
 } // namespace mir_onnx