* @brief NNTrainer *.tflite Interpreter
* @see https://github.com/nnstreamer/nntrainer
* @author Jihoon Lee <jhoon.it.lee@samsung.com>
+ * @author Donghak Park <donghak.park@samsung.com>
* @bug No known bugs except for NYI items
*/
#include <tflite_interpreter.h>
#include <fstream>
#include <map>
#include <memory>
+#include <regex>
#include <set>
#include <string>
#include <tuple>
uint8_t *buf = builder.GetBufferPointer();
size_t size = builder.GetSize();
flatbuffers::Verifier v(buf, size);
-
NNTR_THROW_IF(!tflite::VerifyModelBuffer(v), std::invalid_argument)
<< FUNC_TAG << "Verifying serialized model failed";
-
std::ofstream os(out, std::ios_base::binary);
const size_t error_buflen = 100;
char error_buf[error_buflen];
Exporter e(&fbb);
ln->exportTo(e, ml::train::ExportMethods::METHOD_TFLITE);
+ auto export_output = e.getResult<ml::train::ExportMethods::METHOD_TFLITE>();
+
+ if (export_output.get()->getWeights().size() == 0) {
+ export_output.get()->setTrainable(false);
+ }
- nodes.emplace_back(e.getResult<ml::train::ExportMethods::METHOD_TFLITE>());
+ nodes.emplace_back(move(export_output));
tf_to_layer.insert({nodes.back().get(), ln.get()});
layer_to_tf.insert({ln.get(), nodes.back().get()});
}
* [FC]:local_first
*/
+ // set reorder weight flag for FullyConnected layer
for (auto &n : nodes) {
auto tf_node = n.get();
tf_node->weightReorder(node_count);
}
+ if (tf_node->getOpType() ==
+ tflite::BuiltinOperator::BuiltinOperator_CONV_2D &&
+ nodes.at(node_count + 1).get()->getOpType() ==
+ tflite::BuiltinOperator::BuiltinOperator_MUL &&
+ nodes.at(node_count + 2).get()->getOpType() ==
+ tflite::BuiltinOperator::BuiltinOperator_RELU) {
+ // Fuse Conv2D + Mul + ReLU to Conv2D
+
+ auto props = tf_node->getProps();
+ auto tf_padding = tflite::Padding_SAME;
+
+ if (props[0] == 1) {
+ tf_padding = tflite::Padding_VALID;
+ }
+ auto new_options =
+ tflite::CreateConv2DOptions(fbb, tf_padding, props[1], props[2],
+ tflite::ActivationFunctionType_RELU)
+ .Union();
+ tf_node->setBuiltinOptions(tflite::BuiltinOptions_Conv2DOptions,
+ new_options);
+ // After Fusing Mark ReLU Node to be removed
+ nodes.at(node_count + 2).get()->setToBeRemoved(true);
+ }
+
+ if (node_count < 1) {
+ node_count++;
+ continue;
+ } else {
+ if (nodes.at(node_count - 1).get()->isTrainable() == true &&
+ tf_node->getOpType() == tflite::BuiltinOperator_MUL) {
+
+ // Fused weight(conv)
+ // = weight(conv) * (weight(bn) / sqrt(var(bn) + eps))
+
+ auto conv_weights = nodes.at(node_count - 1).get()->getWeights();
+ auto conv_weight = conv_weights.at(0)->clone();
+ auto conv_bias = conv_weights.at(1)->clone();
+
+ auto mul_weights = tf_node->getWeights();
+ auto mul_mean = mul_weights.at(0)->clone();
+ auto mul_var = mul_weights.at(1)->clone();
+ auto mul_weight = mul_weights.at(2)->clone();
+ auto mul_bias = mul_weights.at(3)->clone();
+ auto mul_epsilon = tf_node->getAdditionalProps().at(0);
+
+ // run sqrt(var(bn) + eps)
+ mul_var.add_i(mul_epsilon);
+ mul_var.pow_i(0.5f);
+ mul_weight.divide_i(mul_var);
+
+ mul_weight.reshape(TensorDim({mul_weight.getDim().channel(), 1, 1, 1}));
+ conv_weight.multiply_i(mul_weight);
+
+ mul_weight.reshape(TensorDim({1, 1, 1, mul_weight.getDim().batch()}));
+ conv_bias.subtract_i(mul_mean);
+ conv_bias.multiply_i(mul_weight);
+ conv_bias.add_i(mul_bias);
+
+ TfOpNode::Variables conv_new_weights;
+ conv_new_weights.push_back(&conv_weight);
+ conv_new_weights.push_back(&conv_bias);
+ nodes.at(node_count - 1).get()->setWeights(conv_new_weights);
+
+ // set mul node to be removed (mul mean batch normalization)
+ n->setToBeRemoved(true);
+ }
+ }
node_count++;
}
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::Tensor>>>
buildTensors(const TfOpIdxMap &map, flatbuffers::FlatBufferBuilder &fbb) {
- /// @todo: the actual (suqeezed) tensor dimension must be known before
+ /// @todo: the actual (squeezed) tensor dimension must be known before
/// coming here. For now, it is directly guessed for the fc layer
const auto &variables = map.getTensors();
const auto &buffer_map = map.getIndexMap<const float *, TfOpIdxMap::Buffer>();
} // namespace
+TfOpNodes buildRealizedOpNodes(TfOpNodes &nodes,
+ flatbuffers::FlatBufferBuilder &fbb) {
+ TfOpNodes realized_nodes;
+
+ bool set_input = false;
+ unsigned int node_count = 0;
+
+ for (auto &node : nodes) {
+
+ if (set_input) { // if front node is new added node set input output
+ node->setArg(0, realized_nodes.back().get());
+ realized_nodes.back()->setOutputs(node->getInputs());
+ set_input = false;
+ }
+
+ if (node->isToBeRemoved() == true) { // Remove node
+ realized_nodes.back().get()->setOutputs(
+ nodes.at(node_count + 1)->getInputs());
+ nodes.at(node_count + 1)->setArg(0, realized_nodes.back().get());
+ nodes.at(node_count + 1)
+ ->setInputs(realized_nodes.back().get()->getOutputs());
+ } else {
+ realized_nodes.push_back(std::move(node));
+
+ if (realized_nodes.back().get()->getOpType() ==
+ tflite::BuiltinOperator_MUL) { // Fused MUL ADD (Non Trainable)
+
+ // remove weights (In .tflite this mean INPUTS)
+ auto removed_weights = realized_nodes.back().get()->getWeights();
+ // y = x
+ // * (gamma / sqrt(variance + epsilon))
+ // + (beta - mean * gamma / sqrt(variance + epsilon) )
+ auto mul_mean = removed_weights.at(0)->clone();
+ auto mul_variance = removed_weights.at(1)->clone();
+ auto mul_gamma = removed_weights.at(2)->clone();
+ auto mul_beta = removed_weights.at(3)->clone();
+ auto mul_epsilon =
+ realized_nodes.back().get()->getAdditionalProps().at(0);
+
+ auto new_mul_weight = mul_gamma.clone();
+ new_mul_weight.allocate();
+
+ mul_variance.add_i(mul_epsilon);
+ mul_variance.pow_i(0.5f);
+ new_mul_weight.divide_i(mul_variance);
+
+ mul_mean.multiply_i(mul_gamma);
+ mul_beta.subtract_i(mul_mean);
+ mul_beta.divide_i(mul_variance);
+
+ auto ptr_add_weight = removed_weights.at(1);
+
+ removed_weights.clear();
+ removed_weights.push_back(&new_mul_weight);
+ realized_nodes.back().get()->setWeights(removed_weights);
+
+ auto removed_weights2 = realized_nodes.back().get()->getWeights();
+ removed_weights2.pop_back();
+ removed_weights2.pop_back();
+ removed_weights2.pop_back();
+ realized_nodes.back().get()->replaceWeights(removed_weights2);
+
+ TfOpNode tf_node;
+ tf_node.setInputs(realized_nodes.back()->getOutputs());
+ tf_node.setOpType(tflite::BuiltinOperator_ADD);
+ auto options =
+ tflite::CreateAddOptions(fbb, tflite::ActivationFunctionType_RELU)
+ .Union();
+
+ auto add_weights = realized_nodes.back().get()->getWeights();
+ add_weights.clear();
+ add_weights.push_back(ptr_add_weight);
+ tf_node.replaceWeights(add_weights);
+
+ auto new_weight_add = mul_beta.clone();
+ auto new_variable = tf_node.getWeights();
+ new_variable.clear();
+ new_variable.push_back(&new_weight_add);
+ tf_node.setWeights(new_variable);
+
+ tf_node.setBuiltinOptions(tflite::BuiltinOptions_AddOptions, options);
+ tf_node.finalize();
+
+ nodes.at(node_count + 1)
+ .get()
+ ->setToBeRemoved(true); // remove ReLU Layer and Fuse with Add
+
+ auto mul_node = realized_nodes.back().get();
+ tf_node.arity(1);
+ tf_node.setArg(0, mul_node);
+ //
+
+ std::unique_ptr<TfOpNode> ptr = std::make_unique<TfOpNode>(tf_node);
+ realized_nodes.push_back(std::move(ptr));
+ set_input = true;
+ }
+ }
+ node_count++;
+ }
+
+ return realized_nodes;
+}
+
void TfliteInterpreter::serialize(const GraphRepresentation &representation,
const std::string &out) {
- /// @todo check if graph is finalized & initialized and ready to serialize.
-
- /// 0. remove batch normalization layer in GraphRepresentation
- BnRealizer realizer({});
- GraphRepresentation graph = realizer.realize(representation);
/// 1. remove loss layer in GraphRepresentation
LossRealizer loss_realizer({});
- graph = loss_realizer.realize(graph);
+ GraphRepresentation graph = loss_realizer.realize(representation);
/// 2. The graph must have weights, input dims, output dims set
flatbuffers::FlatBufferBuilder fbb;
auto opNodes = buildOpNodes(graph, fbb);
- TfOpIdxMap map(opNodes); /// build TfOpIdxMap from opNodes
+ auto converted_opNodes = buildRealizedOpNodes(opNodes, fbb);
+ TfOpIdxMap map(converted_opNodes); /// build TfOpIdxMap from opNodes
auto opcodes = buildOperatorCodes(map, fbb);
- auto subgraphs = buildSubGraphs(opNodes, map, fbb);
+ auto subgraphs = buildSubGraphs(converted_opNodes, map, fbb);
auto buffers = buildBuffers(map, fbb);
auto desc = fbb.CreateString("This file is generated from NNTrainer");
-
tflite::ModelBuilder model_builder(fbb);
model_builder.add_operator_codes(opcodes);
void TfOpNode::setWeights(Variables weights_) {
unsigned int cnt = 0;
for (auto &w : weights_) {
- const unsigned int UNIT = w->batch();
- const unsigned int CHANNEL = w->channel();
- const unsigned int HEIGHT = w->height();
- const unsigned int WIDTH = w->width();
+ const unsigned int unit = w->batch();
+ const unsigned int channel = w->channel();
+ const unsigned int height = w->height();
+ const unsigned int width = w->width();
auto weight_data = weights.at(cnt)->getData();
auto *ptr = const_cast<float *>(weight_data);
memcpy(&ptr[0], &w->getData()[0],
- sizeof(float) * (UNIT * CHANNEL * HEIGHT * WIDTH));
+ sizeof(float) * (unit * channel * height * width));
cnt++;
}
}
auto previous_input_shape = input_nodes[0]->getInputs()[0];
- const unsigned int UNIT = outputs[0]->height();
- const unsigned int CHANNEL = previous_input_shape->channel();
- const unsigned int HEIGHT = previous_input_shape->height();
- const unsigned int WIDTH = previous_input_shape->width();
+ const unsigned int unit = outputs[0]->height();
+ const unsigned int channel = previous_input_shape->channel();
+ const unsigned int height = previous_input_shape->height();
+ const unsigned int width = previous_input_shape->width();
auto weight_data = weights[0]->getData();
auto *ptr = const_cast<float *>(weight_data);
- std::vector<float> old_value_list(UNIT * CHANNEL * HEIGHT * WIDTH);
+ std::vector<float> old_value_list(unit * channel * height * width);
memcpy(&old_value_list[0], &ptr[0],
- sizeof(float) * (UNIT * CHANNEL * HEIGHT * WIDTH));
+ sizeof(float) * (unit * channel * height * width));
- for (unsigned int h = 0; h < HEIGHT; h++) {
- for (unsigned int w = 0; w < WIDTH; w++) {
- for (unsigned int c = 0; c < CHANNEL; c++) {
+ for (unsigned int h = 0; h < height; h++) {
+ for (unsigned int w = 0; w < width; w++) {
+ for (unsigned int c = 0; c < channel; c++) {
- unsigned int now_position = h * (WIDTH * CHANNEL) + w * CHANNEL + c;
- unsigned int next_position = c * (HEIGHT * WIDTH) + h * WIDTH + w;
+ unsigned int now_position = h * (width * channel) + w * channel + c;
+ unsigned int next_position = c * (height * width) + h * width + w;
- memcpy(&ptr[now_position * UNIT],
- &old_value_list[next_position * UNIT], sizeof(float) * UNIT);
+ memcpy(&ptr[now_position * unit],
+ &old_value_list[next_position * unit], sizeof(float) * unit);
}
}
}