Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / lstm / lstm_elt_kernel_base.cpp
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "lstm_elt_kernel_base.h"
18 #include "kernel_selector_utils.h"
19 #include "common_tools.h"
20
21 namespace kernel_selector
22 {
23     JitConstants LSTMEltKernelBase::GetJitConstants(const lstm_elt_params& params) const
24     {
25         JitConstants jit = MakeBaseParamsJitConstants(params);
26
27         if (params.has_cell) {
28             const auto& cell = params.cell;
29             jit.AddConstants({
30                 MakeJitConstant("CELL_TERM", true),
31                 MakeJitConstant("CELL", cell),
32                 MakeJitConstant("CELL_DIRECTION", params.cell_direction)
33             });
34         }
35         if (params.clip > 0) {
36             std::string psclip = toCodeString(params.clip);
37             std::string nsclip = toCodeString(-params.clip);
38             jit.AddConstants({ MakeJitConstant("CLIP(x)", "((x > " + psclip + ") ? " +
39                 psclip + ": (x < " + nsclip + ") ? " + nsclip + " : (x))") });
40         }
41         else {
42             jit.AddConstants({ MakeJitConstant("CLIP(x)", "(x)") });
43         }
44         if (params.input_forget) {
45             jit.AddConstants({ MakeJitConstant("INPUT_FORGET", true) });
46         }
47         jit.AddConstants({ MakeJitConstant("DIRECTION", params.direction) });
48
49         const auto& GEMMInput = params.inputs[0];
50         size_t size = GEMMInput.X().v / 4;
51         jit.AddConstants({
52             MakeJitConstant("GEMM_OFFSET_I", params.GetOffsetIndexI() * size),
53             MakeJitConstant("GEMM_OFFSET_O", params.GetOffsetIndexO() * size),
54             MakeJitConstant("GEMM_OFFSET_F", params.GetOffsetIndexF() * size),
55             MakeJitConstant("GEMM_OFFSET_Z", params.GetOffsetIndexZ() * size),
56         });
57         return jit;
58     }
59
60     KernelsData LSTMEltKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options) const
61     {
62         if (!Validate(params, options))
63         {
64             return{};
65         }
66
67         const lstm_elt_params& orgParams = static_cast<const lstm_elt_params&>(params);
68
69         KernelData kd = KernelData::Default<lstm_elt_params>(params, orgParams.inputs.size());
70
71         float effiency = FORCE_PRIORITY_1;
72         const auto& input = orgParams.inputs[0];
73
74         auto newParams = orgParams;
75         newParams.inputs.resize(1);
76         newParams.inputs[0] = input;
77         auto out = newParams.output;
78
79         auto& kernel = kd.kernels[0];
80         auto cldnnJit = GetJitConstants(newParams);
81         auto entryPoint = GetEntryPoint(kernelName, newParams.layerID, options);
82         auto jit = CreateJit(kernelName, cldnnJit, entryPoint);
83
84         kernel.workGroups.global = { out.X().v, out.Batch().v, 1 };
85         kernel.kernelString = GetKernelString(kernelName, jit, entryPoint, params.engineInfo);
86         kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, 0 });
87         kernel.arguments.push_back({ ArgumentDescriptor::Types::OUTPUT, 0 });
88         if (orgParams.has_cell) {
89             kernel.arguments.push_back({ ArgumentDescriptor::Types::CELL, 0 });
90         }
91
92         kd.estimatedTime = effiency;
93
94         return{ kd };
95     }
96 }