Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / backend / cpu / ops / LSTMLayer.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __ONERT_BACKEND_CPU_OPS_LSTMLAYER_H__
18 #define __ONERT_BACKEND_CPU_OPS_LSTMLAYER_H__
19
20 #include <backend/IPortableTensor.h>
21 #include "OperationUtils.h"
22 #include <ir/InternalType.h>
23 #include <ir/operation/LSTM.h>
24 #include <exec/IFunction.h>
25
// Forward declaration only: this header names nnfw::cker::FCTempArena without directly
// including a cker header, and nothing else in this header refers to the type.
// NOTE(review): presumably kept for the implementation file - confirm before removing.
namespace nnfw
{
namespace cker
{

class FCTempArena;

} // namespace cker
} // namespace nnfw
33
34 namespace onert
35 {
36 namespace backend
37 {
38 namespace cpu
39 {
40 namespace ops
41 {
42
43 // TODO Support LSTM, BiDirectionalSequenceLSTM
44 class LSTMLayer : public ::onert::exec::IFunction
45 {
46 public:
47   LSTMLayer() = default;
48
49 public:
50   void LSTMFloat();
51
52   void configure(
53     const IPortableTensor *input, const IPortableTensor *input_to_input_weights,
54     const IPortableTensor *input_to_forget_weights, const IPortableTensor *input_to_cell_weights,
55     const IPortableTensor *input_to_output_weights,
56     const IPortableTensor *recurrent_to_input_weights,
57     const IPortableTensor *recurrent_to_forget_weights,
58     const IPortableTensor *recurrent_to_cell_weights,
59     const IPortableTensor *recurrent_to_output_weights,
60     const IPortableTensor *cell_to_input_weights, const IPortableTensor *cell_to_forget_weights,
61     const IPortableTensor *cell_to_output_weights, const IPortableTensor *input_layer_norm_weights,
62     const IPortableTensor *forget_layer_norm_weights,
63     const IPortableTensor *cell_layer_norm_weights,
64     const IPortableTensor *output_layer_norm_weights, const IPortableTensor *aux_input,
65     const IPortableTensor *aux_input_to_input_weights,
66     const IPortableTensor *aux_input_to_forget_weights,
67     const IPortableTensor *aux_input_to_cell_weights,
68     const IPortableTensor *aux_input_to_output_weights, const IPortableTensor *input_gate_bias,
69     const IPortableTensor *forget_gate_bias, const IPortableTensor *cell_gate_bias,
70     const IPortableTensor *output_gate_bias, const IPortableTensor *projection_weights,
71     const IPortableTensor *projection_bias, const IPortableTensor *output_state_in,
72     const IPortableTensor *cell_state_in, const ir::operation::LSTM::Param &params,
73     bool forward_sequence, bool time_major, int32_t output_offset, IPortableTensor *scratch_buffer,
74     IPortableTensor *output_state, IPortableTensor *cell_state, IPortableTensor *output,
75     bool has_output_state_data, bool has_cell_state_data);
76
77   void run() override;
78
79 private:
80   const IPortableTensor *_input{nullptr};
81   const IPortableTensor *_input_to_input_weights{nullptr};
82   const IPortableTensor *_input_to_forget_weights{nullptr};
83   const IPortableTensor *_input_to_cell_weights{nullptr};
84   const IPortableTensor *_input_to_output_weights{nullptr};
85   const IPortableTensor *_recurrent_to_input_weights{nullptr};
86   const IPortableTensor *_recurrent_to_forget_weights{nullptr};
87   const IPortableTensor *_recurrent_to_cell_weights{nullptr};
88   const IPortableTensor *_recurrent_to_output_weights{nullptr};
89   const IPortableTensor *_cell_to_input_weights{nullptr};
90   const IPortableTensor *_cell_to_forget_weights{nullptr};
91   const IPortableTensor *_cell_to_output_weights{nullptr};
92   const IPortableTensor *_input_layer_norm_coefficients{nullptr};
93   const IPortableTensor *_forget_layer_norm_coefficients{nullptr};
94   const IPortableTensor *_cell_layer_norm_coefficients{nullptr};
95   const IPortableTensor *_output_layer_norm_coefficients{nullptr};
96   const IPortableTensor *_aux_input{nullptr};
97   const IPortableTensor *_aux_input_to_input_weights{nullptr};
98   const IPortableTensor *_aux_input_to_forget_weights{nullptr};
99   const IPortableTensor *_aux_input_to_cell_weights{nullptr};
100   const IPortableTensor *_aux_input_to_output_weights{nullptr};
101   const IPortableTensor *_input_gate_bias{nullptr};
102   const IPortableTensor *_forget_gate_bias{nullptr};
103   const IPortableTensor *_cell_gate_bias{nullptr};
104   const IPortableTensor *_output_gate_bias{nullptr};
105   const IPortableTensor *_projection_weights{nullptr};
106   const IPortableTensor *_projection_bias{nullptr};
107   const IPortableTensor *_output_state_in{nullptr};
108   const IPortableTensor *_cell_state_in{nullptr};
109   IPortableTensor *_scratch_buffer{nullptr};
110   IPortableTensor *_output_state{nullptr};
111   IPortableTensor *_cell_state{nullptr};
112   IPortableTensor *_output{nullptr};
113   std::vector<uint8_t> _scratch_vec{};
114   std::vector<uint8_t> _output_state_vec{};
115   std::vector<uint8_t> _cell_state_vec{};
116   ir::operation::LSTM::Param _params{};
117   bool _forward_sequence{true};
118   bool _time_major{true};
119   int32_t _output_offset{0};
120   bool _has_output_state_data{false};
121   bool _has_cell_state_data{false};
122 };
123
124 } // namespace ops
125 } // namespace cpu
126 } // namespace backend
127 } // namespace onert
128
129 #endif // __ONERT_BACKEND_CPU_OPS_LSTMLAYER_H__