2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
19 #include "common_kernel_base.h"
20 #include "kernel_selector_params.h"
22 namespace kernel_selector
24 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
26 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Parameters for the LSTM element-wise kernel (KernelType::LSTM_ELT).
27 struct lstm_elt_params : public base_params
// Layout of the four LSTM gates (i, o, f, z) inside the packed gate tensor.
// GetOffsetIndex maps each layout to the slot of every gate.
// NOTE(review): offset_ifoz is used as a key in GetOffsetIndex's table but is
// not visible in this enum as shown — confirm it is declared here.
// NOTE(review): the table row for offset_izof decodes to slot order i, f, z, o
// (PyTorch's i, f, g, o layout), not i, z, o, f as the name reads.
29 enum order_type : int32_t {
30 offset_iofz, // ONNX default
32 offset_izof, // pyTorch
33 offset_fizo // IE default
// The kernel type is fixed to LSTM_ELT for these params.
37 : base_params(KernelType::LSTM_ELT)
// Whether a cell-state input is attached; presumably set by SetCell — confirm.
41 bool has_cell = false;
// Gate layout used by the GetOffsetIndex* accessors; defaults to offset_iofz.
42 order_type gate_order = offset_iofz;
// NOTE(review): semantics of input_forget, direction and cell_direction are not
// established by the visible code (presumably coupled input/forget gate and
// per-direction selectors) — confirm against the kernel implementation.
44 bool input_forget = false;
45 uint32_t direction = 0;
46 uint32_t cell_direction = 0;
48 size_t GetOffsetIndex(order_type type, size_t idx) const {
49 static const std::map<order_type, std::vector<size_t>> offset_map {
50 {offset_iofz, { 0, 1, 2, 3}},
51 {offset_ifoz, { 0, 2, 1, 3}},
52 {offset_izof, { 0, 3, 1, 2}},
53 {offset_fizo, { 1, 3, 0, 2}}
55 return offset_map.at(type)[idx];
58 size_t GetOffsetIndexI() const { return GetOffsetIndex(gate_order, 0); }
59 size_t GetOffsetIndexO() const { return GetOffsetIndex(gate_order, 1); }
60 size_t GetOffsetIndexF() const { return GetOffsetIndex(gate_order, 2); }
61 size_t GetOffsetIndexZ() const { return GetOffsetIndex(gate_order, 3); }
63 void SetOffsetOrder(int32_t t) {
64 gate_order = static_cast<order_type>(t);
// Attaches the cell-state tensor `v`.
// NOTE(review): the body is not visible here; presumably it stores `v` and
// sets has_cell = true — confirm.
67 void SetCell(const DataTensor& v) {
// Builds the capability key for kernel selection: starts from the base-class
// key and adds LSTM-elt specific capabilities.
72 virtual ParamsKey GetParamsKey() const override
74 ParamsKey k = base_params::GetParamsKey();
// Advertise LSTM-elt cell support. NOTE(review): the condition guarding this
// call is not visible here; presumably it depends on has_cell — confirm.
77 k.EnableLSTMEltCell();
83 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
84 // lstm_elt_optional_params
85 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
86 struct lstm_elt_optional_params : optional_params
88 lstm_elt_optional_params() : optional_params(KernelType::LSTM_ELT) {}
91 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
93 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
94 class LSTMEltKernelBase : public common_kernel_base
97 using common_kernel_base::common_kernel_base;
98 virtual ~LSTMEltKernelBase() {}
// Dispatch data for LSTM-elt kernels, derived from CommonDispatchData.
// NOTE(review): its members, if any, are not visible here.
100 struct DispatchData : public CommonDispatchData
// Builds the JIT constants for `params`; virtual so derived kernels can extend them.
104 virtual JitConstants GetJitConstants(const lstm_elt_params& params) const;
// Shared kernel-data construction for LSTM-elt kernels; presumably called by
// derived kernels' GetKernelsData — confirm.
105 KernelsData GetCommonKernelsData(const Params& params, const optional_params& optParams) const;
107 bool Validate(const Params& p, const optional_params&) const override
109 if (p.GetType() != KernelType::LSTM_ELT)