[IE CLDNN] TopK registry spill avoiding for sort-by-value mode (#2590)
authorIlya Znamenskiy <ilya.znamenskiy@intel.com>
Mon, 12 Oct 2020 05:36:57 +0000 (08:36 +0300)
committerGitHub <noreply@github.com>
Mon, 12 Oct 2020 05:36:57 +0000 (08:36 +0300)
inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/arg_max_min/arg_max_min_kernel_axis.cpp
inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/arg_max_min_axis.cl

index 9f557e0..15fc570 100644 (file)
@@ -1,4 +1,4 @@
-// Copyright (c) 2018 Intel Corporation
+// Copyright (c) 2018-2020 Intel Corporation
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -29,6 +29,18 @@ size_t getOperationNumber(const arg_max_min_params& params) {
     }
 }
 
+size_t getSortSize(const arg_max_min_params& params) {
+    switch (params.argMaxMinAxis) {
+        case ArgMaxMinAxis::BATCH: return params.inputs[0].Batch().v;
+        case ArgMaxMinAxis::FEATURE: return params.inputs[0].Feature().v;
+        case ArgMaxMinAxis::Z: return params.inputs[0].Z().v;
+        case ArgMaxMinAxis::Y: return params.inputs[0].Y().v;
+        case ArgMaxMinAxis::X: return params.inputs[0].X().v;
+        default:
+            throw std::invalid_argument("Unsupported axis");
+    }
+}
+
 ParamsKey ArgMaxMinKernelAxis::GetSupportedKey() const {
     ParamsKey k;
     k.EnableInputDataType(Datatype::F16);
@@ -72,19 +84,24 @@ KernelsData ArgMaxMinKernelAxis::GetKernelsData(const Params& params, const opti
     if (!Validate(params, options)) {
         return {};
     }
-
     const arg_max_min_params& orgParams = static_cast<const arg_max_min_params&>(params);
 
     DispatchData runInfo;
     runInfo.fp16UnitUsed = orgParams.inputs[0].GetDType() == Datatype::F16;
 
-    runInfo.gws0 = Align(getOperationNumber(orgParams), 32);
-    runInfo.gws1 = 1;
-    runInfo.gws2 = 1;
+    size_t sort_size = orgParams.argMaxMinSortType == ArgMaxMinSortType::VALUE ? getSortSize(orgParams) : 1;
+
+    std::vector<size_t> local, global;
+    global = { Align(getOperationNumber(orgParams), 32), sort_size, 1 };
+    local = GetOptimalLocalWorkGroupSizes(global, params.engineInfo);
+
+    runInfo.gws0 = global[0];
+    runInfo.gws1 = global[1];
+    runInfo.gws2 = global[2];
 
-    runInfo.lws0 = 32;
-    runInfo.lws1 = 1;
-    runInfo.lws2 = 1;
+    runInfo.lws0 = local[0];
+    runInfo.lws1 = local[1];
+    runInfo.lws2 = local[2];
 
     KernelData kd = KernelData::Default<arg_max_min_params>(params);
 
index fccbeda..ec35852 100644 (file)
     #define COMPARE_SIGN <
     #define COMPARE_PARTIAL_SIGN >=
     #define COMPARE_MERGE_SIGN >
+    #define COMPARE_PARALLEL_SIGN_1 <=
+    #define COMPARE_PARALLEL_SIGN_2 <
     #define INPUT0_FILL_VAL INPUT0_VAL_MIN
 #else
     #define COMPARE_SIGN >
     #define COMPARE_PARTIAL_SIGN <=
     #define COMPARE_MERGE_SIGN <
+    #define COMPARE_PARALLEL_SIGN_1 >=
+    #define COMPARE_PARALLEL_SIGN_2 >
     #define INPUT0_FILL_VAL INPUT0_VAL_MAX
 #endif
 
@@ -83,7 +87,9 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
                             )
 {
 #include "include/arg_max_min_common.cl"
-#if (TOP_K == 1)
+#if SORT_BY_VALUE
+    const uint sort_idx = (uint)get_global_id(1);
+#elif TOP_K == 1
     iav_type result[TOP_K];
 #else
     iav_type result[VALUES_NUM], temp_buf[VALUES_NUM];
@@ -91,9 +97,9 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
     const uint group_num = ((VALUES_NUM - 1) / group_size) + 1;
     const uint last_group_size = (VALUES_NUM % group_size > 0) ? (VALUES_NUM % group_size) : group_size;
     const uint last_group_offset = (group_num - 1) * group_size;
-#endif // (TOP_K == 1)
+#endif // SORT_BY_VALUE
 
-    uint output_idx = (uint)get_global_id(0);
+    const uint output_idx = (uint)get_global_id(0);
 
     if (output_idx >= OPERATION_NUM)
         return;
@@ -162,9 +168,35 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
     #endif
 #endif
 
-// Using simple sorting for (TOP_K == 1)
-#if (TOP_K == 1)
+// Using parallel sorting for sorting by values
+#if SORT_BY_VALUE
+    uint sort_position = 0;
+    indices[AXIS] = sort_idx;
+
+    iav_type result;
+    result.value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
+    result.index = sort_idx;
+
+    for (uint i = 0; i < sort_idx; i++) {
+        indices[AXIS] = i;
+        INPUT0_TYPE test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
+        if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
+            sort_position++;
+        if (sort_position >= TOP_K)
+            return;
+    }
+
+    for (uint i = sort_idx + 1; i < VALUES_NUM; i++) {
+        indices[AXIS] = i;
+        INPUT0_TYPE test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
+        if (result.value COMPARE_PARALLEL_SIGN_2 test_value)
+            sort_position++;
+        if (sort_position >= TOP_K)
+            return;
+    }
 
+// Using simple sorting for sorting by indices and when TOP_K == 1
+#elif TOP_K == 1
     INPUT0_TYPE val = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
     result[0].index = 0;
     result[0].value = val;
@@ -194,9 +226,8 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
         val = INPUT0_FILL_VAL;
     }
 
-// Using merge sorting when (TOP_K >= (VALUES_NUM / 2)) or (VALUES_NUM < MINIMUM_NUMBER_FOR_PARTIAL_SORTING)
+// Using merge sorting for sorting by indices and when (TOP_K >= (VALUES_NUM / 2)) or (VALUES_NUM < MINIMUM_NUMBER_FOR_PARTIAL_SORTING)
 #elif ((TOP_K >= (VALUES_NUM / 2)) || (VALUES_NUM < MINIMUM_NUMBER_FOR_PARTIAL_SORTING))
-
     for (uint i = 0; i < VALUES_NUM / 8; i++) {
         uint index_offset = i * 8;
         indices[AXIS] = result[index_offset].index = index_offset;
@@ -245,9 +276,8 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
         }
     }
 
