2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "lstm_gemm_kernel_base.h"
18 #include "kernel_selector_utils.h"
19 #include "common_tools.h"
21 namespace kernel_selector
23 JitConstants LSTMGemmKernelBase::GetJitConstants(const lstm_gemm_params& params) const
25 JitConstants jit = MakeBaseParamsJitConstants(params);
26 const auto& weights = params.weights;
27 const auto& recurrent = params.recurrent;
28 const auto& hidden = params.hidden;
29 const auto& bias = params.bias;
31 jit.AddConstants({ MakeJitConstant("BIAS", bias), MakeJitConstant("BIAS_TERM", true) });
33 if (params.hasHidden) {
34 jit.AddConstants({ MakeJitConstant("HIDDEN", hidden),
35 MakeJitConstant("HIDDEN_TERM", true),
36 MakeJitConstant("RECURRENT", recurrent),
37 MakeJitConstant("HIDDEN_DIRECTION", params.hidden_direction)
40 jit.AddConstants({ MakeJitConstant("WEIGHTS", weights)});
41 jit.AddConstants({ MakeJitConstant("DIRECTION", params.direction)});
42 jit.AddConstants({ MakeJitConstant("INPUT_DIRECTION", params.input_direction)});
47 KernelsData LSTMGemmKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options) const
49 if (!Validate(params, options))
54 const lstm_gemm_params& orgParams = static_cast<const lstm_gemm_params&>(params);
56 KernelData kd = KernelData::Default<lstm_gemm_params>(params, orgParams.inputs.size());
58 float effiency = FORCE_PRIORITY_9;
59 const auto& input = orgParams.inputs[0];
61 auto newParams = orgParams;
62 newParams.inputs.resize(1);
63 newParams.inputs[0] = input;
64 auto out = newParams.output;
65 //TODO: reorder weights if needed
66 auto& kernel = kd.kernels[0];
67 auto cldnnJit = GetJitConstants(newParams);
68 auto entryPoint = GetEntryPoint(kernelName, newParams.layerID, options);
69 auto jit = CreateJit(kernelName, cldnnJit, entryPoint);
71 kernel.workGroups.global = { out.X().v, out.Batch().v, 1 };
72 kernel.kernelString = GetKernelString(kernelName, jit, entryPoint, params.engineInfo);
73 kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, 0 });
74 kernel.arguments.push_back({ ArgumentDescriptor::Types::OUTPUT, 0 });
75 kernel.arguments.push_back({ ArgumentDescriptor::Types::WEIGHTS, 0 });
76 if (orgParams.hasHidden) {
77 kernel.arguments.push_back({ ArgumentDescriptor::Types::HIDDEN, 0 });
78 kernel.arguments.push_back({ ArgumentDescriptor::Types::RECURRENT, 0 });
80 if (orgParams.hasBias) {
81 kernel.arguments.push_back({ ArgumentDescriptor::Types::BIAS, 0 });
84 kd.estimatedTime = effiency;