arm_compute v18.02
[platform/upstream/armcl.git] / src / core / CL / cl_kernels / direct_convolution_1x1_3x3_5x5_quantized.cl
index d0cf032..b58dc7a 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -247,3 +247,62 @@ __kernel void direct_convolution_1x1_3x3_5x5_quantized(
     vstore8(convert_uchar8_sat(pixels0), 0, (__global uchar *)dst.ptr);
 }
 #endif // defined(DATA_TYPE) && defined(STRIDE_X) && defined(WEIGHTS_DEPTH)
+
+/** This function computes the output stage of a depthwise convolution.
+ *
+ * @param[in] src_ptr                            Pointer to the source image. Supported data types: QASYMM8
+ * @param[in] src_stride_x                       Stride of the source image in X dimension (in bytes)
+ * @param[in] src_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y                       Stride of the source image in Y dimension (in bytes)
+ * @param[in] src_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes  The offset of the first element in the source image
+ * @param[in] src_stride_z                       Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z                         src_stride_z * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_ptr                            Pointer to the destination tensor. Supported data types: QASYMM8
+ * @param[in] dst_stride_x                       Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y                       Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z                         dst_stride_z * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes  The offset of the first element in the destination tensor
+ * @param[in] bias_ptr                           (Optional) Pointer to the biases vector. Supported data types: S32
+ * @param[in] bias_stride_x                      (Optional) Stride of the biases vector in X dimension (in bytes)
+ * @param[in] bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the biases vector
+ * @param[in] output_offset                      Quantized offset of zero point of the output tensor data range
+ * @param[in] output_multiplier                  Output scale multiplier
+ * @param[in] output_shift                       Output scale divisor exponent
+ */
+
+__kernel void output_stage_quantized(
+    TENSOR3D_DECLARATION(src),
+    TENSOR3D_DECLARATION(dst),
+#if defined(HAS_BIAS)
+    VECTOR_DECLARATION(bias),
+#endif //defined(HAS_BIAS)
+    int output_offset,
+    int output_multiplier,
+    int output_shift)
+{
+    Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
+    Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
+#if defined(HAS_BIAS)
+    Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
+#endif //defined(HAS_BIAS)
+
+    // Load input
+    int16 vals = vload16(0, (__global int *)(src.ptr));
+
+#if defined(HAS_BIAS)
+    // Load and add bias
+    int bias_value = *((__global int *)(vector_offset(&bias, get_global_id(2))));
+    vals += (int16)(bias_value);
+#endif //defined(HAS_BIAS)
+
+    vals = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(vals, output_multiplier, output_shift, 16);
+    vals = vals + output_offset;
+
+    // Store result in dst
+    vstore16(convert_uchar16_sat(vals), 0, (__global uchar *)dst.ptr);
+}