[TF Export] Update tflite_opnode
authorDongHak Park <donghak.park@samsung.com>
Fri, 31 Mar 2023 07:58:12 +0000 (16:58 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 10 Apr 2023 22:41:05 +0000 (07:41 +0900)
Update tflite_opnode
- Add is_trainable for make fused op, by checking is trainable we can make fused of for inference
- Add MUL Op for BatchNormalization Fused Op

Signed-off-by: DongHak Park <donghak.park@samsung.com>
nntrainer/compiler/tflite_opnode.cpp
nntrainer/compiler/tflite_opnode.h

index f1b88b5..79af20d 100644 (file)
@@ -7,6 +7,7 @@
  * @brief contains tflite opnode which has information to convert to tflite file
  * @see        https://github.com/nnstreamer/nntrainer
  * @author Jihoon Lee <jhoon.it.lee@samsung.com>
+ * @author Donghak Park <donghak.park@samsung.com>
  * @bug No known bugs except for NYI items
  */
 
@@ -26,6 +27,7 @@ TfOpNode::TfOpNode() :
   is_input(false),
   is_output(false),
   is_virtual(false),
+  is_trainable(true),
   need_reorder_weight(false),
   node_owned_variable(),
   /// @todo distinguish between uninitialized and ADD operator.
@@ -206,6 +208,7 @@ flatbuffers::Offset<void> TfOpNode::getBuiltinOps() const {
   case tflite::BuiltinOperator_RESHAPE:
   case tflite::BuiltinOperator_SOFTMAX:
   case tflite::BuiltinOperator_TRANSPOSE:
+  case tflite::BuiltinOperator_MUL:
     return builtin_ops;
   default:
     throw std::runtime_error{"Unsupported operator"};
index dca1e46..5a752d0 100644 (file)
@@ -7,6 +7,7 @@
  * @brief contains tflite opnode which has information to convert to tflite file
  * @see        https://github.com/nnstreamer/nntrainer
  * @author Jihoon Lee <jhoon.it.lee@samsung.com>
+ * @author Donghak Park <donghak.park@samsung.com>
  * @bug No known bugs except for NYI items
  */
 
@@ -94,6 +95,12 @@ public:
   void setNeedReorderWeight() { need_reorder_weight = true; }
 
   /**
+   * @brief Set the Trainable object
+   *
+   */
+  void setTrainable(bool trainable) { is_trainable = trainable; }
+
+  /**
    * @brief Reorder Weight in case of NCHW --> NHWC
    *
    * @param node_count
@@ -168,12 +175,20 @@ public:
   /**
    * @brief check if this layer need to reorder
    *
-   * @return true
-   * @return false
+   * @return true if weight need to reorder
+   * @return false if reordering is not required
    */
   bool isNeedReorder() const { return need_reorder_weight; }
 
   /**
+   * @brief check if this layer is trainable
+   *
+   * @return true if layer(OpNode) trainable
+   * @return false if layer(OpNode) non-trainable
+   */
+  bool isTrainable() const { return is_trainable; }
+
+  /**
    * @brief Get the Op Type object
    *
    * @return const tflite::BuiltinOperator
@@ -249,6 +264,7 @@ private:
   bool is_input;            /**< true if given input is input; */
   bool is_output;           /**< true if given output is output; */
   bool is_virtual;          /**< true if given node is virtual; */
+  bool is_trainable;        /**< true if given node has weight and trainable */
   bool need_reorder_weight; /**< true if given node need to reorder weight; */
 
   /// @todo change to shared_ptr or unique_ptr