[TFLite Export] Add Realized Path for Fused Op
author DongHak Park <donghak.park@samsung.com>
Fri, 14 Apr 2023 08:35:07 +0000 (17:35 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Thu, 27 Jul 2023 10:38:00 +0000 (19:38 +0900)
Add a realized path for fused ops

1. Check Trainable
 - check whether a node is trainable before fusing
2. Conv + ReLU Fusing
3. Batch Normalization Fusing

Signed-off-by: DongHak Park <donghak.park@samsung.com>
nntrainer/compiler/tflite_interpreter.cpp
nntrainer/compiler/tflite_opnode.cpp

index 43ad4fa..22a930e 100644 (file)
@@ -7,6 +7,7 @@
  * @brief NNTrainer *.tflite Interpreter
  * @see        https://github.com/nnstreamer/nntrainer
  * @author Jihoon Lee <jhoon.it.lee@samsung.com>
+ * @author Donghak Park <donghak.park@samsung.com>
  * @bug No known bugs except for NYI items
  */
 #include <tflite_interpreter.h>
@@ -15,6 +16,7 @@
 #include <fstream>
 #include <map>
 #include <memory>
+#include <regex>
 #include <set>
 #include <string>
 #include <tuple>
@@ -47,10 +49,8 @@ void builder2file(const flatbuffers::FlatBufferBuilder &builder,
   uint8_t *buf = builder.GetBufferPointer();
   size_t size = builder.GetSize();
   flatbuffers::Verifier v(buf, size);
-
   NNTR_THROW_IF(!tflite::VerifyModelBuffer(v), std::invalid_argument)
     << FUNC_TAG << "Verifying serialized model failed";
-
   std::ofstream os(out, std::ios_base::binary);
   const size_t error_buflen = 100;
   char error_buf[error_buflen];
@@ -323,8 +323,13 @@ TfOpNodes buildOpNodes(const GraphRepresentation &representation,
 
     Exporter e(&fbb);
     ln->exportTo(e, ml::train::ExportMethods::METHOD_TFLITE);
+    auto export_output = e.getResult<ml::train::ExportMethods::METHOD_TFLITE>();
+
+    if (export_output.get()->getWeights().size() == 0) {
+      export_output.get()->setTrainable(false);
+    }
 
-    nodes.emplace_back(e.getResult<ml::train::ExportMethods::METHOD_TFLITE>());
+    nodes.emplace_back(move(export_output));
     tf_to_layer.insert({nodes.back().get(), ln.get()});
     layer_to_tf.insert({ln.get(), nodes.back().get()});
   }
@@ -338,6 +343,7 @@ TfOpNodes buildOpNodes(const GraphRepresentation &representation,
    * [FC]:local_first
    */
 
+  // set reorder weight flag for FullyConnected layer
   for (auto &n : nodes) {
     auto tf_node = n.get();
 
@@ -394,6 +400,73 @@ TfOpNodes buildOpNodes(const GraphRepresentation &representation,
       tf_node->weightReorder(node_count);
     }
 
+    if (tf_node->getOpType() ==
+          tflite::BuiltinOperator::BuiltinOperator_CONV_2D &&
+        nodes.at(node_count + 1).get()->getOpType() ==
+          tflite::BuiltinOperator::BuiltinOperator_MUL &&
+        nodes.at(node_count + 2).get()->getOpType() ==
+          tflite::BuiltinOperator::BuiltinOperator_RELU) {
+      // Fuse Conv2D + Mul + ReLU to Conv2D
+
+      auto props = tf_node->getProps();
+      auto tf_padding = tflite::Padding_SAME;
+
+      if (props[0] == 1) {
+        tf_padding = tflite::Padding_VALID;
+      }
+      auto new_options =
+        tflite::CreateConv2DOptions(fbb, tf_padding, props[1], props[2],
+                                    tflite::ActivationFunctionType_RELU)
+          .Union();
+      tf_node->setBuiltinOptions(tflite::BuiltinOptions_Conv2DOptions,
+                                 new_options);
+      // After Fusing Mark ReLU Node to be removed
+      nodes.at(node_count + 2).get()->setToBeRemoved(true);
+    }
+
+    if (node_count < 1) {
+      node_count++;
+      continue;
+    } else {
+      if (nodes.at(node_count - 1).get()->isTrainable() == true &&
+          tf_node->getOpType() == tflite::BuiltinOperator_MUL) {
+
+        // Fused weight(conv)
+        // = weight(conv) * (weight(bn) / sqrt(var(bn) + eps))
+
+        auto conv_weights = nodes.at(node_count - 1).get()->getWeights();
+        auto conv_weight = conv_weights.at(0)->clone();
+        auto conv_bias = conv_weights.at(1)->clone();
+
+        auto mul_weights = tf_node->getWeights();
+        auto mul_mean = mul_weights.at(0)->clone();
+        auto mul_var = mul_weights.at(1)->clone();
+        auto mul_weight = mul_weights.at(2)->clone();
+        auto mul_bias = mul_weights.at(3)->clone();
+        auto mul_epsilon = tf_node->getAdditionalProps().at(0);
+
+        // run sqrt(var(bn) + eps)
+        mul_var.add_i(mul_epsilon);
+        mul_var.pow_i(0.5f);
+        mul_weight.divide_i(mul_var);
+
+        mul_weight.reshape(TensorDim({mul_weight.getDim().channel(), 1, 1, 1}));
+        conv_weight.multiply_i(mul_weight);
+
+        mul_weight.reshape(TensorDim({1, 1, 1, mul_weight.getDim().batch()}));
+        conv_bias.subtract_i(mul_mean);
+        conv_bias.multiply_i(mul_weight);
+        conv_bias.add_i(mul_bias);
+
+        TfOpNode::Variables conv_new_weights;
+        conv_new_weights.push_back(&conv_weight);
+        conv_new_weights.push_back(&conv_bias);
+        nodes.at(node_count - 1).get()->setWeights(conv_new_weights);
+
+        // set mul node to be removed (mul mean batch normalization)
+        n->setToBeRemoved(true);
+      }
+    }
     node_count++;
   }
 
@@ -456,7 +529,7 @@ buildOperatorCodes(const TfOpIdxMap &map, flatbuffers::FlatBufferBuilder &fbb) {
 
 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::Tensor>>>
 buildTensors(const TfOpIdxMap &map, flatbuffers::FlatBufferBuilder &fbb) {
-  /// @todo: the actual (suqeezed) tensor dimension must be known before
+  /// @todo: the actual (squeezed) tensor dimension must be known before
   /// coming here. For now, it is directly guessed for the fc layer
   const auto &variables = map.getTensors();
   const auto &buffer_map = map.getIndexMap<const float *, TfOpIdxMap::Buffer>();
@@ -625,29 +698,127 @@ buildSubGraphs(const TfOpNodes &nodes, const TfOpIdxMap &map,
 
 } // namespace
 
+TfOpNodes buildRealizedOpNodes(TfOpNodes &nodes,
+                               flatbuffers::FlatBufferBuilder &fbb) {
+  TfOpNodes realized_nodes; // nodes kept for export, in order
+  // Drop nodes marked for removal; split each kept MUL node into MUL + ADD.
+  bool set_input = false; // true right after an ADD node is appended
+  unsigned int node_count = 0; // index of the current node in nodes
+
+  for (auto &node : nodes) {
+
+    if (set_input) { // last realized node is the new ADD: link it to this node
+      node->setArg(0, realized_nodes.back().get());
+      realized_nodes.back()->setOutputs(node->getInputs());
+      set_input = false;
+    }
+    // NOTE(review): assumes a removed node is neither first nor last in nodes.
+    if (node->isToBeRemoved() == true) { // skip node, bridge its neighbours
+      realized_nodes.back().get()->setOutputs(
+        nodes.at(node_count + 1)->getInputs());
+      nodes.at(node_count + 1)->setArg(0, realized_nodes.back().get());
+      nodes.at(node_count + 1)
+        ->setInputs(realized_nodes.back().get()->getOutputs());
+    } else {
+      realized_nodes.push_back(std::move(node));
+
+      if (realized_nodes.back().get()->getOpType() ==
+          tflite::BuiltinOperator_MUL) { // Fused MUL ADD (Non Trainable)
+
+        // remove weights (in .tflite, weights mean INPUTS)
+        auto removed_weights = realized_nodes.back().get()->getWeights();
+        // y = x
+        // * (gamma / sqrt(variance + epsilon))
+        // + (beta - mean * gamma / sqrt(variance + epsilon) )
+        auto mul_mean = removed_weights.at(0)->clone();
+        auto mul_variance = removed_weights.at(1)->clone();
+        auto mul_gamma = removed_weights.at(2)->clone();
+        auto mul_beta = removed_weights.at(3)->clone();
+        auto mul_epsilon =
+          realized_nodes.back().get()->getAdditionalProps().at(0);
+
+        auto new_mul_weight = mul_gamma.clone();
+        new_mul_weight.allocate();
+        // mul weight: gamma / sqrt(variance + epsilon), assuming *_i acts in place
+        mul_variance.add_i(mul_epsilon);
+        mul_variance.pow_i(0.5f);
+        new_mul_weight.divide_i(mul_variance);
+        // add weight: (beta - mean * gamma) / sqrt(variance + epsilon)
+        mul_mean.multiply_i(mul_gamma);
+        mul_beta.subtract_i(mul_mean);
+        mul_beta.divide_i(mul_variance);
+        // NOTE(review): beta is divided too, unlike the formula above - confirm.
+        auto ptr_add_weight = removed_weights.at(1); // reused by the ADD node
+
+        removed_weights.clear();
+        removed_weights.push_back(&new_mul_weight);
+        realized_nodes.back().get()->setWeights(removed_weights);
+        // drop the last three weight entries of the MUL node
+        auto removed_weights2 = realized_nodes.back().get()->getWeights();
+        removed_weights2.pop_back();
+        removed_weights2.pop_back();
+        removed_weights2.pop_back();
+        realized_nodes.back().get()->replaceWeights(removed_weights2);
+        // build the ADD node that takes the MUL node's outputs as its inputs
+        TfOpNode tf_node;
+        tf_node.setInputs(realized_nodes.back()->getOutputs());
+        tf_node.setOpType(tflite::BuiltinOperator_ADD);
+        auto options =
+          tflite::CreateAddOptions(fbb, tflite::ActivationFunctionType_RELU)
+            .Union();
+
+        auto add_weights = realized_nodes.back().get()->getWeights();
+        add_weights.clear();
+        add_weights.push_back(ptr_add_weight);
+        tf_node.replaceWeights(add_weights);
+        // pass the add weight computed above (a clone of mul_beta) to setWeights
+        auto new_weight_add = mul_beta.clone();
+        auto new_variable = tf_node.getWeights();
+        new_variable.clear();
+        new_variable.push_back(&new_weight_add);
+        tf_node.setWeights(new_variable);
+
+        tf_node.setBuiltinOptions(tflite::BuiltinOptions_AddOptions, options);
+        tf_node.finalize();
+        // NOTE(review): the next node is flagged without checking it is a RELU.
+        nodes.at(node_count + 1)
+          .get()
+          ->setToBeRemoved(true); // remove ReLU Layer and Fuse with Add
+
+        auto mul_node = realized_nodes.back().get();
+        tf_node.arity(1);
+        tf_node.setArg(0, mul_node);
+        // append a copy of the ADD node; the next iteration finishes its wiring
+
+        std::unique_ptr<TfOpNode> ptr = std::make_unique<TfOpNode>(tf_node);
+        realized_nodes.push_back(std::move(ptr));
+        set_input = true;
+      }
+    }
+    node_count++;
+  }
+
+  return realized_nodes;
+}
+
 void TfliteInterpreter::serialize(const GraphRepresentation &representation,
                                   const std::string &out) {
-  /// @todo check if graph is finalized & initialized and ready to serialize.
-
-  /// 0. remove batch normalization layer in GraphRepresentation
-  BnRealizer realizer({});
-  GraphRepresentation graph = realizer.realize(representation);
 
   /// 1. remove loss layer in GraphRepresentation
   LossRealizer loss_realizer({});
-  graph = loss_realizer.realize(graph);
+  GraphRepresentation graph = loss_realizer.realize(representation);
 
   /// 2. The graph must have weights, input dims, output dims set
   flatbuffers::FlatBufferBuilder fbb;
 
   auto opNodes = buildOpNodes(graph, fbb);
-  TfOpIdxMap map(opNodes); /// build TfOpIdxMap from opNodes
+  auto converted_opNodes = buildRealizedOpNodes(opNodes, fbb);
 
+  TfOpIdxMap map(converted_opNodes); /// build TfOpIdxMap from opNodes
   auto opcodes = buildOperatorCodes(map, fbb);
-  auto subgraphs = buildSubGraphs(opNodes, map, fbb);
+  auto subgraphs = buildSubGraphs(converted_opNodes, map, fbb);
   auto buffers = buildBuffers(map, fbb);
   auto desc = fbb.CreateString("This file is generated from NNTrainer");
-
   tflite::ModelBuilder model_builder(fbb);
 
   model_builder.add_operator_codes(opcodes);
index c542d3b..1ad3a95 100644 (file)
@@ -125,15 +125,15 @@ void TfOpNode::setInputTransformFn(TransformFn fn) { input_transform = fn; }
 void TfOpNode::setWeights(Variables weights_) {
   unsigned int cnt = 0;
   for (auto &w : weights_) {
-    const unsigned int UNIT = w->batch();
-    const unsigned int CHANNEL = w->channel();
-    const unsigned int HEIGHT = w->height();
-    const unsigned int WIDTH = w->width();
+    const unsigned int unit = w->batch();
+    const unsigned int channel = w->channel();
+    const unsigned int height = w->height();
+    const unsigned int width = w->width();
 
     auto weight_data = weights.at(cnt)->getData();
     auto *ptr = const_cast<float *>(weight_data);
     memcpy(&ptr[0], &w->getData()[0],
-           sizeof(float) * (UNIT * CHANNEL * HEIGHT * WIDTH));
+           sizeof(float) * (unit * channel * height * width));
     cnt++;
   }
 }
@@ -144,27 +144,27 @@ void TfOpNode::weightReorder(unsigned int node_count) {
 
     auto previous_input_shape = input_nodes[0]->getInputs()[0];
 
-    const unsigned int UNIT = outputs[0]->height();
-    const unsigned int CHANNEL = previous_input_shape->channel();
-    const unsigned int HEIGHT = previous_input_shape->height();
-    const unsigned int WIDTH = previous_input_shape->width();
+    const unsigned int unit = outputs[0]->height();
+    const unsigned int channel = previous_input_shape->channel();
+    const unsigned int height = previous_input_shape->height();
+    const unsigned int width = previous_input_shape->width();
 
     auto weight_data = weights[0]->getData();
     auto *ptr = const_cast<float *>(weight_data);
 
-    std::vector<float> old_value_list(UNIT * CHANNEL * HEIGHT * WIDTH);
+    std::vector<float> old_value_list(unit * channel * height * width);
     memcpy(&old_value_list[0], &ptr[0],
-           sizeof(float) * (UNIT * CHANNEL * HEIGHT * WIDTH));
+           sizeof(float) * (unit * channel * height * width));
 
-    for (unsigned int h = 0; h < HEIGHT; h++) {
-      for (unsigned int w = 0; w < WIDTH; w++) {
-        for (unsigned int c = 0; c < CHANNEL; c++) {
+    for (unsigned int h = 0; h < height; h++) {
+      for (unsigned int w = 0; w < width; w++) {
+        for (unsigned int c = 0; c < channel; c++) {
 
-          unsigned int now_position = h * (WIDTH * CHANNEL) + w * CHANNEL + c;
-          unsigned int next_position = c * (HEIGHT * WIDTH) + h * WIDTH + w;
+          unsigned int now_position = h * (width * channel) + w * channel + c;
+          unsigned int next_position = c * (height * width) + h * width + w;
 
-          memcpy(&ptr[now_position * UNIT],
-                 &old_value_list[next_position * UNIT], sizeof(float) * UNIT);
+          memcpy(&ptr[now_position * unit],
+                 &old_value_list[next_position * unit], sizeof(float) * unit);
         }
       }
     }