[Layer] Improve forwarding logic of ConcatLayer
authorDonghyeon Jeong <dhyeon.jeong@samsung.com>
Thu, 8 Aug 2024 08:27:20 +0000 (17:27 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 12 Aug 2024 01:23:22 +0000 (10:23 +0900)
This PR updates current ConcatLayer forwarding for faster computation.

**Changes proposed in this PR:**
- Utilize the Tensor::concat() operation to perform forwarding and replace manual mapping and copying.

**Self-evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test:   [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghyeon Jeong <dhyeon.jeong@samsung.com>
nntrainer/layers/concat_layer.cpp

index 8a28fb3e8036ef696e05945927579cf472aad9d9..5536c4a82d9c89ae230ae48df6fdf10c106fb551 100644 (file)
@@ -112,63 +112,43 @@ void ConcatLayer::forwarding(RunLayerContext &context, bool training) {
    * @todo avoid copy by creating input here as a shared_tensor of the output
    * here and then this layer can be in_place as well
    */
-  Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
 
-  const TensorDim out_dim = output.getDim();
-  output.reshape(output_reshape_helper);
-  unsigned int output_height_offset = 0;
-  unsigned int data_copy_size = output_reshape_helper.width();
-  TensorDim::TensorType tensor_type = out_dim.getTensorType();
+  // Store original input tensor dimensions, then reshape input tensors.
+  std::vector<Tensor> input_tensors;
+  std::vector<TensorDim> original_input_dims;
 
   for (unsigned int idx = 0; idx < context.getNumInputs(); idx++) {
     Tensor &input = context.getInput(idx);
-    const TensorDim in_dim = input.getDim();
-    auto const &irh = input_reshape_helper[idx];
-    input.reshape(irh);
+    original_input_dims.push_back(input.getDim());
+    input.reshape(input_reshape_helper[idx]);
+    input_tensors.push_back(input);
+  }
 
-    if (in_dim.getDataType() == TensorDim::DataType::FP32) {
-      /** loop over the dimensions before the concat dimension */
-      for (unsigned int batch = 0; batch < output.batch(); batch++) {
-        /** loop over the concat dimension itself */
-        for (unsigned int count = 0; count < irh.height(); count++) {
-          Tensor dest_tensor = Tensor::Map<float>(
-            output.getAddress<float>(batch, 0, output_height_offset + count, 0),
-            data_copy_size * sizeof(float),
-            {1, 1, 1, data_copy_size, tensor_type});
-          const Tensor source_tensor =
-            Tensor::Map<float>(input.getAddress<float>(batch, 0, count, 0),
-                               data_copy_size * sizeof(float),
-                               {1, 1, 1, data_copy_size, tensor_type});
-          dest_tensor.copy(source_tensor);
-        }
-      }
-    } else if (in_dim.getDataType() == TensorDim::DataType::FP16) {
-#ifdef ENABLE_FP16
-      /** loop over the dimensions before the concat dimension */
-      for (unsigned int batch = 0; batch < output.batch(); batch++) {
-        /** loop over the concat dimension itself */
-        for (unsigned int count = 0; count < irh.height(); count++) {
-          Tensor dest_tensor = Tensor::Map<_FP16>(
-            output.getAddress<_FP16>(batch, 0, output_height_offset + count, 0),
-            data_copy_size * sizeof(_FP16),
-            {1, 1, 1, data_copy_size, tensor_type});
-          const Tensor source_tensor =
-            Tensor::Map<_FP16>(input.getAddress<_FP16>(batch, 0, count, 0),
-                               data_copy_size * sizeof(_FP16),
-                               {1, 1, 1, data_copy_size, tensor_type});
-          dest_tensor.copy(source_tensor);
-        }
-      }
-#else
-      throw std::invalid_argument("Error: enable-fp16 is not enabled");
-#endif
+  // Store the original output tensor dimension, then reshape the output tensor.
+  Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
+  const TensorDim original_output_dim = output.getDim();
+  output.reshape(output_reshape_helper);
+
+  // Search for an axis and concatenate tensors.
+  const TensorDim out_dim = output.getDim();
+  const TensorDim in_dim = context.getInput(0).getDim();
+
+  for (int axis = 0; axis < 4; ++axis) {
+    if (out_dim[axis] != in_dim[axis]) {
+      /// @todo Currently a new output tensor is created. This can be optimized.
+      Tensor result = Tensor::cat(input_tensors, axis);
+      output.copy(result);
+      break;
     }
+  }
 
-    input.reshape(in_dim);
-    output_height_offset += irh.height();
+  // Revert the tensors' dimensions back to their original shape.
+  for (unsigned int idx = 0; idx < context.getNumInputs(); idx++) {
+    Tensor &in = context.getInput(idx);
+    in.reshape(original_input_dims[idx]);
   }
 
-  output.reshape(out_dim);
+  output.reshape(original_output_dim);
 }
 
 void ConcatLayer::incremental_forwarding(RunLayerContext &context,
@@ -229,7 +209,7 @@ void ConcatLayer::calcDerivative(RunLayerContext &context) {
   unsigned int data_copy_size = output_reshape_helper.width();
   TensorDim::TensorType tensor_type = output.getTensorType();
 
- for (unsigned int idx = 0; idx < context.getNumInputs(); idx++) {
 for (unsigned int idx = 0; idx < context.getNumInputs(); idx++) {
     Tensor &input = context.getOutgoingDerivative(idx);
     const TensorDim in_dim = input.getDim();
     auto const &irh = input_reshape_helper[idx];