// See the License for the specific language governing permissions and
// limitations under the License.
-
#include "include/common.cl"
#include "include/data_types.cl"
// Work-item count constants for this kernel file.
// LOCAL_SIZE is the required work-group size in dimension 0 of the kernel
// below (see its reqd_work_group_size attribute) and also the number of
// elements in the kernel's __local scratch array.
// LOCAL_SIZE is defined as GLOBAL_SIZE, so the two always have the same value.
// NOTE(review): presumably the host enqueues exactly GLOBAL_SIZE work-items
// along dimension 0 (i.e. one work-group) - confirm against the host-side dispatch.
#define GLOBAL_SIZE 128
#define LOCAL_SIZE GLOBAL_SIZE
-typedef struct /* Index and Value type that holds index and value used in this kernel */
-{
- uint index;
- UNIT_TYPE value;
-} iav_type;
-
#ifdef BATCH_AXIS
// Definitions used when BATCH_AXIS is defined.
// GAP_SIZE is the product of the input's feature count and its X and Y sizes.
// NOTE(review): presumably this is the element distance between consecutive
// entries along the batch axis - confirm against the input layout.
#define GAP_SIZE (INPUT0_FEATURE_NUM * INPUT0_SIZE_X * INPUT0_SIZE_Y)
// VALUES_NUM is the input's batch count.
// NOTE(review): presumably the number of candidate values compared along the
// selected axis - confirm against the kernel body.
#define VALUES_NUM INPUT0_BATCH_NUM
__attribute__((reqd_work_group_size(LOCAL_SIZE, 1, 1)))
KERNEL(arg_max_gpu_axis)(const __global UNIT_TYPE* input, __global float* output)
{
+#include "include/arg_max_min_common.cl"
uint results[TOP_K];
__local iav_type scratch[LOCAL_SIZE];
const uint first_dim_id = (uint)get_global_id(1);