2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __ONERT_BACKEND_CPU_OPS_LSTMLAYER_H__
18 #define __ONERT_BACKEND_CPU_OPS_LSTMLAYER_H__
20 #include <backend/IPortableTensor.h>
21 #include "OperationUtils.h"
22 #include <ir/InternalType.h>
23 #include <ir/operation/LSTM.h>
24 #include <exec/IFunction.h>
43 // TODO Support LSTM, BiDirectionalSequenceLSTM
44 class LSTMLayer : public ::onert::exec::IFunction
// Defaulted constructor: the layer starts unconfigured. Every tensor-pointer member begins as
// nullptr through its in-class initializer; configure() is expected to populate them.
LSTMLayer() = default;
52 void configure(const IPortableTensor *input, const IPortableTensor *input_to_input_weights,
53 const IPortableTensor *input_to_forget_weights,
54 const IPortableTensor *input_to_cell_weights,
55 const IPortableTensor *input_to_output_weights,
56 const IPortableTensor *recurrent_to_input_weights,
57 const IPortableTensor *recurrent_to_forget_weights,
58 const IPortableTensor *recurrent_to_cell_weights,
59 const IPortableTensor *recurrent_to_output_weights,
60 const IPortableTensor *cell_to_input_weights,
61 const IPortableTensor *cell_to_forget_weights,
62 const IPortableTensor *cell_to_output_weights,
63 const IPortableTensor *input_layer_norm_weights,
64 const IPortableTensor *forget_layer_norm_weights,
65 const IPortableTensor *cell_layer_norm_weights,
66 const IPortableTensor *output_layer_norm_weights, const IPortableTensor *aux_input,
67 const IPortableTensor *aux_input_to_input_weights,
68 const IPortableTensor *aux_input_to_forget_weights,
69 const IPortableTensor *aux_input_to_cell_weights,
70 const IPortableTensor *aux_input_to_output_weights,
71 const IPortableTensor *input_gate_bias, const IPortableTensor *forget_gate_bias,
72 const IPortableTensor *cell_gate_bias, const IPortableTensor *output_gate_bias,
73 const IPortableTensor *projection_weights, const IPortableTensor *projection_bias,
74 const IPortableTensor *output_state_in, const IPortableTensor *cell_state_in,
75 const ir::operation::LSTM::Param ¶ms, bool forward_sequence, bool time_major,
76 int32_t output_offset, IPortableTensor *scratch_buffer,
77 IPortableTensor *output_state, IPortableTensor *cell_state,
78 IPortableTensor *output, bool has_output_state_data, bool has_cell_state_data);
83 const IPortableTensor *_input{nullptr};
84 const IPortableTensor *_input_to_input_weights{nullptr};
85 const IPortableTensor *_input_to_forget_weights{nullptr};
86 const IPortableTensor *_input_to_cell_weights{nullptr};
87 const IPortableTensor *_input_to_output_weights{nullptr};
88 const IPortableTensor *_recurrent_to_input_weights{nullptr};
89 const IPortableTensor *_recurrent_to_forget_weights{nullptr};
90 const IPortableTensor *_recurrent_to_cell_weights{nullptr};
91 const IPortableTensor *_recurrent_to_output_weights{nullptr};
92 const IPortableTensor *_cell_to_input_weights{nullptr};
93 const IPortableTensor *_cell_to_forget_weights{nullptr};
94 const IPortableTensor *_cell_to_output_weights{nullptr};
95 const IPortableTensor *_input_layer_norm_coefficients{nullptr};
96 const IPortableTensor *_forget_layer_norm_coefficients{nullptr};
97 const IPortableTensor *_cell_layer_norm_coefficients{nullptr};
98 const IPortableTensor *_output_layer_norm_coefficients{nullptr};
99 const IPortableTensor *_aux_input{nullptr};
100 const IPortableTensor *_aux_input_to_input_weights{nullptr};
101 const IPortableTensor *_aux_input_to_forget_weights{nullptr};
102 const IPortableTensor *_aux_input_to_cell_weights{nullptr};
103 const IPortableTensor *_aux_input_to_output_weights{nullptr};
104 const IPortableTensor *_input_gate_bias{nullptr};
105 const IPortableTensor *_forget_gate_bias{nullptr};
106 const IPortableTensor *_cell_gate_bias{nullptr};
107 const IPortableTensor *_output_gate_bias{nullptr};
108 const IPortableTensor *_projection_weights{nullptr};
109 const IPortableTensor *_projection_bias{nullptr};
110 const IPortableTensor *_output_state_in{nullptr};
111 const IPortableTensor *_cell_state_in{nullptr};
112 IPortableTensor *_scratch_buffer{nullptr};
113 IPortableTensor *_output_state{nullptr};
114 IPortableTensor *_cell_state{nullptr};
115 IPortableTensor *_output{nullptr};
116 std::vector<uint8_t> _scratch_vec{};
117 std::vector<uint8_t> _output_state_vec{};
118 std::vector<uint8_t> _cell_state_vec{};
119 ir::operation::LSTM::Param _params{};
120 bool _forward_sequence{true};
121 bool _time_major{true};
122 int32_t _output_offset{0};
123 bool _has_output_state_data{false};
124 bool _has_cell_state_data{false};
129 } // namespace backend
132 #endif // __ONERT_BACKEND_CPU_OPS_LSTMLAYER_H__