1 // Copyright (c) 2016-2017 Intel Corporation
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
16 #include "include/include_all.cl"
22 // input = [ batch, sequence, 1, input_size ]
23 // weights = [ 1, direction, 4 * hidden_size, input_size ]
24 // recurrent = [ 1, direction, 4 * hidden_size, hidden_size ]
25 // biases = [ 1, 1, direction, 4 * hidden_size ] optional
26 // hidden = [ batch, direction, 1, hidden_size ] optional
27 // tempGEMM = [ batch, direction, 1, 4 * hidden_size ] output
// Kernel parameters and body. Each work-item accumulates one dot product
// (input x weights, plus hidden x recurrent, plus a bias value) and stores
// it as a single element of `output`.
// NOTE(review): the kernel declaration line, the braces and any preprocessor
// guards are not visible in this excerpt - confirm against the full file.
//
// Parameters:
//   input   - read-only buffer, indexed through GET_DATA_INDEX(INPUT0, ...)
//   output  - destination buffer, indexed through GET_DATA_INDEX(OUTPUT, ...)
//   weights - read-only buffer, indexed through GET_DATA_INDEX(WEIGHTS, ...)
29 const __global INPUT0_TYPE* input,
30 __global OUTPUT_TYPE* output,
31 const __global WEIGHTS_TYPE* weights
//   hidden, recurrent - read-only buffers for the second dot product.
//   NOTE(review): the leading comma suggests these two parameters are
//   conditionally compiled; the guarding directive is not shown here.
33 , const __global OUTPUT_TYPE* hidden,
34 const __global RECURRENT_TYPE* recurrent
//   biases - read-only buffer holding the additive term.
//   NOTE(review): presumably also conditionally compiled - confirm.
37 , const __global BIAS_TYPE* biases
// y (global id, dimension 0) selects the weights/recurrent row, the bias
// element and the output element; b (global id, dimension 1) is passed as
// the first coordinate of the input, hidden and output index computations.
41 const uint y = get_global_id(0);
42 const uint b = get_global_id(1);
// Running sum for this work-item, kept in ACCUMULATOR_TYPE.
44 ACCUMULATOR_TYPE dotProd = 0;
// First term: sum over x in [0, INPUT0_SIZE_X) of input[b, x] * weights[y, x].
// INPUT_DIRECTION and DIRECTION are macros defined outside this excerpt.
45 for(uint x = 0; x < INPUT0_SIZE_X; ++x ) {
46 const uint input_idx = GET_DATA_INDEX(INPUT0, b, 0, INPUT_DIRECTION, x);
47 const uint weights_idx = GET_DATA_INDEX(WEIGHTS, 0, DIRECTION, y, x);
// NOTE(review): the product is formed in the operand types and only then
// cast to ACCUMULATOR_TYPE - confirm this is the intended precision.
48 dotProd += (ACCUMULATOR_TYPE)(input[input_idx] * weights[weights_idx]);
// Second term: sum over x in [0, HIDDEN_SIZE_X) of hidden[b, x] * recurrent[y, x].
52 for(uint x = 0; x < HIDDEN_SIZE_X; ++x ) {
53 const uint hidden_idx = GET_DATA_INDEX(HIDDEN, b, 0, HIDDEN_DIRECTION, x);
54 const uint recurrent_idx = GET_DATA_INDEX(RECURRENT, 0, DIRECTION, y, x);
// Same cast-after-multiply pattern as the first term.
55 dotProd += (ACCUMULATOR_TYPE)(hidden[hidden_idx] * recurrent[recurrent_idx]);
// Third term: add the bias element selected by DIRECTION and y.
60 const uint bias_idx = GET_DATA_INDEX(BIAS, 0, 0, DIRECTION, y);
61 dotProd += (ACCUMULATOR_TYPE)biases[bias_idx];
// Store the accumulated value at output[b, 0, 0, y], converted to OUTPUT_TYPE.
63 const uint output_idx = GET_DATA_INDEX(OUTPUT, b, 0, 0, y);
64 output[output_idx] = (OUTPUT_TYPE)dotProd;