Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / lstm / lstm_elt_kernel_base.h
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #pragma once
18
19 #include "common_kernel_base.h"
20 #include "kernel_selector_params.h"
21
22 namespace kernel_selector
23 {
24     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
25     // lstm_elt_params
26     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
27     struct lstm_elt_params : public base_params
28     {
29         enum order_type : int32_t {
30             offset_iofz, // ONNX default
31             offset_ifoz, // caffe
32             offset_izof, // pyTorch
33             offset_fizo  // IE default
34         };
35
36         lstm_elt_params()
37         : base_params(KernelType::LSTM_ELT)
38         {}
39
40         DataTensor cell;
41         bool has_cell = false;
42         order_type gate_order = offset_iofz;
43         float clip = 0;
44         bool input_forget = false;
45         uint32_t direction = 0;
46         uint32_t cell_direction = 0;
47
48         size_t GetOffsetIndex(order_type type, size_t idx) const {
49             static const std::map<order_type, std::vector<size_t>> offset_map {
50                 {offset_iofz, { 0, 1, 2, 3}},
51                 {offset_ifoz, { 0, 2, 1, 3}},
52                 {offset_izof, { 0, 3, 1, 2}},
53                 {offset_fizo, { 1, 3, 0, 2}}
54             };
55             return offset_map.at(type)[idx];
56         }
57
58         size_t GetOffsetIndexI() const { return GetOffsetIndex(gate_order, 0); }
59         size_t GetOffsetIndexO() const { return GetOffsetIndex(gate_order, 1); }
60         size_t GetOffsetIndexF() const { return GetOffsetIndex(gate_order, 2); }
61         size_t GetOffsetIndexZ() const { return GetOffsetIndex(gate_order, 3); }
62
63         void SetOffsetOrder(int32_t t) {
64             gate_order = static_cast<order_type>(t);
65         }
66
67         void SetCell(const DataTensor& v) {
68             cell = v;
69             has_cell = true;
70         }
71
72         virtual ParamsKey GetParamsKey() const override
73         {
74             ParamsKey k = base_params::GetParamsKey();
75             if (has_cell)
76             {
77                 k.EnableLSTMEltCell();
78             }
79             return k;
80         }
81     };
82
83     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
84     // lstm_elt_optional_params
85     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
86     struct lstm_elt_optional_params : optional_params
87     {
88         lstm_elt_optional_params() : optional_params(KernelType::LSTM_ELT) {}
89     };
90
91     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
92     // LSTMEltKernelBase
93     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
94     class LSTMEltKernelBase : public common_kernel_base
95     {
96     public:
97         using common_kernel_base::common_kernel_base;
98         virtual ~LSTMEltKernelBase() {}
99
100         struct DispatchData : public CommonDispatchData
101         {};
102
103     protected:
104         virtual JitConstants GetJitConstants(const lstm_elt_params& params) const;
105         KernelsData GetCommonKernelsData(const Params& params, const optional_params& optParams) const;
106
107         bool Validate(const Params& p, const optional_params&) const override
108         {
109             if (p.GetType() != KernelType::LSTM_ELT)
110             {
111                 return false;
112             }
113
114             return true;
115         }
116     };
117 }