[ FC ] update incremental_forwarding to support LoRA and multi-batch
authorEunju Yang <ej.yang@samsung.com>
Mon, 2 Sep 2024 08:47:02 +0000 (17:47 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Sun, 22 Sep 2024 23:07:43 +0000 (08:07 +0900)
- This commit adds code to support LoRA in incremental_forwarding (a sketch of
the underlying computation follows this list).
- This commit updates incremental_forwarding to support multi-batch input.
This is not the ideal approach, however, since the per-batch loop is not
parallelized across the batch axis; that issue is left as a TODO comment in the code.
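
For reference, a minimal standalone sketch of the computation the new LoRA path
performs: y = x*W + (alpha/rank) * (x*A)*B. This is plain C++ on flat row-major
buffers, not nntrainer's Tensor API; the function names, the alpha/rank scaling
rule, and all sizes are illustrative assumptions, not part of this patch.

#include <cstddef>
#include <vector>

// y = x^T * M for a row-major [rows x cols] matrix M.
static void vecmat(const std::vector<float> &x, const std::vector<float> &M,
                   std::vector<float> &y, std::size_t rows, std::size_t cols) {
  for (std::size_t c = 0; c < cols; ++c) {
    float acc = 0.0f;
    for (std::size_t r = 0; r < rows; ++r)
      acc += x[r] * M[r * cols + c];
    y[c] = acc;
  }
}

// Hypothetical helper illustrating the LoRA-augmented projection.
// x: [in], W: [in x out], A: [in x rank], B: [rank x out], returns y: [out].
std::vector<float> lora_forward(const std::vector<float> &x,
                                const std::vector<float> &W,
                                const std::vector<float> &A,
                                const std::vector<float> &B, std::size_t in,
                                std::size_t out, std::size_t rank, float alpha) {
  std::vector<float> y(out), tmp(rank), delta(out);
  vecmat(x, W, y, in, out);         // base projection: x * W
  vecmat(x, A, tmp, in, rank);      // low-rank down-projection: x * A
  vecmat(tmp, B, delta, rank, out); // low-rank up-projection: (x * A) * B
  const float scaling = alpha / static_cast<float>(rank); // assumed scaling rule
  for (std::size_t c = 0; c < out; ++c)
    y[c] += scaling * delta[c];     // add the scaled LoRA delta to the output
  return y;
}

The diff below does the same steps with Tensor::dot(), multiply_i() and
add_i() on shared sub-tensors, once per batch slice.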

Self evaluation:

Build test: [X]Passed [ ]Failed [ ]Skipped
Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Eunju Yang <ej.yang@samsung.com>
nntrainer/layers/fc_layer.cpp

index 67a41f50ed49ae7d64cbaf28559c70ecf9e8ebbe..3ee52eed4a593d6ca7c17e0cf503f09100e48f4b 100644 (file)
@@ -230,20 +230,39 @@ void FullyConnectedLayer::incremental_forwarding(RunLayerContext &context,
     to = 1;
   }
 
+  input_step_dim.batch(1);
   input_step_dim.height(to - from);
+  hidden_step_dim.batch(1);
   hidden_step_dim.height(to - from);
 
-  // @todo: set reset stride as false. This implementation only works when batch
-  // size is 1
-  Tensor input_step = input_.getSharedDataTensor(input_step_dim, 0, true);
-  Tensor hidden_step = hidden_.getSharedDataTensor(hidden_step_dim, 0, true);
-
-  input_step.dot(weight, hidden_step, false, false);
+  // @todo parallelize this loop across the batch axis
+  for (unsigned int b = 0; b < hidden_.batch(); ++b) {
+    Tensor input_step = input_.getSharedDataTensor(
+      input_step_dim, b * input_dim.getFeatureLen(), true);
+    Tensor hidden_step = hidden_.getSharedDataTensor(
+      hidden_step_dim, b * hidden_dim.getFeatureLen(), true);
+
+    input_step.dot(weight, hidden_step, false, false);
+
+    if (!std::get<props::LoraRank>(fc_props).empty()) {
+      Tensor &loraA = context.getWeight(lora_idx[LORAParams::loraA]);
+      Tensor &loraB = context.getWeight(lora_idx[LORAParams::loraB]);
+      Tensor &hidden_tmp_lora =
+        context.getTensor(lora_idx[LORAParams::loraTmp]);
+      Tensor &hidden_out_lora =
+        context.getTensor(lora_idx[LORAParams::loraOut]);
+
+      input_step.dot(loraA, hidden_tmp_lora, false, false);
+      hidden_tmp_lora.dot(loraB, hidden_out_lora, false, false);
+      hidden_out_lora.multiply_i(lora_scaling);
+      hidden_step.add_i(hidden_out_lora);
+    }
 
-  if (auto &disable_bias = std::get<props::DisableBias>(*layer_impl_props);
-      disable_bias.empty() || disable_bias.get() == false) {
-    Tensor &bias = context.getWeight(weight_idx[FCParams::bias]);
-    hidden_step.add_i(bias);
+    if (auto &disable_bias = std::get<props::DisableBias>(*layer_impl_props);
+        disable_bias.empty() || disable_bias.get() == false) {
+      Tensor &bias = context.getWeight(weight_idx[FCParams::bias]);
+      hidden_step.add_i(bias);
+    }
   }
 }
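
The TODO above leaves the per-batch loop serial; note also that the LoRA
temporaries (loraTmp / loraOut) are single shared context tensors, so a naive
parallel loop would race on them. Purely as an illustration of the direction,
a rough standalone sketch of a batch-parallel version on flat buffers follows.
OpenMP, the function name, and the buffer layout are assumptions here, not
nntrainer code.

#include <cstddef>

// Hypothetical sketch only, not part of the patch: each batch slice writes a
// disjoint region of `hidden`, so iterations can run concurrently. In the real
// layer the shared LoRA temporaries would first have to become per-batch (or
// thread-private) buffers.
void incremental_forwarding_batched(const float *input, float *hidden,
                                    const float *weight, int batch,
                                    unsigned int in_feat, unsigned int out_feat,
                                    unsigned int steps /* = to - from */) {
#pragma omp parallel for
  for (int b = 0; b < batch; ++b) {
    const float *in_step = input + static_cast<std::size_t>(b) * steps * in_feat;
    float *out_step = hidden + static_cast<std::size_t>(b) * steps * out_feat;
    for (unsigned int s = 0; s < steps; ++s)      // per-step row of the GEMM
      for (unsigned int o = 0; o < out_feat; ++o) {
        float acc = 0.0f;
        for (unsigned int i = 0; i < in_feat; ++i)
          acc += in_step[s * in_feat + i] * weight[i * out_feat + o];
        out_step[s * out_feat + o] = acc;         // [steps x in] * [in x out]
      }
  }
}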