-// In other cases using mixed partial/merge sorting
-#else // (TOP_K == 1)
-
+// In other cases for sorting by indices using mixed partial/merge sorting
+#else // SORT_BY_VALUE
     for (uint i = 0; i < VALUES_NUM / 8; i++) {
         uint index_offset = i * 8;
         indices[AXIS] = temp_buf[index_offset].index = index_offset;
@@ -365,41 +395,56 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
 
         result[i] = merge_buf;
     }
+#endif // SORT_BY_VALUE
 
-#endif // (TOP_K == 1)
+#if SORT_BY_VALUE
+    indices[AXIS] = sort_position;
+#ifdef TOP_K_ORDER
+    output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result.value);
+#else
+    output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result.index);
+#endif
+#ifdef SECOND_OUTPUT_EXIST
+    #ifdef TOP_K_ORDER
+    second_output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result.index);
+    #else
+    second_output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result.value);
+    #endif
+#endif
 
+#else // SORT_BY_VALUE
     for (uint top_k = 0; top_k < TOP_K; ++top_k) {
-#ifdef SORT_BY_VALUE
-        indices[AXIS] = top_k;
-#endif
-#ifdef SORT_BY_INDEX
         uint out_position = 0;
+
         for (uint i = 0; i < TOP_K; ++i) {
             if (i == top_k)
                 continue;
             if (result[i].index < result[top_k].index)
                 out_position++;
         }
+
         indices[AXIS] = out_position;
-#endif
 #ifdef TOP_K_ORDER
-    output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].value);
+        output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].value);
 #else
-    output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].index);
+        output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].index);
 #endif
 #ifdef SECOND_OUTPUT_EXIST
-#ifdef TOP_K_ORDER
-    second_output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].index);
-#else
-    second_output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].value);
-#endif
+    #ifdef TOP_K_ORDER
+        second_output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].index);
+    #else
+        second_output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].value);
+    #endif
 #endif
     }
+#endif
 }
 
 #undef COMPARE_SIGN
 #undef COMPARE_PARTIAL_SIGN
 #undef COMPARE_MERGE_SIGN
+#undef COMPARE_PARALLEL_SIGN_1
+#undef COMPARE_PARALLEL_SIGN_2
 #undef INPUT0_FILL_VAL
 #undef AXIS
 #undef VALUES_NUM