#include "ONNXOpRegistration.h"
#include "onnx/onnx.pb.h"
-#include "mir/Operation.h"
#include "mir/Shape.h"
#include "mir/TensorUtil.h"
-#include "mir/TensorVariant.h"
#include "mir/ops/ConstantOp.h"
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/text_format.h>
#include <functional>
#include <iostream>
#include <stdex/Memory.h>
#include <utility>
-#include <map>
-namespace
+namespace mir_onnx
{
-using namespace mir_onnx;
+namespace
+{
class ONNXImporterImpl final
{
public:
- explicit ONNXImporterImpl(std::string filename);
+ ONNXImporterImpl();
~ONNXImporterImpl();
/// @brief Load the model and convert it into a MIR Graph.
- std::unique_ptr<mir::Graph> importModel();
+ std::unique_ptr<mir::Graph> importModelFromBinaryFile(const std::string &filename);
+ std::unique_ptr<mir::Graph> importModelFromTextFile(const std::string &filename);
private:
- void import();
std::unique_ptr<mir::Graph> createIR();
void createGraphInputs();
void collectUnsupportedOps();
- // Maps ONNX tensor names to corresponding MIR operation outputs.
- std::string _modelFilename;
std::unique_ptr<onnx::ModelProto> _model;
std::unique_ptr<ConverterContext> _context;
std::unique_ptr<mir::Graph> _graph;
};
-ONNXImporterImpl::ONNXImporterImpl(std::string filename) : _modelFilename(std::move(filename))
-{
- registerSupportedOps();
-}
+ONNXImporterImpl::ONNXImporterImpl() { registerSupportedOps(); }
ONNXImporterImpl::~ONNXImporterImpl() = default;
-static void loadModelFile(const std::string &filename, onnx::ModelProto *model)
+void loadModelFromBinaryFile(const std::string &filename, onnx::ModelProto *model)
{
GOOGLE_PROTOBUF_VERIFY_VERSION;
throw std::runtime_error("File \"" + filename + "\" has not been consumed entirely.");
}
-void ONNXImporterImpl::import()
+void loadModelFromTextFile(const std::string &filename, onnx::ModelProto *model)
+{
+ GOOGLE_PROTOBUF_VERIFY_VERSION;
+
+ int file_handle = open(filename.c_str(), O_RDONLY);
+
+ if (file_handle == -1)
+ throw std::runtime_error("Couldn't open file \"" + filename + "\": " + std::strerror(errno) +
+ ".");
+
+ google::protobuf::io::FileInputStream file_stream(file_handle);
+ file_stream.SetCloseOnDelete(true);
+
+ if (!google::protobuf::TextFormat::Parse(&file_stream, model))
+ throw std::runtime_error("Couldn't parse file \"" + filename + "\".");
+}
+
+std::unique_ptr<mir::Graph> ONNXImporterImpl::importModelFromBinaryFile(const std::string &filename)
{
_model = stdex::make_unique<onnx::ModelProto>();
- loadModelFile(_modelFilename, _model.get());
+ loadModelFromBinaryFile(filename, _model.get());
collectUnsupportedOps();
+ return createIR();
+}
+
+std::unique_ptr<mir::Graph> ONNXImporterImpl::importModelFromTextFile(const std::string &filename)
+{
+ _model = stdex::make_unique<onnx::ModelProto>();
+ loadModelFromTextFile(filename, _model.get());
+
+ collectUnsupportedOps();
+ return createIR();
}
void ONNXImporterImpl::collectUnsupportedOps()
return std::move(_graph);
}
-std::unique_ptr<mir::Graph> ONNXImporterImpl::importModel()
+} // namespace
+
+std::unique_ptr<mir::Graph> importModelFromBinaryFile(const std::string &filename)
{
- import();
- return createIR();
-}
+ ONNXImporterImpl importer;
+ return importer.importModelFromBinaryFile(filename);
}
-namespace mir_onnx
+std::unique_ptr<mir::Graph> importModelFromTextFile(const std::string &filename)
{
+ ONNXImporterImpl importer;
+ return importer.importModelFromTextFile(filename);
+}
-std::unique_ptr<mir::Graph> loadModel(std::string filename)
+std::unique_ptr<mir::Graph> loadModel(const std::string &filename)
{
- ONNXImporterImpl importer(std::move(filename));
- return importer.importModel();
+ return importModelFromBinaryFile(filename);
}
} // namespace mir_onnx