/*
 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "tflite_importer.h"
18 #include "tflite_op_creator.h"
19 #include "schema_generated.h"
21 #include "mir/TensorVariant.h"
22 #include "mir/ops/ConstantOp.h"
23 #include "mir/ops/OutputOp.h"
26 #include <stdex/Memory.h>
40 explicit TfliteImporter(std::string filename);
42 /// @brief Load the model and convert it into a MIR Graph.
43 std::unique_ptr<mir::Graph> importModel();
48 std::string _filename;
49 std::unique_ptr<tflite::ModelT> _model;
51 std::unique_ptr<mir::Graph> _graph;
52 std::unique_ptr<TFLiteOpCreator> _opCreator;
54 // Maps TFLite tensors indices to corresponding MIR operation outputs.
55 std::vector<mir::Operation::Output *> _tensorMap;
59 void walkModel(const tflite::ModelT *model);
61 void walkSubgraph(const tflite::SubGraphT *subgraph);
63 void walkOperator(const tflite::SubGraphT *subgraph, const tflite::OperatorT *op);
66 * @brief Pass through tflite graph and collect operators unsupported by NNC
67 * @throw PassException with message, containing detected problems
69 void collectUnsupportedOps();
72 * @brief Returns MIR operation outputs corresponding to the inputs of the given operator.
74 std::vector<mir::Operation::Output *> getMIRInputsForOperator(const tflite::SubGraphT *subgraph,
75 const tflite::OperatorT *op);
78 TfliteImporter::TfliteImporter(std::string filename) : _filename(std::move(filename))
80 _graph = stdex::make_unique<mir::Graph>();
81 _opCreator = stdex::make_unique<TFLiteOpCreator>(_graph.get());
84 TfliteImporter::~TfliteImporter() = default;
86 void TfliteImporter::import()
88 std::ifstream stream(_filename, std::ios::in | std::ios::binary);
90 throw std::runtime_error("Couldn't open file \"" + _filename + "\".");
92 std::vector<char> model_buffer((std::istreambuf_iterator<char>(stream)),
93 std::istreambuf_iterator<char>());
96 throw std::runtime_error("Couldn't read file \"" + _filename + "\".");
98 flatbuffers::Verifier verifier(reinterpret_cast<const std::uint8_t *>(model_buffer.data()),
101 if (!tflite::VerifyModelBuffer(verifier))
102 throw std::runtime_error("Could not load model: " + _filename + "\n");
104 _model = tflite::UnPackModel(model_buffer.data());
107 static const std::set<tflite::BuiltinOperator> supportedOperators = {
108 tflite::BuiltinOperator_ADD,
109 tflite::BuiltinOperator_AVERAGE_POOL_2D,
110 tflite::BuiltinOperator_CONCATENATION,
111 tflite::BuiltinOperator_CONV_2D,
112 tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
113 tflite::BuiltinOperator_DIV,
114 tflite::BuiltinOperator_FULLY_CONNECTED,
115 tflite::BuiltinOperator_HARD_SWISH,
116 tflite::BuiltinOperator_LEAKY_RELU,
117 tflite::BuiltinOperator_LOGISTIC,
118 tflite::BuiltinOperator_MAX_POOL_2D,
119 tflite::BuiltinOperator_MAXIMUM,
120 tflite::BuiltinOperator_MEAN,
121 tflite::BuiltinOperator_MUL,
122 tflite::BuiltinOperator_PAD,
123 tflite::BuiltinOperator_RELU,
124 tflite::BuiltinOperator_RELU6,
125 tflite::BuiltinOperator_RESHAPE,
126 tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
127 tflite::BuiltinOperator_RSQRT,
128 tflite::BuiltinOperator_SHAPE,
129 tflite::BuiltinOperator_SLICE,
130 tflite::BuiltinOperator_SOFTMAX,
131 tflite::BuiltinOperator_SQRT,
132 tflite::BuiltinOperator_SQUARED_DIFFERENCE,
133 tflite::BuiltinOperator_SQUEEZE,
134 tflite::BuiltinOperator_STRIDED_SLICE,
135 tflite::BuiltinOperator_SUB,
136 tflite::BuiltinOperator_TANH,
137 tflite::BuiltinOperator_TRANSPOSE,
138 tflite::BuiltinOperator_TRANSPOSE_CONV,
141 void TfliteImporter::collectUnsupportedOps()
143 std::set<std::string> errors;
144 for (const auto &subgraph : _model->subgraphs)
145 for (const auto &op : subgraph->operators)
147 tflite::BuiltinOperator opcode = _model->operator_codes[op->opcode_index]->builtin_code;
148 if (supportedOperators.find(opcode) == supportedOperators.end())
150 if (opcode <= tflite::BuiltinOperator_MAX)
151 errors.insert(std::string(EnumNameBuiltinOperator(opcode)) + ": unsupported operator");
153 errors.insert(std::to_string(opcode) + ": unsuppored in tflite custom opcode");
159 std::string msg("NNC can't load model. Detected problems:");
160 for (const auto &e : errors)
161 msg.append("\n * " + e);
162 throw std::runtime_error(msg);
166 std::unique_ptr<mir::Graph> TfliteImporter::importModel()
169 collectUnsupportedOps();
170 walkModel(_model.get());
171 return std::move(_graph);
174 void TfliteImporter::walkModel(const tflite::ModelT *model)
176 for (const auto &subgraph : model->subgraphs)
177 walkSubgraph(subgraph.get());
180 mir::DataType convertElementType(tflite::TensorType type)
184 case tflite::TensorType_INT32:
185 return mir::DataType::INT32;
186 case tflite::TensorType_FLOAT32:
187 return mir::DataType::FLOAT32;
188 case tflite::TensorType_INT64:
189 return mir::DataType::INT64;
190 case tflite::TensorType_UINT8:
191 return mir::DataType::UINT8;
193 throw std::runtime_error(std::string("Unsupported tensor type: ") + EnumNameTensorType(type));
197 mir::TensorType getMirTensorType(const tflite::TensorT &tensor)
199 mir::DataType element_type = convertElementType(tensor.type);
201 mir::Shape shape(tensor.shape.size());
202 for (std::size_t i = 0; i < tensor.shape.size(); ++i)
204 shape.dim(i) = tensor.shape[i];
207 if (tensor.quantization != nullptr)
209 const tflite::QuantizationParametersT ¶ms = *tensor.quantization;
211 if (params.details.type != tflite::QuantizationDetails_NONE)
212 throw std::runtime_error("Custom quantization is not supported.");
214 // Empty parameters mean no quantization at all.
215 if (params.scale.empty() && params.zero_point.empty())
216 return mir::TensorType{element_type, shape};
218 if (params.scale.size() != 1 || params.zero_point.size() != 1)
219 throw std::runtime_error("Non-scalar quantization is not supported.");
221 mir::AffineQuantization quantization{params.scale[0], static_cast<int>(params.zero_point[0])};
223 return mir::TensorType{element_type, shape, quantization};
227 return mir::TensorType{element_type, shape};
231 void TfliteImporter::walkSubgraph(const tflite::SubGraphT *subgraph)
233 _tensorMap.assign(subgraph->tensors.size(), nullptr);
235 for (const auto input_tensor_index : subgraph->inputs)
237 const tflite::TensorT &tensor = *subgraph->tensors[input_tensor_index];
239 mir::TensorType input_type = getMirTensorType(tensor);
240 auto input = _graph->create<mir::ops::InputOp>(input_type)->getOutput(0);
241 input->setName(tensor.name);
243 assert(_tensorMap[input_tensor_index] == nullptr);
244 _tensorMap[input_tensor_index] = input;
247 for (const auto &op : subgraph->operators)
249 walkOperator(subgraph, op.get());
252 for (const auto output_tensor_index : subgraph->outputs)
254 auto output = _tensorMap[output_tensor_index];
255 _graph->create<mir::ops::OutputOp>(output);
259 void TfliteImporter::walkOperator(const tflite::SubGraphT *subgraph, const tflite::OperatorT *op)
261 std::vector<mir::Operation::Output *> inputs = getMIRInputsForOperator(subgraph, op);
262 std::vector<mir::Operation::Output *> outputs;
264 tflite::BuiltinOperator opcode = _model->operator_codes[op->opcode_index]->builtin_code;
267 case tflite::BuiltinOperator_CONV_2D:
268 outputs = _opCreator->convertConv2D(op->builtin_options.AsConv2DOptions(), inputs);
270 case tflite::BuiltinOperator_DEPTHWISE_CONV_2D:
271 outputs = _opCreator->convertDepthwiseConv2D(op->builtin_options.AsDepthwiseConv2DOptions(),
274 case tflite::BuiltinOperator_MAX_POOL_2D:
275 outputs = _opCreator->convertMaxPool2D(op->builtin_options.AsPool2DOptions(), inputs);
277 case tflite::BuiltinOperator_AVERAGE_POOL_2D:
278 outputs = _opCreator->convertAveragePool2D(op->builtin_options.AsPool2DOptions(), inputs);
280 case tflite::BuiltinOperator_CONCATENATION:
282 _opCreator->convertConcatenation(op->builtin_options.AsConcatenationOptions(), inputs);
284 case tflite::BuiltinOperator_RESHAPE:
285 outputs = _opCreator->convertReshape(op->builtin_options.AsReshapeOptions(), inputs);
287 case tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
288 outputs = _opCreator->convertResizeNearestNeighbor(
289 op->builtin_options.AsResizeNearestNeighborOptions(), inputs);
291 case tflite::BuiltinOperator_MEAN:
292 outputs = _opCreator->convertMean(op->builtin_options.AsReducerOptions(), inputs);
294 case tflite::BuiltinOperator_FULLY_CONNECTED:
296 _opCreator->convertFullyConnected(op->builtin_options.AsFullyConnectedOptions(), inputs);
298 case tflite::BuiltinOperator_SOFTMAX:
299 outputs = _opCreator->convertSoftmax(op->builtin_options.AsSoftmaxOptions(), inputs);
301 case tflite::BuiltinOperator_SLICE:
302 outputs = _opCreator->convertSlice(op->builtin_options.AsSliceOptions(), inputs);
304 case tflite::BuiltinOperator_SQUEEZE:
305 outputs = _opCreator->convertSqueeze(op->builtin_options.AsSqueezeOptions(), inputs);
307 case tflite::BuiltinOperator_LOGISTIC:
308 outputs = _opCreator->convertLogistic(inputs);
310 case tflite::BuiltinOperator_RSQRT:
311 outputs = _opCreator->convertRsqrt(inputs);
313 case tflite::BuiltinOperator_SQRT:
314 outputs = _opCreator->convertSqrt(inputs);
316 case tflite::BuiltinOperator_ADD:
317 outputs = _opCreator->convertAdd(op->builtin_options.AsAddOptions(), inputs);
319 case tflite::BuiltinOperator_SUB:
320 outputs = _opCreator->convertSub(op->builtin_options.AsSubOptions(), inputs);
322 case tflite::BuiltinOperator_MUL:
323 outputs = _opCreator->convertMul(op->builtin_options.AsMulOptions(), inputs);
325 case tflite::BuiltinOperator_DIV:
326 outputs = _opCreator->convertDiv(op->builtin_options.AsDivOptions(), inputs);
328 case tflite::BuiltinOperator_MAXIMUM:
329 outputs = _opCreator->convertMax(inputs);
331 case tflite::BuiltinOperator_SQUARED_DIFFERENCE:
332 outputs = _opCreator->convertSquaredDifference(inputs);
334 case tflite::BuiltinOperator_TRANSPOSE_CONV:
336 _opCreator->convertTransposeConv(op->builtin_options.AsTransposeConvOptions(), inputs);
338 case tflite::BuiltinOperator_PAD:
339 outputs = _opCreator->convertPad(op->builtin_options.AsPadOptions(), inputs);
341 case tflite::BuiltinOperator_TANH:
342 outputs = _opCreator->convertTanh(inputs);
344 case tflite::BuiltinOperator_RELU:
345 outputs = _opCreator->convertReLU(inputs);
347 case tflite::BuiltinOperator_RELU6:
348 outputs = _opCreator->convertReLU6(inputs);
350 case tflite::BuiltinOperator_TRANSPOSE:
351 outputs = _opCreator->convertTranspose(op->builtin_options.AsTransposeOptions(), inputs);
353 case tflite::BuiltinOperator_STRIDED_SLICE:
355 _opCreator->convertStridedSlice(op->builtin_options.AsStridedSliceOptions(), inputs);
357 case tflite::BuiltinOperator_LEAKY_RELU:
358 outputs = _opCreator->convertLeakyReLU(op->builtin_options.AsLeakyReluOptions(), inputs);
360 case tflite::BuiltinOperator_SHAPE:
361 outputs = _opCreator->convertShape(op->builtin_options.AsShapeOptions(), inputs);
363 case tflite::BuiltinOperator_HARD_SWISH:
364 outputs = _opCreator->convertHardSwish(op->builtin_options.AsHardSwishOptions(), inputs);
367 assert(false && "All unsupported types should have been found before this pass.");
370 assert(outputs.size() == op->outputs.size());
371 for (std::size_t i = 0; i < op->outputs.size(); ++i)
373 const auto tensor_index = op->outputs[i];
374 const tflite::TensorT &tensor = *subgraph->tensors[tensor_index];
376 mir::TensorType output_type = getMirTensorType(tensor);
378 // The type should have been inferred correctly, except for quantization information.
379 assert(outputs[i]->getType().getElementType() == output_type.getElementType() &&
380 outputs[i]->getType().getShape() == output_type.getShape());
382 outputs[i]->setName(tensor.name);
383 outputs[i]->setType(output_type);
385 assert(_tensorMap[tensor_index] == nullptr);
386 _tensorMap[tensor_index] = outputs[i];
390 std::vector<mir::Operation::Output *>
391 TfliteImporter::getMIRInputsForOperator(const tflite::SubGraphT *subgraph,
392 const tflite::OperatorT *op)
394 std::vector<mir::Operation::Output *> inputs;
396 for (const auto tensor_index : op->inputs)
398 const tflite::TensorT &tensor = *subgraph->tensors[tensor_index];
399 const tflite::BufferT &buffer = *_model->buffers[tensor.buffer];
400 if (!buffer.data.empty())
402 assert(_tensorMap[tensor_index] == nullptr);
403 mir::TensorType type = getMirTensorType(tensor);
404 mir::TensorVariant mir_tensor{type, buffer.data.data()};
405 inputs.emplace_back(_graph->create<mir::ops::ConstantOp>(mir_tensor)->getOutput(0));
409 assert(_tensorMap[tensor_index] != nullptr);
410 // By this point every input for the operation "op" should have corresponding
411 // Model IR operations that output its inputs. This assumption is provided by the fact
412 // that TFLite format specifies all operations in the execution order.
413 inputs.emplace_back(_tensorMap[tensor_index]);
422 std::unique_ptr<mir::Graph> loadModel(std::string filename)
424 TfliteImporter importer(std::move(filename));
425 return importer.importModel();
428 } // namespace mir_tflite