[TFLite Export] Add variable, functions TfOpNodes for Fused OP export
authorDongHak Park <donghak.park@samsung.com>
Fri, 14 Apr 2023 08:27:46 +0000 (17:27 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Thu, 27 Jul 2023 09:19:40 +0000 (18:19 +0900)
for Export Tflite format with Fused Op add some Variable and Function

1. Add getter, setter, replace to weights
- for Fused Op we need to adjust weights after made Opnode

2. Add isToBeRemove variable
- After made Opnode, check condition and mark as to be remove

3. Add additional_props
- for BatchNormalization Fused Op we need additional props from nntrainer
- made vector<float> variable for save additional data

Signed-off-by: DongHak Park <donghak.park@samsung.com>
nntrainer/compiler/tflite_opnode.cpp
nntrainer/compiler/tflite_opnode.h

index d32f9c4..c542d3b 100644 (file)
@@ -28,6 +28,7 @@ TfOpNode::TfOpNode() :
   is_output(false),
   is_virtual(false),
   is_trainable(true),
+  is_to_be_removed(false),
   need_reorder_weight(false),
   node_owned_variable(),
   /// @todo distinguish between uninitialized and ADD operator.
@@ -111,12 +112,32 @@ void TfOpNode::setLayerNode(const LayerNode &layer) {
       return &t;
     },
     context.getNumWeights());
+
+  if (context.getNumWeights() == 0) {
+    is_trainable = false;
+  }
 }
 
 void TfOpNode::setWeightTransformFn(TransformFn fn) { weight_transform = fn; }
 
 void TfOpNode::setInputTransformFn(TransformFn fn) { input_transform = fn; }
 
+void TfOpNode::setWeights(Variables weights_) {
+  unsigned int cnt = 0;
+  for (auto &w : weights_) {
+    const unsigned int UNIT = w->batch();
+    const unsigned int CHANNEL = w->channel();
+    const unsigned int HEIGHT = w->height();
+    const unsigned int WIDTH = w->width();
+
+    auto weight_data = weights.at(cnt)->getData();
+    auto *ptr = const_cast<float *>(weight_data);
+    memcpy(&ptr[0], &w->getData()[0],
+           sizeof(float) * (UNIT * CHANNEL * HEIGHT * WIDTH));
+    cnt++;
+  }
+}
+
 void TfOpNode::weightReorder(unsigned int node_count) {
 
   if (need_reorder_weight == true) {
@@ -209,6 +230,7 @@ flatbuffers::Offset<void> TfOpNode::getBuiltinOps() const {
   case tflite::BuiltinOperator_SOFTMAX:
   case tflite::BuiltinOperator_TRANSPOSE:
   case tflite::BuiltinOperator_MUL:
+
     return builtin_ops;
   default:
     throw std::runtime_error{"Unsupported operator"};
index 5a752d0..9d7ba2f 100644 (file)
@@ -95,12 +95,60 @@ public:
   void setNeedReorderWeight() { need_reorder_weight = true; }
 
   /**
+   * @brief Set the To Be Removed object
+   *
+   */
+  void setToBeRemoved(bool to_be_removed) { is_to_be_removed = to_be_removed; }
+
+  /**
    * @brief Set the Trainable object
    *
    */
   void setTrainable(bool trainable) { is_trainable = trainable; }
 
   /**
+   * @brief Set the Inputs object
+   *
+   * @param inputs_
+   */
+  void setInputs(const Variables &inputs_) { inputs = inputs_; }
+
+  /**
+   * @brief Set the Outputs object
+   *
+   * @param outputs_
+   */
+  void setOutputs(const Variables &outputs_) { outputs = outputs_; }
+
+  /**
+   * @brief Set the Weights object
+   *
+   * @param weights_
+   */
+  void setWeights(Variables weights_);
+  /**
+   * @brief Replace the Weights object
+   *
+   * @param weights_
+   */
+  void replaceWeights(const Variables &weights_) { weights = weights_; }
+  /**
+   * @brief Set(Append) the Props object
+   *
+   * @param value
+   */
+  void AppendProps(const int &value) { props_vector.push_back(value); }
+
+  /**
+   * @brief Set(Append) the Additional Props object
+   *
+   * @param value
+   */
+  void AppendAdditionalProps(const float &value) {
+    additional_props.push_back(value);
+  }
+
+  /**
    * @brief Reorder Weight in case of NCHW --> NHWC
    *
    * @param node_count
@@ -189,6 +237,28 @@ public:
   bool isTrainable() const { return is_trainable; }
 
   /**
+   * @brief check if this layer is to be removed
+   *
+   * @return true
+   * @return false
+   */
+  bool isToBeRemoved() const { return is_to_be_removed; }
+
+  /**
+   * @brief Get the Props Vector
+   *
+   * @return const std::vector<int> props_vector
+   */
+  std::vector<int> getProps() const { return props_vector; }
+
+  /**
+   * @brief Get the Additional Props Vector
+   *
+   * @return const std::vector<float> additional_props
+   */
+  std::vector<float> getAdditionalProps() const { return additional_props; }
+
+  /**
    * @brief Get the Op Type object
    *
    * @return const tflite::BuiltinOperator
@@ -252,19 +322,22 @@ private:
   Variables outputs;                   /**< output variables */
   Variables weights;                   /**< weight variables */
   std::vector<TfOpNode *> input_nodes; /**< input nodes */
+  std::vector<int> props_vector;       /**< props vector */
+  std::vector<float> additional_props; /**< additional props vector */
 
-  TransformFn weight_transform; /**< weight transforms */
   /**
    * Q) Why do we need input transform?
    * A) To transform the nntrainer input data format(NCHW) to tflite
    *format(NHWC)
    **/
-  TransformFn input_transform; /**< input transforms */
+  TransformFn weight_transform; /**< weight transforms */
+  TransformFn input_transform;  /**< input transforms */
 
   bool is_input;            /**< true if given input is input; */
   bool is_output;           /**< true if given output is output; */
   bool is_virtual;          /**< true if given node is virtual; */
   bool is_trainable;        /**< true if given node has weight and trainable */
+  bool is_to_be_removed;    /**< true if given node is to be removed */
   bool need_reorder_weight; /**< true if given node need to reorder weight; */
 
   /// @todo change to shared_ptr or unique_ptr