Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / lstm / lstm_gemm_kernel_base.cpp
1 /*
2 // Copyright (c) 2018 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "lstm_gemm_kernel_base.h"
18 #include "kernel_selector_utils.h"
19 #include "common_tools.h"
20
21 namespace kernel_selector
22 {
23     JitConstants LSTMGemmKernelBase::GetJitConstants(const lstm_gemm_params& params) const
24     {
25         JitConstants jit = MakeBaseParamsJitConstants(params);
26         const auto& weights = params.weights;
27         const auto& recurrent = params.recurrent;
28         const auto& hidden = params.hidden;
29         const auto& bias = params.bias;
30         if (params.hasBias) {
31             jit.AddConstants({ MakeJitConstant("BIAS", bias), MakeJitConstant("BIAS_TERM", true) });
32         }
33         if (params.hasHidden) {
34             jit.AddConstants({ MakeJitConstant("HIDDEN", hidden),
35                 MakeJitConstant("HIDDEN_TERM", true),
36                 MakeJitConstant("RECURRENT", recurrent),
37                 MakeJitConstant("HIDDEN_DIRECTION", params.hidden_direction)
38             });
39         }
40         jit.AddConstants({ MakeJitConstant("WEIGHTS", weights)});
41         jit.AddConstants({ MakeJitConstant("DIRECTION", params.direction)});
42         jit.AddConstants({ MakeJitConstant("INPUT_DIRECTION", params.input_direction)});
43
44         return jit;
45     }
46
47     KernelsData LSTMGemmKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options) const
48     {
49         if (!Validate(params,  options))
50         {
51             return{};
52         }
53
54         const lstm_gemm_params& orgParams = static_cast<const lstm_gemm_params&>(params);
55
56         KernelData kd = KernelData::Default<lstm_gemm_params>(params, orgParams.inputs.size());
57
58         float effiency = FORCE_PRIORITY_9;
59         const auto& input = orgParams.inputs[0];
60
61         auto newParams = orgParams;
62         newParams.inputs.resize(1);
63         newParams.inputs[0] = input;
64         auto out = newParams.output;
65         //TODO: reorder weights if needed
66         auto& kernel = kd.kernels[0];
67         auto cldnnJit = GetJitConstants(newParams);
68         auto entryPoint = GetEntryPoint(kernelName, newParams.layerID, options);
69         auto jit = CreateJit(kernelName, cldnnJit, entryPoint);
70
71         kernel.workGroups.global = { out.X().v, out.Batch().v, 1 };
72         kernel.kernelString = GetKernelString(kernelName, jit, entryPoint, params.engineInfo);
73         kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, 0 });
74         kernel.arguments.push_back({ ArgumentDescriptor::Types::OUTPUT, 0 });
75         kernel.arguments.push_back({ ArgumentDescriptor::Types::WEIGHTS, 0 });
76         if (orgParams.hasHidden) {
77             kernel.arguments.push_back({ ArgumentDescriptor::Types::HIDDEN, 0 });
78             kernel.arguments.push_back({ ArgumentDescriptor::Types::RECURRENT, 0 });
79         }
80         if (orgParams.hasBias) {
81             kernel.arguments.push_back({ ArgumentDescriptor::Types::BIAS, 0 });
82         }
83
84         kd.estimatedTime = effiency;
85
86         return{ kd };
87     }
88 }