-// Copyright (c) 2018 Intel Corporation
+// Copyright (c) 2018-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
}
}
+// Returns the extent of the first input tensor along the axis selected by
+// params.argMaxMinAxis (batch, feature, z, y or x).
+// Throws std::invalid_argument for any axis value not handled below.
+size_t getSortSize(const arg_max_min_params& params) {
+    const auto axis = params.argMaxMinAxis;
+
+    // The input is only touched inside the matching branch, so an unsupported
+    // axis reaches the throw without reading params.inputs.
+    if (axis == ArgMaxMinAxis::BATCH)
+        return params.inputs[0].Batch().v;
+    if (axis == ArgMaxMinAxis::FEATURE)
+        return params.inputs[0].Feature().v;
+    if (axis == ArgMaxMinAxis::Z)
+        return params.inputs[0].Z().v;
+    if (axis == ArgMaxMinAxis::Y)
+        return params.inputs[0].Y().v;
+    if (axis == ArgMaxMinAxis::X)
+        return params.inputs[0].X().v;
+
+    throw std::invalid_argument("Unsupported axis");
+}
+
ParamsKey ArgMaxMinKernelAxis::GetSupportedKey() const {
ParamsKey k;
k.EnableInputDataType(Datatype::F16);
if (!Validate(params, options)) {
return {};
}
-
const arg_max_min_params& orgParams = static_cast<const arg_max_min_params&>(params);
DispatchData runInfo;
runInfo.fp16UnitUsed = orgParams.inputs[0].GetDType() == Datatype::F16;
- runInfo.gws0 = Align(getOperationNumber(orgParams), 32);
- runInfo.gws1 = 1;
- runInfo.gws2 = 1;
+ size_t sort_size = orgParams.argMaxMinSortType == ArgMaxMinSortType::VALUE ? getSortSize(orgParams) : 1;
+
+ std::vector<size_t> local, global;
+ global = { Align(getOperationNumber(orgParams), 32), sort_size, 1 };
+ local = GetOptimalLocalWorkGroupSizes(global, params.engineInfo);
+
+ runInfo.gws0 = global[0];
+ runInfo.gws1 = global[1];
+ runInfo.gws2 = global[2];
- runInfo.lws0 = 32;
- runInfo.lws1 = 1;
- runInfo.lws2 = 1;
+ runInfo.lws0 = local[0];
+ runInfo.lws1 = local[1];
+ runInfo.lws2 = local[2];
KernelData kd = KernelData::Default<arg_max_min_params>(params);
#define COMPARE_SIGN <
#define COMPARE_PARTIAL_SIGN >=
#define COMPARE_MERGE_SIGN >
+ #define COMPARE_PARALLEL_SIGN_1 <=
+ #define COMPARE_PARALLEL_SIGN_2 <
#define INPUT0_FILL_VAL INPUT0_VAL_MIN
#else
#define COMPARE_SIGN >
#define COMPARE_PARTIAL_SIGN <=
#define COMPARE_MERGE_SIGN <
+ #define COMPARE_PARALLEL_SIGN_1 >=
+ #define COMPARE_PARALLEL_SIGN_2 >
#define INPUT0_FILL_VAL INPUT0_VAL_MAX
#endif
)
{
#include "include/arg_max_min_common.cl"
-#if (TOP_K == 1)
+#if SORT_BY_VALUE
+ const uint sort_idx = (uint)get_global_id(1);
+#elif TOP_K == 1
iav_type result[TOP_K];
#else
iav_type result[VALUES_NUM], temp_buf[VALUES_NUM];
const uint group_num = ((VALUES_NUM - 1) / group_size) + 1;
const uint last_group_size = (VALUES_NUM % group_size > 0) ? (VALUES_NUM % group_size) : group_size;
const uint last_group_offset = (group_num - 1) * group_size;
-#endif // (TOP_K == 1)
+#endif // SORT_BY_VALUE
- uint output_idx = (uint)get_global_id(0);
+ const uint output_idx = (uint)get_global_id(0);
if (output_idx >= OPERATION_NUM)
return;
#endif
#endif
-// Using simple sorting for (TOP_K == 1)
-#if (TOP_K == 1)
+// Using parallel sorting for sorting by values
+#if SORT_BY_VALUE
+ uint sort_position = 0;
+ indices[AXIS] = sort_idx;
+
+ iav_type result;
+ result.value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
+ result.index = sort_idx;
+
+ for (uint i = 0; i < sort_idx; i++) {
+ indices[AXIS] = i;
+ INPUT0_TYPE test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
+ if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
+ sort_position++;
+ if (sort_position >= TOP_K)
+ return;
+ }
+
+ for (uint i = sort_idx + 1; i < VALUES_NUM; i++) {
+ indices[AXIS] = i;
+ INPUT0_TYPE test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
+ if (result.value COMPARE_PARALLEL_SIGN_2 test_value)
+ sort_position++;
+ if (sort_position >= TOP_K)
+ return;
+ }
+// Using simple sorting for sorting by indices and when TOP_K == 1
+#elif TOP_K == 1
INPUT0_TYPE val = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
result[0].index = 0;
result[0].value = val;
val = INPUT0_FILL_VAL;
}
-// Using merge sorting when (TOP_K >= (VALUES_NUM / 2)) or (VALUES_NUM < MINIMUM_NUMBER_FOR_PARTIAL_SORTING)
+// Using merge sorting for sorting by indices and when (TOP_K >= (VALUES_NUM / 2)) or (VALUES_NUM < MINIMUM_NUMBER_FOR_PARTIAL_SORTING)
#elif ((TOP_K >= (VALUES_NUM / 2)) || (VALUES_NUM < MINIMUM_NUMBER_FOR_PARTIAL_SORTING))
-
for (uint i = 0; i < VALUES_NUM / 8; i++) {
uint index_offset = i * 8;
indices[AXIS] = result[index_offset].index = index_offset;
}
}
-// In other cases using mixed partial/merge sorting
-#else // (TOP_K == 1)
-
+// In other cases for sorting by indices using mixed partial/merge sorting
+#else // SORT_BY_VALUE
for (uint i = 0; i < VALUES_NUM / 8; i++) {
uint index_offset = i * 8;
indices[AXIS] = temp_buf[index_offset].index = index_offset;
result[i] = merge_buf;
}
+#endif // SORT_BY_VALUE
-#endif // (TOP_K == 1)
+#if SORT_BY_VALUE
+ indices[AXIS] = sort_position;
+#ifdef TOP_K_ORDER
+ output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result.value);
+#else
+ output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result.index);
+#endif
+#ifdef SECOND_OUTPUT_EXIST
+ #ifdef TOP_K_ORDER
+ second_output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result.index);
+ #else
+ second_output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result.value);
+ #endif
+#endif
+#else // SORT_BY_VALUE
for (uint top_k = 0; top_k < TOP_K; ++top_k) {
-#ifdef SORT_BY_VALUE
- indices[AXIS] = top_k;
-#endif
-#ifdef SORT_BY_INDEX
uint out_position = 0;
+
for (uint i = 0; i < TOP_K; ++i) {
if (i == top_k)
continue;
if (result[i].index < result[top_k].index)
out_position++;
}
+
indices[AXIS] = out_position;
-#endif
#ifdef TOP_K_ORDER
- output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].value);
+ output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].value);
#else
- output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].index);
+ output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].index);
#endif
#ifdef SECOND_OUTPUT_EXIST
-#ifdef TOP_K_ORDER
- second_output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].index);
-#else
- second_output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].value);
-#endif
+ #ifdef TOP_K_ORDER
+ second_output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].index);
+ #else
+ second_output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].value);
+ #endif
#endif
}
+#endif
}
#undef COMPARE_SIGN
#undef COMPARE_PARTIAL_SIGN
#undef COMPARE_MERGE_SIGN
+#undef COMPARE_PARALLEL_SIGN_1
+#undef COMPARE_PARALLEL_SIGN_2
#undef INPUT0_FILL_VAL
#undef AXIS
#undef VALUES_NUM