Imported Upstream version 1.4.0
[platform/core/ml/nnfw.git] / compiler / mir-tflite-importer / tflite_importer.cpp
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
#include "tflite_importer.h"
#include "tflite_op_creator.h"
#include "schema_generated.h"

#include "mir/TensorVariant.h"
#include "mir/ops/ConstantOp.h"
#include "mir/ops/OutputOp.h"

#include <stdex/Memory.h>

#include <cassert>
#include <cstdint>
#include <fstream>
#include <iterator>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
30
31 namespace mir_tflite
32 {
33
34 namespace
35 {
36
37 class TfliteImporter
38 {
39 public:
40   explicit TfliteImporter(std::string filename);
41
42   /// @brief Load the model and convert it into a MIR Graph.
43   std::unique_ptr<mir::Graph> importModel();
44
45   ~TfliteImporter();
46
47 private:
48   std::string _filename;
49   std::unique_ptr<tflite::ModelT> _model;
50
51   std::unique_ptr<mir::Graph> _graph;
52   std::unique_ptr<TFLiteOpCreator> _opCreator;
53
54   // Maps TFLite tensors indices to corresponding MIR operation outputs.
55   std::vector<mir::Operation::Output *> _tensorMap;
56
57   void import();
58
59   void walkModel(const tflite::ModelT *model);
60
61   void walkSubgraph(const tflite::SubGraphT *subgraph);
62
63   void walkOperator(const tflite::SubGraphT *subgraph, const tflite::OperatorT *op);
64
65   /**
66    * @brief Pass through tflite graph and collect operators unsupported by NNC
67    * @throw PassException with message, containing detected problems
68    */
69   void collectUnsupportedOps();
70
71   /**
72    * @brief Returns MIR operation outputs corresponding to the inputs of the given operator.
73    */
74   std::vector<mir::Operation::Output *> getMIRInputsForOperator(const tflite::SubGraphT *subgraph,
75                                                                 const tflite::OperatorT *op);
76 };
77
78 TfliteImporter::TfliteImporter(std::string filename) : _filename(std::move(filename))
79 {
80   _graph = stdex::make_unique<mir::Graph>();
81   _opCreator = stdex::make_unique<TFLiteOpCreator>(_graph.get());
82 }
83
84 TfliteImporter::~TfliteImporter() = default;
85
86 void TfliteImporter::import()
87 {
88   std::ifstream stream(_filename, std::ios::in | std::ios::binary);
89   if (stream.fail())
90     throw std::runtime_error("Couldn't open file \"" + _filename + "\".");
91
92   std::vector<char> model_buffer((std::istreambuf_iterator<char>(stream)),
93                                  std::istreambuf_iterator<char>());
94
95   if (stream.fail())
96     throw std::runtime_error("Couldn't read file \"" + _filename + "\".");
97
98   flatbuffers::Verifier verifier(reinterpret_cast<const std::uint8_t *>(model_buffer.data()),
99                                  model_buffer.size());
100
101   if (!tflite::VerifyModelBuffer(verifier))
102     throw std::runtime_error("Could not load model: " + _filename + "\n");
103
104   _model = tflite::UnPackModel(model_buffer.data());
105 }
106
107 static const std::set<tflite::BuiltinOperator> supportedOperators = {
108     tflite::BuiltinOperator_ADD,
109     tflite::BuiltinOperator_AVERAGE_POOL_2D,
110     tflite::BuiltinOperator_CONCATENATION,
111     tflite::BuiltinOperator_CONV_2D,
112     tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
113     tflite::BuiltinOperator_DIV,
114     tflite::BuiltinOperator_FULLY_CONNECTED,
115     tflite::BuiltinOperator_HARD_SWISH,
116     tflite::BuiltinOperator_LEAKY_RELU,
117     tflite::BuiltinOperator_LOGISTIC,
118     tflite::BuiltinOperator_MAX_POOL_2D,
119     tflite::BuiltinOperator_MAXIMUM,
120     tflite::BuiltinOperator_MEAN,
121     tflite::BuiltinOperator_MUL,
122     tflite::BuiltinOperator_PAD,
123     tflite::BuiltinOperator_RELU,
124     tflite::BuiltinOperator_RELU6,
125     tflite::BuiltinOperator_RESHAPE,
126     tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
127     tflite::BuiltinOperator_RSQRT,
128     tflite::BuiltinOperator_SHAPE,
129     tflite::BuiltinOperator_SLICE,
130     tflite::BuiltinOperator_SOFTMAX,
131     tflite::BuiltinOperator_SQRT,
132     tflite::BuiltinOperator_SQUARED_DIFFERENCE,
133     tflite::BuiltinOperator_SQUEEZE,
134     tflite::BuiltinOperator_STRIDED_SLICE,
135     tflite::BuiltinOperator_SUB,
136     tflite::BuiltinOperator_TANH,
137     tflite::BuiltinOperator_TRANSPOSE,
138     tflite::BuiltinOperator_TRANSPOSE_CONV,
139 };
140
141 void TfliteImporter::collectUnsupportedOps()
142 {
143   std::set<std::string> errors;
144   for (const auto &subgraph : _model->subgraphs)
145     for (const auto &op : subgraph->operators)
146     {
147       tflite::BuiltinOperator opcode = _model->operator_codes[op->opcode_index]->builtin_code;
148       if (supportedOperators.find(opcode) == supportedOperators.end())
149       {
150         if (opcode <= tflite::BuiltinOperator_MAX)
151           errors.insert(std::string(EnumNameBuiltinOperator(opcode)) + ": unsupported operator");
152         else
153           errors.insert(std::to_string(opcode) + ": unsuppored in tflite custom opcode");
154       }
155     }
156
157   if (!errors.empty())
158   {
159     std::string msg("NNC can't load model. Detected problems:");
160     for (const auto &e : errors)
161       msg.append("\n  * " + e);
162     throw std::runtime_error(msg);
163   }
164 }
165
166 std::unique_ptr<mir::Graph> TfliteImporter::importModel()
167 {
168   import();
169   collectUnsupportedOps();
170   walkModel(_model.get());
171   return std::move(_graph);
172 }
173
174 void TfliteImporter::walkModel(const tflite::ModelT *model)
175 {
176   for (const auto &subgraph : model->subgraphs)
177     walkSubgraph(subgraph.get());
178 }
179
180 mir::DataType convertElementType(tflite::TensorType type)
181 {
182   switch (type)
183   {
184     case tflite::TensorType_INT32:
185       return mir::DataType::INT32;
186     case tflite::TensorType_FLOAT32:
187       return mir::DataType::FLOAT32;
188     case tflite::TensorType_INT64:
189       return mir::DataType::INT64;
190     case tflite::TensorType_UINT8:
191       return mir::DataType::UINT8;
192     default:
193       throw std::runtime_error(std::string("Unsupported tensor type: ") + EnumNameTensorType(type));
194   }
195 }
196
197 mir::TensorType getMirTensorType(const tflite::TensorT &tensor)
198 {
199   mir::DataType element_type = convertElementType(tensor.type);
200
201   mir::Shape shape(tensor.shape.size());
202   for (std::size_t i = 0; i < tensor.shape.size(); ++i)
203   {
204     shape.dim(i) = tensor.shape[i];
205   }
206
207   if (tensor.quantization != nullptr)
208   {
209     const tflite::QuantizationParametersT &params = *tensor.quantization;
210
211     if (params.details.type != tflite::QuantizationDetails_NONE)
212       throw std::runtime_error("Custom quantization is not supported.");
213
214     // Empty parameters mean no quantization at all.
215     if (params.scale.empty() && params.zero_point.empty())
216       return mir::TensorType{element_type, shape};
217
218     if (params.scale.size() != 1 || params.zero_point.size() != 1)
219       throw std::runtime_error("Non-scalar quantization is not supported.");
220
221     mir::AffineQuantization quantization{params.scale[0], static_cast<int>(params.zero_point[0])};
222
223     return mir::TensorType{element_type, shape, quantization};
224   }
225   else
226   {
227     return mir::TensorType{element_type, shape};
228   }
229 }
230
231 void TfliteImporter::walkSubgraph(const tflite::SubGraphT *subgraph)
232 {
233   _tensorMap.assign(subgraph->tensors.size(), nullptr);
234
235   for (const auto input_tensor_index : subgraph->inputs)
236   {
237     const tflite::TensorT &tensor = *subgraph->tensors[input_tensor_index];
238
239     mir::TensorType input_type = getMirTensorType(tensor);
240     auto input = _graph->create<mir::ops::InputOp>(input_type)->getOutput(0);
241     input->setName(tensor.name);
242
243     assert(_tensorMap[input_tensor_index] == nullptr);
244     _tensorMap[input_tensor_index] = input;
245   }
246
247   for (const auto &op : subgraph->operators)
248   {
249     walkOperator(subgraph, op.get());
250   }
251
252   for (const auto output_tensor_index : subgraph->outputs)
253   {
254     auto output = _tensorMap[output_tensor_index];
255     _graph->create<mir::ops::OutputOp>(output);
256   }
257 }
258
259 void TfliteImporter::walkOperator(const tflite::SubGraphT *subgraph, const tflite::OperatorT *op)
260 {
261   std::vector<mir::Operation::Output *> inputs = getMIRInputsForOperator(subgraph, op);
262   std::vector<mir::Operation::Output *> outputs;
263
264   tflite::BuiltinOperator opcode = _model->operator_codes[op->opcode_index]->builtin_code;
265   switch (opcode)
266   {
267     case tflite::BuiltinOperator_CONV_2D:
268       outputs = _opCreator->convertConv2D(op->builtin_options.AsConv2DOptions(), inputs);
269       break;
270     case tflite::BuiltinOperator_DEPTHWISE_CONV_2D:
271       outputs = _opCreator->convertDepthwiseConv2D(op->builtin_options.AsDepthwiseConv2DOptions(),
272                                                    inputs);
273       break;
274     case tflite::BuiltinOperator_MAX_POOL_2D:
275       outputs = _opCreator->convertMaxPool2D(op->builtin_options.AsPool2DOptions(), inputs);
276       break;
277     case tflite::BuiltinOperator_AVERAGE_POOL_2D:
278       outputs = _opCreator->convertAveragePool2D(op->builtin_options.AsPool2DOptions(), inputs);
279       break;
280     case tflite::BuiltinOperator_CONCATENATION:
281       outputs =
282           _opCreator->convertConcatenation(op->builtin_options.AsConcatenationOptions(), inputs);
283       break;
284     case tflite::BuiltinOperator_RESHAPE:
285       outputs = _opCreator->convertReshape(op->builtin_options.AsReshapeOptions(), inputs);
286       break;
287     case tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
288       outputs = _opCreator->convertResizeNearestNeighbor(
289           op->builtin_options.AsResizeNearestNeighborOptions(), inputs);
290       break;
291     case tflite::BuiltinOperator_MEAN:
292       outputs = _opCreator->convertMean(op->builtin_options.AsReducerOptions(), inputs);
293       break;
294     case tflite::BuiltinOperator_FULLY_CONNECTED:
295       outputs =
296           _opCreator->convertFullyConnected(op->builtin_options.AsFullyConnectedOptions(), inputs);
297       break;
298     case tflite::BuiltinOperator_SOFTMAX:
299       outputs = _opCreator->convertSoftmax(op->builtin_options.AsSoftmaxOptions(), inputs);
300       break;
301     case tflite::BuiltinOperator_SLICE:
302       outputs = _opCreator->convertSlice(op->builtin_options.AsSliceOptions(), inputs);
303       break;
304     case tflite::BuiltinOperator_SQUEEZE:
305       outputs = _opCreator->convertSqueeze(op->builtin_options.AsSqueezeOptions(), inputs);
306       break;
307     case tflite::BuiltinOperator_LOGISTIC:
308       outputs = _opCreator->convertLogistic(inputs);
309       break;
310     case tflite::BuiltinOperator_RSQRT:
311       outputs = _opCreator->convertRsqrt(inputs);
312       break;
313     case tflite::BuiltinOperator_SQRT:
314       outputs = _opCreator->convertSqrt(inputs);
315       break;
316     case tflite::BuiltinOperator_ADD:
317       outputs = _opCreator->convertAdd(op->builtin_options.AsAddOptions(), inputs);
318       break;
319     case tflite::BuiltinOperator_SUB:
320       outputs = _opCreator->convertSub(op->builtin_options.AsSubOptions(), inputs);
321       break;
322     case tflite::BuiltinOperator_MUL:
323       outputs = _opCreator->convertMul(op->builtin_options.AsMulOptions(), inputs);
324       break;
325     case tflite::BuiltinOperator_DIV:
326       outputs = _opCreator->convertDiv(op->builtin_options.AsDivOptions(), inputs);
327       break;
328     case tflite::BuiltinOperator_MAXIMUM:
329       outputs = _opCreator->convertMax(inputs);
330       break;
331     case tflite::BuiltinOperator_SQUARED_DIFFERENCE:
332       outputs = _opCreator->convertSquaredDifference(inputs);
333       break;
334     case tflite::BuiltinOperator_TRANSPOSE_CONV:
335       outputs =
336           _opCreator->convertTransposeConv(op->builtin_options.AsTransposeConvOptions(), inputs);
337       break;
338     case tflite::BuiltinOperator_PAD:
339       outputs = _opCreator->convertPad(op->builtin_options.AsPadOptions(), inputs);
340       break;
341     case tflite::BuiltinOperator_TANH:
342       outputs = _opCreator->convertTanh(inputs);
343       break;
344     case tflite::BuiltinOperator_RELU:
345       outputs = _opCreator->convertReLU(inputs);
346       break;
347     case tflite::BuiltinOperator_RELU6:
348       outputs = _opCreator->convertReLU6(inputs);
349       break;
350     case tflite::BuiltinOperator_TRANSPOSE:
351       outputs = _opCreator->convertTranspose(op->builtin_options.AsTransposeOptions(), inputs);
352       break;
353     case tflite::BuiltinOperator_STRIDED_SLICE:
354       outputs =
355           _opCreator->convertStridedSlice(op->builtin_options.AsStridedSliceOptions(), inputs);
356       break;
357     case tflite::BuiltinOperator_LEAKY_RELU:
358       outputs = _opCreator->convertLeakyReLU(op->builtin_options.AsLeakyReluOptions(), inputs);
359       break;
360     case tflite::BuiltinOperator_SHAPE:
361       outputs = _opCreator->convertShape(op->builtin_options.AsShapeOptions(), inputs);
362       break;
363     case tflite::BuiltinOperator_HARD_SWISH:
364       outputs = _opCreator->convertHardSwish(op->builtin_options.AsHardSwishOptions(), inputs);
365       break;
366     default:
367       assert(false && "All unsupported types should have been found before this pass.");
368   }
369
370   assert(outputs.size() == op->outputs.size());
371   for (std::size_t i = 0; i < op->outputs.size(); ++i)
372   {
373     const auto tensor_index = op->outputs[i];
374     const tflite::TensorT &tensor = *subgraph->tensors[tensor_index];
375
376     mir::TensorType output_type = getMirTensorType(tensor);
377
378     // The type should have been inferred correctly, except for quantization information.
379     assert(outputs[i]->getType().getElementType() == output_type.getElementType() &&
380            outputs[i]->getType().getShape() == output_type.getShape());
381
382     outputs[i]->setName(tensor.name);
383     outputs[i]->setType(output_type);
384
385     assert(_tensorMap[tensor_index] == nullptr);
386     _tensorMap[tensor_index] = outputs[i];
387   }
388 }
389
390 std::vector<mir::Operation::Output *>
391 TfliteImporter::getMIRInputsForOperator(const tflite::SubGraphT *subgraph,
392                                         const tflite::OperatorT *op)
393 {
394   std::vector<mir::Operation::Output *> inputs;
395
396   for (const auto tensor_index : op->inputs)
397   {
398     const tflite::TensorT &tensor = *subgraph->tensors[tensor_index];
399     const tflite::BufferT &buffer = *_model->buffers[tensor.buffer];
400     if (!buffer.data.empty())
401     {
402       assert(_tensorMap[tensor_index] == nullptr);
403       mir::TensorType type = getMirTensorType(tensor);
404       mir::TensorVariant mir_tensor{type, buffer.data.data()};
405       inputs.emplace_back(_graph->create<mir::ops::ConstantOp>(mir_tensor)->getOutput(0));
406     }
407     else
408     {
409       assert(_tensorMap[tensor_index] != nullptr);
410       // By this point every input for the operation "op" should have corresponding
411       // Model IR operations that output its inputs. This assumption is provided by the fact
412       // that TFLite format specifies all operations in the execution order.
413       inputs.emplace_back(_tensorMap[tensor_index]);
414     }
415   }
416
417   return inputs;
418 }
419
420 } // namespace
421
422 std::unique_ptr<mir::Graph> loadModel(std::string filename)
423 {
424   TfliteImporter importer(std::move(filename));
425   return importer.importModel();
426 }
427
428 } // namespace mir_tflite