Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / cl_kernels / lstm_gemm_gpu_bfyx_ref.cl
1 // Copyright (c) 2016-2017 Intel Corporation
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15
16 #include "include/include_all.cl"
17
18 #ifndef DIRECTION
19 #define DIRECTION 0
20 #endif
21
22 // input     = [    batch,  sequence,               1,      input_size ]
23 // weights   = [        1, direction, 4 * hidden_size,      input_size ]
24 // recurrent = [        1, direction, 4 * hidden_size,     hidden_size ]
25 // biases    = [        1,         1,       direction, 4 * hidden_size ] optional
26 // hidden    = [    batch, direction,               1,     hidden_size ] optional
27 // tempGEMM  = [    batch, direction,               1, 4 * hidden_size ] output
28 KERNEL(lstm_gemm)(
29     const __global INPUT0_TYPE* input,
30     __global OUTPUT_TYPE* output,
31     const __global WEIGHTS_TYPE* weights
32 #if HIDDEN_TERM
33     , const __global OUTPUT_TYPE* hidden,
34     const __global RECURRENT_TYPE* recurrent
35 #endif
36 #if BIAS_TERM
37     , const __global BIAS_TYPE* biases
38 #endif
39     )
40 {
41     const uint y = get_global_id(0);
42     const uint b = get_global_id(1);
43
44     ACCUMULATOR_TYPE dotProd = 0;
45     for(uint x = 0; x < INPUT0_SIZE_X; ++x ) {
46       const uint input_idx     = GET_DATA_INDEX(INPUT0, b, 0, INPUT_DIRECTION, x);
47       const uint weights_idx   = GET_DATA_INDEX(WEIGHTS, 0, DIRECTION, y, x);
48       dotProd += (ACCUMULATOR_TYPE)(input[input_idx] * weights[weights_idx]);
49     }
50
51 #if HIDDEN_TERM
52     for(uint x = 0; x < HIDDEN_SIZE_X; ++x ) {
53       const uint hidden_idx    = GET_DATA_INDEX(HIDDEN, b, 0, HIDDEN_DIRECTION, x);
54       const uint recurrent_idx = GET_DATA_INDEX(RECURRENT, 0, DIRECTION, y, x);
55       dotProd += (ACCUMULATOR_TYPE)(hidden[hidden_idx] * recurrent[recurrent_idx]);
56     }
57 #endif
58
59 #if BIAS_TERM
60     const uint bias_idx = GET_DATA_INDEX(BIAS, 0, 0, DIRECTION, y);
61     dotProd += (ACCUMULATOR_TYPE)biases[bias_idx];
62 #endif
63     const uint output_idx = GET_DATA_INDEX(OUTPUT, b, 0, 0, y);
64     output[output_idx] = (OUTPUT_TYPE)dotProd;
65 }