Imported Upstream version 1.4.0
[platform/core/ml/nnfw.git] / compiler / mir-onnx-importer / ONNXImporterImpl.cpp
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "ONNXImporterImpl.h"
18 #include "ONNXHelpers.h"
19 #include "ONNXOpRegistration.h"
20 #include "onnx/onnx.pb.h"
21
22 #include "mir/Shape.h"
23 #include "mir/TensorUtil.h"
24
25 #include "mir/ops/ConstantOp.h"
26
27 #include <fcntl.h>
28
29 #include <google/protobuf/io/zero_copy_stream_impl.h>
30 #include <google/protobuf/io/coded_stream.h>
31 #include <google/protobuf/text_format.h>
32 #include <functional>
33 #include <iostream>
34 #include <stdex/Memory.h>
35 #include <utility>
36
37 namespace mir_onnx
38 {
39
40 namespace
41 {
42
43 class ONNXImporterImpl final
44 {
45 public:
46   ONNXImporterImpl();
47   ~ONNXImporterImpl();
48   /// @brief Load the model and convert it into a MIR Graph.
49   std::unique_ptr<mir::Graph> importModelFromBinaryFile(const std::string &filename);
50   std::unique_ptr<mir::Graph> importModelFromTextFile(const std::string &filename);
51
52 private:
53   std::unique_ptr<mir::Graph> createIR();
54   void createGraphInputs();
55   void collectUnsupportedOps();
56   std::unique_ptr<onnx::ModelProto> _model;
57   std::unique_ptr<ConverterContext> _converterCtx;
58   std::unique_ptr<ModelContext> _modelCtx;
59   std::unique_ptr<mir::Graph> _graph;
60 };
61
62 ONNXImporterImpl::ONNXImporterImpl() { registerSupportedOps(); }
63
64 ONNXImporterImpl::~ONNXImporterImpl() = default;
65
66 void loadModelFromBinaryFile(const std::string &filename, onnx::ModelProto *model)
67 {
68   GOOGLE_PROTOBUF_VERIFY_VERSION;
69
70   int file_handle = open(filename.c_str(), O_RDONLY);
71
72   if (file_handle == -1)
73     throw std::runtime_error("Couldn't open file \"" + filename + "\": " + std::strerror(errno) +
74                              ".");
75
76   google::protobuf::io::FileInputStream file_stream(file_handle);
77   file_stream.SetCloseOnDelete(true);
78
79   google::protobuf::io::CodedInputStream coded_stream(&file_stream);
80   coded_stream.SetTotalBytesLimit(INT_MAX, INT_MAX);
81
82   if (!model->ParseFromCodedStream(&coded_stream))
83     throw std::runtime_error("Couldn't parse file \"" + filename + "\".");
84
85   // If the file has not been consumed entirely, assume that the file is in the wrong format.
86   if (!coded_stream.ConsumedEntireMessage())
87     throw std::runtime_error("File \"" + filename + "\" has not been consumed entirely.");
88 }
89
90 void loadModelFromTextFile(const std::string &filename, onnx::ModelProto *model)
91 {
92   GOOGLE_PROTOBUF_VERIFY_VERSION;
93
94   int file_handle = open(filename.c_str(), O_RDONLY);
95
96   if (file_handle == -1)
97     throw std::runtime_error("Couldn't open file \"" + filename + "\": " + std::strerror(errno) +
98                              ".");
99
100   google::protobuf::io::FileInputStream file_stream(file_handle);
101   file_stream.SetCloseOnDelete(true);
102
103   if (!google::protobuf::TextFormat::Parse(&file_stream, model))
104     throw std::runtime_error("Couldn't parse file \"" + filename + "\".");
105 }
106
107 std::unique_ptr<mir::Graph> ONNXImporterImpl::importModelFromBinaryFile(const std::string &filename)
108 {
109   _model = stdex::make_unique<onnx::ModelProto>();
110   loadModelFromBinaryFile(filename, _model.get());
111   _modelCtx = stdex::make_unique<ModelContext>(_model.get());
112   collectUnsupportedOps();
113   return createIR();
114 }
115
116 std::unique_ptr<mir::Graph> ONNXImporterImpl::importModelFromTextFile(const std::string &filename)
117 {
118   _model = stdex::make_unique<onnx::ModelProto>();
119   loadModelFromTextFile(filename, _model.get());
120   _modelCtx = stdex::make_unique<ModelContext>(_model.get());
121   collectUnsupportedOps();
122   return createIR();
123 }
124
125 void ONNXImporterImpl::collectUnsupportedOps()
126 {
127   std::set<std::pair<std::string, int64_t>> problems_op_set;
128
129   for (int i = 0; i < _model->graph().node_size(); i++)
130   {
131     const auto &onnx_node = _model->graph().node(i);
132     assert(onnx_node.has_op_type());
133     const auto &op_type = onnx_node.op_type();
134     auto opset = _modelCtx->getDomainOpsetVersion(onnx_node.domain());
135
136     NodeConverterRegistry::ConverterFunc converter =
137         NodeConverterRegistry::getInstance().lookup(op_type, opset);
138
139     if (converter == nullptr)
140       problems_op_set.emplace(op_type, opset);
141   }
142   if (!problems_op_set.empty())
143   {
144     std::cerr << "The following operators are not supported:\n";
145     for (const auto &op : problems_op_set)
146       std::cerr << op.first << " opset " << op.second << std::endl;
147     throw std::runtime_error("Unsupported operators found");
148   }
149 }
150
151 void ONNXImporterImpl::createGraphInputs()
152 {
153   const auto &graph = _model->graph();
154   const auto &initializer = graph.initializer();
155   const auto &value_info = graph.value_info();
156
157   // Create all initializer Tensors
158   for (const auto &tensor : initializer)
159   {
160     const auto mir_tensor = createTensor(&tensor);
161     auto *op = _graph->create<mir::ops::ConstantOp>(mir_tensor);
162     _converterCtx->setOutput(tensor.name(), op->getOutput(0));
163   }
164
165   for (const auto &input : graph.input())
166   {
167     assert(input.has_name());
168
169     if (_converterCtx->getOutput(input.name()) == nullptr)
170     {
171       const auto &onnx_input_shape = input.type().tensor_type().shape();
172       mir::Shape shape(onnx_input_shape.dim_size());
173       for (int i = 0; i < onnx_input_shape.dim_size(); i++)
174       {
175         assert(onnx_input_shape.dim(i).has_dim_value());
176         shape.dim(i) = static_cast<int32_t>(onnx_input_shape.dim(i).dim_value());
177       }
178
179       auto elem_type = onnxDataTypeToMirDataType(
180           (onnx::TensorProto_DataType)input.type().tensor_type().elem_type());
181       mir::TensorType type{elem_type, shape};
182       auto *op = _graph->create<mir::ops::InputOp>(type);
183       _converterCtx->setOutput(input.name(), op->getOutput(0));
184     }
185   }
186 }
187
188 std::unique_ptr<mir::Graph> ONNXImporterImpl::createIR()
189 {
190   _graph = stdex::make_unique<mir::Graph>();
191   _converterCtx = stdex::make_unique<ConverterContext>(_graph.get());
192
193   createGraphInputs();
194
195   // Forming partially ordered computation graph
196   for (const auto &onnx_node : _model->graph().node())
197   {
198     assert(onnx_node.has_op_type());
199     auto &op_type = onnx_node.op_type();
200     auto opset = _modelCtx->getDomainOpsetVersion(onnx_node.domain());
201     // Get converter
202     NodeConverterRegistry::ConverterFunc converter =
203         NodeConverterRegistry::getInstance().lookup(op_type, opset);
204     assert(converter != nullptr);
205     converter(onnx_node, _converterCtx.get());
206   }
207   // Set graph outputs
208   const auto &outputs = _model->graph().output();
209   for (const auto &output : outputs)
210   {
211     assert(output.has_name());
212     auto mir_output = _converterCtx->getOutput(output.name());
213     if (mir_output == nullptr)
214       throw std::runtime_error("Bad output name!");
215
216     _graph->create<mir::ops::OutputOp>(mir_output);
217   }
218
219   return std::move(_graph);
220 }
221
222 } // namespace
223
224 std::unique_ptr<mir::Graph> importModelFromBinaryFile(const std::string &filename)
225 {
226   ONNXImporterImpl importer;
227   return importer.importModelFromBinaryFile(filename);
228 }
229
230 std::unique_ptr<mir::Graph> importModelFromTextFile(const std::string &filename)
231 {
232   ONNXImporterImpl importer;
233   return importer.importModelFromTextFile(filename);
234 }
235
236 std::unique_ptr<mir::Graph> loadModel(const std::string &filename)
237 {
238   return importModelFromBinaryFile(filename);
239 }
240
241 } // namespace mir_onnx