[IE CLDNN] Add pooling b_fs_yx_fsv16 int8 (#565)
author Jedrzej Hajduczenia <jedrzej.hajduczenia@intel.com>
Thu, 18 Jun 2020 13:40:52 +0000 (15:40 +0200)
committer GitHub <noreply@github.com>
Thu, 18 Jun 2020 13:40:52 +0000 (16:40 +0300)
inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/pooling/pooling_kernel_gpu_b_fs_yx_fsv16_imad.cpp [new file with mode: 0644]
inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/pooling/pooling_kernel_gpu_b_fs_yx_fsv16_imad.h [new file with mode: 0644]
inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/pooling/pooling_kernel_selector.cpp
inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/pooling_gpu_b_fs_yx_fsv16_imad.cl [new file with mode: 0644]
inference-engine/thirdparty/clDNN/tests/test_cases/pooling_gpu_test.cpp

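The new kernel targets the b_fs_yx_fsv16 blocked layout, in which features are grouped into slices of 16 and the 16 features of a slice are stored contiguously as the innermost dimension; this is what lets the kernel below read a whole slice with a single 16-byte load. A minimal host-side C++ sketch of the offset computation for an unpadded buffer (illustrative only, not part of this change; the helper name is made up):

    #include <cstddef>
    #include <cstdio>

    // Linear offset of element (b, f, y, x) in an unpadded b_fs_yx_fsv16 buffer:
    // order is batch, feature slice, y, x, feature-within-slice (innermost).
    std::size_t offset_b_fs_yx_fsv16(std::size_t b, std::size_t f,
                                     std::size_t y, std::size_t x,
                                     std::size_t F, std::size_t Y, std::size_t X) {
        const std::size_t fsv = 16;                      // feature slice size
        const std::size_t slices = (F + fsv - 1) / fsv;  // CeilDiv(F, 16)
        const std::size_t fs = f / fsv;                  // which slice
        const std::size_t fi = f % fsv;                  // position inside the slice
        return (((b * slices + fs) * Y + y) * X + x) * fsv + fi;
    }

    int main() {
        // batch 0, feature 17, (y, x) = (1, 2) in a 1x64x4x4 tensor
        std::printf("%zu\n", offset_b_fs_yx_fsv16(0, 17, 1, 2, 64, 4, 4));
    }
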
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/pooling/pooling_kernel_gpu_b_fs_yx_fsv16_imad.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/pooling/pooling_kernel_gpu_b_fs_yx_fsv16_imad.cpp
new file mode 100644 (file)
index 0000000..9574e41
--- /dev/null
@@ -0,0 +1,102 @@
+// Copyright (c) 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "pooling_kernel_gpu_b_fs_yx_fsv16_imad.h"
+#include "kernel_selector_utils.h"
+
+#define FEATURE_SLICE_SIZE 16
+
+namespace kernel_selector {
+ParamsKey PoolingKernelGPU_b_fs_yx_fsv16_imad::GetSupportedKey() const {
+    ParamsKey k;
+    k.EnableInputDataType(Datatype::INT8);
+    k.EnableInputDataType(Datatype::UINT8);
+    k.EnableOutputDataType(Datatype::INT8);
+    k.EnableOutputDataType(Datatype::UINT8);
+    k.EnableOutputDataType(Datatype::F32);
+    k.EnableInputLayout(DataLayout::b_fs_yx_fsv16);
+    k.EnableOutputLayout(DataLayout::b_fs_yx_fsv16);
+    k.EnableTensorOffset();
+    k.EnableTensorPitches();
+    k.EnableBatching();
+    k.EnablePoolType(PoolType::MAX);
+    k.EnablePoolType(PoolType::AVG);
+    k.EnablePoolRemainder(PoolRemainder::FLOOR);
+    k.EnablePoolRemainder(PoolRemainder::CEIL);
+    k.EnablePoolKernelDividerMode(KernelDividerMode::FIXED);
+    k.EnablePoolKernelDividerMode(KernelDividerMode::DYNAMIC);
+    k.EnablePoolKernelDividerMode(KernelDividerMode::DYNAMIC_WITH_PADDING);
+    k.EnableDifferentTypes();
+    return k;
+}
+
+PoolingKernelBase::DispatchData PoolingKernelGPU_b_fs_yx_fsv16_imad::SetDefault(const pooling_params& params) const {
+    DispatchData runInfo = PoolingKernelBase::SetDefault(params);
+
+    const auto& out = params.output;
+    auto x = out.X().v;
+    auto y = out.Y().v;
+    auto f = out.Feature().v;
+    auto b = out.Batch().v;
+
+    runInfo.gws0 = x;
+    runInfo.gws1 = y;
+    // b_fs_yx_fsv16 layout: each work-item processes one slice of 16 features
+    runInfo.gws2 = CeilDiv(f, FEATURE_SLICE_SIZE) * b;
+
+    auto local = GetOptimalLocalWorkGroupSizes({ runInfo.gws0, runInfo.gws1, runInfo.gws2 }, params.engineInfo);
+
+    runInfo.lws0 = local[0];
+    runInfo.lws1 = local[1];
+    runInfo.lws2 = local[2];
+
+    return runInfo;
+}
+
+JitConstants PoolingKernelGPU_b_fs_yx_fsv16_imad::GetJitConstants(const pooling_params& params, DispatchData kd) const {
+    auto jit = PoolingKernelBase::GetJitConstants(params, kd);
+
+    const size_t in_x_pitch = FEATURE_SLICE_SIZE;
+    const size_t in_y_pitch = FEATURE_SLICE_SIZE * params.inputs[0].X().LogicalDimPadded();
+    jit.AddConstant(MakeJitConstant("IN_X_PITCH", in_x_pitch));
+    jit.AddConstant(MakeJitConstant("IN_Y_PITCH", in_y_pitch));
+    jit.Merge(MakeTypeJitConstants(GetActivationType(params), "ACTIVATION"));
+    jit.Merge(MakeTypeJitConstants(GetAccumulatorType(params), "ACCUMULATOR"));
+
+    if (!params.fused_ops.empty()) {
+        auto input_dt = EnableRound(params) ? Datatype::INT32 : GetActivationType(params);
+        FusedOpsConfiguration conf = {"", {"b", "f", "y", "x"}, "pool_result[i]", input_dt, 1};
+        conf.SetLoopAxes({ Tensor::DataChannelName::FEATURE }, true);
+        jit.Merge(MakeFusedOpsJitConstants(params, { conf }));
+    }
+
+    return jit;
+}
+
+KernelsData PoolingKernelGPU_b_fs_yx_fsv16_imad::GetKernelsData(const Params& params, const optional_params& options) const {
+    return GetCommonKernelsData(params, options, FORCE_PRIORITY_1);
+}
+
+bool PoolingKernelGPU_b_fs_yx_fsv16_imad::Validate(const Params& params, const optional_params& options) const {
+    if (!PoolingKernelBase::Validate(params, options)) {
+        return false;
+    }
+    const auto& p = dynamic_cast<const pooling_params&>(params);
+
+    if (p.inputs[0].Feature().v % FEATURE_SLICE_SIZE != 0)
+        return false;
+
+    return true;
+}
+}  // namespace kernel_selector
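For reference, SetDefault above launches one work-item per (x, y, feature-slice, batch) combination: gws2 packs CeilDiv(F, 16) feature slices times the batch count, and the OpenCL kernel unpacks it with the matching ALIGN_TO(F, 16) arithmetic. A small sketch of that round trip (illustrative only, not part of this change; sizes are made up):

    #include <cassert>
    #include <cstddef>

    int main() {
        const std::size_t fsv = 16, F = 48, B = 2;          // example sizes
        const std::size_t slices = (F + fsv - 1) / fsv;     // CeilDiv(F, 16)
        const std::size_t aligned_f = slices * fsv;         // ALIGN_TO(F, 16)
        for (std::size_t bf = 0; bf < slices * B; ++bf) {   // one work-item per (b, slice)
            const std::size_t f = (bf * fsv) % aligned_f;   // first feature of the slice
            const std::size_t b = (bf * fsv) / aligned_f;   // batch index
            assert(f % fsv == 0 && f < aligned_f && b < B);
        }
        return 0;
    }
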
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/pooling/pooling_kernel_gpu_b_fs_yx_fsv16_imad.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/pooling/pooling_kernel_gpu_b_fs_yx_fsv16_imad.h
new file mode 100644 (file)
index 0000000..329cb8b
--- /dev/null
@@ -0,0 +1,41 @@
+// Copyright (c) 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#pragma once
+
+#include "pooling_kernel_base.h"
+#include <vector>
+
+namespace kernel_selector {
+class PoolingKernelGPU_b_fs_yx_fsv16_imad : public PoolingKernelBase {
+public:
+    PoolingKernelGPU_b_fs_yx_fsv16_imad() : PoolingKernelBase("pooling_gpu_b_fs_yx_fsv16_imad") {}
+    virtual ~PoolingKernelGPU_b_fs_yx_fsv16_imad() {}
+
+    KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
+    ParamsKey GetSupportedKey() const override;
+    DispatchData SetDefault(const pooling_params& params) const override;
+    bool Validate(const Params&, const optional_params&) const override;
+    std::vector<FusedOpType> GetSupportedFusedOps() const override {
+        return { FusedOpType::ELTWISE,
+                 FusedOpType::QUANTIZE,
+                 FusedOpType::SCALE,
+                 FusedOpType::ACTIVATION };
+    }
+
+protected:
+    JitConstants GetJitConstants(const pooling_params& params, DispatchData kd) const override;
+};
+}  // namespace kernel_selector
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/pooling/pooling_kernel_selector.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/pooling/pooling_kernel_selector.cpp
index 3177325..9785c2f 100644 (file)
@@ -27,6 +27,7 @@
 #include "pooling_kernel_gpu_fs_b_yx_fsv32.h"
 #include "pooling_kernel_gpu_b_fs_yx_fsv16.h"
 #include "pooling_kernel_gpu_bsv16_fsv16.h"
+#include "pooling_kernel_gpu_b_fs_yx_fsv16_imad.h"
 
 namespace kernel_selector {
 
@@ -44,6 +45,7 @@ pooling_kernel_selector::pooling_kernel_selector() {
     Attach<PoolingKerneGPU_fs_b_yx_fsv32>();
     Attach<PoolingKernel_b_fs_yx_fsv16>();
     Attach<PoolingKernel_bsv16_fsv16>();
+    Attach<PoolingKernelGPU_b_fs_yx_fsv16_imad>();
 }
 
 KernelsData pooling_kernel_selector::GetBestKernels(const Params& params, const optional_params& options) const {
diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/pooling_gpu_b_fs_yx_fsv16_imad.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/pooling_gpu_b_fs_yx_fsv16_imad.cl
new file mode 100644 (file)
index 0000000..91e272e
--- /dev/null
@@ -0,0 +1,207 @@
+// Copyright (c) 2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#include "include/include_all.cl"
+#include "include/data_types.cl"
+
+#define ALIGN_TO(val, multiple) (((val) + (multiple) - 1) / (multiple) * (multiple))
+
+#define AS_TYPE(type, val) CAT(as_, type)(val)
+#define IN_VEC16 MAKE_VECTOR_TYPE(INPUT0_TYPE, 16)
+#define OUT_VEC16 MAKE_VECTOR_TYPE(OUTPUT_TYPE, 16)
+
+#define ACTIVATION_VEC16 MAKE_VECTOR_TYPE(ACTIVATION_TYPE, 16)
+#define TO_ACTIVATION_VEC16 CAT(convert_, ACTIVATION_VEC16)
+
+#define FEATURE_SLICE_SIZE 16
+
+#if MAX_POOLING
+    #define INIT_VAL ACCUMULATOR_VAL_MIN
+#elif AVG_POOLING
+    #define INIT_VAL ACCUMULATOR_VAL_ZERO
+#else
+#error pooling_gpu_b_fs_yx_fsv16_imad.cl - unsupported pooling mode
+#endif
+
+
+inline ACCUMULATOR_TYPE FUNC(apply_pooling)(ACCUMULATOR_TYPE tmp, ACCUMULATOR_TYPE in)
+{
+#if MAX_POOLING
+    return ACCUMULATOR_MAX_FUNC(tmp, in);
+#elif AVG_POOLING
+    return tmp + in;
+#endif
+}
+
+__attribute__((intel_reqd_sub_group_size(FEATURE_SLICE_SIZE)))
+KERNEL(pooling_gpu_b_fs_yx_fsv16)(
+    const __global INPUT0_TYPE* input,
+    __global OUTPUT_TYPE* output
+#if HAS_FUSED_OPS_DECLS
+    , FUSED_OPS_DECLS
+#endif
+)
+{
+    const uint x    = (uint)get_global_id(0);
+    const uint y    = (uint)get_global_id(1);
+    const uint bf   = (uint)get_global_id(2);
+    const uint f    = (bf * FEATURE_SLICE_SIZE) % ALIGN_TO(INPUT0_FEATURE_NUM, FEATURE_SLICE_SIZE);
+    const uint b    = (bf * FEATURE_SLICE_SIZE) / ALIGN_TO(INPUT0_FEATURE_NUM, FEATURE_SLICE_SIZE);
+
+    const int offset_x = (int)x*STRIDE_SIZE_X - PADDING_SIZE_X;
+    const int offset_y = (int)y*STRIDE_SIZE_Y - PADDING_SIZE_Y;
+
+    ACCUMULATOR_TYPE result[FEATURE_SLICE_SIZE] = { INIT_VAL, INIT_VAL, INIT_VAL, INIT_VAL, INIT_VAL, INIT_VAL, INIT_VAL, INIT_VAL,
+                                                    INIT_VAL, INIT_VAL, INIT_VAL, INIT_VAL, INIT_VAL, INIT_VAL, INIT_VAL, INIT_VAL };
+
+#ifdef CHECK_BOUNDRY
+    if (offset_x + POOL_SIZE_X < 0 || offset_x >= INPUT0_SIZE_X ||
+        offset_y + POOL_SIZE_Y < 0 || offset_y >= INPUT0_SIZE_Y)
+    {
+        return;
+    }
+
+#ifdef DYNAMIC_KERNEL_DIVIDER
+    uint num_elements = 0;
+#endif
+
+    const uint batch_and_feature_offset = INPUT0_GET_INDEX(b, f, 0, 0);
+    __attribute__((opencl_unroll_hint(POOL_SIZE_Y)))
+    for(uint j = 0; j < POOL_SIZE_Y; j++)
+    {
+        int input_offset_y = offset_y + j;
+        bool zero_y = input_offset_y >= INPUT0_SIZE_Y || input_offset_y < 0;
+        if(!zero_y)
+        {
+            __attribute__((opencl_unroll_hint(POOL_SIZE_X)))
+            for(uint i = 0; i < POOL_SIZE_X; i++)
+            {
+                int input_offset_x = offset_x + i;
+                bool zero = input_offset_x >= INPUT0_SIZE_X || input_offset_x < 0;
+                if(!zero)
+                {
+                    const uint input_idx = batch_and_feature_offset + input_offset_y*IN_Y_PITCH + input_offset_x*IN_X_PITCH;
+
+                    int4 int_data = vload4(0, (__global int*)(input + input_idx));
+                    IN_VEC16 ch16_data = AS_TYPE(IN_VEC16, int_data);
+                    __attribute__((opencl_unroll_hint(FEATURE_SLICE_SIZE)))
+                    for(uint k = 0; k < FEATURE_SLICE_SIZE; k++)
+                    {
+                        result[k] = FUNC_CALL(apply_pooling)(result[k], ch16_data[k]);
+                    }
+
+#ifdef DYNAMIC_KERNEL_DIVIDER
+                    num_elements++;
+#endif
+                }
+            }
+        }
+    }
+#ifdef DYNAMIC_WITH_PADDING_KERNEL_DIVIDER
+    const int hend = min(offset_y + POOL_SIZE_Y, INPUT0_SIZE_Y + PADDING_SIZE_Y);
+    const int wend = min(offset_x + POOL_SIZE_X, INPUT0_SIZE_X + PADDING_SIZE_X);
+    const uint num_elements = (hend - offset_y) * (wend - offset_x);
+#endif
+#else // !CHECK_BOUNDRY
+    uint input_idx = INPUT0_GET_INDEX(b, f, offset_y, offset_x);
+    __attribute__((opencl_unroll_hint(POOL_SIZE_Y)))
+    for(uint j = 0; j < POOL_SIZE_Y; j++)
+    {
+        __attribute__((opencl_unroll_hint(POOL_SIZE_X)))
+        for(uint i = 0; i < POOL_SIZE_X; i++)
+        {
+            int4 int_data = vload4(0, (__global int*)(input + input_idx));
+            IN_VEC16 ch16_data = AS_TYPE(IN_VEC16, int_data);
+            __attribute__((opencl_unroll_hint(FEATURE_SLICE_SIZE)))
+            for(uint k = 0; k < FEATURE_SLICE_SIZE; k++)
+            {
+                result[k] = FUNC_CALL(apply_pooling)(result[k], ch16_data[k]);
+            }
+
+            input_idx += IN_X_PITCH;
+        }
+        input_idx += (IN_Y_PITCH - POOL_SIZE_X*IN_X_PITCH);
+    }
+
+#if defined(DYNAMIC_KERNEL_DIVIDER) || defined(DYNAMIC_WITH_PADDING_KERNEL_DIVIDER)
+    const uint num_elements = POOL_SIZE_X*POOL_SIZE_Y;
+#endif
+#endif
+
+
+    ACTIVATION_VEC16 pool_result;
+#if defined AVG_POOLING
+#if ENABLE_ROUND
+    __attribute__((opencl_unroll_hint(FEATURE_SLICE_SIZE)))
+    for(uint i = 0; i < FEATURE_SLICE_SIZE; i++) {
+    #if defined(DYNAMIC_KERNEL_DIVIDER) || defined(DYNAMIC_WITH_PADDING_KERNEL_DIVIDER)
+        pool_result[i] = convert_int(round(((float)result[i] / max(num_elements, (uint)1))));
+    #else
+        pool_result[i] = convert_int(round((float)result[i] / (int)(POOL_SIZE_Y * POOL_SIZE_X)));
+    #endif
+    }
+#else
+    __attribute__((opencl_unroll_hint(FEATURE_SLICE_SIZE)))
+    for(uint i = 0; i < FEATURE_SLICE_SIZE; i++) {
+    #if defined(DYNAMIC_KERNEL_DIVIDER) || defined(DYNAMIC_WITH_PADDING_KERNEL_DIVIDER)
+        pool_result[i] = (float)result[i] / max(num_elements, (uint)1);
+    #else
+        pool_result[i] = (float)result[i] / (int)(POOL_SIZE_Y * POOL_SIZE_X);
+    #endif
+    }
+#endif  // ENABLE_ROUND
+#else  // AVG_POOLING
+    __attribute__((opencl_unroll_hint(FEATURE_SLICE_SIZE)))
+    for (uint i = 0; i < FEATURE_SLICE_SIZE; ++i) {
+        pool_result[i] = result[i];
+    }
+#endif  // AVG_POOLING
+
+    OUT_VEC16 final_result = (OUTPUT_TYPE)(0);
+#if HAS_FUSED_OPS && FUSED_OPS_CAN_USE_PRELOAD
+    FUSED_OPS_PRELOAD
+#endif
+
+    __attribute__((opencl_unroll_hint(FEATURE_SLICE_SIZE)))
+    for (uint i = 0; i < FEATURE_SLICE_SIZE; ++i) {
+#if HAS_FUSED_OPS
+#if FUSED_OPS_CAN_USE_PRELOAD
+        FUSED_OPS_CALC
+#else
+        FUSED_OPS
+#endif
+        final_result[i] = FUSED_OPS_RESULT;
+#else
+        final_result[i] = TO_OUTPUT_TYPE(ACTIVATION(pool_result[i], ACTIVATION_PARAMS));
+#endif
+    }
+
+    const uint output_pos = OUTPUT_GET_INDEX(b, f, y, x);
+
+#if OUTPUT_TYPE_SIZE == 1
+    vstore4(as_uint4(final_result), 0, ((__global uint*)(output + output_pos)));
+#else
+    *((__global OUT_VEC16*)(output + output_pos)) = final_result;
+#endif
+}
+
+#undef ALIGN_TO
+#undef AS_TYPE
+#undef IN_VEC16
+#undef OUT_VEC16
+#undef ACTIVATION_VEC16
+#undef TO_ACTIVATION_VEC16
+#undef INIT_VAL
+#undef FEATURE_SLICE_SIZE
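The average-pooling path above branches on the three divider modes enabled in GetSupportedKey(): FIXED always divides by the full window area, DYNAMIC divides by the number of taps that actually landed inside the input (the num_elements counter under CHECK_BOUNDRY), and DYNAMIC_WITH_PADDING clips the window only against the padded extent (the hend/wend computation). A rough host-side sketch of the three divisors for one border window, with made-up sizes (illustrative only, not part of this change):

    #include <algorithm>
    #include <cstdio>

    int main() {
        const int pool_y = 3, pool_x = 3, H = 12, W = 12, pad_y = 1, pad_x = 1;
        const int offset_y = -1, offset_x = -1;   // window hanging over the top-left corner

        // FIXED: always divide by the full window area.
        const int fixed = pool_y * pool_x;

        // DYNAMIC: divide by the number of window positions that hit the input
        // (what the kernel counts in num_elements under CHECK_BOUNDRY).
        const int dyn = (std::min(offset_y + pool_y, H) - std::max(offset_y, 0)) *
                        (std::min(offset_x + pool_x, W) - std::max(offset_x, 0));

        // DYNAMIC_WITH_PADDING: clip only against the padded extent, matching the
        // hend/wend computation in the kernel.
        const int dyn_pad = (std::min(offset_y + pool_y, H + pad_y) - offset_y) *
                            (std::min(offset_x + pool_x, W + pad_x) - offset_x);

        std::printf("fixed=%d dynamic=%d dynamic_with_padding=%d\n", fixed, dyn, dyn_pad);
    }
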
diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/pooling_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/pooling_gpu_test.cpp
index 78b1fa8..5cfd1bf 100644 (file)
@@ -2586,7 +2586,7 @@ INSTANTIATE_TEST_CASE_P(
     smoke_low_precision,
     pooling_random_test,
     testing::Combine(testing::Values(1, 2),
-                     testing::Values(3, 8),
+                     testing::Values(3, 8, 64),
                      testing::Values(std::tuple<size_t, size_t>(12, 12), std::tuple<size_t, size_t>(24, 24)),
                      testing::Values(std::tuple<size_t, size_t>(4, 4), std::tuple<size_t, size_t>(2, 2)),
                      testing::Values(std::tuple<int, int>(2, 2)),
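The extra feature count in this test combination matters because the new kernel's Validate() rejects inputs whose feature count is not a multiple of 16; of the values 3, 8 and 64, only 64 qualifies, so the added case is what actually exercises pooling_gpu_b_fs_yx_fsv16_imad. A trivial check (illustrative only, not part of this change):

    #include <cstdio>

    int main() {
        const int feature_slice_size = 16;
        for (int features : {3, 8, 64})
            std::printf("features=%d -> fsv16 imad kernel eligible: %s\n",
                        features, features % feature_slice_size == 0 ? "yes" : "no");
    }
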