2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "lstm_elt_kernel_base.h"
18 #include "kernel_selector_utils.h"
19 #include "common_tools.h"
21 namespace kernel_selector
// Builds the JIT constants for the LSTM element-wise kernel described by `params`.
//
// Constants visible in this listing:
//   CELL_TERM, CELL, CELL_DIRECTION  - listed under the `params.has_cell` check
//   CLIP(x)                          - function-like constant, see below
//   INPUT_FORGET                     - only when `params.input_forget` is set
//   DIRECTION                        - always, from `params.direction`
//   GEMM_OFFSET_I / _O / _F / _Z     - offset index times a quarter of input 0's X extent
//
// NOTE(review): several structural lines (braces, an `else`, the enclosing
// `jit.AddConstants({ ... })` calls around the multi-line lists, and the final
// return) are not visible in this listing; comments about them are assumptions.
23 JitConstants LSTMEltKernelBase::GetJitConstants(const lstm_elt_params& params) const
// Start from the constants produced by MakeBaseParamsJitConstants for these params.
25 JitConstants jit = MakeBaseParamsJitConstants(params);
// Cell-state constants, only relevant when the params carry a cell tensor.
27 if (params.has_cell) {
28 const auto& cell = params.cell;
// Presumably these three entries are passed to jit.AddConstants - the call line is not visible.
30 MakeJitConstant("CELL_TERM", true),
31 MakeJitConstant("CELL", cell),
32 MakeJitConstant("CELL_DIRECTION", params.cell_direction)
// With a positive clip value, CLIP(x) becomes a nested ternary that limits x to
// the range [-clip, +clip]; both bounds are rendered through toCodeString.
// NOTE(review): `x` is not parenthesized inside the two comparisons of the
// generated expression, so an argument containing a lower-precedence operator
// (e.g. a conditional expression) would expand incorrectly - confirm call sites.
35 if (params.clip > 0) {
36 std::string psclip = toCodeString(params.clip);
37 std::string nsclip = toCodeString(-params.clip);
38 jit.AddConstants({ MakeJitConstant("CLIP(x)", "((x > " + psclip + ") ? " +
39 psclip + ": (x < " + nsclip + ") ? " + nsclip + " : (x))") });
// Identity definition of CLIP(x); presumably the `else` branch of the clip check
// above (the `else` line itself is not visible here).
42 jit.AddConstants({ MakeJitConstant("CLIP(x)", "(x)") });
44 if (params.input_forget) {
45 jit.AddConstants({ MakeJitConstant("INPUT_FORGET", true) });
47 jit.AddConstants({ MakeJitConstant("DIRECTION", params.direction) });
// `size` is one quarter of the X extent of input 0. The division by 4 presumably
// reflects four equally sized slices (one per offset constant below) - TODO confirm.
// NOTE(review): if X().v is an integer type the division truncates when the
// extent is not a multiple of 4.
49 const auto& GEMMInput = params.inputs[0];
50 size_t size = GEMMInput.X().v / 4;
// Each offset is the corresponding index reported by `params` scaled by `size`.
// Presumably these entries are also wrapped in a jit.AddConstants call that is
// not visible in this listing.
52 MakeJitConstant("GEMM_OFFSET_I", params.GetOffsetIndexI() * size),
53 MakeJitConstant("GEMM_OFFSET_O", params.GetOffsetIndexO() * size),
54 MakeJitConstant("GEMM_OFFSET_F", params.GetOffsetIndexF() * size),
55 MakeJitConstant("GEMM_OFFSET_Z", params.GetOffsetIndexZ() * size),
// Assembles the kernel data for the LSTM element-wise kernel: validates the
// request, generates the JIT for a single-input copy of the params, sets the
// global work size from the output tensor and registers the kernel arguments.
//
// NOTE(review): the body of the validation guard and the end of this function
// (return statement, closing brace) are not visible in this listing.
60 KernelsData LSTMEltKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options) const
// Guard on Validate(); presumably bails out with an empty result on failure
// (the guarded statement is not visible here).
62 if (!Validate(params, options))
// Unchecked downcast: assumes `params` really is an lstm_elt_params -
// presumably guaranteed by Validate() above; TODO confirm.
67 const lstm_elt_params& orgParams = static_cast<const lstm_elt_params&>(params);
// NOTE(review): the second argument is the number of inputs, yet only
// kd.kernels[0] is filled in below - confirm what Default() expects here.
69 KernelData kd = KernelData::Default<lstm_elt_params>(params, orgParams.inputs.size());
// Value later stored in kd.estimatedTime.
71 float effiency = FORCE_PRIORITY_1;
72 const auto& input = orgParams.inputs[0];
// Work on a copy of the params that keeps only the first input.
// NOTE(review): after the copy and resize(1), newParams.inputs[0] should already
// hold a copy of `input`, so the assignment below looks redundant - assumes
// `inputs` has std::vector-like value semantics.
74 auto newParams = orgParams;
75 newParams.inputs.resize(1);
76 newParams.inputs[0] = input;
77 auto out = newParams.output;
// JIT constants and entry point are generated from the single-input copy.
79 auto& kernel = kd.kernels[0];
80 auto cldnnJit = GetJitConstants(newParams);
81 auto entryPoint = GetEntryPoint(kernelName, newParams.layerID, options);
82 auto jit = CreateJit(kernelName, cldnnJit, entryPoint);
// Global work size: output X extent by output batch extent by 1.
84 kernel.workGroups.global = { out.X().v, out.Batch().v, 1 };
85 kernel.kernelString = GetKernelString(kernelName, jit, entryPoint, params.engineInfo);
// Argument order: input 0, output 0, then the cell tensor when present.
86 kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, 0 });
87 kernel.arguments.push_back({ ArgumentDescriptor::Types::OUTPUT, 0 });
88 if (orgParams.has_cell) {
89 kernel.arguments.push_back({ ArgumentDescriptor::Types::CELL, 0 });
92 kd.estimatedTime = effiency;