1 // SPDX-License-Identifier: Apache-2.0
3 * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
5 * @file tflite_interpreter.cpp
7 * @brief NNTrainer *.tflite Interpreter
8 * @see https://github.com/nnstreamer/nntrainer
9 * @author Jihoon Lee <jhoon.it.lee@samsung.com>
10 * @bug No known bugs except for NYI items
12 #include <tflite_interpreter.h>
21 #include <type_traits>
24 #include <bn_realizer.h>
26 #include <layer_node.h>
27 #include <loss_realizer.h>
28 #include <nntrainer_error.h>
29 #include <node_exporter.h>
31 #include <tf_schema_generated.h>
32 #include <tflite_opnode.h>
34 static constexpr const char *FUNC_TAG = "[TFLITE INTERPRETER] ";
40 * @brief after finishing building, call this to save the result to a file
42 * @param builder flatbuffer builder
45 void builder2file(const flatbuffers::FlatBufferBuilder &builder,
46 const std::string &out) {
47 uint8_t *buf = builder.GetBufferPointer();
48 size_t size = builder.GetSize();
49 flatbuffers::Verifier v(buf, size);
51 NNTR_THROW_IF(!tflite::VerifyModelBuffer(v), std::invalid_argument)
52 << FUNC_TAG << "Verifying serialized model failed";
54 std::ofstream os(out, std::ios_base::binary);
55 const size_t error_buflen = 100;
56 char error_buf[error_buflen];
57 NNTR_THROW_IF(!os.good(), std::invalid_argument)
59 << "failed to open, reason: " << strerror_r(errno, error_buf, error_buflen);
61 std::streamsize sz = static_cast<std::streamsize>(builder.GetSize());
62 NNTR_THROW_IF(sz < 0, std::invalid_argument)
63 << FUNC_TAG << "builder size: " << builder.GetSize()
64 << " is too big. It cannot be represented by std::streamsize";
66 os.write((char *)builder.GetBufferPointer(), sz);
71 * @brief get predecessor nodes
73 * @param node the node from which to get predecessor nodes
74 * @note virtual nodes are ignored
76 std::vector<const TfOpNode *> getPredNodes(const TfOpNode &node) {
77 std::vector<const TfOpNode *> predNodes;
79 for (auto input : node.getInputNodes()) {
80 const TfOpNode *pred = input;
81 while (pred->isVirtualNode()) {
82 /// Assume that virtual nodes have single input
83 assert(pred->arity() == 1);
86 predNodes.push_back(pred);
91 using TfOpNodes = std::vector<std::unique_ptr<TfOpNode>>;
94 * @brief Bidirectional Index map
96 * @tparam KeyType type of an underlying hashable value; please note that the key will be
97 * copied, so please use this for pointers and primitive values that are okay to copy
99 * @tparam DataType data type to be stored inside the vector; if not given, it is the same as KeyType
102 template <typename KeyType, typename DataType = KeyType>
103 class BidirectionalIndexMap {
106 * @brief addDatapoint to the map
108 * @param key key to be added to search for the data
109 * @param data data to be added if there is no occurrence, data will be
112 void addDataWhenNotFound(KeyType key, DataType data) {
113 auto search = key2index.find(key);
115 if (search == key2index.end()) {
116 key2index[key] = index2data.size();
117 index2data.push_back(data);
122 * @brief addDatapoint to the map when key and datatype is same
124 * @param key key/data to add
126 void addDataWhenNotFound(KeyType key) {
127 static_assert(std::is_same<KeyType, DataType>::value == true,
128 "key type and data type are different!");
129 addDataWhenNotFound(key, key);
133 * @brief Get the Index of the data
135 * @param key data that will be the key
136 * @return unsigned int index
138 unsigned int getIndex(const KeyType &key) const {
139 auto search = key2index.find(key);
141 NNTR_THROW_IF(search == key2index.end(), std::invalid_argument)
142 << FUNC_TAG << "Cannot find index for key: " << key;
144 return search->second;
148 * @brief Get the Data object
150 * @param idx index to be searched
151 * @return T datapoint T
153 DataType getData(unsigned int index) const {
154 NNTR_THROW_IF(index >= index2data.size(), std::invalid_argument)
155 << FUNC_TAG << "Cannot find data for index: " << index;
157 return index2data[index];
161 * @brief Get the Data object
163 * @return const std::vector<T>& underlying data
165 const std::vector<DataType> &getData() const { return index2data; }
168 std::unordered_map<KeyType, unsigned int> key2index; /**< key -> index map */
169 std::vector<DataType> index2data; /**< index -> data map */
173 * @brief tensorflow operation index map, this class manages operation index
179 using Buffer = std::pair<size_t, const float *>;
181 TfOpIdxMap(const TfOpNodes &nodes) {
182 auto &opcode_map = getIndexMap<tflite::BuiltinOperator>();
183 auto update_opcode = [&opcode_map](tflite::BuiltinOperator opcode) {
184 opcode_map.addDataWhenNotFound(opcode);
187 auto &buffer_map = getIndexMap<const float *, Buffer>();
188 buffer_map.addDataWhenNotFound(
189 nullptr, {0, empty_buffer}); // this represents undefined buffer
190 buffer_map.addDataWhenNotFound(
191 empty_buffer, {0, empty_buffer}); /// this represents empty buffer
193 auto update_buffer_map = [&buffer_map](const TfOpNode::Variables &variables,
195 for (auto &variable : variables) {
196 const float *buf = variable->getData();
197 assert(buf != nullptr);
198 auto byte_size = dynamic ? 0 : variable->bytes();
199 buffer_map.addDataWhenNotFound(buf, {byte_size, buf});
203 auto register_tensors =
204 [&tensors = this->tensors](const TfOpNode::Variables &variables) {
205 for (auto &variable : variables) {
206 auto tensor_it = std::find(tensors.begin(), tensors.end(), variable);
207 if (tensor_it == tensors.end()) {
208 tensors.push_back(variable);
213 for (auto &op_node : nodes) {
214 if (op_node->isVirtualNode())
216 update_opcode(op_node->getOpType());
218 if (op_node->isInputNode()) {
220 * Q) Why only register graph input tensor?
222 * A) the tflite needs only one tensor between nodes. Therefore,
223 *basically, no inputs are considered except graph input that doesn't
226 register_tensors(op_node->getInputs());
228 * Q) Why only update second input of the input node?
230 * A) 1. graph input nodes should be Transpose operator to change data
231 *format from NCHW to NHWC.
232 * 2. Transpose operator has two inputs - input to be
233 *transposed(input[0]), 1d permute vector(input[1])
234 * 3. input[0] has nullptr data pointer, which can't be added to
235 *buffer_map. But, input[0] should have its own buffer and it will be
236 *considered when the tflite buffers are built.
238 assert(op_node->getInputs()[0]->getData() == nullptr);
239 update_buffer_map({op_node->getInputs()[1]}, false);
241 register_tensors(op_node->getWeights());
242 update_buffer_map(op_node->getWeights(), false);
244 register_tensors(op_node->getOutputs());
245 update_buffer_map(op_node->getOutputs(), true);
248 auto update_model_io_to = [this](const TfOpNode::Variables &variables,
249 std::vector<int> &v) {
250 for (auto &variable : variables) {
251 if (variable->getName().find("nntrainer_internal_perm") !=
254 v.push_back(this->getTensorIndex(variable));
258 for (auto &op_node : nodes) {
259 if (op_node->isVirtualNode())
261 if (op_node->isInputNode()) {
262 update_model_io_to(op_node->getInputs(), inputs);
264 if (op_node->isOutputNode()) {
265 update_model_io_to(op_node->getOutputs(), outputs);
270 template <typename KeyType, typename DataType = KeyType>
271 BidirectionalIndexMap<KeyType, DataType> &getIndexMap() {
272 return std::get<BidirectionalIndexMap<KeyType, DataType>>(maps);
275 template <typename KeyType, typename DataType = KeyType>
276 const BidirectionalIndexMap<KeyType, DataType> &getIndexMap() const {
277 return std::get<BidirectionalIndexMap<KeyType, DataType>>(maps);
280 const float *get_empty_buffer() const { return empty_buffer; }
282 const std::vector<int> &getInputs() const { return inputs; }
284 const std::vector<int> &getOutputs() const { return outputs; }
286 const std::vector<const Tensor *> &getTensors() const { return tensors; }
288 std::ptrdiff_t getTensorIndex(const Tensor *tensor) const {
289 auto tensor_it = std::find(tensors.begin(), tensors.end(), tensor);
290 NNTR_THROW_IF(tensor_it == tensors.cend(), std::invalid_argument)
291 << FUNC_TAG << "Cannot find index for tensor: " << tensor->getName();
292 return std::distance(tensors.begin(), tensor_it);
296 float empty_buffer[0]; /**< reserved uninitialized tensor points to this
299 std::tuple<BidirectionalIndexMap<const float *, Buffer>, /**< buffer map
301 BidirectionalIndexMap<tflite::BuiltinOperator>> /**< opcode map
305 std::vector<int> inputs;
306 std::vector<int> outputs;
307 /// since it is used as a tensor index, the order is important
308 std::vector<const Tensor *> tensors;
311 TfOpNodes buildOpNodes(const GraphRepresentation &representation,
312 flatbuffers::FlatBufferBuilder &fbb) {
314 /// @todo TfOpNode needs to have LayerNode pointer
315 std::map<TfOpNode *, const LayerNode *> tf_to_layer;
316 std::map<const LayerNode *, TfOpNode *> layer_to_tf;
318 /// @todo, look ahead of layers to get nodes that can be fused
319 /// we will need to have a dedicated builder
320 for (auto iter = representation.cbegin(); iter != representation.cend();
322 const auto &ln = *iter;
325 ln->exportTo(e, ml::train::ExportMethods::METHOD_TFLITE);
327 nodes.emplace_back(e.getResult<ml::train::ExportMethods::METHOD_TFLITE>());
328 tf_to_layer.insert({nodes.back().get(), ln.get()});
329 layer_to_tf.insert({ln.get(), nodes.back().get()});
333 bool is_local_first = true;
334 /** is_local_first : first FC Layer after Channel related layer
336 * : Input -> Conv -> Conv -> Flatten -> [FC]:local_first
337 * : Input -> Conv -> Flatten -> [FC]:local_first -> Conv -> Flatten ->
341 for (auto &n : nodes) {
342 auto tf_node = n.get();
344 if (tf_node->getOptionType() ==
345 tflite::BuiltinOptions::BuiltinOptions_FullyConnectedOptions &&
346 node_count != 0 && is_local_first) {
347 tf_node->setNeedReorderWeight();
348 is_local_first = false;
351 if (is_local_first == false &&
352 tf_node->getOptionType() !=
353 tflite::BuiltinOptions::BuiltinOptions_FullyConnectedOptions) {
354 is_local_first = true;
360 /// set arity of TfOpNodes
361 for (auto &n : nodes) {
362 auto tf_node = n.get();
363 auto layer_node = tf_to_layer.find(tf_node)->second;
364 auto layer_node_inputs = layer_node->getInputConnections();
366 /// assume that the TfOpNode and the LayerNode have a one-to-one
368 tf_node->arity(layer_node_inputs.size());
369 for (size_t index = 0; index < layer_node_inputs.size(); index++) {
370 auto input_layer_name = layer_node_inputs[index];
371 auto input_later_node_iterator = std::find_if(
372 representation.begin(), representation.end(),
373 [&input_layer_name](std::shared_ptr<nntrainer::LayerNode> node) {
374 return istrequal(node.get()->getName(), input_layer_name);
376 if (input_later_node_iterator != representation.end()) {
377 auto input_layer_node = input_later_node_iterator->get();
378 tf_node->setArg(index, layer_to_tf.find(input_layer_node)->second);
384 for (auto &n : nodes) {
385 auto tf_node = n.get();
386 if (tf_node->getOptionType() ==
387 tflite::BuiltinOptions::BuiltinOptions_FullyConnectedOptions) {
388 tf_node->weightReorder(node_count);
397 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::Buffer>>>
398 buildBuffers(const TfOpIdxMap &map, flatbuffers::FlatBufferBuilder &fbb) {
399 const auto &buffers =
400 map.getIndexMap<const float *, TfOpIdxMap::Buffer>().getData();
402 std::vector<flatbuffers::Offset<tflite::Buffer>> fb_buffers;
403 fb_buffers.reserve(buffers.size());
405 auto create_buffer_offset = [&fbb](const TfOpIdxMap::Buffer &buffer) {
406 if (buffer.first == 0) {
407 return tflite::CreateBuffer(fbb);
410 auto data = fbb.CreateVector(
411 reinterpret_cast<const uint8_t *>(buffer.second), buffer.first);
413 return tflite::CreateBuffer(fbb, data);
416 std::transform(buffers.begin(), buffers.end(), std::back_inserter(fb_buffers),
417 create_buffer_offset);
420 for (unsigned index = 0; index < map.getInputs().size(); index++) {
421 fb_buffers.push_back(create_buffer_offset({0, nullptr}));
423 return fbb.CreateVector(fb_buffers);
427 flatbuffers::Vector<flatbuffers::Offset<tflite::OperatorCode>>>
428 buildOperatorCodes(const TfOpIdxMap &map, flatbuffers::FlatBufferBuilder &fbb) {
429 const auto &op_codes = map.getIndexMap<tflite::BuiltinOperator>().getData();
431 std::vector<flatbuffers::Offset<tflite::OperatorCode>> fb_op_codes;
432 fb_op_codes.reserve(op_codes.size());
434 auto create_op_offset = [&fbb](const tflite::BuiltinOperator &op,
435 int32_t version = 1) {
436 tflite::OperatorCodeBuilder builder(fbb);
437 builder.add_deprecated_builtin_code(static_cast<int8_t>(op));
438 /// @todo find reason why version field is not shown
439 /// on json when version is 1 (other versions are fine)
440 builder.add_version(version);
441 builder.add_builtin_code(op);
442 return builder.Finish();
445 std::transform(op_codes.begin(), op_codes.end(),
446 std::back_inserter(fb_op_codes), create_op_offset);
448 return fbb.CreateVector(fb_op_codes);
451 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::Tensor>>>
452 buildTensors(const TfOpIdxMap &map, flatbuffers::FlatBufferBuilder &fbb) {
453 /// @todo: the actual (suqeezed) tensor dimension must be known before
454 /// coming here. For now, it is directly guessed for the fc layer
455 const auto &variables = map.getTensors();
456 const auto &buffer_map = map.getIndexMap<const float *, TfOpIdxMap::Buffer>();
457 auto graph_input_offset = map.getInputs().size() - 1;
459 std::vector<flatbuffers::Offset<tflite::Tensor>> fb_tensors;
460 fb_tensors.reserve(variables.size());
462 auto create_tensor = [&fbb, &buffer_map,
463 &graph_input_offset](const Tensor *var) {
464 auto dim = var->getDim();
465 bool need_shape_signature = dim.is_dynamic();
466 std::vector<int32_t> eff_dim = dim.getEffectiveDimension();
467 auto shape = fbb.CreateVector(eff_dim);
469 decltype(shape) shape_sig;
470 if (need_shape_signature) {
471 std::vector<int32_t> dyn_dim = dim.getEffectiveDimension(true);
472 shape_sig = fbb.CreateVector(dyn_dim);
475 /// change this var->getName when tensor have it's own name
476 auto name = fbb.CreateString("nntrainer_converted" + var->getName());
478 /// only graph inputs have nullptr data pointer.
479 unsigned int buffer_idx =
480 var->getData() == nullptr
481 ? buffer_map.getData().size() - graph_input_offset--
482 : buffer_map.getIndex(var->getData());
484 tflite::TensorBuilder builder(fbb);
485 builder.add_name(name);
486 builder.add_buffer(buffer_idx);
487 /// @todo support more data types
488 /// @note this is workaround because nntrainer tensor allows only float
490 if (var->getName().find("nntrainer_internal_perm") != std::string::npos) {
491 builder.add_type(tflite::TensorType_INT32);
493 builder.add_type(tflite::TensorType_FLOAT32);
494 builder.add_shape(shape);
495 if (need_shape_signature) {
496 builder.add_shape_signature(shape_sig);
498 return builder.Finish();
501 std::transform(variables.begin(), variables.end(),
502 std::back_inserter(fb_tensors), create_tensor);
504 return fbb.CreateVector(fb_tensors);
507 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::Operator>>>
508 buildOperators(const TfOpNodes &nodes, const TfOpIdxMap &map,
509 flatbuffers::FlatBufferBuilder &fbb) {
511 /// this lambda maps variables to list of indexes in the map
512 auto variables_to_idx_vector = [&map](const TfOpNode::Variables &v) {
513 std::vector<int> idx_vector;
514 idx_vector.reserve(v.size());
517 v.begin(), v.end(), std::back_inserter(idx_vector),
518 [&map](const Tensor *variable) { return map.getTensorIndex(variable); });
522 auto create_operator = [&fbb, &map,
523 &variables_to_idx_vector](const TfOpNode &node) {
524 auto &index_map = map.getIndexMap<tflite::BuiltinOperator>();
526 auto op_code = index_map.getIndex(node.getOpType());
527 std::vector<int> inputs;
528 if (node.isInputNode()) {
529 inputs = variables_to_idx_vector(node.getInputs());
532 * Q) Why find a tensor that shares a buffer with input tensor?
534 * A) the tflite needs only one tensor between nodes. Therefore,
535 *basically, output tensors are used for tflite tensor that shares its
538 TfOpNode::Variables input_tensors;
539 for (auto parent_node : getPredNodes(node)) {
540 for (auto parent_out : parent_node->getOutputs()) {
541 for (auto in : node.getInputs()) {
542 /// second condition is a workaround
543 /// Transpose op output tensor originally had nullptr data pointer
544 /// but it has been allocated (parent_out->getData()). But, the
545 /// buffer that shared its buffer hasn't so it has still nullptr
547 /// @todo remove this workaround
548 if (parent_out->getData() == in->getData() ||
549 (in->getData() == nullptr && parent_out->getData())) {
550 if (std::find(input_tensors.begin(), input_tensors.end(),
551 parent_out) != input_tensors.end())
553 input_tensors.push_back(parent_out);
558 inputs = variables_to_idx_vector(input_tensors);
560 auto weights = variables_to_idx_vector(node.getWeights());
562 /// weights are part of input in tflite
563 inputs.insert(inputs.end(), weights.begin(), weights.end());
565 auto outputs = variables_to_idx_vector(node.getOutputs());
567 auto fb_inputs = fbb.CreateVector(inputs);
568 auto fb_outputs = fbb.CreateVector(outputs);
569 auto fb_options = node.getBuiltinOps();
571 tflite::OperatorBuilder builder(fbb);
572 builder.add_opcode_index(op_code);
573 builder.add_builtin_options_type(node.getOptionType());
574 builder.add_builtin_options(fb_options);
575 builder.add_inputs(fb_inputs);
576 builder.add_outputs(fb_outputs);
577 return builder.Finish();
580 std::vector<flatbuffers::Offset<tflite::Operator>> v;
581 v.reserve(nodes.size());
583 for (auto &node : nodes) {
584 if (node->isVirtualNode())
586 auto op = create_operator(*node);
590 return fbb.CreateVector(v);
593 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::SubGraph>>>
594 buildSubGraphs(const TfOpNodes &nodes, const TfOpIdxMap &map,
595 flatbuffers::FlatBufferBuilder &fbb) {
597 auto tensors = buildTensors(map, fbb);
598 auto ops = buildOperators(nodes, map, fbb);
600 /// @todo extract this to buildSubgraph if there is one or more subgraph
601 auto name = fbb.CreateString("main");
602 auto inputs = fbb.CreateVector(map.getInputs());
603 auto outputs = fbb.CreateVector(map.getOutputs());
605 auto builder = tflite::SubGraphBuilder(fbb);
606 builder.add_tensors(tensors);
607 builder.add_inputs(inputs);
608 builder.add_outputs(outputs);
609 builder.add_name(name);
610 builder.add_operators(ops);
611 auto subgraph = builder.Finish();
613 std::vector<flatbuffers::Offset<tflite::SubGraph>> subgraphs;
614 subgraphs.reserve(1);
615 subgraphs.push_back(subgraph);
617 return fbb.CreateVector(subgraphs);
622 void TfliteInterpreter::serialize(const GraphRepresentation &representation,
623 const std::string &out) {
624 /// @todo check if graph is finalized & initialized and ready to serialize.
626 /// 0. remove batch normalization layer in GraphRepresentation
627 BnRealizer realizer({});
628 GraphRepresentation graph = realizer.realize(representation);
630 /// 1. remove loss layer in GraphRepresentation
631 LossRealizer loss_realizer({});
632 graph = loss_realizer.realize(graph);
634 /// 2. The graph must have weights, input dims, output dims set
635 flatbuffers::FlatBufferBuilder fbb;
637 auto opNodes = buildOpNodes(graph, fbb);
638 TfOpIdxMap map(opNodes); /// build TfOpIdxMap from opNodes
640 auto opcodes = buildOperatorCodes(map, fbb);
641 auto subgraphs = buildSubGraphs(opNodes, map, fbb);
642 auto buffers = buildBuffers(map, fbb);
643 auto desc = fbb.CreateString("This file is generated from NNTrainer");
645 tflite::ModelBuilder model_builder(fbb);
647 model_builder.add_operator_codes(opcodes);
648 model_builder.add_subgraphs(subgraphs);
649 model_builder.add_buffers(buffers);
650 model_builder.add_version(3);
651 model_builder.add_description(desc);
652 auto model = model_builder.Finish();
654 fbb.Finish(model, tflite::ModelIdentifier());
655 builder2file(fbb, out);
658 GraphRepresentation TfliteInterpreter::deserialize(const std::string &in) {
659 /// ======== list of things to consider ========
660 /// we need to reconstruct some properties from the shape
661 /// eg) units are not saved as a property
667 } // namespace nntrainer