2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __ONERT_BACKEND_ACL_COMMON_ACL_KERNEL_GEN_H_
18 #define __ONERT_BACKEND_ACL_COMMON_ACL_KERNEL_GEN_H_
#include <cassert>
#include <memory>
#include <utility>

#include <arm_compute/runtime/CL/CLFunctions.h>

#include <exec/IFunction.h>
#include <ir/Operands.h>
#include <ir/operation/LSTM.h>
33 template <typename Layer, typename... Args>
34 std::unique_ptr<arm_compute::IFunction> generateLayer(Args &&... args)
36 auto l = std::make_unique<Layer>();
38 l->configure(std::forward<Args>(args)...);
43 template <typename Layer, typename... Args>
44 std::unique_ptr<arm_compute::IFunction>
45 generateLayer(std::shared_ptr<arm_compute::IMemoryManager> memory_manager, Args &&... args)
47 auto l = std::make_unique<Layer>(memory_manager);
49 l->configure(std::forward<Args>(args)...);
54 template <typename T_FunctionWrapper, typename T_Tensor, typename T_ACLLayer,
55 typename T_TensorRegistry>
56 std::unique_ptr<exec::IFunction> kernelGenLSTM(const ir::operation::LSTM &node,
57 const ir::Operands &operands,
58 const std::shared_ptr<T_TensorRegistry> &tensor_reg)
60 // TODO Support dynamic rnn
61 // TODO Fix subtle error in the case of non-CIFG, non-peephole and No Projection.
62 const auto scratch_buffer_index{
63 node.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)};
64 const auto output_state_out_index{
65 node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)};
66 const auto cell_state_out_index{
67 node.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)};
68 const auto output_index{node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)};
70 const auto input_index{node.getInputs().at(ir::operation::LSTM::Input::INPUT)};
71 const auto input_to_input_weights_index{
72 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)}; // optional
73 const auto input_to_forget_weights_index{
74 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_FORGET_WEIGHTS)};
75 const auto input_to_cell_weights_index{
76 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_CELL_WEIGHTS)};
77 const auto input_to_output_weights_index{
78 node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)};
79 const auto recurrent_to_input_weights_index{
80 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)}; // optional
81 const auto recurrent_to_forget_weights_index{
82 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_FORGET_WEIGHTS)};
83 const auto recurrent_to_cell_weights_index{
84 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_CELL_WEIGHTS)};
85 const auto recurrent_to_output_weights_index{
86 node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)};
87 const auto cell_to_input_weights_index{
88 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_INPUT_WEIGHTS)}; // optional
89 const auto cell_to_forget_weights_index{
90 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_FORGET_WEIGHTS)}; // optional
91 const auto cell_to_output_weights_index{
92 node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_OUTPUT_WEIGHTS)}; // optional
93 const auto input_gate_bias_index{
94 node.getInputs().at(ir::operation::LSTM::Input::INPUT_GATE_BIAS)};
95 const auto forget_gate_bias_index{
96 node.getInputs().at(ir::operation::LSTM::Input::FORGET_GATE_BIAS)};
97 const auto cell_bias_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_BIAS)};
98 const auto output_gate_bias_index{
99 node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_GATE_BIAS)};
100 const auto projection_weights_index{
101 node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_WEIGHTS)}; // optional
102 const auto projection_bias_index{
103 node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_BIAS)}; // optional
104 const auto output_state_in_index{
105 node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_STATE_IN)};
106 const auto cell_state_in_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_STATE_IN)};
107 const auto cell_threshold = node.param().cell_threshold;
108 const auto projection_threshold = node.param().projection_threshold;
110 bool has_input_to_input_weights = operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
111 operands.at(input_to_input_weights_index).shape().dim(1) != 0;
112 bool has_recurrent_to_input_weights =
113 operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
114 operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0;
115 bool has_cell_to_forget_weights = operands.at(cell_to_forget_weights_index).shape().dim(0) != 0;
116 bool has_cell_to_output_weights = operands.at(cell_to_output_weights_index).shape().dim(0) != 0;
117 bool has_projection_weights = operands.at(projection_weights_index).shape().dim(0) != 0 &&
118 operands.at(projection_weights_index).shape().dim(1) != 0;
119 bool has_projection_bias = operands.at(projection_bias_index).shape().dim(0);
121 // NOTE The input_to_input_weights and the recurrent_to_input_weights do not exist in CIFG.
124 // NOTE The cell_to_input_weights does not exist in non-peephole although regular LSTM(non-CIFG).
125 bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
127 // NOTE The cell_to_forget_weights and the cell_to_output_weights exist in peephole.
128 // But the cell_to_input_weights does not exist in regular CIFG although peephole.
130 // false: no peephole
131 bool has_peephole_param = has_cell_to_forget_weights && has_cell_to_output_weights;
133 // NOTE Although the projection weights has data the projection bias may not have data.
134 bool has_projection_param = has_projection_weights;
136 const auto activation = node.param().activation;
137 const auto cell_clip = cell_threshold;
138 const auto projection_clip = projection_threshold;
139 assert(cell_clip >= 0.f && projection_clip >= 0.f);
141 auto scratch_buffer_tensor = tensor_reg->getAclTensor(scratch_buffer_index).get();
142 auto output_state_out_tensor = tensor_reg->getAclTensor(output_state_out_index).get();
143 auto cell_state_out_tensor = tensor_reg->getAclTensor(cell_state_out_index).get();
144 auto output_tensor = tensor_reg->getAclTensor(output_index).get();
146 auto input_tensor = tensor_reg->getAclTensor(input_index).get();
148 auto input_to_forget_weights_tensor =
149 tensor_reg->getAclTensor(input_to_forget_weights_index).get();
150 auto input_to_cell_weights_tensor = tensor_reg->getAclTensor(input_to_cell_weights_index).get();
151 auto input_to_output_weights_tensor =
152 tensor_reg->getAclTensor(input_to_output_weights_index).get();
153 auto recurrent_to_forget_weights_tensor =
154 tensor_reg->getAclTensor(recurrent_to_forget_weights_index).get();
155 auto recurrent_to_cell_weights_tensor =
156 tensor_reg->getAclTensor(recurrent_to_cell_weights_index).get();
157 auto recurrent_to_output_weights_tensor =
158 tensor_reg->getAclTensor(recurrent_to_output_weights_index).get();
160 auto forget_gate_bias_tensor = tensor_reg->getAclTensor(forget_gate_bias_index).get();
161 auto cell_bias_tensor = tensor_reg->getAclTensor(cell_bias_index).get();
162 auto output_gate_bias_tensor = tensor_reg->getAclTensor(output_gate_bias_index).get();
163 auto output_state_in_tensor = tensor_reg->getAclTensor(output_state_in_index).get();
164 auto cell_state_in_tensor = tensor_reg->getAclTensor(cell_state_in_index).get();
166 auto act_info = asActivationLayerInfo(activation);
168 ::arm_compute::LSTMParams<T_Tensor> lstm_params{};
171 auto input_to_input_weights_tensor =
172 tensor_reg->getAclTensor(input_to_input_weights_index).get(); // optional
173 auto recurrent_to_input_weights_tensor =
174 tensor_reg->getAclTensor(recurrent_to_input_weights_index).get(); // optional
175 auto cell_to_input_weights_handle =
176 has_peephole_param ? tensor_reg->getAclTensor(cell_to_input_weights_index).get()->handle()
177 : nullptr; // optional (non-cifg && peephole)
178 auto input_gate_bias_tensor = tensor_reg->getAclTensor(input_gate_bias_index).get(); // optional
179 lstm_params.set_cifg_params(input_to_input_weights_tensor->handle(),
180 recurrent_to_input_weights_tensor->handle(),
181 cell_to_input_weights_handle, input_gate_bias_tensor->handle());
183 if (has_peephole_param)
185 auto cell_to_forget_weights_tensor =
186 tensor_reg->getAclTensor(cell_to_forget_weights_index).get(); // optional
187 auto cell_to_output_weights_tensor =
188 tensor_reg->getAclTensor(cell_to_output_weights_index).get(); // optional
189 lstm_params.set_peephole_params(cell_to_forget_weights_tensor->handle(),
190 cell_to_output_weights_tensor->handle());
192 if (has_projection_param)
194 auto projection_weights_tensor =
195 tensor_reg->getAclTensor(projection_weights_index).get(); // optional
196 auto projection_bias_handle =
197 has_projection_bias ? tensor_reg->getAclTensor(projection_bias_index).get()->handle()
198 : nullptr; // optional
199 lstm_params.set_projection_params(projection_weights_tensor->handle(), projection_bias_handle);
202 auto fn = generateLayer<T_ACLLayer>(
203 input_tensor->handle(), input_to_forget_weights_tensor->handle(),
204 input_to_cell_weights_tensor->handle(), input_to_output_weights_tensor->handle(),
205 recurrent_to_forget_weights_tensor->handle(), recurrent_to_cell_weights_tensor->handle(),
206 recurrent_to_output_weights_tensor->handle(), forget_gate_bias_tensor->handle(),
207 cell_bias_tensor->handle(), output_gate_bias_tensor->handle(),
208 output_state_in_tensor->handle(), cell_state_in_tensor->handle(),
209 scratch_buffer_tensor->handle(), output_state_out_tensor->handle(),
210 cell_state_out_tensor->handle(), output_tensor->handle(), lstm_params, act_info, cell_clip,
213 return std::make_unique<T_FunctionWrapper>(std::move(fn));
216 template <typename T_FunctionWrapper, typename T_Tensor, typename T_ACLLayer,
217 typename T_TensorBuilder, typename T_TensorRegistry>
218 std::unique_ptr<exec::IFunction>
219 kernelGenFullyConnected(const ir::operation::FullyConnected &node, const ir::Operands &operands,
220 const std::shared_ptr<T_TensorBuilder> &tensor_builder,
221 const std::shared_ptr<T_TensorRegistry> &tensor_reg, ir::Layout layout)
223 using ir::operation::FullyConnected;
225 const auto output_index{node.getOutputs().at(0)};
226 const auto input_index{node.getInputs().at(FullyConnected::Input::INPUT)};
227 const auto weight_index{node.getInputs().at(FullyConnected::Input::WEIGHT)};
228 const auto bias_index{node.getInputs().at(FullyConnected::Input::BIAS)};
230 const auto input_rank = operands.at(input_index).shape().rank();
232 const auto output_size =
233 operands.at(output_index).shape().dim(operands.at(output_index).shape().rank() - 1);
234 UNUSED_RELEASE(output_size);
235 assert(operands.at(bias_index).shape().dim(0) == output_size);
236 assert(operands.at(weight_index).shape().dim(0) == output_size);
237 const auto batch_size =
238 operands.at(output_index).shape().dim(operands.at(output_index).shape().rank() - 2);
239 const auto input_size =
240 operands.at(weight_index).shape().dim(operands.at(weight_index).shape().rank() - 1);
242 // Check for reshaping input's shape into rank-2
243 bool needs_reshape = false;
244 ir::Shape reshape(2);
245 if (input_rank == 3 || input_rank == 4)
247 const auto &ifm_shape = operands.at(input_index).shape();
248 auto feature_size = 1;
249 for (int i = 0; i < ifm_shape.rank(); ++i)
251 feature_size *= ifm_shape.dim(i);
254 UNUSED_RELEASE(feature_size);
255 assert(feature_size == batch_size * input_size);
258 needs_reshape = true;
259 reshape.dim(0) = batch_size; /* H */
260 reshape.dim(1) = input_size; /* W */
263 auto output_tensor = tensor_reg->getAclTensor(output_index).get();
264 const auto input_tensor = tensor_reg->getAclTensor(input_index).get();
265 const auto weight_tensor = tensor_reg->getAclTensor(weight_index).get();
266 const auto bias_tensor = tensor_reg->getAclTensor(bias_index).get();
267 const auto frontend_layout = layout;
268 const auto acl_layout = output_tensor->handle()->info()->data_layout();
270 typename T_ACLLayer::KernelType kernel_type = T_ACLLayer::KernelType::GENERAL;
271 if (operands.at(weight_index).isConstant())
273 kernel_type = T_ACLLayer::KernelType::PREPROCESSED_WEIGHTS;
274 assert(operands.at(weight_index).data());
277 auto fn = generateLayer<T_ACLLayer>(
278 tensor_builder->acl_tensor_manager()->internal_buffer_manager(), input_tensor->handle(),
279 weight_tensor->handle(), bias_tensor->handle(), output_tensor->handle(), needs_reshape,
280 asTensorShape(reshape, frontend_layout, asRuntimeLayout(acl_layout)), kernel_type);
282 return std::make_unique<T_FunctionWrapper>(std::move(fn));
285 template <typename T_ACLLayer, typename T_PoolOp, typename T_AclTensorRegistry>
286 std::unique_ptr<::arm_compute::IFunction>
287 kernelGenPool2D(const T_PoolOp &node, const ir::Operands &operands,
288 const std::shared_ptr<T_AclTensorRegistry> &tensor_reg, ir::Layout layout,
289 ::arm_compute::PoolingType pooling_type)
291 const auto ofm_index{node.getOutputs().at(0)};
292 const auto ifm_index{node.getInputs().at(0)};
294 const auto ofm_shape = operands.at(ofm_index).shape().asFeature(layout);
295 const auto ifm_shape = operands.at(ifm_index).shape().asFeature(layout);
297 const auto kh = node.param().kh;
298 const auto kw = node.param().kw;
299 const auto stride = node.param().stride;
301 ir::calculatePadding(node.param().padding, ifm_shape, ofm_shape, stride, kw, kh);
303 VERBOSE(Pool2DParam) << "IFM_H: " << ifm_shape.H << std::endl;
304 VERBOSE(Pool2DParam) << "IFM_W: " << ifm_shape.W << std::endl;
305 VERBOSE(Pool2DParam) << "OFM_H: " << ofm_shape.H << std::endl;
306 VERBOSE(Pool2DParam) << "OFM_W: " << ofm_shape.W << std::endl;
307 VERBOSE(Pool2DParam) << "KER_H: " << kh << std::endl;
308 VERBOSE(Pool2DParam) << "KER_W: " << kw << std::endl;
309 VERBOSE(Pool2DParam) << "STRIDE_H: " << stride.vertical << std::endl;
310 VERBOSE(Pool2DParam) << "STRIDE_W: " << stride.horizontal << std::endl;
311 VERBOSE(Pool2DParam) << "PAD(T): " << padding.top << std::endl;
312 VERBOSE(Pool2DParam) << "PAD(B): " << padding.bottom << std::endl;
313 VERBOSE(Pool2DParam) << "PAD(L): " << padding.left << std::endl;
314 VERBOSE(Pool2DParam) << "PAD(R): " << padding.right << std::endl;
316 auto ofm_tensor = tensor_reg->getAclTensor(ofm_index).get();
317 auto ifm_tensor = tensor_reg->getAclTensor(ifm_index).get();
319 ::arm_compute::PoolingLayerInfo info{
320 pooling_type, ::arm_compute::Size2D{kw, kh}, ifm_tensor->info()->data_layout(),
321 asPadStrideInfo(padding, stride), true /* exclude_padding */};
323 auto fn = generateLayer<T_ACLLayer>(ifm_tensor->handle(), ofm_tensor->handle(), info);
328 } // namespace acl_common
329 } // namespace backend
332 #endif // __ONERT_BACKEND_ACL_COMMON_ACL_KERNEL_GEN_H_