1 // Copyright (c) 2016-2017 Intel Corporation
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
16 #include "include/include_all.cl"
26 // Sums val across the work items of one sub-group (not across sub-groups).
// Each step adds to val the copy of val held by the work item at sub-group
// index x+1, then x+2, then x+4. The x+8 step is taken only when SIMD > 8 and
// the x+16 step only when SIMD > 16; otherwise that step adds 0.
//
// Usage notes:
//  - x is not a macro parameter; it is picked up from the scope in which the
//    macro is expanded.
//  - val is evaluated several times and assigned to, so it has to be a plain
//    modifiable variable.
//  - NOTE(review): for work items where x+offset is past the end of the
//    sub-group the shuffle index is out of range; presumably only the work
//    item with x == 0 holds the complete sum afterwards - confirm against the
//    cl_intel_subgroups specification.
//  - NOTE(review): no enclosing braces are visible and the last line below
//    still ends with a line-continuation character, so whatever line follows
//    is spliced into the macro. Lines appear to be missing from this excerpt.
27 #define SUM_ACROSS_SUB_GROUP(val) \
30 val += intel_sub_group_shuffle(val, x+1); \
31 val += intel_sub_group_shuffle(val, x+2); \
32 val += intel_sub_group_shuffle(val, x+4); \
33 val += (SIMD > 8) ? intel_sub_group_shuffle(val, x+8) : 0; \
34 val += (SIMD > 16) ? intel_sub_group_shuffle(val, x+16) : 0; \
37 // input = [ batch, sequence, 1, input_size ]
38 // weights = [ 1, direction, 4 * hidden_size, input_size ]
39 // recurrent = [ 1, direction, 4 * hidden_size, hidden_size ]
40 // biases = [ 1, 1, direction, 4 * hidden_size ] optional
41 // hidden = [ batch, direction, 1, hidden_size ] optional
42 // tempGEMM = [ batch, direction, 1, 4 * hidden_size ] output
// Kernel entry point.
// NOTE(review): the line that names the kernel, its braces, the declarations of
// K / start_offset / end_offset / matrix_offset / vector_offset / sum / result
// and any preprocessor guards are not visible in this excerpt.
//
// What the visible lines do:
//  - the work-group size is fixed to SIMD x 1 x 1;
//  - work item x processes row y of the weight matrix four elements at a
//    time, advancing by local_sz * 4 elements per iteration, and pairs them
//    with four elements of the input vector;
//  - the same walk is repeated over the recurrent matrix and the hidden
//    vector, and its sum is added to result;
//  - the per-work-item partial sums are combined with SUM_ACROSS_SUB_GROUP;
//  - result is stored to output[y] and the bias for row y is added to it.
44 __attribute__((reqd_work_group_size(SIMD, 1, 1)))
// input:   vector operand of the first product (read only).
// output:  destination, written at index y.
// weights: matrix operand of the first product (read only).
46 const __global INPUT0_TYPE* input,
47 __global OUTPUT_TYPE* output,
48 const __global WEIGHTS_TYPE* weights
// hidden / recurrent: vector and matrix operands of the second product.
// NOTE(review): the leading comma suggests these are compiled in conditionally;
// the guarding directive is not visible here - confirm.
50 , const __global OUTPUT_TYPE* hidden,
51 const __global RECURRENT_TYPE* recurrent
// biases: per-row value added to the stored result.
// NOTE(review): presumably optional as well (leading comma) - confirm.
54 , const __global BIAS_TYPE* biases
// x: index of this work item inside its work-group (dimension 0).
// y: global index in dimension 1, used as the matrix row and as the output index.
58 const uint x = get_local_id(0);
59 const uint y = get_global_id(1);
60 const int local_sz = get_local_size(0);
// weight_num_rows is not referenced by any of the visible lines.
61 const int weight_num_rows = get_global_size(1);
// ---- First product: row y of weights with input ----
71 K = INPUT0_SIZE_X; // Width of weight matrix
72 start_offset = GET_DATA_INDEX(WEIGHTS, 0, DIRECTION, y, 0); // set as the starting offset of the weight matrix
73 end_offset = start_offset + K;
74 matrix_offset = start_offset + (x * 4); // Weight offset for the work item to work on
75 vector_offset = GET_DATA_INDEX(INPUT0, 0, 0, INPUT_DIRECTION, (x*4)); // Input offset for the work item to work on
// Each iteration consumes 4 consecutive elements; the step of local_sz * 4
// moves this work item past the elements handled by the other work items.
78 for(; matrix_offset < end_offset; matrix_offset += (local_sz * 4), vector_offset += (local_sz * 4))
// mask: component i is 1 when matrix_offset + i is still before end_offset and
// 0 otherwise (scalar comparisons yield 1 or 0), marking elements past the row end.
// NOTE(review): the statement that applies mask is not visible in this excerpt.
80 float4 mask = (float4) (1 , (matrix_offset + 1) < end_offset , (matrix_offset + 2) < end_offset , (matrix_offset + 3) < end_offset);
// NOTE(review): the four loads below (and the four input loads) are issued
// unconditionally, so the last iteration can read up to 3 elements past
// end_offset - confirm the buffers tolerate that.
81 float4 m = (float4) (weights[matrix_offset], weights[matrix_offset + 1], weights[matrix_offset + 2], weights[matrix_offset + 3]);
84 const float4 v = (float4) (input[vector_offset], input[vector_offset + 1], input[vector_offset + 2], input[vector_offset + 3]);
// Collapse the four accumulator components into this work item's partial sum.
89 result = sum.x + sum.y + sum.z + sum.w;
// ---- Second product: row y of recurrent with hidden, added to result ----
92 K = HIDDEN_SIZE_X; // width of recurrent matrix
93 start_offset = GET_DATA_INDEX(RECURRENT, 0, DIRECTION, y, 0); // set as the starting offset of the recurrent matrix
94 end_offset = start_offset + K;
95 matrix_offset = start_offset + (x * 4); // recurrent offset for the work item to work on
96 vector_offset = GET_DATA_INDEX(HIDDEN, 0, 0, HIDDEN_DIRECTION, (x*4)); // hidden vector offset for the work item to work on
98 for(; matrix_offset < end_offset; matrix_offset += (local_sz * 4), vector_offset += (local_sz * 4))
// Same tail mask as in the first loop; the same review notes apply.
100 float4 mask = (float4) (1 , (matrix_offset + 1) < end_offset , (matrix_offset + 2) < end_offset , (matrix_offset + 3) < end_offset);
101 float4 m = (float4) (recurrent[matrix_offset], recurrent[matrix_offset + 1], recurrent[matrix_offset + 2], recurrent[matrix_offset + 3]);
104 const float4 v = (float4) (hidden[vector_offset], hidden[vector_offset + 1], hidden[vector_offset + 2], hidden[vector_offset + 3]);
// Component-wise multiply-accumulate: sum = m * v + sum.
106 sum = mad(m, v, sum);
109 result += sum.x + sum.y + sum.z + sum.w;
112 // Add together partial sums contained in each work item's "result" variable
113 SUM_ACROSS_SUB_GROUP(result);
// NOTE(review): in the visible lines this store is not guarded by a work-item
// check; presumably only one work item per sub-group is meant to write - confirm.
117 output[y] = (OUTPUT_TYPE)result;
// Add the bias for row y on top of the value just stored (read-modify-write of output[y]).
120 const uint bias_idx = GET_DATA_INDEX(BIAS, 0, 0, DIRECTION, y);
121 float bias = (ACCUMULATOR_TYPE)biases[bias_idx];
122 output[y] += (OUTPUT_TYPE)bias;
// Drop the helper macro once the kernel has used it; presumably so that it does
// not leak into other sources built in the same program - confirm.
127 #undef SUM_ACROSS_SUB_GROUP