1 // Copyright (c) 2018 Intel Corporation
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
15 #include "include/common.cl"
16 #include "include/data_types.cl"
// Number of elements along the reduced axis that each work-item skips between
// successive loads in the kernel below (it advances by GLOBAL_SIZE * GAP_SIZE).
18 #define GLOBAL_SIZE 128
// Work-group size in dimension 0. It is used for the kernel's reqd_work_group_size
// attribute and for the length of its __local scratch array; it is defined to be the
// same value as GLOBAL_SIZE.
19 #define LOCAL_SIZE GLOBAL_SIZE
// Layout parameters for reducing over INPUT0_BATCH_NUM values (see VALUES_NUM).
// NOTE(review): the preprocessor guard that selects this group is not visible here.
// GAP_SIZE       - distance, in input elements, between consecutive values along the reduced axis.
// VALUES_NUM     - number of values along the reduced axis.
// FIRST/SECOND_DIM_SIZE - extents used by the kernel to split its global ids and to
//                  linearise the output index.
// *_DIM_MUL      - multipliers the kernel applies to the first/second/third coordinate to
//                  form the input offset of the slice being reduced.
22 #define GAP_SIZE (INPUT0_FEATURE_NUM * INPUT0_SIZE_X * INPUT0_SIZE_Y)
23 #define VALUES_NUM INPUT0_BATCH_NUM
24 #define FIRST_DIM_SIZE INPUT0_SIZE_X
25 #define SECOND_DIM_SIZE INPUT0_SIZE_Y
26 #define FIRST_DIM_MUL 1
27 #define SECOND_DIM_MUL INPUT0_SIZE_X
// Third coordinate steps by one X*Y plane (presumably the feature index - confirm).
28 #define THIRD_DIM_MUL (INPUT0_SIZE_X * INPUT0_SIZE_Y)
// Layout parameters for reducing over INPUT0_FEATURE_NUM values (see VALUES_NUM).
// NOTE(review): the preprocessor guard that selects this group is not visible here.
// Consecutive values along the reduced axis are one X*Y plane apart (GAP_SIZE).
31 #define GAP_SIZE (INPUT0_SIZE_X * INPUT0_SIZE_Y)
32 #define VALUES_NUM INPUT0_FEATURE_NUM
// The first and second coordinates range over X and Y and use multipliers 1 and INPUT0_SIZE_X.
33 #define FIRST_DIM_SIZE INPUT0_SIZE_X
34 #define SECOND_DIM_SIZE INPUT0_SIZE_Y
35 #define FIRST_DIM_MUL 1
36 #define SECOND_DIM_MUL INPUT0_SIZE_X
// Third coordinate steps by X*Y*FEATURE elements (presumably the batch index - confirm).
37 #define THIRD_DIM_MUL (INPUT0_SIZE_X * INPUT0_SIZE_Y * INPUT0_FEATURE_NUM)
// Layout parameters for reducing over INPUT0_SIZE_Y values (see VALUES_NUM).
// NOTE(review): the preprocessor guard that selects this group is not visible here.
// Consecutive values along the reduced axis are INPUT0_SIZE_X elements apart (GAP_SIZE).
40 #define GAP_SIZE INPUT0_SIZE_X
41 #define VALUES_NUM INPUT0_SIZE_Y
// The first coordinate ranges over X (multiplier 1); the second ranges over
// INPUT0_FEATURE_NUM and steps by one X*Y plane.
42 #define FIRST_DIM_SIZE INPUT0_SIZE_X
43 #define SECOND_DIM_SIZE INPUT0_FEATURE_NUM
44 #define FIRST_DIM_MUL 1
45 #define SECOND_DIM_MUL (INPUT0_SIZE_Y * INPUT0_SIZE_X)
// Third coordinate steps by X*Y*FEATURE elements (presumably the batch index - confirm).
46 #define THIRD_DIM_MUL (INPUT0_SIZE_X * INPUT0_SIZE_Y * INPUT0_FEATURE_NUM)
// Layout parameters for reducing over INPUT0_SIZE_X values (see VALUES_NUM).
// NOTE(review): neither the selecting preprocessor guard nor a GAP_SIZE definition for
// this group is visible here - confirm that GAP_SIZE is defined alongside these macros.
50 #define VALUES_NUM INPUT0_SIZE_X
// The first coordinate ranges over Y and steps by INPUT0_SIZE_X elements; the second
// ranges over INPUT0_FEATURE_NUM and steps by one X*Y plane.
51 #define FIRST_DIM_SIZE INPUT0_SIZE_Y
52 #define SECOND_DIM_SIZE INPUT0_FEATURE_NUM
53 #define FIRST_DIM_MUL INPUT0_SIZE_X
54 #define SECOND_DIM_MUL (INPUT0_SIZE_Y * INPUT0_SIZE_X)
// Third coordinate steps by X*Y*FEATURE elements (presumably the batch index - confirm).
55 #define THIRD_DIM_MUL (INPUT0_SIZE_X * INPUT0_SIZE_Y * INPUT0_FEATURE_NUM)
// Comparison operator and masking value used by the kernel below. The kernel replaces its
// current best whenever `best.value COMPARE_SIGN candidate.value` holds, and writes
// UNIT_FILL_VAL over values whose index was already selected and into unused scratch slots.
// NOTE(review): two alternative pairs follow back to back; the #if/#else that chooses
// between them is not visible here. UNIT_VAL_MIN / UNIT_VAL_MAX are defined elsewhere.
// With `<`, a strictly larger candidate replaces the current best (maximum search).
59 #define COMPARE_SIGN <
60 #define UNIT_FILL_VAL UNIT_VAL_MIN
// With `>`, a strictly smaller candidate replaces the current best (minimum search).
62 #define COMPARE_SIGN >
63 #define UNIT_FILL_VAL UNIT_VAL_MAX
// Kernel arg_max_gpu_axis: for one slice of `input` along the reduced axis, finds TOP_K
// element positions in order of preference under COMPARE_SIGN and writes them to `output`.
//
//   input  - source values, read at offset + k * GAP_SIZE for k in [0, VALUES_NUM).
//   output - receives TOP_K consecutive entries per slice; each entry is a position
//            along the reduced axis, stored through a float pointer.
//
// One work-group of LOCAL_SIZE work-items (dimension 0) cooperates on a slice; global ids
// 1 and 2 select which slice. Each of the TOP_K passes performs a per-work-item scan,
// then a reduction through __local memory, and masks indices chosen by earlier passes.
//
// NOTE(review): braces and the declarations of `accumulator`, `element` and `results`
// are not visible in this excerpt; the comments below describe only the visible statements.
66 __attribute__((reqd_work_group_size(LOCAL_SIZE, 1, 1)))
67 KERNEL(arg_max_gpu_axis)(const __global UNIT_TYPE* input, __global float* output)
// presumably provides iav_type (a struct with .index and .value) - defined outside this file.
69 #include "include/arg_max_min_common.cl"
// One candidate slot per work-item, shared across the work-group.
71 __local iav_type scratch[LOCAL_SIZE];
// Global id 1 is the first coordinate; global id 2 packs the second and third coordinates,
// which are separated with modulo / division by SECOND_DIM_SIZE.
72 const uint first_dim_id = (uint)get_global_id(1);
73 const uint second_third_dim_id = (uint)get_global_id(2);
74 const uint second_dim_id = second_third_dim_id % SECOND_DIM_SIZE;
75 const uint third_dim_id = second_third_dim_id / SECOND_DIM_SIZE;
// Linearised slice coordinate, scaled by TOP_K so each slice owns TOP_K output entries.
76 const uint output_index = (first_dim_id + second_dim_id * FIRST_DIM_SIZE + third_dim_id * FIRST_DIM_SIZE * SECOND_DIM_SIZE) * TOP_K;
// Input position of the slice's first value along the reduced axis.
77 const uint offset = first_dim_id * FIRST_DIM_MUL + second_dim_id * SECOND_DIM_MUL + third_dim_id * THIRD_DIM_MUL;
78 uint local_index = get_local_id(0);
// First input element handled by this work-item: position local_index along the axis.
79 uint global_index = offset + local_index * GAP_SIZE;
// Saved so global_index can be rewound at the end of every pass.
83 uint temp_index = global_index;
// Position along the reduced axis of this work-item's first element.
84 uint start_index = (global_index - offset) / GAP_SIZE;
// One pass per requested result; pass i selects the i-th entry.
85 __attribute__((opencl_unroll_hint))
86 for (uint i = 0; i < TOP_K; i++)
// Seed the running best with this work-item's first element.
// NOTE(review): this load happens before any comparison of local_index with VALUES_NUM;
// confirm that work-items with local_index >= VALUES_NUM still read inside the buffer.
88 accumulator.index = start_index;
89 accumulator.value = input[global_index];
// Mask the seed if its index was already selected by an earlier pass.
// NOTE(review): `j` is int while `i` is uint, so the comparison mixes signedness.
90 for (int j = 0; j < i; j++)
92 if (accumulator.index == results[j])
93 accumulator.value = UNIT_FILL_VAL;
// Step to this work-item's next element, GLOBAL_SIZE positions further along the axis.
95 global_index += GLOBAL_SIZE * GAP_SIZE;
96 uint element_index = start_index + GLOBAL_SIZE;
// Scan the remaining elements assigned to this work-item until the end of the slice.
97 while (global_index < offset + VALUES_NUM * GAP_SIZE)
100 element.value = input[global_index];
101 element.index = element_index;
// Mask elements whose index was already selected by an earlier pass.
102 for (int j = 0; j < i; j++){
103 if (element.index == results[j])
104 element.value = UNIT_FILL_VAL;
// Strict comparison: on equal values the earlier (lower-index) candidate is kept.
106 if(accumulator.value COMPARE_SIGN element.value)
108 accumulator.value = element.value;
109 accumulator.index = element.index;
111 element_index += GLOBAL_SIZE;
112 global_index += GLOBAL_SIZE * GAP_SIZE;
// Publish this work-item's best candidate; slots beyond VALUES_NUM get the fill value.
// NOTE(review): the two statements below are presumably the branches of an if/else whose
// `else` line is not visible here. The fill store sets only .value, so .index of such a
// slot is left as it was - confirm that is acceptable when TOP_K exceeds the valid count.
114 if (local_index < VALUES_NUM)
115 scratch[local_index] = accumulator;
117 scratch[local_index].value = UNIT_FILL_VAL;
// Make every work-item's scratch write visible before the reduction reads it.
119 barrier(CLK_LOCAL_MEM_FENCE);
// Pairwise reduction: each step halves the active range; work-item k compares its slot
// with slot k + scratch_offset and takes the other one when it compares as better.
121 __attribute__((opencl_unroll_hint))
122 for(uint scratch_offset = LOCAL_SIZE / 2; scratch_offset > 0; scratch_offset /= 2)
124 if (local_index < scratch_offset)
126 iav_type other = scratch[local_index + scratch_offset];
127 iav_type mine = scratch[local_index];
// Strict comparison: on equal values the slot with the lower local index is kept.
129 if(mine.value COMPARE_SIGN other.value)
131 scratch[local_index] = other;
// Synchronise before the next step reads slots written in this one.
134 barrier(CLK_LOCAL_MEM_FENCE);
// scratch[0] now holds the winner of this pass; work-item 0 writes its index to output.
137 if (local_index == 0)
139 output[output_index + i] = scratch[0].index;
// Rewind to this work-item's first element for the next pass.
141 global_index = temp_index;
// Every work-item records the winner so later passes can mask it.
// NOTE(review): no barrier is visible between this read of scratch[0] and the next pass's
// write to scratch[local_index]; confirm work-item 0 cannot overwrite scratch[0] before
// slower work-items have read it.
142 results[i] = scratch[0].index;
// Remove layout macros defined earlier in this file.
// NOTE(review): no #undef is visible here for FIRST_DIM_MUL, THIRD_DIM_MUL, GAP_SIZE,
// VALUES_NUM, COMPARE_SIGN or UNIT_FILL_VAL - confirm they are undefined as well.
150 #undef FIRST_DIM_SIZE
151 #undef SECOND_DIM_SIZE
153 #undef SECOND_DIM_MUL