1 // SPDX-License-Identifier: Apache-2.0
/**
 * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
 *
 * @file tflite_interpreter.cpp
 * @brief NNTrainer *.tflite Interpreter
 * @see https://github.com/nnstreamer/nntrainer
 * @author Jihoon Lee <jhoon.it.lee@samsung.com>
 * @bug No known bugs except for NYI items
 */
12 #include <tflite_interpreter.h>
21 #include <type_traits>
24 #include <bn_realizer.h>
26 #include <layer_node.h>
27 #include <loss_realizer.h>
28 #include <nntrainer_error.h>
29 #include <node_exporter.h>
31 #include <tf_schema_generated.h>
32 #include <tflite_opnode.h>
/// Prefix attached to every diagnostic/exception message in this file.
static constexpr const char *FUNC_TAG = "[TFLITE INTERPRETER] ";
/**
 * @brief after finishing building, call this to save to a file
 *
 * @param builder flatbuffer builder
 * @param out output file path
 */
void builder2file(const flatbuffers::FlatBufferBuilder &builder,
                  const std::string &out) {
  uint8_t *buf = builder.GetBufferPointer();
  size_t size = builder.GetSize();
  // Verify the finished flatbuffer before touching the filesystem so a
  // corrupt buffer is never written to disk.
  flatbuffers::Verifier v(buf, size);
  NNTR_THROW_IF(!tflite::VerifyModelBuffer(v), std::invalid_argument)
    << FUNC_TAG << "Verifying serialized model failed";
  std::ofstream os(out, std::ios_base::binary);
  // strerror_r requires a caller-provided buffer for the error description.
  const size_t error_buflen = 100;
  char error_buf[error_buflen];
  NNTR_THROW_IF(!os.good(), std::invalid_argument)
    << "failed to open, reason: " << strerror_r(errno, error_buf, error_buflen);
  // ostream::write takes a signed std::streamsize; guard against a buffer so
  // large that the size_t -> streamsize cast wraps to a negative value.
  std::streamsize sz = static_cast<std::streamsize>(builder.GetSize());
  NNTR_THROW_IF(sz < 0, std::invalid_argument)
    << FUNC_TAG << "builder size: " << builder.GetSize()
    << " is too big. It cannot be represented by std::streamsize";
  os.write((char *)builder.GetBufferPointer(), sz);
/**
 * @brief get predecessor nodes
 *
 * @param node the node from which to get predecessor nodes
 * @note virtual nodes are ignored
 */
std::vector<const TfOpNode *> getPredNodes(const TfOpNode &node) {
  std::vector<const TfOpNode *> predNodes;
  // For each input, walk upward past virtual nodes so only real operator
  // nodes are reported as predecessors.
  for (auto input : node.getInputNodes()) {
    const TfOpNode *pred = input;
    while (pred->isVirtualNode()) {
      /// Assume that virtual nodes have a single input
      assert(pred->arity() == 1);
    predNodes.push_back(pred);
/// Owning, ordered collection of TfOpNode instances built from the graph.
using TfOpNodes = std::vector<std::unique_ptr<TfOpNode>>;
/**
 * @brief Bidirectional Index map
 *
 * @tparam KeyType type of an underlying hashable value; please note that the
 * key will be copied, so use this for pointers and primitive values that are
 * cheap to copy
 * @tparam DataType data type to be stored inside the vector; if not given,
 * same as KeyType
 */
template <typename KeyType, typename DataType = KeyType>
class BidirectionalIndexMap {
  /**
   * @brief add a datapoint to the map
   *
   * @param key key used later to look the data up
   * @param data data to be stored; inserted only when the key has not been
   * seen before, so repeated insertions are no-ops
   */
  void addDataWhenNotFound(KeyType key, DataType data) {
    auto search = key2index.find(key);
    if (search == key2index.end()) {
      // first sighting: the key maps to the slot the data is about to occupy
      key2index[key] = index2data.size();
      index2data.push_back(data);
  /**
   * @brief add a datapoint to the map when key and data type are the same
   *
   * @param key key/data to add
   */
  void addDataWhenNotFound(KeyType key) {
    static_assert(std::is_same<KeyType, DataType>::value == true,
                  "key type and data type are different!");
    addDataWhenNotFound(key, key);
  /**
   * @brief Get the index of the data
   *
   * @param key data that will be the key
   * @return unsigned int index
   * @throws std::invalid_argument when the key is unknown
   */
  unsigned int getIndex(const KeyType &key) const {
    auto search = key2index.find(key);
    NNTR_THROW_IF(search == key2index.end(), std::invalid_argument)
      << FUNC_TAG << "Cannot find index for key: " << key;
    return search->second;
  /**
   * @brief Get the data stored at an index
   *
   * @param index index to be searched
   * @return DataType datapoint stored at the index
   * @throws std::invalid_argument when the index is out of range
   */
  DataType getData(unsigned int index) const {
    NNTR_THROW_IF(index >= index2data.size(), std::invalid_argument)
      << FUNC_TAG << "Cannot find data for index: " << index;
    return index2data[index];
  /**
   * @brief Get the whole data vector
   *
   * @return const std::vector<DataType>& underlying data, ordered by index
   */
  const std::vector<DataType> &getData() const { return index2data; }
  std::unordered_map<KeyType, unsigned int> key2index; /**< key -> index map */
  std::vector<DataType> index2data; /**< index -> data map */
173 * @brief tensorflow operation index map, this class manages operation index
  /// (byte size, data pointer) pair; a size of 0 denotes an empty or
  /// undefined buffer with no serialized payload
  using Buffer = std::pair<size_t, const float *>;
  TfOpIdxMap(const TfOpNodes &nodes) {
    auto &opcode_map = getIndexMap<tflite::BuiltinOperator>();
    // register an opcode the first time it is seen; duplicates are no-ops
    auto update_opcode = [&opcode_map](tflite::BuiltinOperator opcode) {
      opcode_map.addDataWhenNotFound(opcode);
    auto &buffer_map = getIndexMap<const float *, Buffer>();
    buffer_map.addDataWhenNotFound(
      nullptr, {0, empty_buffer}); // this represents undefined buffer
    buffer_map.addDataWhenNotFound(
      empty_buffer, {0, empty_buffer}); /// this represents empty buffer
    // register the backing buffer of each variable; `dynamic` buffers are
    // recorded with byte size 0, so no payload is serialized for them
    auto update_buffer_map = [&buffer_map](const TfOpNode::Variables &variables,
      for (auto &variable : variables) {
        const float *buf = variable->getData();
        assert(buf != nullptr);
        auto byte_size = dynamic ? 0 : variable->bytes();
        buffer_map.addDataWhenNotFound(buf, {byte_size, buf});
    // append each tensor to `tensors` unless already registered; insertion
    // order defines the tflite tensor index
    auto register_tensors =
      [&tensors = this->tensors](const TfOpNode::Variables &variables) {
        for (auto &variable : variables) {
          auto tensor_it = std::find(tensors.begin(), tensors.end(), variable);
          if (tensor_it == tensors.end()) {
            tensors.push_back(variable);
    for (auto &op_node : nodes) {
      if (op_node->isVirtualNode())
      update_opcode(op_node->getOpType());
      if (op_node->isInputNode()) {
        /**
         * Q) Why only register graph input tensors?
         *
         * A) tflite needs only one tensor between nodes. Therefore,
         * basically, no inputs are considered except the graph input, which
         * has no producing node.
         */
        register_tensors(op_node->getInputs());
        /**
         * Q) Why only update the second input of the input node?
         *
         * A) 1. graph input nodes should be Transpose operators to change
         *    the data format from NCHW to NHWC.
         *    2. the Transpose operator has two inputs - the input to be
         *    transposed (input[0]) and a 1d permute vector (input[1]).
         *    3. input[0] has a nullptr data pointer, which can't be added to
         *    buffer_map. But input[0] should have its own buffer, and it
         *    will be considered when the tflite buffers are built.
         */
        assert(op_node->getInputs()[0]->getData() == nullptr);
        update_buffer_map({op_node->getInputs()[1]}, false);
      register_tensors(op_node->getWeights());
      update_buffer_map(op_node->getWeights(), false);
      register_tensors(op_node->getOutputs());
      update_buffer_map(op_node->getOutputs(), true);
    // collect the tensor indices forming the model inputs/outputs.
    // NOTE(review): tensors named "nntrainer_internal_perm" look to be
    // excluded from model IO here — confirm against the full source.
    auto update_model_io_to = [this](const TfOpNode::Variables &variables,
                                     std::vector<int> &v) {
      for (auto &variable : variables) {
        if (variable->getName().find("nntrainer_internal_perm") !=
          v.push_back(this->getTensorIndex(variable));
    for (auto &op_node : nodes) {
      if (op_node->isVirtualNode())
      if (op_node->isInputNode()) {
        update_model_io_to(op_node->getInputs(), inputs);
      if (op_node->isOutputNode()) {
        update_model_io_to(op_node->getOutputs(), outputs);
  /**
   * @brief fetch the bidirectional map for the given key/data types
   */
  template <typename KeyType, typename DataType = KeyType>
  BidirectionalIndexMap<KeyType, DataType> &getIndexMap() {
    return std::get<BidirectionalIndexMap<KeyType, DataType>>(maps);
  /**
   * @brief fetch the bidirectional map for the given key/data types
   * (const overload)
   */
  template <typename KeyType, typename DataType = KeyType>
  const BidirectionalIndexMap<KeyType, DataType> &getIndexMap() const {
    return std::get<BidirectionalIndexMap<KeyType, DataType>>(maps);
  /// sentinel buffer address used for empty/uninitialized tensors
  const float *get_empty_buffer() const { return empty_buffer; }
  /// tensor indices of the model inputs
  const std::vector<int> &getInputs() const { return inputs; }
  /// tensor indices of the model outputs
  const std::vector<int> &getOutputs() const { return outputs; }
  /// all registered tensors, ordered by their tflite tensor index
  const std::vector<const Tensor *> &getTensors() const { return tensors; }
  /// position of @a tensor inside getTensors()
  /// @throws std::invalid_argument when the tensor was never registered
  std::ptrdiff_t getTensorIndex(const Tensor *tensor) const {
    auto tensor_it = std::find(tensors.begin(), tensors.end(), tensor);
    NNTR_THROW_IF(tensor_it == tensors.cend(), std::invalid_argument)
      << FUNC_TAG << "Cannot find index for tensor: " << tensor->getName();
    return std::distance(tensors.begin(), tensor_it);
  /// reserved uninitialized tensors point to this address
  /// NOTE(review): a zero-length array is a compiler extension, not
  /// standard C++
  float empty_buffer[0];
  std::tuple<BidirectionalIndexMap<const float *, Buffer>, /**< buffer map */
             BidirectionalIndexMap<tflite::BuiltinOperator>> /**< opcode map */
  std::vector<int> inputs;  /**< tensor indices of model inputs */
  std::vector<int> outputs; /**< tensor indices of model outputs */
  /// since it is used as a tensor index, the order is important
  std::vector<const Tensor *> tensors;
TfOpNodes buildOpNodes(const GraphRepresentation &representation,
                       flatbuffers::FlatBufferBuilder &fbb) {
  /// @todo TfOpNode needs to have LayerNode pointer
  std::map<TfOpNode *, const LayerNode *> tf_to_layer;
  std::map<const LayerNode *, TfOpNode *> layer_to_tf;
  /// @todo look ahead of layers to get nodes that can be fused
  /// we will need to have a dedicated builder
  // 1. export every layer to its tflite op-node counterpart, remembering the
  // mapping in both directions for the wiring pass below
  for (auto iter = representation.cbegin(); iter != representation.cend();
    const auto &ln = *iter;
    ln->exportTo(e, ml::train::ExportMethods::METHOD_TFLITE);
    nodes.emplace_back(e.getResult<ml::train::ExportMethods::METHOD_TFLITE>());
    tf_to_layer.insert({nodes.back().get(), ln.get()});
    layer_to_tf.insert({ln.get(), nodes.back().get()});
  // 2. mark FC layers whose weights need reordering
  // (setNeedReorderWeight / weightReorder in step 4 below)
  bool is_local_first = true;
  /** is_local_first : first FC Layer after a channel-related layer
   *
   * : Input -> Conv -> Conv -> Flatten -> [FC]:local_first
   * : Input -> Conv -> Flatten -> [FC]:local_first -> Conv -> Flatten ->
   *   [FC]:local_first
   */
  for (auto &n : nodes) {
    auto tf_node = n.get();
    if (tf_node->getOptionType() ==
          tflite::BuiltinOptions::BuiltinOptions_FullyConnectedOptions &&
        node_count != 0 && is_local_first) {
      tf_node->setNeedReorderWeight();
      is_local_first = false;
    if (is_local_first == false &&
        tf_node->getOptionType() !=
          tflite::BuiltinOptions::BuiltinOptions_FullyConnectedOptions) {
      is_local_first = true;
  /// set arity of TfOpNodes
  // 3. wire each op node to the op nodes of its input layers, resolved by
  // layer name lookup in the representation
  for (auto &n : nodes) {
    auto tf_node = n.get();
    auto layer_node = tf_to_layer.find(tf_node)->second;
    auto layer_node_inputs = layer_node->getInputConnections();
    /// assume that the TfOpNode and the LayerNode have a one-to-one
    /// correspondence
    tf_node->arity(layer_node_inputs.size());
    for (size_t index = 0; index < layer_node_inputs.size(); index++) {
      auto input_layer_name = layer_node_inputs[index];
      auto input_later_node_iterator = std::find_if(
        representation.begin(), representation.end(),
        [&input_layer_name](std::shared_ptr<nntrainer::LayerNode> node) {
          return istrequal(node.get()->getName(), input_layer_name);
      if (input_later_node_iterator != representation.end()) {
        auto input_layer_node = input_later_node_iterator->get();
        if (layer_to_tf.find(input_layer_node) != layer_to_tf.end()) {
          tf_node->setArg(index, layer_to_tf.find(input_layer_node)->second);
  // 4. perform the weight reordering decided in step 2
  for (auto &n : nodes) {
    auto tf_node = n.get();
    if (tf_node->getOptionType() ==
        tflite::BuiltinOptions::BuiltinOptions_FullyConnectedOptions) {
      tf_node->weightReorder(node_count);
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::Buffer>>>
buildBuffers(const TfOpIdxMap &map, flatbuffers::FlatBufferBuilder &fbb) {
  const auto &buffers =
    map.getIndexMap<const float *, TfOpIdxMap::Buffer>().getData();
  std::vector<flatbuffers::Offset<tflite::Buffer>> fb_buffers;
  fb_buffers.reserve(buffers.size());
  // a byte size of 0 marks an empty/undefined buffer -> emit a Buffer table
  // with no data field
  auto create_buffer_offset = [&fbb](const TfOpIdxMap::Buffer &buffer) {
    if (buffer.first == 0) {
      return tflite::CreateBuffer(fbb);
    auto data = fbb.CreateVector(
      reinterpret_cast<const uint8_t *>(buffer.second), buffer.first);
    return tflite::CreateBuffer(fbb, data);
  std::transform(buffers.begin(), buffers.end(), std::back_inserter(fb_buffers),
                 create_buffer_offset);
  // append one empty buffer per graph input; graph-input tensors have a
  // nullptr data pointer and are therefore absent from the buffer map
  for (unsigned index = 0; index < map.getInputs().size(); index++) {
    fb_buffers.push_back(create_buffer_offset({0, nullptr}));
  return fbb.CreateVector(fb_buffers);
  flatbuffers::Vector<flatbuffers::Offset<tflite::OperatorCode>>>
buildOperatorCodes(const TfOpIdxMap &map, flatbuffers::FlatBufferBuilder &fbb) {
  const auto &op_codes = map.getIndexMap<tflite::BuiltinOperator>().getData();
  std::vector<flatbuffers::Offset<tflite::OperatorCode>> fb_op_codes;
  fb_op_codes.reserve(op_codes.size());
  auto create_op_offset = [&fbb](const tflite::BuiltinOperator &op,
                                 int32_t version = 1) {
    tflite::OperatorCodeBuilder builder(fbb);
    // legacy field kept for consumers of the old tflite schema
    builder.add_deprecated_builtin_code(static_cast<int8_t>(op));
    /// @todo find reason why version field is not shown
    /// on json when version is 1 (other versions are fine)
    builder.add_version(version);
    builder.add_builtin_code(op);
    return builder.Finish();
  std::transform(op_codes.begin(), op_codes.end(),
                 std::back_inserter(fb_op_codes), create_op_offset);
  return fbb.CreateVector(fb_op_codes);
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::Tensor>>>
buildTensors(const TfOpIdxMap &map, flatbuffers::FlatBufferBuilder &fbb) {
  /// @todo: the actual (squeezed) tensor dimension must be known before
  /// coming here. For now, it is directly guessed for the fc layer
  const auto &variables = map.getTensors();
  const auto &buffer_map = map.getIndexMap<const float *, TfOpIdxMap::Buffer>();
  // counts down as graph-input tensors are assigned to the trailing empty
  // buffers that buildBuffers appends past the buffer map
  auto graph_input_offset = map.getInputs().size() - 1;
  std::vector<flatbuffers::Offset<tflite::Tensor>> fb_tensors;
  fb_tensors.reserve(variables.size());
  auto create_tensor = [&fbb, &buffer_map,
                        &graph_input_offset](const Tensor *var) {
    auto dim = var->getDim();
    bool need_shape_signature = dim.is_dynamic();
    std::vector<int32_t> eff_dim = dim.getEffectiveDimension();
    auto shape = fbb.CreateVector(eff_dim);
    decltype(shape) shape_sig;
    if (need_shape_signature) {
      // dynamic dims are serialized separately as the shape signature
      std::vector<int32_t> dyn_dim = dim.getEffectiveDimension(true);
      shape_sig = fbb.CreateVector(dyn_dim);
    /// change this to var->getName when tensors have their own names
    auto name = fbb.CreateString("nntrainer_converted" + var->getName());
    /// only graph inputs have a nullptr data pointer.
    unsigned int buffer_idx =
      var->getData() == nullptr
        ? buffer_map.getData().size() - graph_input_offset--
        : buffer_map.getIndex(var->getData());
    tflite::TensorBuilder builder(fbb);
    builder.add_name(name);
    builder.add_buffer(buffer_idx);
    /// @todo support more data types
    /// @note this is a workaround because an nntrainer tensor allows only
    /// float; the internal permute vector must be serialized as int32
    if (var->getName().find("nntrainer_internal_perm") != std::string::npos) {
      builder.add_type(tflite::TensorType_INT32);
      builder.add_type(tflite::TensorType_FLOAT32);
    builder.add_shape(shape);
    if (need_shape_signature) {
      builder.add_shape_signature(shape_sig);
    return builder.Finish();
  std::transform(variables.begin(), variables.end(),
                 std::back_inserter(fb_tensors), create_tensor);
  return fbb.CreateVector(fb_tensors);
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::Operator>>>
buildOperators(const TfOpNodes &nodes, const TfOpIdxMap &map,
               flatbuffers::FlatBufferBuilder &fbb) {
  /// this lambda maps variables to the list of their indexes in the map
  auto variables_to_idx_vector = [&map](const TfOpNode::Variables &v) {
    std::vector<int> idx_vector;
    idx_vector.reserve(v.size());
      v.begin(), v.end(), std::back_inserter(idx_vector),
      [&map](const Tensor *variable) { return map.getTensorIndex(variable); });
  // builds a single tflite::Operator table for one op node
  auto create_operator = [&fbb, &map,
                          &variables_to_idx_vector](const TfOpNode &node) {
    auto &index_map = map.getIndexMap<tflite::BuiltinOperator>();
    auto op_code = index_map.getIndex(node.getOpType());
    std::vector<int> inputs;
    if (node.isInputNode()) {
      // graph inputs use their own input tensors directly
      inputs = variables_to_idx_vector(node.getInputs());
      /**
       * Q) Why find a tensor that shares a buffer with the input tensor?
       *
       * A) tflite needs only one tensor between nodes. Therefore,
       * basically, the producer's output tensors are used as the tflite
       * tensors that share the buffer.
       */
      TfOpNode::Variables input_tensors;
      for (auto parent_node : getPredNodes(node)) {
        for (auto parent_out : parent_node->getOutputs()) {
          for (auto in : node.getInputs()) {
            /// second condition is a workaround
            /// Transpose op output tensor originally had a nullptr data
            /// pointer but it has been allocated (parent_out->getData()).
            /// But the buffer that shared its buffer hasn't, so it still
            /// has nullptr
            /// @todo remove this workaround
            if (parent_out->getData() == in->getData() ||
                (in->getData() == nullptr && parent_out->getData())) {
              // avoid registering the same parent output twice
              if (std::find(input_tensors.begin(), input_tensors.end(),
                            parent_out) != input_tensors.end())
              input_tensors.push_back(parent_out);
      inputs = variables_to_idx_vector(input_tensors);
    auto weights = variables_to_idx_vector(node.getWeights());
    /// weights are part of the inputs in tflite
    inputs.insert(inputs.end(), weights.begin(), weights.end());
    auto outputs = variables_to_idx_vector(node.getOutputs());
    auto fb_inputs = fbb.CreateVector(inputs);
    auto fb_outputs = fbb.CreateVector(outputs);
    auto fb_options = node.getBuiltinOps();
    tflite::OperatorBuilder builder(fbb);
    builder.add_opcode_index(op_code);
    builder.add_builtin_options_type(node.getOptionType());
    builder.add_builtin_options(fb_options);
    builder.add_inputs(fb_inputs);
    builder.add_outputs(fb_outputs);
    return builder.Finish();
  std::vector<flatbuffers::Offset<tflite::Operator>> v;
  v.reserve(nodes.size());
  // virtual nodes carry no tflite operator and are skipped
  for (auto &node : nodes) {
    if (node->isVirtualNode())
    auto op = create_operator(*node);
  return fbb.CreateVector(v);
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::SubGraph>>>
buildSubGraphs(const TfOpNodes &nodes, const TfOpIdxMap &map,
               flatbuffers::FlatBufferBuilder &fbb) {
  // serialize the tensor and operator tables first; their offsets are
  // referenced by the subgraph below
  auto tensors = buildTensors(map, fbb);
  auto ops = buildOperators(nodes, map, fbb);
  /// @todo extract this to buildSubgraph if there is one or more subgraph
  auto name = fbb.CreateString("main");
  auto inputs = fbb.CreateVector(map.getInputs());
  auto outputs = fbb.CreateVector(map.getOutputs());
  auto builder = tflite::SubGraphBuilder(fbb);
  builder.add_tensors(tensors);
  builder.add_inputs(inputs);
  builder.add_outputs(outputs);
  builder.add_name(name);
  builder.add_operators(ops);
  auto subgraph = builder.Finish();
  // only a single subgraph ("main") is emitted for now
  std::vector<flatbuffers::Offset<tflite::SubGraph>> subgraphs;
  subgraphs.reserve(1);
  subgraphs.push_back(subgraph);
  return fbb.CreateVector(subgraphs);
void TfliteInterpreter::serialize(const GraphRepresentation &representation,
                                  const std::string &out) {
  /// @todo check if graph is finalized & initialized and ready to serialize.
  /// 0. remove batch normalization layers from the GraphRepresentation
  BnRealizer realizer({});
  GraphRepresentation graph = realizer.realize(representation);
  /// 1. remove loss layers from the GraphRepresentation
  LossRealizer loss_realizer({});
  graph = loss_realizer.realize(graph);
  /// 2. The graph must have weights, input dims and output dims set
  flatbuffers::FlatBufferBuilder fbb;
  auto opNodes = buildOpNodes(graph, fbb);
  TfOpIdxMap map(opNodes); /// build TfOpIdxMap from opNodes
  // serialize each section, then assemble the top-level Model table
  auto opcodes = buildOperatorCodes(map, fbb);
  auto subgraphs = buildSubGraphs(opNodes, map, fbb);
  auto buffers = buildBuffers(map, fbb);
  auto desc = fbb.CreateString("This file is generated from NNTrainer");
  tflite::ModelBuilder model_builder(fbb);
  model_builder.add_operator_codes(opcodes);
  model_builder.add_subgraphs(subgraphs);
  model_builder.add_buffers(buffers);
  model_builder.add_version(3); // tflite schema version
  model_builder.add_description(desc);
  auto model = model_builder.Finish();
  // finish with the schema's file identifier, then flush to disk
  fbb.Finish(model, tflite::ModelIdentifier());
  builder2file(fbb, out);
GraphRepresentation TfliteInterpreter::deserialize(const std::string &in) {
  /// NOTE(review): this appears to be a stub (NYI) — only design notes below
  /// ======== list of things to consider ========
  /// we need to reconstruct some properties from the shape
  /// eg) units are not saved as a property
670 } // namespace nntrainer