[GPU/OpenCL] Initial version of Addition Layer with OpenCL ops
author yash.singh <yash.singh@samsung.com>
Thu, 23 May 2024 10:42:12 +0000 (16:12 +0530)
committer Jijoong Moon <jijoong.moon@samsung.com>
Tue, 25 Jun 2024 07:56:59 +0000 (16:56 +0900)
Added a naive OpenCL implementation of the Addition layer.
Incorporated the kernel for the ops used.
Added a unit test for addition_layer_cl.

Signed-off-by: yash.singh <yash.singh@samsung.com>
api/ccapi/include/layer.h
nntrainer/cl_context.cpp
nntrainer/layers/cl_layers/addition_layer_cl.cpp [new file with mode: 0644]
nntrainer/layers/cl_layers/addition_layer_cl.h [new file with mode: 0644]
nntrainer/layers/cl_layers/meson.build
nntrainer/layers/layer_context.cpp
nntrainer/layers/layer_context.h
test/input_gen/gen_layer_tests.py
test/jni/Android.mk
test/unittest/layers/unittest_layers_addition_cl.cpp [new file with mode: 0644]

index ca0ae19f62ab32613bc88d3b98659f63b68cefe8..7e76134c5ba9021bb154ac648a84e83eb70e4174 100644 (file)
@@ -359,6 +359,17 @@ Addition(const std::vector<std::string> &properties = {}) {
   return createLayer(LayerType::LAYER_ADDITION, properties);
 }
 
+#ifdef ENABLE_OPENCL
+/**
+ * @brief Helper function to create Addition layer for GPU
+ */
+inline std::unique_ptr<Layer>
+AdditionCL(const std::vector<std::string> &properties = {},
+           const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) {
+  return createLayer(LayerType::LAYER_ADDITION, properties, compute_engine);
+}
+#endif
+
 /**
  * @brief Helper function to create concat layer
  */
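The new AdditionCL() helper mirrors the existing Addition() factory but also
forwards a compute engine. A minimal usage sketch, assuming the
ml::train::layer namespace used by the neighbouring helpers and a build with
ENABLE_OPENCL (the layer name below is hypothetical):

    #include <layer.h>

    // Create the addition layer backed by the OpenCL kernel registered in
    // cl_context.cpp. Note that compute_engine defaults to CPU in the helper
    // above, so the GPU engine must be requested explicitly.
    std::unique_ptr<ml::train::Layer> make_add_gpu() {
      return ml::train::layer::AdditionCL(
        {"name=add_gpu"}, ml::train::LayerComputeEngine::GPU);
    }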
index be7345eed0d0441b4c4d43757da4cdaae0511495..b92a14ca0d3967ad5a470ed86373dd4e31251bc3 100644 (file)
@@ -12,6 +12,7 @@
  * creates the OpenCL command queue and context.
  */
 
+#include <addition_layer_cl.h>
 #include <cl_context.h>
 #include <fc_layer_cl.h>
 
@@ -26,6 +27,10 @@ static void add_default_object(ClContext &cc) {
   cc.registerFactory(nntrainer::createLayer<FullyConnectedLayerCl>,
                      FullyConnectedLayerCl::type,
                      ml::train::LayerType::LAYER_FC);
+
+  cc.registerFactory(nntrainer::createLayer<AdditionLayerCL>,
+                     AdditionLayerCL::type,
+                     ml::train::LayerType::LAYER_ADDITION);
 }
 
 static void registerer(ClContext &cc) noexcept {
diff --git a/nntrainer/layers/cl_layers/addition_layer_cl.cpp b/nntrainer/layers/cl_layers/addition_layer_cl.cpp
new file mode 100644 (file)
index 0000000..48ea84d
--- /dev/null
@@ -0,0 +1,210 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Yash Singh <yash.singh@samsung.com>
+ *
+ * @file   addition_layer_cl.cpp
+ * @date   17 May 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Yash Singh <yash.singh@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is the Addition Layer class for Neural Network with OpenCL
+ * implementation
+ */
+
+#include <addition_layer_cl.h>
+#include <nntrainer_error.h>
+#include <nntrainer_log.h>
+#include <node_exporter.h>
+#include <util_func.h>
+
+#include <layer_context.h>
+
+std::string addition_cl_kernel_ =
+  R"(__kernel void addition_cl(__global const float* input, __global float* output, const unsigned int size) {
+    #pragma printf_support
+    size_t idx = get_global_id(0);
+    if (idx < size) {
+        output[idx] = output[idx] + input[idx];
+    }
+})";
+
+namespace nntrainer {
+
+static constexpr size_t SINGLE_INOUT_IDX = 0;
+
+void AdditionLayerCL::finalize(InitLayerContext &context) {
+  context.setOutputDimensions({context.getInputDimensions()[0]});
+}
+
+void AdditionLayerCL::forwarding(RunLayerContext &context, bool training) {
+  Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
+
+  /** @todo check possibility for in-place of addition layer */
+  for (unsigned int idx = 0; idx < context.getNumInputs(); ++idx) {
+    const Tensor &input_ = context.getInput(idx);
+    if (!idx) {
+      hidden_.copy(input_);
+    } else {
+      // hidden_.add_i(input_);
+      AddProcess(input_, hidden_, context);
+    }
+  }
+}
+
+/**
+ * @brief declaring static kernel objects
+ *
+ */
+opencl::Kernel AdditionLayerCL::kernel_addition;
+
+void AdditionLayerCL::AddProcess(Tensor const &input, Tensor &result,
+                                 RunLayerContext &context) {
+
+  CREATE_IF_EMPTY_DIMS(result, result.getDim());
+
+  NNTR_THROW_IF(result.getData() == nullptr, std::invalid_argument)
+    << result.getName() << " is not allocated";
+  NNTR_THROW_IF(input.getData() == nullptr, std::invalid_argument)
+    << input.getName() << " is not allocated";
+
+  if (input.getDim() != result.getDim()) {
+    throw std::invalid_argument(
+      "Error: Dimensions does not match for addition");
+  }
+
+  if (input.getDataType() == ml::train::TensorDim::DataType::FP32) {
+    unsigned int size = input.size();
+    const float *data = input.getData();
+    float *rdata = result.getData();
+
+    addition_cl(data, rdata, size, context);
+
+  } else
+    throw std::invalid_argument("Error: OpenCL fp16 is not supported yet.");
+}
+
+void AdditionLayerCL::addition_cl(const float *input, float *res,
+                                  unsigned int size, RunLayerContext &context) {
+
+  bool result = false;
+  do {
+    result =
+      context.clCreateKernel(addition_cl_kernel_, context.LayerKernel::ADD,
+                             AdditionLayerCL::kernel_addition);
+    if (!result) {
+      break;
+    }
+
+    size_t dim1_size = sizeof(float) * size;
+    opencl::Buffer inputA(context.context_inst_, dim1_size, true, nullptr);
+
+    opencl::Buffer inOutRes(context.context_inst_, dim1_size, true, nullptr);
+
+    result = inputA.WriteData(context.command_queue_inst_, input);
+    if (!result) {
+      break;
+    }
+
+    result = inOutRes.WriteData(context.command_queue_inst_, res);
+    if (!result) {
+      break;
+    }
+
+    result = AdditionLayerCL::kernel_addition.SetKernelArguments(
+      0, &inputA, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = AdditionLayerCL::kernel_addition.SetKernelArguments(
+      1, &inOutRes, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = AdditionLayerCL::kernel_addition.SetKernelArguments(2, &size,
+                                                                 sizeof(int));
+    if (!result) {
+      break;
+    }
+
+    const int work_groups_count[3] = {(int)size, 1, 1};
+    const int work_group_size[3] = {32, 32, 1}; // test value, not yet tuned
+    result = context.command_queue_inst_.DispatchCommand(
+      AdditionLayerCL::kernel_addition, work_groups_count, work_group_size);
+    if (!result) {
+      break;
+    }
+
+    result = inOutRes.ReadData(context.command_queue_inst_, res);
+    if (!result) {
+      break;
+    }
+
+  } while (false);
+}
+
+void AdditionLayerCL::incremental_forwarding(RunLayerContext &context,
+                                             unsigned int from, unsigned int to,
+                                             bool training) {
+  Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
+  TensorDim hidden_dim = hidden_.getDim();
+  TensorDim hidden_step_dim = hidden_dim;
+
+  if (from) {
+    NNTR_THROW_IF(to - from != 1, std::invalid_argument)
+      << "incremental step size is not 1";
+    from = 0;
+    to = 1;
+  }
+
+  hidden_step_dim.batch(1);
+  hidden_step_dim.height(to - from);
+
+  for (unsigned int b = 0; b < hidden_.batch(); ++b) {
+    Tensor hidden_step = hidden_.getSharedDataTensor(
+      hidden_step_dim, b * hidden_dim.getFeatureLen(), true);
+
+    /** @todo check possibility for in-place of addition layer */
+    for (unsigned int idx = 0; idx < context.getNumInputs(); ++idx) {
+      const Tensor &input_ = context.getInput(idx);
+      TensorDim input_dim = input_.getDim();
+
+      TensorDim input_step_dim = input_dim;
+      input_step_dim.batch(1);
+      input_step_dim.height(to - from);
+
+      Tensor input_step = input_.getSharedDataTensor(
+        input_step_dim, b * input_dim.getFeatureLen(), true);
+      if (!idx) {
+        hidden_step.copy(input_step);
+      } else {
+        // hidden_step.add_i(input_step);
+        AddProcess(input_step, hidden_step, context);
+      }
+    }
+  }
+}
+
+void AdditionLayerCL::calcDerivative(RunLayerContext &context) {
+
+  for (unsigned int idx = 0; idx < context.getNumInputs(); ++idx) {
+    /**
+     * TODO: replace this with tensor assignment during optimization.
+     * Tensor assignment needs to make sure that the previous connected layers
+     * are not inplace
+     */
+    context.getOutgoingDerivative(idx).copy(
+      context.getIncomingDerivative(SINGLE_INOUT_IDX));
+  }
+}
+
+void AdditionLayerCL::setProperty(const std::vector<std::string> &values) {
+  auto remain_props = loadProperties(values, add_props);
+  if (!remain_props.empty()) {
+    std::string msg = "[AdditionLayer] Unknown Layer Properties count " +
+                      std::to_string(values.size());
+    throw exception::not_supported(msg);
+  }
+}
+} /* namespace nntrainer */
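For readers unfamiliar with the opencl::Buffer / opencl::Kernel wrappers, the
sequence inside addition_cl() (upload both operands, bind the three kernel
arguments, dispatch one work item per element, read the result back)
corresponds roughly to the raw OpenCL host calls below. This is an
illustrative sketch with error handling omitted, not the wrappers' actual
implementation:

    #include <CL/cl.h>

    void add_inplace_raw(cl_context ctx, cl_command_queue queue,
                         cl_kernel kernel, const float *input, float *res,
                         unsigned int size) {
      cl_int err = CL_SUCCESS;
      size_t bytes = sizeof(float) * size;

      cl_mem in = clCreateBuffer(ctx, CL_MEM_READ_ONLY, bytes, nullptr, &err);
      cl_mem inout =
        clCreateBuffer(ctx, CL_MEM_READ_WRITE, bytes, nullptr, &err);

      // res is both an input and the output of the kernel, so it has to be
      // uploaded as well, not just read back.
      clEnqueueWriteBuffer(queue, in, CL_TRUE, 0, bytes, input, 0, nullptr,
                           nullptr);
      clEnqueueWriteBuffer(queue, inout, CL_TRUE, 0, bytes, res, 0, nullptr,
                           nullptr);

      clSetKernelArg(kernel, 0, sizeof(cl_mem), &in);
      clSetKernelArg(kernel, 1, sizeof(cl_mem), &inout);
      clSetKernelArg(kernel, 2, sizeof(unsigned int), &size);

      size_t global = size; // one work item per float
      clEnqueueNDRangeKernel(queue, kernel, 1, nullptr, &global, nullptr, 0,
                             nullptr, nullptr);

      clEnqueueReadBuffer(queue, inout, CL_TRUE, 0, bytes, res, 0, nullptr,
                          nullptr);

      clReleaseMemObject(in);
      clReleaseMemObject(inout);
    }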
diff --git a/nntrainer/layers/cl_layers/addition_layer_cl.h b/nntrainer/layers/cl_layers/addition_layer_cl.h
new file mode 100644 (file)
index 0000000..78b9293
--- /dev/null
@@ -0,0 +1,136 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Yash Singh <yash.singh@samsung.com>
+ *
+ * @file   addition_layer_cl.h
+ * @date   17 May 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Yash Singh <yash.singh@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is the Addition Layer class for Neural Network with OpenCL
+ * implementation
+ */
+
+#ifndef __ADDITION_LAYER_CL_H__
+#define __ADDITION_LAYER_CL_H__
+#ifdef __cplusplus
+
+#include <common_properties.h>
+#include <layer_devel.h>
+#include <opencl_buffer.h>
+#include <opencl_kernel.h>
+
+#define CREATE_IF_EMPTY_DIMS(tensor, ...) \
+  do {                                    \
+    if (tensor.empty())                   \
+      tensor = Tensor(__VA_ARGS__);       \
+  } while (0);
+
+namespace nntrainer {
+
+/**
+ * @class   AdditionLayerCL
+ * @brief   Addition Layer
+ */
+class AdditionLayerCL : public Layer {
+public:
+  /**
+   * @brief     Constructor of Addition Layer
+   */
+  AdditionLayerCL() : Layer(), add_props(props::Print()) {}
+
+  /**
+   * @brief     Destructor of Addition Layer
+   */
+  ~AdditionLayerCL(){};
+
+  /**
+   *  @brief  Move constructor of AdditionLayerCL.
+   *  @param[in] rhs AdditionLayerCL to be moved.
+   */
+  AdditionLayerCL(AdditionLayerCL &&rhs) noexcept = default;
+
+  /**
+   * @brief  Move assignment operator.
+   * @param[in] rhs AdditionLayerCL to be moved.
+   */
+  AdditionLayerCL &operator=(AdditionLayerCL &&rhs) = default;
+
+  /**
+   * @copydoc Layer::finalize(InitLayerContext &context)
+   */
+  void finalize(InitLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
+   */
+  void forwarding(RunLayerContext &context, bool training) override;
+
+  /**
+   * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
+   * int from, unsigned int to, bool training)
+   */
+  void incremental_forwarding(RunLayerContext &context, unsigned int from,
+                              unsigned int to, bool training) override;
+
+  /**
+   * @copydoc Layer::calcDerivative(RunLayerContext &context)
+   */
+  void calcDerivative(RunLayerContext &context) override;
+
+  /**
+   * @brief declaring static kernel objects
+   */
+  static opencl::Kernel kernel_addition;
+
+  /**
+   * @brief Process data and dimensions for add operation used in addition layer
+   * @param[in] input Tensor
+   * @param[in] result Tensor
+   * @param[in] RunLayerContext reference
+   */
+  void AddProcess(Tensor const &input, Tensor &result,
+                  RunLayerContext &context);
+
+  /**
+   * @brief     addition : sum of all input vectors
+   * @param[in] input float * for input
+   * @param[in] res float * for result/output
+   * @param[in] size number of elements in input vector
+   * @param[in] context RunLayerContext reference
+   */
+  void addition_cl(const float *input, float *res, unsigned int size,
+                   RunLayerContext &context);
+
+  /**
+   * @copydoc bool supportBackwarding() const
+   */
+  bool supportBackwarding() const override { return true; };
+
+  /**
+   * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
+   * method)
+   */
+  void exportTo(Exporter &exporter,
+                const ml::train::ExportMethods &method) const override {}
+
+  /**
+   * @copydoc Layer::setProperty(const std::vector<std::string> &values)
+   */
+  void setProperty(const std::vector<std::string> &values) override;
+
+  /**
+   * @copydoc Layer::getType()
+   */
+  const std::string getType() const override { return AdditionLayerCL::type; };
+
+  std::tuple<props::Print>
+    add_props; /**< addition layer properties */
+
+  inline static const std::string type = "addition";
+};
+
+} // namespace nntrainer
+
+#endif /* __cplusplus */
+#endif /* __ADDITION_LAYER_CL_H__ */
index 5c6ad1358f8e5f414c22bb7fe2dfb2da748c0947..349e1f443de99ac37d9670d847c03c57370bb506 100644 (file)
@@ -1,5 +1,7 @@
 cl_layer_sources = [
   'fc_layer_cl.cpp',
+  'blas_kernels.cpp',
+  'addition_layer_cl.cpp'
 ]
 
 foreach s : cl_layer_sources
index 015879cdf3c85d47ec991f0e7a080578bc64b757..a6615b92aa8e2282ff2ec6015e9d74d119945857 100644 (file)
@@ -690,6 +690,8 @@ std::string RunLayerContext::getKernelName(LayerKernel layerKernel) {
     return "dot_cl_fp16";
   case LayerKernel::SGEMM_FP16:
     return "sgemm_cl_fp16";
+  case LayerKernel::ADD:
+    return "addition_cl";
   default:
     return "";
   }
index e0ed137c3c63ce0373eabcae1ae20ecf03cfbdde..105725a57b0fad12a15c328fb8221f72d3bce6ce 100644 (file)
@@ -835,6 +835,7 @@ public:
     SGEMV_FP16 = 1 << 3, /**< placeholder for kernel name */
     DOT_FP16 = 1 << 4,   /**< placeholder for kernel name */
     SGEMM_FP16 = 1 << 5, /**< placeholder for kernel name */
+    ADD = 1 << 6         /**< placeholder for kernel name */
   };
 
   /**
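Together with the layer_context.cpp hunk above, the new enum value and the
kernel name it resolves to are linked as sketched below (context stands for
the RunLayerContext reference available in addition_layer_cl.cpp):

    // LayerKernel::ADD selects the source string registered under the name
    // "addition_cl"; clCreateKernel() in addition_cl() builds and caches the
    // kernel under that name.
    std::string kernel_name =
      context.getKernelName(nntrainer::RunLayerContext::LayerKernel::ADD);
    // kernel_name == "addition_cl"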
index 7a1ed18ec672539e397cf03446ded73135c787ea..5c20c7b10d1fad0996bb55458bc32aaf664a6ccd 100644 (file)
@@ -865,9 +865,6 @@ if __name__ == "__main__":
         positional_encoding, [(3, 1, 10, 6)], "positional_encoding_w16a16"
     )
 
-    added = K.layers.Add()
-    record_single_fp16(added, [(2, 3, 3, 3), (2, 3, 3, 3)], "added_w16a16")
-
     def swiglu(inputs):
         [x, y] = inputs
         # swish(x) = x * sigmoid(x)
@@ -883,3 +880,15 @@ if __name__ == "__main__":
         "swiglu",
         input_type="float",
     )
+
+    added = K.layers.Add()
+    record_single_fp16(added, [(2, 3, 3, 3), (2, 3, 3, 3)], "added_w16a16")
+
+    added = K.layers.Add()
+    record_single(added, [(2, 3, 3, 3), (2, 3, 3, 3)], "added_w32a32")
+
+    added = K.layers.Add()
+    record_single(added, [(3, 4, 3, 4), (3, 4, 3, 4)], "added_w32a32_2")
+
+    added = K.layers.Add()
+    record_single(added, [(20, 55, 50, 55), (20, 55, 50, 55)], "added_w32a32_3")
index 978e98bd67a633e43fb4b5d79c2995bdf737d97e..963beb3b012d7ba7ffbf35eff2cb05aa68ab3851 100644 (file)
@@ -453,6 +453,7 @@ LOCAL_SRC_FILES := \
         ../unittest/layers/unittest_layers_flatten.cpp \
         ../unittest/layers/unittest_layers_activation.cpp \
         ../unittest/layers/unittest_layers_addition.cpp \
+        ../unittest/layers/unittest_layers_addition_cl.cpp \
         ../unittest/layers/unittest_layers_multiout.cpp \
         ../unittest/layers/unittest_layers_rnn.cpp \
         ../unittest/layers/unittest_layers_rnncell.cpp \
diff --git a/test/unittest/layers/unittest_layers_addition_cl.cpp b/test/unittest/layers/unittest_layers_addition_cl.cpp
new file mode 100644 (file)
index 0000000..a5d6907
--- /dev/null
@@ -0,0 +1,50 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Yash Singh <yash.singh@samsung.com>
+ *
+ * @file unittest_layers_addition_cl.cpp
+ * @date 17 May 2024
+ * @brief Addition Layer Test
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Yash Singh <yash.singh@samsung.com>
+ * @bug No known bugs except for NYI items
+ */
+#include <tuple>
+
+#include <gtest/gtest.h>
+
+#include <addition_layer_cl.h>
+#include <layers_common_tests.h>
+
+auto semantic_addition_gpu = LayerSemanticsParamType(
+  nntrainer::createLayer<nntrainer::AdditionLayerCL>,
+  nntrainer::AdditionLayerCL::type, {},
+  LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 1);
+
+auto semantic_addition_multi_gpu = LayerSemanticsParamType(
+  nntrainer::createLayer<nntrainer::AdditionLayerCL>,
+  nntrainer::AdditionLayerCL::type, {},
+  LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 2);
+
+GTEST_PARAMETER_TEST(AdditionGPU, LayerSemantics,
+                     ::testing::Values(semantic_addition_gpu,
+                                       semantic_addition_multi_gpu));
+
+auto addition_w32a32 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::AdditionLayerCL>, {}, "2:3:3:3,2:3:3:3",
+  "added_w32a32.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT, "nchw",
+  "fp32", "fp32");
+
+auto addition_w32a32_2 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::AdditionLayerCL>, {}, "3:4:3:4,3:4:3:4",
+  "added_w32a32_2.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT, "nchw",
+  "fp32", "fp32");
+
+auto addition_w32a32_3 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::AdditionLayerCL>, {},
+  "20:55:50:55,20:55:50:55", "added_w32a32_3.nnlayergolden",
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
+
+GTEST_PARAMETER_TEST(AdditionGPU, LayerGoldenTest,
+                     ::testing::Values(addition_w32a32, addition_w32a32_2,
+                                       addition_w32a32_3));
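Once built with OpenCL enabled, the new cases can be exercised in isolation
with the standard gtest filter, e.g. --gtest_filter='AdditionGPU*' (the
unit-test binary name depends on the local build layout).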