// SPDX-License-Identifier: Apache-2.0
/**
 * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
 *
 * @file tflite_opnode.cpp
 * @brief contains tflite opnode which has information to convert to tflite file
 * @see https://github.com/nnstreamer/nntrainer
 * @author Jihoon Lee <jhoon.it.lee@samsung.com>
 * @author Donghak Park <donghak.park@samsung.com>
 * @bug No known bugs except for NYI items
 */
14 #include <tflite_opnode.h>
16 #include <layer_context.h>
17 #include <layer_node.h>
22 TfOpNode::TfOpNode() :
26 weight_transform(nullptr),
31 is_to_be_removed(false),
32 need_reorder_weight(false),
33 node_owned_variable(),
34 /// @todo distinguish between uninitialized and ADD operator.
35 op_type(tflite::BuiltinOperator_ADD),
37 builtin_option_type(tflite::BuiltinOptions_NONE){};
39 void TfOpNode::setLayerNode(const LayerNode &layer) {
40 is_input = layer.getNumInputConnections() == 0;
41 is_output = layer.getNumOutputConnections() == 0;
42 /// @todo support more loss layers
43 static const std::set<std::string> loss_type = {"mse", "cross"};
44 /** set to graph output node if output connection of the node includes loss
46 * @note this is workaround because it cannot be guaranteed that a loss layer
47 *always has a loss type in its name.
49 * There are two ways to pass `GraphRepresentation` parameters to `serialize`
52 * 1. with loss layer at the end of the graph
53 * 2. without loss layer but last node has loss layer output connection
55 * Loss layer of the first case is removed by `LossRealizer` and the previous
56 *layer of the loss layer is set as the output node. And, the below logic is
59 /// assume that loss layers have single output
60 if (layer.getNumOutputConnections() == 1) {
61 for (auto &loss : loss_type) {
62 if (layer.getOutputConnections()[0].find(loss) != std::string::npos) {
67 /// @todo support more virtual nodes
68 is_virtual = layer.getType() == "multiout";
70 auto &context = layer.getRunContext();
71 auto create_variables = [](auto tensor_getter, unsigned size) {
74 for (unsigned i = 0; i < size; ++i) {
75 v.push_back(tensor_getter(i));
81 * Q1) Why convert from NCHW to NHWC?
82 * A1) the tflite uses NHWC format; nntrainer uses NCHW
84 * Q2) Why are only output tensors reshaped?
85 * A2) the tflite needs only one tensor between nodes. Therefore,
86 * basically, outputs are used for tflite tensors
88 auto create_variables_with_NCHW_to_NHWC = [](auto tensor_getter,
92 for (unsigned i = 0; i < size; ++i) {
93 Tensor *tensor = const_cast<Tensor *>(tensor_getter(i));
94 tensor->reshape(TensorDim{tensor->batch(), tensor->height(),
95 tensor->width(), tensor->channel()});
101 inputs = create_variables(
102 [&context](unsigned idx) { return &context.getInput(idx); },
103 context.getNumInputs());
104 outputs = create_variables_with_NCHW_to_NHWC(
105 [&context](unsigned idx) { return &context.getOutput(idx); },
106 context.getNumOutputs());
107 weights = create_variables(
108 [&context](unsigned idx) {
109 auto &t = context.getWeight(idx);
110 NNTR_THROW_IF(t.empty() || !t.isAllocated(), std::invalid_argument)
111 << "every weight tensor must be allocated";
114 context.getNumWeights());
116 if (context.getNumWeights() == 0) {
117 is_trainable = false;
121 void TfOpNode::setWeightTransformFn(TransformFn fn) { weight_transform = fn; }
123 void TfOpNode::setInputTransformFn(TransformFn fn) { input_transform = fn; }
125 void TfOpNode::setWeights(Variables weights_) {
126 unsigned int cnt = 0;
127 for (auto &w : weights_) {
128 const unsigned int UNIT = w->batch();
129 const unsigned int CHANNEL = w->channel();
130 const unsigned int HEIGHT = w->height();
131 const unsigned int WIDTH = w->width();
133 auto weight_data = weights.at(cnt)->getData();
134 auto *ptr = const_cast<float *>(weight_data);
135 memcpy(&ptr[0], &w->getData()[0],
136 sizeof(float) * (UNIT * CHANNEL * HEIGHT * WIDTH));
141 void TfOpNode::weightReorder(unsigned int node_count) {
143 if (need_reorder_weight == true) {
145 auto previous_input_shape = input_nodes[0]->getInputs()[0];
147 const unsigned int unit = outputs[0]->height();
148 const unsigned int channel = previous_input_shape->channel();
149 const unsigned int height = previous_input_shape->height();
150 const unsigned int width = previous_input_shape->width();
152 auto weight_data = weights[0]->getData();
153 auto *ptr = const_cast<float *>(weight_data);
155 std::vector<float> old_value_list(unit * channel * height * width);
156 memcpy(&old_value_list[0], &ptr[0],
157 sizeof(float) * (unit * channel * height * width));
159 for (unsigned int h = 0; h < height; h++) {
160 for (unsigned int w = 0; w < width; w++) {
161 for (unsigned int c = 0; c < channel; c++) {
163 unsigned int now_position = h * (width * channel) + w * channel + c;
164 unsigned int next_position = c * (height * width) + h * width + w;
166 memcpy(&ptr[now_position * unit],
167 &old_value_list[next_position * unit], sizeof(float) * unit);
173 auto weight_transform_fn = [](std::vector<const Tensor *> &weights) {
174 std::vector<Tensor> new_weights;
175 new_weights.reserve(weights.size());
176 new_weights.push_back(weights[0]->transpose("0:2:1"));
177 new_weights.push_back(*weights[1]);
181 setWeightTransformFn(weight_transform_fn);
183 auto transform_if = [this](TransformFn &fn, Variables &v) {
186 v.resize(result.size());
187 node_owned_variable.insert(node_owned_variable.end(), result.begin(),
189 std::transform(node_owned_variable.end() - result.size(),
190 node_owned_variable.end(), v.begin(),
191 [](Tensor &t) { return &t; });
195 transform_if(weight_transform, weights);
198 void TfOpNode::finalize() {
199 auto transform_if = [this](TransformFn &fn, Variables &v) {
202 v.resize(result.size());
203 /// basically, result.size() == v.size() except InputLayer because a
204 /// Transpose operator is added for converting nchw to nhwc
205 /// @todo comment out below codes. TfOpNode needs to have LayerNode
207 // NNTR_THROW_IF(dynamic_cast<InputLayer>(layer_ptr->getLayer()) ==
208 // nullptr && result.size() != v.size(), std::invalid_argument)
209 // << "result size must match with given variable size";
210 node_owned_variable.insert(node_owned_variable.end(), result.begin(),
212 std::transform(node_owned_variable.end() - result.size(),
213 node_owned_variable.end(), v.begin(),
214 [](Tensor &t) { return &t; });
218 transform_if(weight_transform, weights);
219 transform_if(input_transform, inputs);
222 flatbuffers::Offset<void> TfOpNode::getBuiltinOps() const {
224 case tflite::BuiltinOperator_ADD:
225 case tflite::BuiltinOperator_AVERAGE_POOL_2D:
226 case tflite::BuiltinOperator_CONV_2D:
227 case tflite::BuiltinOperator_FULLY_CONNECTED:
228 case tflite::BuiltinOperator_RELU:
229 case tflite::BuiltinOperator_RESHAPE:
230 case tflite::BuiltinOperator_SOFTMAX:
231 case tflite::BuiltinOperator_TRANSPOSE:
232 case tflite::BuiltinOperator_MUL:
236 throw std::runtime_error{"Unsupported operator"};
240 void TfOpNode::setBuiltinOptions(
241 tflite::BuiltinOptions builtin_option_type_,
242 const flatbuffers::Offset<void> &builtin_ops_) {
243 builtin_ops = builtin_ops_;
244 builtin_option_type = builtin_option_type_;
247 } // namespace nntrainer