[ LORA ] Bugfix in LoRA support in FC Layer
authorEunju Yang <ej.yang@samsung.com>
Thu, 5 Sep 2024 04:55:57 +0000 (13:55 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Sun, 22 Sep 2024 23:07:43 +0000 (08:07 +0900)
- In the previous code, LoRA didn't work for the case batch_size > 1.
- Tensors used in LoRA-related computation were not updated when the
batch size is upsted.
- `setBatch()` function is implemented for `FullyConnectedLayer`.
- BugFix in Lifespan of loraTmp Tensor: FORWARD_DERIV_LIFESPANE ->
FORWARD_GRAD_LIFESPAN

Self evaluation:

Build test: [X]Passed [ ]Failed [ ]Skipped
Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Eunju Yang <ej.yang@samsung.com>
nntrainer/layers/fc_layer.cpp
nntrainer/layers/fc_layer.h

index 3ee52eed4a593d6ca7c17e0cf503f09100e48f4b..3cf25e19772e3b7c3eb15fe9913f44f77c48df44 100644 (file)
@@ -91,11 +91,14 @@ void FullyConnectedLayer::finalize(InitLayerContext &context) {
   /** set weight specifications */
   // @todo : This NCHW format setting is just temporal, it needs to be set by
   // global configuration
+
+  /** Bias Dimension : (1, 1, 1, unit) */
   TensorDim bias_dim(
     1, is_nchw ? 1 : unit, 1, is_nchw ? unit : 1,
     TensorDim::TensorType(context.getFormat(), context.getWeightDataType()),
     is_nchw ? 0b0001 : 0b0100);
 
+  /** Weight Dimension : (1, 1, in_dim.width(), unit)*/
   TensorDim weight_dim(
     1, is_nchw ? 1 : unit, is_nchw ? in_dim.width() : 1,
     is_nchw ? unit : in_dim.channel(),
@@ -115,25 +118,33 @@ void FullyConnectedLayer::finalize(InitLayerContext &context) {
   /** create weights for LoRA */
   if (lora_rank) {
 
-    /** loraA : (in_dim.width, lora_rank) */
+    /** loraA Dimension : (1, 1, in_dim.width, lora_rank) */
     TensorDim loraA_dim(
       1, is_nchw ? 1 : lora_rank, is_nchw ? in_dim.width() : 1,
       is_nchw ? lora_rank : in_dim.channel(),
       TensorDim::TensorType(context.getFormat(), context.getWeightDataType()),
       is_nchw ? 0b0011 : 0b0101);
 
-    /** loraB: (lora_rank, out_dim) */
+    /** loraB Dimension : (1, 1, lora_rank, unit) */
     TensorDim loraB_dim(
       1, is_nchw ? 1 : unit, is_nchw ? lora_rank : 1,
       is_nchw ? unit : lora_rank,
       TensorDim::TensorType(context.getFormat(), context.getWeightDataType()),
       is_nchw ? 0b0011 : 0b0101);
 
-    /** loraTmp: (1, lora_rank) */
+    /** loraTmp Dimension : (B, 1, in_dim.height(), lora_rank) */
     TensorDim loraTmp_dim(
-      1, is_nchw ? 1 : lora_rank, 1, is_nchw ? lora_rank : 1,
+      in_dim.batch(), is_nchw ? 1 : lora_rank, is_nchw ? in_dim.height() : 1,
+      is_nchw ? lora_rank : in_dim.width(),
+      TensorDim::TensorType(context.getFormat(), context.getWeightDataType()),
+      is_nchw ? 0b1011 : 0b1101);
+
+    /** loraTmp Dimension : (B, 1, in_dim.height(), unit) */
+    TensorDim loraOut_dim(
+      in_dim.batch(), is_nchw ? 1 : unit, is_nchw ? in_dim.height() : 1,
+      is_nchw ? unit : in_dim.width(),
       TensorDim::TensorType(context.getFormat(), context.getWeightDataType()),
-      is_nchw ? 0b0001 : 0b0100);
+      is_nchw ? 0b1011 : 0b1101);
 
     lora_idx[LORAParams::loraA] = context.requestWeight(
       loraA_dim, Initializer::ZEROS, weight_regularizer,
@@ -145,10 +156,10 @@ void FullyConnectedLayer::finalize(InitLayerContext &context) {
 
     lora_idx[LORAParams::loraTmp] =
       context.requestTensor(loraTmp_dim, "hidden_tmp_lora", Initializer::NONE,
-                            true, TensorLifespan::FORWARD_DERIV_LIFESPAN);
+                            true, TensorLifespan::FORWARD_GRAD_LIFESPAN);
 
     lora_idx[LORAParams::loraOut] =
-      context.requestTensor(bias_dim, "hidden_lora", Initializer::NONE, true,
+      context.requestTensor(loraOut_dim, "hidden_lora", Initializer::NONE, true,
                             TensorLifespan::FORWARD_FUNC_LIFESPAN);
   }
 }
@@ -164,6 +175,15 @@ void FullyConnectedLayer::setProperty(const std::vector<std::string> &values) {
   LayerImpl::setProperty(remain_props);
 }
 
+void FullyConnectedLayer::setBatch(nntrainer::RunLayerContext &context,
+                                   unsigned int batch) {
+  if (!std::get<props::LoraRank>(fc_props).empty()) {
+    // update Lora Tensor's batch info.
+    context.updateTensor(lora_idx[LORAParams::loraTmp], batch);
+    context.updateTensor(lora_idx[LORAParams::loraOut], batch);
+  }
+}
+
 void FullyConnectedLayer::forwarding(RunLayerContext &context, bool training) {
   Tensor &weight = context.getWeight(weight_idx[FCParams::weight]);
   Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
index 44ef99d912dbc1d4f55e23032dd36b5417ff33d9..f901697413b814280f1b828194e505d50808fcf0 100644 (file)
@@ -103,6 +103,12 @@ public:
    */
   void setProperty(const std::vector<std::string> &values) override;
 
+  /**
+   * @copydoc Layer::setBatch(RunLayerContext &context, unsigned int batch)
+   */
+  void setBatch(nntrainer::RunLayerContext &context,
+                unsigned int batch) override;
+
   inline static const std::string type = "fully_connected";
 
 private: