/*
 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "ONNXImporterImpl.h"
#include "ONNXHelpers.h"
#include "ONNXOpRegistration.h"
#include "onnx/onnx.pb.h"

#include "mir/Shape.h"
#include "mir/TensorUtil.h"

#include "mir/ops/ConstantOp.h"
#include "mir/ops/InputOp.h"
#include "mir/ops/OutputOp.h"

#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/text_format.h>

#include <fcntl.h>

#include <stdex/Memory.h>

#include <cerrno>
#include <climits>
#include <cstring>
#include <iostream>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
43 class ONNXImporterImpl final
48 /// @brief Load the model and convert it into a MIR Graph.
49 std::unique_ptr<mir::Graph> importModelFromBinaryFile(const std::string &filename);
50 std::unique_ptr<mir::Graph> importModelFromTextFile(const std::string &filename);
53 std::unique_ptr<mir::Graph> createIR();
54 void createGraphInputs();
55 void collectUnsupportedOps();
56 std::unique_ptr<onnx::ModelProto> _model;
57 std::unique_ptr<ConverterContext> _converterCtx;
58 std::unique_ptr<ModelContext> _modelCtx;
59 std::unique_ptr<mir::Graph> _graph;
62 ONNXImporterImpl::ONNXImporterImpl() { registerSupportedOps(); }
64 ONNXImporterImpl::~ONNXImporterImpl() = default;
66 void loadModelFromBinaryFile(const std::string &filename, onnx::ModelProto *model)
68 GOOGLE_PROTOBUF_VERIFY_VERSION;
70 int file_handle = open(filename.c_str(), O_RDONLY);
72 if (file_handle == -1)
73 throw std::runtime_error("Couldn't open file \"" + filename + "\": " + std::strerror(errno) +
76 google::protobuf::io::FileInputStream file_stream(file_handle);
77 file_stream.SetCloseOnDelete(true);
79 google::protobuf::io::CodedInputStream coded_stream(&file_stream);
80 coded_stream.SetTotalBytesLimit(INT_MAX, INT_MAX);
82 if (!model->ParseFromCodedStream(&coded_stream))
83 throw std::runtime_error("Couldn't parse file \"" + filename + "\".");
85 // If the file has not been consumed entirely, assume that the file is in the wrong format.
86 if (!coded_stream.ConsumedEntireMessage())
87 throw std::runtime_error("File \"" + filename + "\" has not been consumed entirely.");
90 void loadModelFromTextFile(const std::string &filename, onnx::ModelProto *model)
92 GOOGLE_PROTOBUF_VERIFY_VERSION;
94 int file_handle = open(filename.c_str(), O_RDONLY);
96 if (file_handle == -1)
97 throw std::runtime_error("Couldn't open file \"" + filename + "\": " + std::strerror(errno) +
100 google::protobuf::io::FileInputStream file_stream(file_handle);
101 file_stream.SetCloseOnDelete(true);
103 if (!google::protobuf::TextFormat::Parse(&file_stream, model))
104 throw std::runtime_error("Couldn't parse file \"" + filename + "\".");
107 std::unique_ptr<mir::Graph> ONNXImporterImpl::importModelFromBinaryFile(const std::string &filename)
109 _model = stdex::make_unique<onnx::ModelProto>();
110 loadModelFromBinaryFile(filename, _model.get());
111 _modelCtx = stdex::make_unique<ModelContext>(_model.get());
112 collectUnsupportedOps();
116 std::unique_ptr<mir::Graph> ONNXImporterImpl::importModelFromTextFile(const std::string &filename)
118 _model = stdex::make_unique<onnx::ModelProto>();
119 loadModelFromTextFile(filename, _model.get());
120 _modelCtx = stdex::make_unique<ModelContext>(_model.get());
121 collectUnsupportedOps();
125 void ONNXImporterImpl::collectUnsupportedOps()
127 std::set<std::pair<std::string, int64_t>> problems_op_set;
129 for (int i = 0; i < _model->graph().node_size(); i++)
131 const auto &onnx_node = _model->graph().node(i);
132 assert(onnx_node.has_op_type());
133 const auto &op_type = onnx_node.op_type();
134 auto opset = _modelCtx->getDomainOpsetVersion(onnx_node.domain());
136 NodeConverterRegistry::ConverterFunc converter =
137 NodeConverterRegistry::getInstance().lookup(op_type, opset);
139 if (converter == nullptr)
140 problems_op_set.emplace(op_type, opset);
142 if (!problems_op_set.empty())
144 std::cerr << "The following operators are not supported:\n";
145 for (const auto &op : problems_op_set)
146 std::cerr << op.first << " opset " << op.second << std::endl;
147 throw std::runtime_error("Unsupported operators found");
151 void ONNXImporterImpl::createGraphInputs()
153 const auto &graph = _model->graph();
154 const auto &initializer = graph.initializer();
155 const auto &value_info = graph.value_info();
157 // Create all initializer Tensors
158 for (const auto &tensor : initializer)
160 const auto mir_tensor = createTensor(&tensor);
161 auto *op = _graph->create<mir::ops::ConstantOp>(mir_tensor);
162 _converterCtx->setOutput(tensor.name(), op->getOutput(0));
165 for (const auto &input : graph.input())
167 assert(input.has_name());
169 if (_converterCtx->getOutput(input.name()) == nullptr)
171 const auto &onnx_input_shape = input.type().tensor_type().shape();
172 mir::Shape shape(onnx_input_shape.dim_size());
173 for (int i = 0; i < onnx_input_shape.dim_size(); i++)
175 assert(onnx_input_shape.dim(i).has_dim_value());
176 shape.dim(i) = static_cast<int32_t>(onnx_input_shape.dim(i).dim_value());
179 auto elem_type = onnxDataTypeToMirDataType(
180 (onnx::TensorProto_DataType)input.type().tensor_type().elem_type());
181 mir::TensorType type{elem_type, shape};
182 auto *op = _graph->create<mir::ops::InputOp>(type);
183 _converterCtx->setOutput(input.name(), op->getOutput(0));
188 std::unique_ptr<mir::Graph> ONNXImporterImpl::createIR()
190 _graph = stdex::make_unique<mir::Graph>();
191 _converterCtx = stdex::make_unique<ConverterContext>(_graph.get());
195 // Forming partially ordered computation graph
196 for (const auto &onnx_node : _model->graph().node())
198 assert(onnx_node.has_op_type());
199 auto &op_type = onnx_node.op_type();
200 auto opset = _modelCtx->getDomainOpsetVersion(onnx_node.domain());
202 NodeConverterRegistry::ConverterFunc converter =
203 NodeConverterRegistry::getInstance().lookup(op_type, opset);
204 assert(converter != nullptr);
205 converter(onnx_node, _converterCtx.get());
208 const auto &outputs = _model->graph().output();
209 for (const auto &output : outputs)
211 assert(output.has_name());
212 auto mir_output = _converterCtx->getOutput(output.name());
213 if (mir_output == nullptr)
214 throw std::runtime_error("Bad output name!");
216 _graph->create<mir::ops::OutputOp>(mir_output);
219 return std::move(_graph);
224 std::unique_ptr<mir::Graph> importModelFromBinaryFile(const std::string &filename)
226 ONNXImporterImpl importer;
227 return importer.importModelFromBinaryFile(filename);
230 std::unique_ptr<mir::Graph> importModelFromTextFile(const std::string &filename)
232 ONNXImporterImpl importer;
233 return importer.importModelFromTextFile(filename);
236 std::unique_ptr<mir::Graph> loadModel(const std::string &filename)
238 return importModelFromBinaryFile(filename);
241 } // namespace mir_onnx