[GPU/OpenCL] Initial version of Concat Layer with OpenCL ops
author Niket Agarwal <niket.a@samsung.com>
Wed, 3 Jul 2024 10:42:38 +0000 (16:12 +0530)
committer Jijoong Moon <jijoong.moon@samsung.com>
Wed, 31 Jul 2024 08:32:19 +0000 (17:32 +0900)
Added a naive OpenCL implementation of the Concat Layer.
Incorporated kernels for the ops used (concat along axis 1, 2, and 3, in FP32 and FP16).
Added unit tests for Concat_cl.

Signed-off-by: Niket Agarwal <niket.a@samsung.com>
api/ccapi/include/layer.h
nntrainer/cl_context.cpp
nntrainer/layers/cl_layers/concat_cl.cpp [new file with mode: 0644]
nntrainer/layers/cl_layers/concat_cl.h [new file with mode: 0644]
nntrainer/layers/cl_layers/meson.build
nntrainer/layers/layer_context.cpp
nntrainer/layers/layer_context.h
test/jni/Android.mk
test/unittest/layers/unittest_layers_concat_cl.cpp [new file with mode: 0644]

index 5b09216a5fa32410510716b0109783ff2886153c..e384231e6fa3deab758d47d6e6bb3296156a9232 100644 (file)
@@ -387,8 +387,9 @@ Addition(const std::vector<std::string> &properties = {},
  * @brief Helper function to create concat layer
  */
 inline std::unique_ptr<Layer>
-Concat(const std::vector<std::string> &properties = {}) {
-  return createLayer(LayerType::LAYER_CONCAT, properties);
+Concat(const std::vector<std::string> &properties = {},
+       const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) {
+  return createLayer(LayerType::LAYER_CONCAT, properties, compute_engine);
 }
 
 /**
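With this overload, callers can keep the default CPU path or explicitly request the OpenCL backend when building a model through the C++ API. A minimal usage sketch, assuming the ml::train::layer helper namespace and the LayerComputeEngine enum exposed by this header (the GPU enumerator and namespace qualification are assumptions, not shown in this diff):

    // Illustrative only: create a concat layer on the GPU compute engine;
    // "axis=3" concatenates along the width dimension.
    std::unique_ptr<ml::train::Layer> concat_gpu =
      ml::train::layer::Concat({"axis=3"}, ml::train::LayerComputeEngine::GPU);
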
index 2ba0a390d39bd98c89ad9ca40bdef48f144606a6..a1288cbdc0946296067aed2d0754885d7aeb09cc 100644 (file)
 
 #include <addition_layer_cl.h>
 #include <cl_context.h>
+#include <concat_cl.h>
 #include <fc_layer_cl.h>
 #include <reshape_cl.h>
-#include <swiglu_cl.h>
 #include <rmsnorm_layer_cl.h>
+#include <swiglu_cl.h>
 
 namespace nntrainer {
 
@@ -45,6 +46,9 @@ static void add_default_object(ClContext &cc) {
 
   cc.registerFactory(nntrainer::createLayer<RMSNormLayerCl>,
                      RMSNormLayerCl::type, ml::train::LayerType::LAYER_RMSNORM);
+
+  cc.registerFactory(nntrainer::createLayer<ConcatLayerCl>, ConcatLayerCl::type,
+                     ml::train::LayerType::LAYER_CONCAT);
 }
 
 static void registerer(ClContext &cc) noexcept {
diff --git a/nntrainer/layers/cl_layers/concat_cl.cpp b/nntrainer/layers/cl_layers/concat_cl.cpp
new file mode 100644 (file)
index 0000000..12c2099
--- /dev/null
@@ -0,0 +1,1073 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024  Niket Agarwal <niket.a@samsung.com>
+ *
+ * @file   concat_cl.cpp
+ * @date   2 July 2024
+ * @brief  Implementation of Concat Layer
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Niket Agarwal <niket.a@samsung.com>
+ * @bug    No known bugs except for NYI items
+ *
+ */
+
+#include <cstring>
+#include <vector>
+
+#include <concat_cl.h>
+#include <iostream>
+#include <layer_context.h>
+#include <nntr_threads.h>
+#include <nntrainer_error.h>
+#include <nntrainer_log.h>
+#include <node_exporter.h>
+#include <tensor_dim.h>
+#include <util_func.h>
+
+std::string concat_cl_axis3_kernel_fp16_ =
+  R"(
+    #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+    __kernel void concat_cl_axis3_fp16(__global const half* in1, 
+                                           __global const half* in2, 
+                                           __global half* out,
+                                           const int batch_size, 
+                                           const int channels, 
+                                           const int height, 
+                                           const int width1, 
+                                           const int width2) {
+    int global_id = get_global_id(0);
+    
+    int total_width = width1 + width2;
+
+    // 4D space coordinates
+    int w = global_id % total_width;
+    int h = (global_id / total_width) % height;
+    int c = (global_id / (total_width * height)) % channels;
+    int b = global_id / (total_width * height * channels);
+
+    int output_index = ((b * channels + c) * height + h) * total_width + w;
+    
+    // Determining if the index is in in1 or in2
+    if (w < width1) {
+        // in1 index calculation
+        int input1_index = ((b * channels + c) * height + h) * width1 + w;
+        out[output_index] = in1[input1_index];
+  
+    } else {
+        // in2 index calculation
+        int input2_index = ((b * channels + c) * height + h) * width2 + (w - width1);
+        out[output_index] = in2[input2_index];
+    }
+})";
+
+std::string concat_cl_axis3_kernel_ =
+  R"(__kernel void concat_cl_axis3(__global const float* in1, 
+                                           __global const float* in2, 
+                                           __global float* out,
+                                           const int batch_size, 
+                                           const int channels, 
+                                           const int height, 
+                                           const int width1, 
+                                           const int width2) {
+    int global_id = get_global_id(0);
+    
+    int total_width = width1 + width2;
+
+    // 4D space coordinates
+    int w = global_id % total_width;
+    int h = (global_id / total_width) % height;
+    int c = (global_id / (total_width * height)) % channels;
+    int b = global_id / (total_width * height * channels);
+
+    int output_index = ((b * channels + c) * height + h) * total_width + w;
+    
+    // Determining if the index is in in1 or in2
+    if (w < width1) {
+        // in1 index calculation
+        int input1_index = ((b * channels + c) * height + h) * width1 + w;
+        out[output_index] = in1[input1_index];
+  
+    } else {
+        // in2 index calculation
+        int input2_index = ((b * channels + c) * height + h) * width2 + (w - width1);
+        out[output_index] = in2[input2_index];
+    }
+})";
+
+std::string concat_cl_axis2_kernel_fp16_ =
+  R"(__kernel void concat_cl_axis2_fp16(__global const half* in1,
+                             __global const half* in2,
+                             __global half* out,
+                             const int batch_size,
+                             const int channels,
+                             const int height1,
+                             const int height2,
+                             const int width) {
+    
+    int total_height = height1 + height2;
+    int global_id = get_global_id(0);
+    
+    // Calculate the coordinates in the 4D space
+    int w = global_id % width;
+    int h = (global_id / width) % total_height;
+    int c = (global_id / (width * total_height)) % channels;
+    int b = global_id / (width * total_height * channels);
+
+    // Calculate the offset for the current batch, channel, and width in the output tensor
+    int output_index = ((b * channels + c) * total_height + h) * width + w;
+
+    if (h < height1) {
+        // Index within input1
+        int input1_index = ((b * channels + c) * height1 + h) * width + w;
+        out[output_index] = in1[input1_index];
+    } else {
+        // Index within input2
+        int input2_index = ((b * channels + c) * height2 + (h - height1)) * width + w;
+        out[output_index] = in2[input2_index];
+    }
+
+})";
+
+std::string concat_cl_axis2_kernel_ =
+  R"(__kernel void concat_cl_axis2(__global const float* in1,
+                             __global const float* in2,
+                             __global float* out,
+                             const int batch_size,
+                             const int channels,
+                             const int height1,
+                             const int height2,
+                             const int width) {
+    
+    int total_height = height1 + height2;
+    int global_id = get_global_id(0);
+    
+    // Calculate the coordinates in the 4D space
+    int w = global_id % width;
+    int h = (global_id / width) % total_height;
+    int c = (global_id / (width * total_height)) % channels;
+    int b = global_id / (width * total_height * channels);
+
+    // Calculate the offset for the current batch, channel, and width in the output tensor
+    int output_index = ((b * channels + c) * total_height + h) * width + w;
+
+    if (h < height1) {
+        // Index within input1
+        int input1_index = ((b * channels + c) * height1 + h) * width + w;
+        out[output_index] = in1[input1_index];
+    } else {
+        // Index within input2
+        int input2_index = ((b * channels + c) * height2 + (h - height1)) * width + w;
+        out[output_index] = in2[input2_index];
+    }
+
+})";
+
+std::string concat_cl_axis1_kernel_fp16_ =
+  R"(__kernel void concat_cl_axis1_fp16(__global const half* in1, 
+                                           __global const half* in2, 
+                                           __global half* out,
+                                           const int batch_size, 
+                                           const int channels1, 
+                                           const int channels2, 
+                                           const int height, 
+                                           const int width) {
+    int global_id = get_global_id(0);
+    
+    int total_channels = channels1 + channels2;
+
+    // Calculate the coordinates in the 4D space
+    int w = global_id % width;
+    int h = (global_id / width) % height;
+    int c = (global_id / (width * height)) % total_channels;
+    int b = global_id / (width * height * total_channels);
+
+    // Calculate the offset for the current batch, height, and width in the output tensor
+    int output_index = ((b * total_channels + c) * height + h) * width + w;
+
+    if (c < channels1) {
+        // Index within input1
+        int input1_index = ((b * channels1 + c) * height + h) * width + w;
+        out[output_index] = in1[input1_index];
+    } else {
+        // Index within input2
+        int input2_index = ((b * channels2 + (c - channels1)) * height + h) * width + w;
+        out[output_index] = in2[input2_index];
+    }
+})";
+
+std::string concat_cl_axis1_kernel_ =
+  R"(__kernel void concat_cl_axis1(__global const float* in1, 
+                                           __global const float* in2, 
+                                           __global float* out,
+                                           const int batch_size, 
+                                           const int channels1, 
+                                           const int channels2, 
+                                           const int height, 
+                                           const int width) {
+    int global_id = get_global_id(0);
+    
+    int total_channels = channels1 + channels2;
+
+    // Calculate the coordinates in the 4D space
+    int w = global_id % width;
+    int h = (global_id / width) % height;
+    int c = (global_id / (width * height)) % total_channels;
+    int b = global_id / (width * height * total_channels);
+
+    // Calculate the offset for the current batch, height, and width in the output tensor
+    int output_index = ((b * total_channels + c) * height + h) * width + w;
+
+    if (c < channels1) {
+        // Index within input1
+        int input1_index = ((b * channels1 + c) * height + h) * width + w;
+        out[output_index] = in1[input1_index];
+    } else {
+        // Index within input2
+        int input2_index = ((b * channels2 + (c - channels1)) * height + h) * width + w;
+        out[output_index] = in2[input2_index];
+    }
+})";
+
+namespace nntrainer {
+ConcatLayerCl::ConcatLayerCl() : Layer() {}
+
+static constexpr size_t SINGLE_INOUT_IDX = 0;
+static constexpr size_t INPUT_IDX_1 = 0;
+static constexpr size_t INPUT_IDX_2 = 1;
+
+void ConcatLayerCl::finalize(InitLayerContext &context) {
+  auto &concat_dimension_prop = std::get<props::ConcatDimension>(concat_props);
+  /** for backward compatibility, default concat dimension will be channel */
+  /// @todo this is a hacky way to force the concat dimension to width when the
+  /// channel dimension is already taken; this is because the recurrent
+  /// realizer's return_sequence exploits the concat layer but has no control
+  /// over where to stack (the axis)
+  unsigned int concat_dimension =
+    context.getInputDimensions().front().channel() > 1 ? 3 : 1;
+  if (!concat_dimension_prop.empty())
+    concat_dimension = concat_dimension_prop.get();
+
+  /**
+   * The concat is only done along the axis dimension.
+   * For example, consider 2 inputs a, b with dimensions [b,c,h,w] each
+   * 1. concat_dimension = 1, output_dim = [b,c_a+c_b,h,w]
+   * 2. concat_dimension = 2, output_dim = [b,c,h_a+h_b,w]
+   * 3. concat_dimension = 3, output_dim = [b,c,h,w_a+w_b]
+   */
+  auto const &input_dims = context.getInputDimensions();
+  const TensorDim &input_dim_0 = input_dims[SINGLE_INOUT_IDX];
+  unsigned int concat_dim_val = input_dim_0.getTensorDim(concat_dimension);
+
+  for (unsigned int idx = 1; idx < input_dims.size(); ++idx) {
+    const TensorDim &dim = input_dims[idx];
+
+    for (unsigned int i = 0; i < ml::train::TensorDim::getNumDim(); ++i) {
+      if (i == concat_dimension)
+        continue;
+      NNTR_THROW_IF(input_dim_0[i] != dim[i], std::runtime_error)
+        << "Error: concat layer requires same shape from all input layers "
+           "along non-concat dimension";
+    }
+    concat_dim_val += dim[concat_dimension];
+  }
+
+  TensorDim output_dim = input_dim_0;
+  output_dim.setTensorDim(concat_dimension, concat_dim_val);
+
+  context.setOutputDimensions({output_dim});
+}
+
+void ConcatLayerCl::forwarding(RunLayerContext &context, bool training) {
+  Tensor &out = context.getOutput(SINGLE_INOUT_IDX);
+  const Tensor &in1 = context.getInput(INPUT_IDX_1);
+  const Tensor &in2 = context.getInput(INPUT_IDX_2);
+  ConcatProcess(in1, in2, out, context);
+}
+
+void ConcatLayerCl::incremental_forwarding(RunLayerContext &context,
+                                           unsigned int from, unsigned int to,
+                                           bool training) {
+  Tensor &out = context.getOutput(SINGLE_INOUT_IDX);
+  const Tensor &in1 = context.getInput(INPUT_IDX_1);
+  const Tensor &in2 = context.getInput(INPUT_IDX_2);
+  if (from) {
+    NNTR_THROW_IF(to - from != 1, std::invalid_argument)
+      << "incremental step size is not 1";
+    from = 0;
+    to = 1;
+  }
+  ConcatProcess(in1, in2, out, context);
+}
+
+opencl::Kernel ConcatLayerCl::kernel_concat_axis3;
+opencl::Kernel ConcatLayerCl::kernel_concat_axis3_fp16;
+opencl::Kernel ConcatLayerCl::kernel_concat_axis2;
+opencl::Kernel ConcatLayerCl::kernel_concat_axis2_fp16;
+opencl::Kernel ConcatLayerCl::kernel_concat_axis1;
+opencl::Kernel ConcatLayerCl::kernel_concat_axis1_fp16;
+
+void ConcatLayerCl::ConcatProcess(Tensor const &in1, Tensor const &in2,
+                                  Tensor &result, RunLayerContext &context) {
+
+  unsigned int input1_batch_size, input1_height, input1_width, input1_channels,
+    input2_batch_size, input2_height, input2_width, input2_channels;
+
+  auto dim1 = in1.getDim();
+  auto dim2 = in2.getDim();
+  input1_batch_size = dim1.batch();
+  input1_height = dim1.height();
+  input1_channels = dim1.channel();
+  input1_width = dim1.width();
+  input2_batch_size = dim2.batch();
+  input2_height = dim2.height();
+  input2_channels = dim2.channel();
+  input2_width = dim2.width();
+
+  if (in1.getDataType() == ml::train::TensorDim::DataType::FP32) {
+    const float *data1 = in1.getData();
+    const float *data2 = in2.getData();
+    float *rdata = result.getData();
+    if (input1_width != input2_width) {
+      concat_cl_axis3(data1, data2, rdata, input1_batch_size, input1_channels,
+                      input1_height, input1_width, input2_width, context);
+    } else if (input1_height != input2_height) {
+      concat_cl_axis2(data1, data2, rdata, input1_batch_size, input1_channels,
+                      input1_width, input1_height, input2_height, context);
+    } else if (input1_channels != input2_channels) {
+      concat_cl_axis1(data1, data2, rdata, input1_batch_size, input1_height,
+                      input1_width, input1_channels, input2_channels, context);
+    }
+  } else if (in1.getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+    const _FP16 *data1 = in1.getData<_FP16>();
+    const _FP16 *data2 = in2.getData<_FP16>();
+    _FP16 *rdata = result.getData<_FP16>();
+    if (input1_width != input2_width) {
+      concat_cl_axis3_fp16(data1, data2, rdata, input1_batch_size,
+                           input1_channels, input1_height, input1_width,
+                           input2_width, context);
+    } else if (input1_height != input2_height) {
+      concat_cl_axis2_fp16(data1, data2, rdata, input1_batch_size,
+                           input1_channels, input1_width, input1_height,
+                           input2_height, context);
+    } else if (input1_channels != input2_channels) {
+      concat_cl_axis1_fp16(data1, data2, rdata, input1_batch_size,
+                           input1_height, input1_width, input1_channels,
+                           input2_channels, context);
+    }
+#else
+    throw std::invalid_argument("Error: enable-fp16 is not enabled");
+#endif
+  }
+}
+
+void ConcatLayerCl::concat_cl_axis3(
+  const float *matAdata, const float *vecXdata, float *vecYdata,
+  unsigned int input1_batch_size, unsigned int input1_channels,
+  unsigned int input1_height, unsigned int input1_width,
+  unsigned int input2_width, RunLayerContext &context) {
+
+  bool result = false;
+
+  do {
+    result = context.clCreateKernel(concat_cl_axis3_kernel_,
+                                    context.LayerKernel::CONCAT_AXIS3,
+                                    ConcatLayerCl::kernel_concat_axis3);
+    if (!result) {
+      break;
+    }
+
+    int dim = int(input1_batch_size * input1_channels * input1_height *
+                  (input1_width + input2_width));
+
+    opencl::Buffer inputA(context.context_inst_,
+                          sizeof(float) * input1_batch_size * input1_channels *
+                            input1_height * input1_width,
+                          true, nullptr);
+
+    opencl::Buffer inputX(context.context_inst_,
+                          sizeof(float) * input1_batch_size * input1_channels *
+                            input1_height * input2_width,
+                          true, nullptr);
+
+    opencl::Buffer inOutY(context.context_inst_,
+                          sizeof(float) * input1_batch_size * input1_channels *
+                            input1_height * (input1_width + input2_width),
+                          true, nullptr);
+
+    result = inputA.WriteData(context.command_queue_inst_, matAdata);
+    if (!result) {
+      break;
+    }
+
+    result = inputX.WriteData(context.command_queue_inst_, vecXdata);
+    if (!result) {
+      break;
+    }
+
+    result = inOutY.WriteData(context.command_queue_inst_, vecYdata);
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis3.SetKernelArguments(
+      0, &inputA, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis3.SetKernelArguments(
+      1, &inputX, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis3.SetKernelArguments(
+      2, &inOutY, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis3.SetKernelArguments(
+      3, &input1_batch_size, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis3.SetKernelArguments(
+      4, &input1_channels, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis3.SetKernelArguments(
+      5, &input1_height, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis3.SetKernelArguments(
+      6, &input1_width, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis3.SetKernelArguments(
+      7, &input2_width, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    const int work_groups_count[3] = {dim, 1, 1};
+    const int work_group_size[3] = {32, 32, 1}; // test-value
+
+    result = context.command_queue_inst_.DispatchCommand(
+      ConcatLayerCl::kernel_concat_axis3, work_groups_count, work_group_size);
+    if (!result) {
+      break;
+    }
+
+    result = inOutY.ReadData(context.command_queue_inst_, vecYdata);
+    if (!result) {
+      break;
+    }
+
+  } while (false);
+}
+
+void ConcatLayerCl::concat_cl_axis3_fp16(
+  const __fp16 *matAdata, const __fp16 *vecXdata, __fp16 *vecYdata,
+  unsigned int input1_batch_size, unsigned int input1_channels,
+  unsigned int input1_height, unsigned int input1_width,
+  unsigned int input2_width, RunLayerContext &context) {
+
+  bool result = false;
+
+  do {
+    result = context.clCreateKernel(concat_cl_axis3_kernel_fp16_,
+                                    context.LayerKernel::CONCAT_AXIS3_FP16,
+                                    ConcatLayerCl::kernel_concat_axis3_fp16);
+    if (!result) {
+      break;
+    }
+
+    int dim = int(input1_batch_size * input1_channels * input1_height *
+                  (input1_width + input2_width));
+
+    opencl::Buffer inputA(context.context_inst_,
+                          sizeof(__fp16) * input1_batch_size * input1_channels *
+                            input1_height * input1_width,
+                          true, nullptr);
+
+    opencl::Buffer inputX(context.context_inst_,
+                          sizeof(__fp16) * input1_batch_size * input1_channels *
+                            input1_height * input2_width,
+                          true, nullptr);
+
+    opencl::Buffer inOutY(context.context_inst_,
+                          sizeof(__fp16) * input1_batch_size * input1_channels *
+                            input1_height * (input1_width + input2_width),
+                          true, nullptr);
+
+    result = inputA.WriteData(context.command_queue_inst_, matAdata);
+    if (!result) {
+      break;
+    }
+
+    result = inputX.WriteData(context.command_queue_inst_, vecXdata);
+    if (!result) {
+      break;
+    }
+
+    result = inOutY.WriteData(context.command_queue_inst_, vecYdata);
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis3_fp16.SetKernelArguments(
+      0, &inputA, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis3_fp16.SetKernelArguments(
+      1, &inputX, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis3_fp16.SetKernelArguments(
+      2, &inOutY, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis3_fp16.SetKernelArguments(
+      3, &input1_batch_size, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis3_fp16.SetKernelArguments(
+      4, &input1_channels, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis3_fp16.SetKernelArguments(
+      5, &input1_height, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis3_fp16.SetKernelArguments(
+      6, &input1_width, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis3_fp16.SetKernelArguments(
+      7, &input2_width, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    const int work_groups_count[3] = {dim, 1, 1};
+    const int work_group_size[3] = {32, 32, 1}; // test-value
+
+    result = context.command_queue_inst_.DispatchCommand(
+      ConcatLayerCl::kernel_concat_axis3_fp16, work_groups_count,
+      work_group_size);
+    if (!result) {
+      break;
+    }
+
+    result = inOutY.ReadData(context.command_queue_inst_, vecYdata);
+    if (!result) {
+      break;
+    }
+
+  } while (false);
+}
+
+void ConcatLayerCl::concat_cl_axis2(
+  const float *matAdata, const float *vecXdata, float *vecYdata,
+  unsigned int input1_batch_size, unsigned int input1_channels,
+  unsigned int input1_width, unsigned int input1_height,
+  unsigned int input2_height, RunLayerContext &context) {
+
+  bool result = false;
+
+  do {
+    result = context.clCreateKernel(concat_cl_axis2_kernel_,
+                                    context.LayerKernel::CONCAT_AXIS2,
+                                    ConcatLayerCl::kernel_concat_axis2);
+    if (!result) {
+      break;
+    }
+
+    int dim = int(input1_batch_size * input1_channels * input1_width *
+                  (input1_height + input2_height));
+
+    opencl::Buffer inputA(context.context_inst_,
+                          sizeof(float) * input1_batch_size * input1_channels *
+                            input1_height * input1_width,
+                          true, nullptr);
+
+    opencl::Buffer inputX(context.context_inst_,
+                          sizeof(float) * input1_batch_size * input1_channels *
+                            input2_height * input1_width,
+                          true, nullptr);
+
+    opencl::Buffer inOutY(context.context_inst_,
+                          sizeof(float) * input1_batch_size * input1_channels *
+                            (input1_height + input2_height) * input1_width,
+                          true, nullptr);
+
+    result = inputA.WriteData(context.command_queue_inst_, matAdata);
+    if (!result) {
+      break;
+    }
+
+    result = inputX.WriteData(context.command_queue_inst_, vecXdata);
+    if (!result) {
+      break;
+    }
+
+    result = inOutY.WriteData(context.command_queue_inst_, vecYdata);
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis2.SetKernelArguments(
+      0, &inputA, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis2.SetKernelArguments(
+      1, &inputX, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis2.SetKernelArguments(
+      2, &inOutY, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis2.SetKernelArguments(
+      3, &input1_batch_size, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis2.SetKernelArguments(
+      4, &input1_channels, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis2.SetKernelArguments(
+      5, &input1_height, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis2.SetKernelArguments(
+      6, &input2_height, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis2.SetKernelArguments(
+      7, &input1_width, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    const int work_groups_count[3] = {dim, 1, 1};
+    const int work_group_size[3] = {32, 32, 1}; // test-value
+
+    result = context.command_queue_inst_.DispatchCommand(
+      ConcatLayerCl::kernel_concat_axis2, work_groups_count, work_group_size);
+    if (!result) {
+      break;
+    }
+
+    result = inOutY.ReadData(context.command_queue_inst_, vecYdata);
+    if (!result) {
+      break;
+    }
+
+  } while (false);
+}
+
+void ConcatLayerCl::concat_cl_axis2_fp16(
+  const __fp16 *matAdata, const __fp16 *vecXdata, __fp16 *vecYdata,
+  unsigned int input1_batch_size, unsigned int input1_channels,
+  unsigned int input1_width, unsigned int input1_height,
+  unsigned int input2_height, RunLayerContext &context) {
+
+  bool result = false;
+
+  do {
+    result = context.clCreateKernel(concat_cl_axis2_kernel_fp16_,
+                                    context.LayerKernel::CONCAT_AXIS2_FP16,
+                                    ConcatLayerCl::kernel_concat_axis2_fp16);
+    if (!result) {
+      break;
+    }
+
+    int dim = int(input1_batch_size * input1_channels * input1_width *
+                  (input1_height + input2_height));
+
+    opencl::Buffer inputA(context.context_inst_,
+                          sizeof(__fp16) * input1_batch_size * input1_channels *
+                            input1_height * input1_width,
+                          true, nullptr);
+
+    opencl::Buffer inputX(context.context_inst_,
+                          sizeof(__fp16) * input1_batch_size * input1_channels *
+                            input2_height * input1_width,
+                          true, nullptr);
+
+    opencl::Buffer inOutY(context.context_inst_,
+                          sizeof(__fp16) * input1_batch_size * input1_channels *
+                            (input1_height + input2_height) * input1_width,
+                          true, nullptr);
+
+    result = inputA.WriteData(context.command_queue_inst_, matAdata);
+    if (!result) {
+      break;
+    }
+
+    result = inputX.WriteData(context.command_queue_inst_, vecXdata);
+    if (!result) {
+      break;
+    }
+
+    result = inOutY.WriteData(context.command_queue_inst_, vecYdata);
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis2_fp16.SetKernelArguments(
+      0, &inputA, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis2_fp16.SetKernelArguments(
+      1, &inputX, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis2_fp16.SetKernelArguments(
+      2, &inOutY, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis2_fp16.SetKernelArguments(
+      3, &input1_batch_size, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis2_fp16.SetKernelArguments(
+      4, &input1_channels, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis2_fp16.SetKernelArguments(
+      5, &input1_height, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis2_fp16.SetKernelArguments(
+      6, &input2_height, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis2_fp16.SetKernelArguments(
+      7, &input1_width, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    const int work_groups_count[3] = {dim, 1, 1};
+    const int work_group_size[3] = {32, 32, 1}; // test-value
+
+    result = context.command_queue_inst_.DispatchCommand(
+      ConcatLayerCl::kernel_concat_axis2_fp16, work_groups_count,
+      work_group_size);
+    if (!result) {
+      break;
+    }
+
+    result = inOutY.ReadData(context.command_queue_inst_, vecYdata);
+    if (!result) {
+      break;
+    }
+
+  } while (false);
+}
+
+void ConcatLayerCl::concat_cl_axis1(
+  const float *matAdata, const float *vecXdata, float *vecYdata,
+  unsigned int input1_batch_size, unsigned int input1_height,
+  unsigned int input1_width, unsigned int input1_channels,
+  unsigned int input2_channels, RunLayerContext &context) {
+
+  bool result = false;
+
+  do {
+    result = context.clCreateKernel(concat_cl_axis1_kernel_,
+                                    context.LayerKernel::CONCAT_AXIS1,
+                                    ConcatLayerCl::kernel_concat_axis1);
+    if (!result) {
+      break;
+    }
+
+    int dim = int(input1_batch_size * input1_width * input1_height *
+                  (input1_channels + input2_channels));
+
+    opencl::Buffer inputA(context.context_inst_,
+                          sizeof(float) * input1_batch_size * input1_channels *
+                            input1_height * input1_width,
+                          true, nullptr);
+
+    opencl::Buffer inputX(context.context_inst_,
+                          sizeof(float) * input1_batch_size * input2_channels *
+                            input1_height * input1_width,
+                          true, nullptr);
+
+    opencl::Buffer inOutY(context.context_inst_,
+                          sizeof(float) * input1_batch_size * input1_width *
+                            input1_height * (input1_channels + input2_channels),
+                          true, nullptr);
+
+    result = inputA.WriteData(context.command_queue_inst_, matAdata);
+    if (!result) {
+      break;
+    }
+
+    result = inputX.WriteData(context.command_queue_inst_, vecXdata);
+    if (!result) {
+      break;
+    }
+
+    result = inOutY.WriteData(context.command_queue_inst_, vecYdata);
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis1.SetKernelArguments(
+      0, &inputA, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis1.SetKernelArguments(
+      1, &inputX, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis1.SetKernelArguments(
+      2, &inOutY, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis1.SetKernelArguments(
+      3, &input1_batch_size, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis1.SetKernelArguments(
+      4, &input1_channels, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis1.SetKernelArguments(
+      5, &input2_channels, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis1.SetKernelArguments(
+      6, &input1_height, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis1.SetKernelArguments(
+      7, &input1_width, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    const int work_groups_count[3] = {dim, 1, 1};
+    const int work_group_size[3] = {32, 32, 1}; // test-value
+
+    result = context.command_queue_inst_.DispatchCommand(
+      ConcatLayerCl::kernel_concat_axis1, work_groups_count, work_group_size);
+    if (!result) {
+      break;
+    }
+
+    result = inOutY.ReadData(context.command_queue_inst_, vecYdata);
+    if (!result) {
+      break;
+    }
+
+  } while (false);
+}
+
+void ConcatLayerCl::concat_cl_axis1_fp16(
+  const __fp16 *matAdata, const __fp16 *vecXdata, __fp16 *vecYdata,
+  unsigned int input1_batch_size, unsigned int input1_height,
+  unsigned int input1_width, unsigned int input1_channels,
+  unsigned int input2_channels, RunLayerContext &context) {
+
+  bool result = false;
+
+  do {
+    result = context.clCreateKernel(concat_cl_axis1_kernel_fp16_,
+                                    context.LayerKernel::CONCAT_AXIS1_FP16,
+                                    ConcatLayerCl::kernel_concat_axis1_fp16);
+    if (!result) {
+      break;
+    }
+
+    int dim = int(input1_batch_size * input1_width * input1_height *
+                  (input1_channels + input2_channels));
+
+    opencl::Buffer inputA(context.context_inst_,
+                          sizeof(__fp16) * input1_batch_size * input1_channels *
+                            input1_height * input1_width,
+                          true, nullptr);
+
+    opencl::Buffer inputX(context.context_inst_,
+                          sizeof(__fp16) * input1_batch_size * input2_channels *
+                            input1_height * input1_width,
+                          true, nullptr);
+
+    opencl::Buffer inOutY(context.context_inst_,
+                          sizeof(__fp16) * input1_batch_size * input1_width *
+                            input1_height * (input1_channels + input2_channels),
+                          true, nullptr);
+
+    result = inputA.WriteData(context.command_queue_inst_, matAdata);
+    if (!result) {
+      break;
+    }
+
+    result = inputX.WriteData(context.command_queue_inst_, vecXdata);
+    if (!result) {
+      break;
+    }
+
+    result = inOutY.WriteData(context.command_queue_inst_, vecYdata);
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis1_fp16.SetKernelArguments(
+      0, &inputA, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis1_fp16.SetKernelArguments(
+      1, &inputX, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis1_fp16.SetKernelArguments(
+      2, &inOutY, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis1_fp16.SetKernelArguments(
+      3, &input1_batch_size, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis1_fp16.SetKernelArguments(
+      4, &input1_channels, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis1_fp16.SetKernelArguments(
+      5, &input2_channels, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis1_fp16.SetKernelArguments(
+      6, &input1_height, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    result = ConcatLayerCl::kernel_concat_axis1_fp16.SetKernelArguments(
+      7, &input1_width, sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    const int work_groups_count[3] = {dim, 1, 1};
+    const int work_group_size[3] = {32, 32, 1}; // test-value
+
+    result = context.command_queue_inst_.DispatchCommand(
+      ConcatLayerCl::kernel_concat_axis1_fp16, work_groups_count,
+      work_group_size);
+    if (!result) {
+      break;
+    }
+
+    result = inOutY.ReadData(context.command_queue_inst_, vecYdata);
+    if (!result) {
+      break;
+    }
+
+  } while (false);
+}
+
+void ConcatLayerCl::calcDerivative(RunLayerContext &context) {
+  /**
+   * @todo skipping calcDerivative, support yet to be added
+   */
+}
+
+void ConcatLayerCl::setProperty(const std::vector<std::string> &values) {
+  auto remain_props = loadProperties(values, concat_props);
+  NNTR_THROW_IF(!remain_props.empty(), std::invalid_argument)
+    << "[ConcatLayer] Unknown Layer Properties count " +
+         std::to_string(values.size());
+}
+
+void ConcatLayerCl::exportTo(Exporter &exporter,
+                             const ml::train::ExportMethods &method) const {
+  Layer::exportTo(exporter, method);
+  exporter.saveResult(concat_props, method, this);
+}
+
+} /* namespace nntrainer */
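For reference, each kernel maps a flat global id to (b, c, h, w) coordinates of the concatenated output and copies the element from in1 or in2 depending on which side of the concat boundary the coordinate falls on. ConcatProcess selects the axis by checking which single dimension differs between the two inputs, so this naive version assumes exactly one of channel, height, or width differs. A host-side sketch of the axis-3 (width) index math, illustrative only and not part of the commit:

    // Reference CPU version of the mapping used by concat_cl_axis3.
    // For the unit-test shapes 2:3:3:2 and 2:3:3:3 (b:c:h:w), the output is 2:3:3:5.
    #include <vector>

    std::vector<float> concat_width_ref(const std::vector<float> &in1,
                                        const std::vector<float> &in2,
                                        int batch, int channels, int height,
                                        int width1, int width2) {
      int total_width = width1 + width2;
      std::vector<float> out(batch * channels * height * total_width);
      for (int id = 0; id < (int)out.size(); ++id) {
        int w = id % total_width;
        int h = (id / total_width) % height;
        int c = (id / (total_width * height)) % channels;
        int b = id / (total_width * height * channels);
        if (w < width1)
          out[id] = in1[((b * channels + c) * height + h) * width1 + w];
        else
          out[id] = in2[((b * channels + c) * height + h) * width2 + (w - width1)];
      }
      return out;
    }
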
diff --git a/nntrainer/layers/cl_layers/concat_cl.h b/nntrainer/layers/cl_layers/concat_cl.h
new file mode 100644 (file)
index 0000000..f0933d3
--- /dev/null
@@ -0,0 +1,248 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024  Niket Agarwal <niket.a@samsung.com>
+ *
+ * @file   concat_cl.h
+ * @date   2 July 2024
+ * @brief  Implementation of Concat Layer
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Niket Agarwal <niket.a@samsung.com>
+ * @bug    No known bugs except for NYI items
+ *
+ */
+
+#ifndef __CONCAT_LAYER_CL_H__
+#define __CONCAT_LAYER_CL_H__
+#ifdef __cplusplus
+
+#include <common_properties.h>
+#include <layer_context.h>
+#include <layer_devel.h>
+#include <layer_impl.h>
+#include <opencl_buffer.h>
+#include <opencl_kernel.h>
+#include <tensor_dim.h>
+#include <utility>
+
+namespace nntrainer {
+
+/**
+ * @class   Concat Layer
+ * @brief   Concat Layer
+ */
+class ConcatLayerCl : public Layer {
+public:
+  /**
+   * @brief     Constructor of Concat Layer
+   */
+  ConcatLayerCl();
+
+  /**
+   * @brief     Destructor of Concat Layer
+   */
+  ~ConcatLayerCl() = default;
+
+  /**
+   *  @brief  Move constructor of ConcatLayerCl.
+   *  @param[in] ConcatLayerCl &&
+   */
+  ConcatLayerCl(ConcatLayerCl &&rhs) noexcept = default;
+
+  /**
+   * @brief  Move assignment operator.
+   * @param[in] rhs ConcatLayerCl to be moved.
+   */
+  ConcatLayerCl &operator=(ConcatLayerCl &&rhs) = default;
+
+  /**
+   * @copydoc Layer::finalize(InitLayerContext &context)
+   */
+  void finalize(InitLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
+   */
+  void forwarding(RunLayerContext &context, bool training) override;
+
+  /**
+   * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
+   * int from, unsigned int to, bool training)
+   */
+  void incremental_forwarding(RunLayerContext &context, unsigned int from,
+                              unsigned int to, bool training) override;
+
+  /**
+   * @copydoc Layer::calcDerivative(RunLayerContext &context)
+   */
+  void calcDerivative(RunLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::getType()
+   */
+  const std::string getType() const override { return ConcatLayerCl::type; };
+
+  /**
+   * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
+   * method)
+   */
+  void exportTo(Exporter &exporter,
+                const ml::train::ExportMethods &method) const override;
+
+  /**
+   * @copydoc Layer::supportBackwarding()
+   */
+  bool supportBackwarding() const override { return false; }
+
+  /**
+   * @copydoc Layer::setProperty(const PropertyType type, const std::string
+   * &value)
+   */
+  void setProperty(const std::vector<std::string> &values) override;
+
+  inline static const std::string type = "concat";
+
+  static opencl::Kernel kernel_concat_axis3;
+  static opencl::Kernel kernel_concat_axis3_fp16;
+  static opencl::Kernel kernel_concat_axis2;
+  static opencl::Kernel kernel_concat_axis2_fp16;
+  static opencl::Kernel kernel_concat_axis1;
+  static opencl::Kernel kernel_concat_axis1_fp16;
+
+  /**
+   * @brief Process data and dimensions for concat
+   * @param[in] input1 Tensor
+   * @param[in] input2 Tensor
+   * @param[in] result Tensor
+   * @param[in] RunLayerContext reference
+   */
+  void ConcatProcess(Tensor const &in1, Tensor const &in2, Tensor &result,
+                     RunLayerContext &context);
+
+  /**
+   * @brief     concat computation for axis 3
+   * @param[in] matAdata float * for Input Tensor A
+   * @param[in] vecXdata float * for Input Tensor X
+   * @param[in] vecYdata float * for Output Tensor Y
+   * @param[in] input1_batch_size  represents the number of samples in the input
+   * tensor
+   * @param[in] input1_channels   represents the channels of the input tensor
+   * @param[in] input1_height   represents the height of the input tensor
+   * @param[in] input1_width   represents the width of the input tensor A
+   * @param[in] input2_width   represents the width of the input tensor X
+   * @param[in] context RunLayerContext reference
+   */
+  void concat_cl_axis3(const float *matAdata, const float *vecXdata,
+                       float *vecYdata, unsigned int input1_batch_size,
+                       unsigned int input1_channels, unsigned int input1_height,
+                       unsigned int input1_width, unsigned int input2_width,
+                       RunLayerContext &context);
+
+  /**
+   * @brief     concat computation for axis 3 fp16
+   * @param[in] matAdata fp16 * for Input Tensor A
+   * @param[in] vecXdata fp16 * for Input Tensor X
+   * @param[in] vecYdata fp16 * for Output Tensor Y
+   * @param[in] input1_batch_size  represents the number of samples in the input
+   * tensor
+   * @param[in] input1_channels   represents the channels of the input tensor
+   * @param[in] input1_height   represents the height of the input tensor
+   * @param[in] input1_width   represents the width of the input tensor A
+   * @param[in] input2_width   represents the width of the input tensor X
+   * @param[in] context RunLayerContext reference
+   */
+  void concat_cl_axis3_fp16(const __fp16 *matAdata, const __fp16 *vecXdata,
+                            __fp16 *vecYdata, unsigned int input1_batch_size,
+                            unsigned int input1_channels,
+                            unsigned int input1_height,
+                            unsigned int input1_width,
+                            unsigned int input2_width,
+                            RunLayerContext &context);
+
+  /**
+   * @brief     concat computation for axis 2
+   * @param[in] matAdata float * for Input Tensor A
+   * @param[in] vecXdata float * for Input Tensor X
+   * @param[in] vecYdata float * for Output Tensor Y
+   * @param[in] input1_batch_size  represents the number of samples in the input
+   * tensor
+   * @param[in] input1_channels   represents the channels of the input tensor
+   * @param[in] input1_width   represents the width of the input tensor
+   * @param[in] input1_height   represents the height of the input tensor A
+   * @param[in] input2_height   represents the height of the input tensor X
+   * @param[in] context RunLayerContext reference
+   */
+  void concat_cl_axis2(const float *matAdata, const float *vecXdata,
+                       float *vecYdata, unsigned int input1_batch_size,
+                       unsigned int input1_channels, unsigned int input1_width,
+                       unsigned int input1_height, unsigned int input2_height,
+                       RunLayerContext &context);
+
+  /**
+   * @brief     concat computation for axis 2 fp16
+   * @param[in] matAdata fp16 * for Input Tensor A
+   * @param[in] vecXdata fp16 * for Input Tensor X
+   * @param[in] vecYdata fp16 * for Output Tensor Y
+   * @param[in] input1_batch_size  represents the number of samples in the input
+   * tensor
+   * @param[in] input1_channels   represents the channels of the input tensor
+   * @param[in] input1_width   represents the width of the input tensor
+   * @param[in] input1_height   represents the height of the input tensor A
+   * @param[in] input2_height   represents the height of the input tensor X
+   * @param[in] context RunLayerContext reference
+   */
+  void concat_cl_axis2_fp16(const __fp16 *matAdata, const __fp16 *vecXdata,
+                            __fp16 *vecYdata, unsigned int input1_batch_size,
+                            unsigned int input1_channels,
+                            unsigned int input1_width,
+                            unsigned int input1_height,
+                            unsigned int input2_height,
+                            RunLayerContext &context);
+
+  /**
+   * @brief     concat computation for axis 1
+   * @param[in] matAdata float * for Input Tensor A
+   * @param[in] vecXdata float * for Input Tensor X
+   * @param[in] vecYdata float * for Output Tensor Y
+   * @param[in] input1_batch_size  represents the number of samples in the input
+   * tensor
+   * @param[in] input1_height   represents the height of the input tensor
+   * @param[in] input1_width   represents the width of the input tensor
+   * @param[in] input1_channels   represents the channels of the input tensor A
+   * @param[in] input2_channels   represents the channels of the input tensor X
+   * @param[in] context RunLayerContext reference
+   */
+  void concat_cl_axis1(const float *matAdata, const float *vecXdata,
+                       float *vecYdata, unsigned int input1_batch_size,
+                       unsigned int input1_height, unsigned int input1_width,
+                       unsigned int input1_channels,
+                       unsigned int input2_channels, RunLayerContext &context);
+
+  /**
+   * @brief     concat computation for axis 1 fp16
+   * @param[in] matAdata fp16 * for Input Tensor A
+   * @param[in] vecXdata fp16 * for Input Tensor X
+   * @param[in] vecYdata fp16 * for Output Tensor Y
+   * @param[in] input1_batch_size  represents the number of samples in the input
+   * tensor
+   * @param[in] input1_height   represents the height of the input tensor
+   * @param[in] input1_width   represents the width of the input tensor
+   * @param[in] input1_channels   represents the channels of the input tensor A
+   * @param[in] input2_channels   represents the channels of the input tensor X
+   * @param[in] context RunLayerContext reference
+   */
+  void concat_cl_axis1_fp16(const __fp16 *matAdata, const __fp16 *vecXdata,
+                            __fp16 *vecYdata, unsigned int input1_batch_size,
+                            unsigned int input1_height,
+                            unsigned int input1_width,
+                            unsigned int input1_channels,
+                            unsigned int input2_channels,
+                            RunLayerContext &context);
+
+private:
+  std::tuple<props::ConcatDimension> concat_props;
+};
+
+} // namespace nntrainer
+
+#endif /* __cplusplus */
+#endif /* __CONCAT_LAYER_CL_H__ */
index 634a2d8fb56b1d04c3444ce5b19fac1770d362d1..e4ca30a880a07b984bbcb6f506063ba241af73c6 100644 (file)
@@ -3,7 +3,8 @@ cl_layer_sources = [
   'addition_layer_cl.cpp',
   'swiglu_cl.cpp',
   'reshape_cl.cpp',
-  'rmsnorm_layer_cl.cpp'
+  'rmsnorm_layer_cl.cpp',
+  'concat_cl.cpp',
 ]
 
 foreach s : cl_layer_sources
index 53951d4f69e352395905ac83053591a68fa5c24c..e114a4bca21acadfa9341856c518cfa42eae70a9 100644 (file)
@@ -723,6 +723,18 @@ std::string RunLayerContext::getKernelName(LayerKernel layerKernel) {
     return "rmsnorm_cl";
   case LayerKernel::RMSNORM_FP16:
     return "rmsnorm_cl_fp16";
+  case LayerKernel::CONCAT_AXIS3:
+    return "concat_cl_axis3";
+  case LayerKernel::CONCAT_AXIS3_FP16:
+    return "concat_cl_axis3_fp16";
+  case LayerKernel::CONCAT_AXIS2:
+    return "concat_cl_axis2";
+  case LayerKernel::CONCAT_AXIS2_FP16:
+    return "concat_cl_axis2_fp16";
+  case LayerKernel::CONCAT_AXIS1:
+    return "concat_cl_axis1";
+  case LayerKernel::CONCAT_AXIS1_FP16:
+    return "concat_cl_axis1_fp16";
   default:
     return "";
   }
index 993e98fd0132f58e5ddd95edef58e3ca78ee42f7..1d781210295aa55ee498bb717a790f952a799585 100644 (file)
@@ -849,8 +849,14 @@ public:
     SSCAL_FP16 = 1 << 17,         /**< placeholder for kernel name */
     COPY = 1 << 18,               /**< placeholder for kernel name */
     COPY_FP16 = 1 << 19,          /**< placeholder for kernel name */
-    RMSNORM = 1 << 20,
-    RMSNORM_FP16 = 1 << 21
+    RMSNORM = 1 << 20,            /**< placeholder for kernel name */
+    RMSNORM_FP16 = 1 << 21,       /**< placeholder for kernel name */
+    CONCAT_AXIS3 = 1 << 22,       /**< placeholder for kernel name */
+    CONCAT_AXIS3_FP16 = 1 << 23,  /**< placeholder for kernel name */
+    CONCAT_AXIS2 = 1 << 24,       /**< placeholder for kernel name */
+    CONCAT_AXIS2_FP16 = 1 << 25,  /**< placeholder for kernel name */
+    CONCAT_AXIS1 = 1 << 26,       /**< placeholder for kernel name */
+    CONCAT_AXIS1_FP16 = 1 << 27,  /**< placeholder for kernel name */
   };
 
   /**
index e0853d90b8932a6c59d46ae7452225fa7224247e..cd5faad571f9460e6a0a50794ee50157feee6591 100644 (file)
@@ -442,6 +442,7 @@ LOCAL_SRC_FILES := \
         ../unittest/layers/unittest_layer_node.cpp \
         ../unittest/layers/unittest_layers.cpp \
         ../unittest/layers/unittest_layers_impl.cpp \
+        ../unittest/layers/unittest_layers_concat_cl.cpp \
         ../unittest/layers/unittest_layers_swiglu_cl.cpp \
         ../unittest/layers/unittest_layers_fully_connected_cl.cpp \
         ../unittest/layers/unittest_layers_input.cpp \
diff --git a/test/unittest/layers/unittest_layers_concat_cl.cpp b/test/unittest/layers/unittest_layers_concat_cl.cpp
new file mode 100644 (file)
index 0000000..9c456a6
--- /dev/null
@@ -0,0 +1,64 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Niket Agarwal <niket.a@samsung.com>
+ *
+ * @file unittest_layers_concat_cl.cpp
+ * @date 2 July 2024
+ * @brief Concat Layer Test
+ * @see        https://github.com/nnstreamer/nntrainer
+ * @author Niket Agarwal <niket.a@samsung.com>
+ * @bug No known bugs except for NYI items
+ */
+#include <tuple>
+
+#include <gtest/gtest.h>
+
+#include <concat_cl.h>
+#include <layers_common_tests.h>
+
+auto semantic_concat_gpu = LayerSemanticsParamType(
+  nntrainer::createLayer<nntrainer::ConcatLayerCl>,
+  nntrainer::ConcatLayerCl::type, {},
+  LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 1);
+
+GTEST_PARAMETER_TEST(ConcatGPU, LayerSemanticsGpu,
+                     ::testing::Values(semantic_concat_gpu));
+
+auto concat_dim3 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::ConcatLayerCl>, {"axis=3"},
+  "2:3:3:2,2:3:3:3", "concat_dim3.nnlayergolden",
+  LayerGoldenTestParamOptions::SKIP_CALC_DERIV, "nchw", "fp32", "fp32");
+
+auto concat_dim2 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::ConcatLayerCl>, {"axis=2"},
+  "2:3:2:3,2:3:3:3", "concat_dim2.nnlayergolden",
+  LayerGoldenTestParamOptions::SKIP_CALC_DERIV, "nchw", "fp32", "fp32");
+
+auto concat_dim1 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::ConcatLayerCl>, {"axis=1"},
+  "2:2:3:3,2:3:3:3", "concat_dim1.nnlayergolden",
+  LayerGoldenTestParamOptions::SKIP_CALC_DERIV, "nchw", "fp32", "fp32");
+
+GTEST_PARAMETER_TEST(ConcatGPU, LayerGoldenTest,
+                     ::testing::Values(concat_dim3, concat_dim2, concat_dim1));
+
+#ifdef ENABLE_FP16
+auto concat_dim3_w16a16 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::ConcatLayerCl>, {"axis=3"},
+  "2:3:3:2,2:3:3:3", "concat_dim3_w16a16.nnlayergolden",
+  LayerGoldenTestParamOptions::SKIP_CALC_DERIV, "nchw", "fp16", "fp16");
+
+auto concat_dim2_w16a16 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::ConcatLayerCl>, {"axis=2"},
+  "2:3:2:3,2:3:3:3", "concat_dim2_w16a16.nnlayergolden",
+  LayerGoldenTestParamOptions::SKIP_CALC_DERIV, "nchw", "fp16", "fp16");
+
+auto concat_dim1_w16a16 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::ConcatLayerCl>, {"axis=1"},
+  "2:2:3:3,2:3:3:3", "concat_dim1_w16a16.nnlayergolden",
+  LayerGoldenTestParamOptions::SKIP_CALC_DERIV, "nchw", "fp16", "fp16");
+
+GTEST_PARAMETER_TEST(ConcatGPU16, LayerGoldenTest,
+                     ::testing::Values(concat_dim3_w16a16, concat_dim2_w16a16,
+                                       concat_dim1_w16a16));
+#endif