2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __ONERT_BACKEND_ACL_COMMON_ACL_KERNEL_GEN_H_
18 #define __ONERT_BACKEND_ACL_COMMON_ACL_KERNEL_GEN_H_
20 #include <exec/IFunction.h>
21 #include <ir/Operands.h>
23 #include <ir/operation/LSTM.h>
24 #include <arm_compute/runtime/CL/CLFunctions.h>
33 template <typename T_FunctionWrapper, typename T_Tensor, typename T_ACLLayer,
34 typename T_TensorBuilder>
35 std::unique_ptr<exec::IFunction>
36 kernelGenLSTM(const ir::operation::LSTM &node, const ir::Operands &operands,
37 const std::shared_ptr<T_TensorBuilder> &tensor_builder)
39 // TODO Support dynamic rnn
40 // TODO Fix subtle error in the case of non-CIFG, non-peephole and No Projection.
41 const auto scratch_buffer_index{
42 node.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)};
43 const auto output_state_out_index{
44 node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)};
45 const auto cell_state_out_index{
46 node.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)};
47 const auto output_index{node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)};
49 const auto input_index{node.getInputs().at(ir::operation::LSTM::Input::INPUT)};
50 const auto input_to_input_weights_index{
51 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)}; // optional
52 const auto input_to_forget_weights_index{
53 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_FORGET_WEIGHTS)};
54 const auto input_to_cell_weights_index{
55 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_CELL_WEIGHTS)};
56 const auto input_to_output_weights_index{
57 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)};
58 const auto recurrent_to_input_weights_index{
59 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)}; // optional
60 const auto recurrent_to_forget_weights_index{
61 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_FORGET_WEIGHTS)};
62 const auto recurrent_to_cell_weights_index{
63 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_CELL_WEIGHTS)};
64 const auto recurrent_to_output_weights_index{
65 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)};
66 const auto cell_to_input_weights_index{
67 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_INPUT_WEIGHTS)}; // optional
68 const auto cell_to_forget_weights_index{
69 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_FORGET_WEIGHTS)}; // optional
70 const auto cell_to_output_weights_index{
71 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_OUTPUT_WEIGHTS)}; // optional
72 const auto input_gate_bias_index{
73 node.getInputs().at(ir::operation::LSTM::Input::INPUT_GATE_BIAS)};
74 const auto forget_gate_bias_index{
75 node.getInputs().at(ir::operation::LSTM::Input::FORGET_GATE_BIAS)};
76 const auto cell_bias_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_BIAS)};
77 const auto output_gate_bias_index{
78 node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_GATE_BIAS)};
79 const auto projection_weights_index{
80 node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_WEIGHTS)}; // optional
81 const auto projection_bias_index{
82 node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_BIAS)}; // optional
83 const auto output_state_in_index{
84 node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_STATE_IN)};
85 const auto cell_state_in_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_STATE_IN)};
86 const auto cell_threshold = node.param().cell_threshold;
87 const auto projection_threshold = node.param().projection_threshold;
89 bool has_input_to_input_weights = operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
90 operands.at(input_to_input_weights_index).shape().dim(1) != 0;
91 bool has_recurrent_to_input_weights =
92 operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
93 operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0;
94 bool has_cell_to_forget_weights = operands.at(cell_to_forget_weights_index).shape().dim(0) != 0;
95 bool has_cell_to_output_weights = operands.at(cell_to_output_weights_index).shape().dim(0) != 0;
96 bool has_projection_weights = operands.at(projection_weights_index).shape().dim(0) != 0 &&
97 operands.at(projection_weights_index).shape().dim(1) != 0;
98 bool has_projection_bias = operands.at(projection_bias_index).shape().dim(0);
100 // NOTE The input_to_input_weights and the recurrent_to_input_weights do not exist in CIFG.
103 // NOTE The cell_to_input_weights does not exist in non-peephole although regular LSTM(non-CIFG).
104 bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
106 // NOTE The cell_to_forget_weights and the cell_to_output_weights exist in peephole.
107 // But the cell_to_input_weights does not exist in regular CIFG although peephole.
109 // false: no peephole
110 bool has_peephole_param = has_cell_to_forget_weights && has_cell_to_output_weights;
112 // NOTE Although the projection weights has data the projection bias may not have data.
113 bool has_projection_param = has_projection_weights;
115 const auto activation = node.param().activation;
116 const auto cell_clip = cell_threshold;
117 const auto projection_clip = projection_threshold;
118 assert(cell_clip >= 0.f && projection_clip >= 0.f);
120 auto scratch_buffer_tensor = tensor_builder->at(scratch_buffer_index).get();
121 auto output_state_out_tensor = tensor_builder->at(output_state_out_index).get();
122 auto cell_state_out_tensor = tensor_builder->at(cell_state_out_index).get();
123 auto output_tensor = tensor_builder->at(output_index).get();
125 auto input_tensor = tensor_builder->at(input_index).get();
127 auto input_to_forget_weights_tensor = tensor_builder->at(input_to_forget_weights_index).get();
128 auto input_to_cell_weights_tensor = tensor_builder->at(input_to_cell_weights_index).get();
129 auto input_to_output_weights_tensor = tensor_builder->at(input_to_output_weights_index).get();
130 auto recurrent_to_forget_weights_tensor =
131 tensor_builder->at(recurrent_to_forget_weights_index).get();
132 auto recurrent_to_cell_weights_tensor = tensor_builder->at(recurrent_to_cell_weights_index).get();
133 auto recurrent_to_output_weights_tensor =
134 tensor_builder->at(recurrent_to_output_weights_index).get();
136 auto forget_gate_bias_tensor = tensor_builder->at(forget_gate_bias_index).get();
137 auto cell_bias_tensor = tensor_builder->at(cell_bias_index).get();
138 auto output_gate_bias_tensor = tensor_builder->at(output_gate_bias_index).get();
139 auto output_state_in_tensor = tensor_builder->at(output_state_in_index).get();
140 auto cell_state_in_tensor = tensor_builder->at(cell_state_in_index).get();
142 auto act_info = ::onert::backend::acl_common::asActivationLayerInfo(activation);
144 auto fn = std::make_unique<T_ACLLayer>();
146 ::arm_compute::LSTMParams<T_Tensor> lstm_params{};
149 auto input_to_input_weights_tensor =
150 tensor_builder->at(input_to_input_weights_index).get(); // optional
151 auto recurrent_to_input_weights_tensor =
152 tensor_builder->at(recurrent_to_input_weights_index).get(); // optional
153 auto cell_to_input_weights_handle =
154 has_peephole_param ? tensor_builder->at(cell_to_input_weights_index).get()->handle()
155 : nullptr; // optional (non-cifg && peephole)
156 auto input_gate_bias_tensor = tensor_builder->at(input_gate_bias_index).get(); // optional
157 lstm_params.set_cifg_params(input_to_input_weights_tensor->handle(),
158 recurrent_to_input_weights_tensor->handle(),
159 cell_to_input_weights_handle, input_gate_bias_tensor->handle());
161 if (has_peephole_param)
163 auto cell_to_forget_weights_tensor =
164 tensor_builder->at(cell_to_forget_weights_index).get(); // optional
165 auto cell_to_output_weights_tensor =
166 tensor_builder->at(cell_to_output_weights_index).get(); // optional
167 lstm_params.set_peephole_params(cell_to_forget_weights_tensor->handle(),
168 cell_to_output_weights_tensor->handle());
170 if (has_projection_param)
172 auto projection_weights_tensor = tensor_builder->at(projection_weights_index).get(); // optional
173 auto projection_bias_handle = has_projection_bias
174 ? tensor_builder->at(projection_bias_index).get()->handle()
175 : nullptr; // optional
176 lstm_params.set_projection_params(projection_weights_tensor->handle(), projection_bias_handle);
179 fn->configure(input_tensor->handle(), input_to_forget_weights_tensor->handle(),
180 input_to_cell_weights_tensor->handle(), input_to_output_weights_tensor->handle(),
181 recurrent_to_forget_weights_tensor->handle(),
182 recurrent_to_cell_weights_tensor->handle(),
183 recurrent_to_output_weights_tensor->handle(), forget_gate_bias_tensor->handle(),
184 cell_bias_tensor->handle(), output_gate_bias_tensor->handle(),
185 output_state_in_tensor->handle(), cell_state_in_tensor->handle(),
186 scratch_buffer_tensor->handle(), output_state_out_tensor->handle(),
187 cell_state_out_tensor->handle(), output_tensor->handle(), lstm_params, act_info,
188 cell_clip, projection_clip);
190 return std::make_unique<T_FunctionWrapper>(std::move(fn));
193 template <typename T_FunctionWrapper, typename T_Tensor, typename T_ACLLayer,
194 typename T_TensorBuilder>
195 std::unique_ptr<exec::IFunction>
196 kernelGenFullyConnected(const ir::operation::FullyConnected &node, const ir::Operands &operands,
197 const std::shared_ptr<T_TensorBuilder> &tensor_builder, ir::Layout layout)
199 using ir::operation::FullyConnected;
201 const auto output_index{node.getOutputs().at(0)};
202 const auto input_index{node.getInputs().at(FullyConnected::Input::INPUT)};
203 const auto weight_index{node.getInputs().at(FullyConnected::Input::WEIGHT)};
204 const auto bias_index{node.getInputs().at(FullyConnected::Input::BIAS)};
206 const auto input_rank = operands.at(input_index).shape().rank();
208 const auto output_size =
209 operands.at(output_index).shape().dim(operands.at(output_index).shape().rank() - 1);
210 UNUSED_RELEASE(output_size);
211 assert(operands.at(bias_index).shape().dim(0) == output_size);
212 assert(operands.at(weight_index).shape().dim(0) == output_size);
213 const auto batch_size =
214 operands.at(output_index).shape().dim(operands.at(output_index).shape().rank() - 2);
215 const auto input_size =
216 operands.at(weight_index).shape().dim(operands.at(weight_index).shape().rank() - 1);
218 // Check for reshaping input's shape into rank-2
219 bool needs_reshape = false;
220 ir::Shape reshape(2);
221 if (input_rank == 3 || input_rank == 4)
223 const auto &ifm_shape = operands.at(input_index).shape();
224 auto feature_size = 1;
225 for (int i = 0; i < ifm_shape.rank(); ++i)
227 feature_size *= ifm_shape.dim(i);
230 UNUSED_RELEASE(feature_size);
231 assert(feature_size == batch_size * input_size);
234 needs_reshape = true;
235 reshape.dim(0) = batch_size; /* H */
236 reshape.dim(1) = input_size; /* W */
239 auto output_tensor = tensor_builder->at(output_index).get();
240 const auto input_tensor = tensor_builder->at(input_index).get();
241 const auto weight_tensor = tensor_builder->at(weight_index).get();
242 const auto bias_tensor = tensor_builder->at(bias_index).get();
243 const auto frontend_layout = layout;
244 const auto acl_layout = output_tensor->handle()->info()->data_layout();
247 std::make_unique<T_ACLLayer>(tensor_builder->acl_tensor_manager()->internal_buffer_manager());
249 typename T_ACLLayer::KernelType kernel_type = T_ACLLayer::KernelType::GENERAL;
250 if (operands.at(weight_index).isConstant())
252 kernel_type = T_ACLLayer::KernelType::PREPROCESSED_WEIGHTS;
253 assert(operands.at(weight_index).data());
257 input_tensor->handle(), weight_tensor->handle(), bias_tensor->handle(),
258 output_tensor->handle(), needs_reshape,
259 ::onert::backend::acl_common::asTensorShape(
260 reshape, frontend_layout, ::onert::backend::acl_common::asRuntimeLayout(acl_layout)),
263 return std::make_unique<T_FunctionWrapper>(std::move(fn));
266 template <typename T_ACLLayer, typename T_PoolOp, typename T_TensorBuilder>
267 std::unique_ptr<::arm_compute::IFunction>
268 kernelGenPool2D(const T_PoolOp &node, const ir::Operands &operands,
269 const std::shared_ptr<T_TensorBuilder> &tensor_builder, ir::Layout layout,
270 ::arm_compute::PoolingType pooling_type)
272 const auto ofm_index{node.getOutputs().at(0)};
273 const auto ifm_index{node.getInputs().at(0)};
275 const auto ofm_shape = operands.at(ofm_index).shape().asFeature(layout);
276 const auto ifm_shape = operands.at(ifm_index).shape().asFeature(layout);
278 const auto kh = node.param().kh;
279 const auto kw = node.param().kw;
280 const auto stride = node.param().stride;
282 ir::calculatePadding(node.param().padding, ifm_shape, ofm_shape, stride, kw, kh);
284 VERBOSE(Pool2DParam) << "IFM_H: " << ifm_shape.H << std::endl;
285 VERBOSE(Pool2DParam) << "IFM_W: " << ifm_shape.W << std::endl;
286 VERBOSE(Pool2DParam) << "OFM_H: " << ofm_shape.H << std::endl;
287 VERBOSE(Pool2DParam) << "OFM_W: " << ofm_shape.W << std::endl;
288 VERBOSE(Pool2DParam) << "KER_H: " << kh << std::endl;
289 VERBOSE(Pool2DParam) << "KER_W: " << kw << std::endl;
290 VERBOSE(Pool2DParam) << "STRIDE_H: " << stride.vertical << std::endl;
291 VERBOSE(Pool2DParam) << "STRIDE_W: " << stride.horizontal << std::endl;
292 VERBOSE(Pool2DParam) << "PAD(T): " << padding.top << std::endl;
293 VERBOSE(Pool2DParam) << "PAD(B): " << padding.bottom << std::endl;
294 VERBOSE(Pool2DParam) << "PAD(L): " << padding.left << std::endl;
295 VERBOSE(Pool2DParam) << "PAD(R): " << padding.right << std::endl;
297 auto ofm_tensor = tensor_builder->at(ofm_index).get();
298 auto ifm_tensor = tensor_builder->at(ifm_index).get();
300 ::arm_compute::PoolingLayerInfo info{
301 pooling_type, ::arm_compute::Size2D{kw, kh}, ifm_tensor->info()->data_layout(),
302 acl_common::asPadStrideInfo(padding, stride), true /* exclude_padding */};
304 auto fn = std::make_unique<T_ACLLayer>();
306 fn->configure(ifm_tensor->handle(), ofm_tensor->handle(), info);
311 } // namespace acl_common
312 } // namespace backend
315 #endif // __ONERT_BACKEND_ACL_COMMON_ACL_KERNEL_GEN_H_