/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef __ONERT_BACKEND_ACL_COMMON_ACL_KERNEL_GEN_H_
#define __ONERT_BACKEND_ACL_COMMON_ACL_KERNEL_GEN_H_

#include <exec/IFunction.h>
#include <ir/Operands.h>
#include <ir/Padding.h>   // ir::calculatePadding
#include <ir/operation/FullyConnected.h>
#include <ir/operation/LSTM.h>
#include <util/Utils.h>   // UNUSED_RELEASE
#include <util/logging.h> // VERBOSE

#include <arm_compute/runtime/CL/CLFunctions.h>

#include "Convert.h" // asActivationLayerInfo, asPadStrideInfo, asTensorShape, asRuntimeLayout

namespace onert
{
namespace backend
{
namespace acl_common
{

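// Builds an exec::IFunction that configures ACL's LSTM layer for `node` and
// wraps it in T_FunctionWrapper; T_Tensor and T_ACLLayer select the concrete
// backend. A minimal usage sketch, assuming the acl_cl backend's types
// (AclFunction, ::arm_compute::ICLTensor, ::arm_compute::CLLSTMLayer); the
// acl_neon backend would pass its NE* counterparts instead:
//
//   auto fn = kernelGenLSTM<AclFunction, ::arm_compute::ICLTensor,
//                           ::arm_compute::CLLSTMLayer>(node, operands, tensor_builder);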
template <typename T_FunctionWrapper, typename T_Tensor, typename T_ACLLayer,
          typename T_TensorBuilder>
std::unique_ptr<exec::IFunction>
kernelGenLSTM(const ir::operation::LSTM &node, const ir::Operands &operands,
              const std::shared_ptr<T_TensorBuilder> &tensor_builder)
{
  // TODO Support dynamic rnn
  // TODO Fix subtle error in the case of non-CIFG, non-peephole and No Projection.
  const auto scratch_buffer_index{
      node.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)};
  const auto output_state_out_index{
      node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)};
  const auto cell_state_out_index{
      node.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)};
  const auto output_index{node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)};

  const auto input_index{node.getInputs().at(ir::operation::LSTM::Input::INPUT)};
  const auto input_to_input_weights_index{
      node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)}; // optional
  const auto input_to_forget_weights_index{
      node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_FORGET_WEIGHTS)};
  const auto input_to_cell_weights_index{
      node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_CELL_WEIGHTS)};
  const auto input_to_output_weights_index{
      node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)};
  const auto recurrent_to_input_weights_index{
      node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)}; // optional
  const auto recurrent_to_forget_weights_index{
      node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_FORGET_WEIGHTS)};
  const auto recurrent_to_cell_weights_index{
      node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_CELL_WEIGHTS)};
  const auto recurrent_to_output_weights_index{
      node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)};
  const auto cell_to_input_weights_index{
      node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_INPUT_WEIGHTS)}; // optional
  const auto cell_to_forget_weights_index{
      node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_FORGET_WEIGHTS)}; // optional
  const auto cell_to_output_weights_index{
      node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_OUTPUT_WEIGHTS)}; // optional
  const auto input_gate_bias_index{
      node.getInputs().at(ir::operation::LSTM::Input::INPUT_GATE_BIAS)};
  const auto forget_gate_bias_index{
      node.getInputs().at(ir::operation::LSTM::Input::FORGET_GATE_BIAS)};
  const auto cell_bias_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_BIAS)};
  const auto output_gate_bias_index{
      node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_GATE_BIAS)};
  const auto projection_weights_index{
      node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_WEIGHTS)}; // optional
  const auto projection_bias_index{
      node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_BIAS)}; // optional
  const auto output_state_in_index{
      node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_STATE_IN)};
  const auto cell_state_in_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_STATE_IN)};
  const auto cell_threshold = node.param().cell_threshold;
  const auto projection_threshold = node.param().projection_threshold;

  bool has_input_to_input_weights = operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
                                    operands.at(input_to_input_weights_index).shape().dim(1) != 0;
  bool has_recurrent_to_input_weights =
      operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
      operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0;
  bool has_cell_to_forget_weights = operands.at(cell_to_forget_weights_index).shape().dim(0) != 0;
  bool has_cell_to_output_weights = operands.at(cell_to_output_weights_index).shape().dim(0) != 0;
  bool has_projection_weights = operands.at(projection_weights_index).shape().dim(0) != 0 &&
                                operands.at(projection_weights_index).shape().dim(1) != 0;
  bool has_projection_bias = operands.at(projection_bias_index).shape().dim(0) != 0;

  // NOTE input_to_input_weights and recurrent_to_input_weights do not exist in CIFG.
  // true: non-CIFG
  // false: CIFG
  // NOTE cell_to_input_weights does not exist without peephole connections,
  // even in a regular (non-CIFG) LSTM.
  bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;

  // NOTE cell_to_forget_weights and cell_to_output_weights exist with peephole connections,
  // but cell_to_input_weights does not exist in a CIFG LSTM even with peephole.
  // true: peephole
  // false: no peephole
  bool has_peephole_param = has_cell_to_forget_weights && has_cell_to_output_weights;

  // NOTE Even when the projection weights have data, the projection bias may not.
  bool has_projection_param = has_projection_weights;

  const auto activation = node.param().activation;
  const auto cell_clip = cell_threshold;
  const auto projection_clip = projection_threshold;
  // ACL treats a clip value of 0 as "clipping disabled", so the thresholds must be non-negative.
  assert(cell_clip >= 0.f && projection_clip >= 0.f);

  auto scratch_buffer_tensor = tensor_builder->at(scratch_buffer_index).get();
  auto output_state_out_tensor = tensor_builder->at(output_state_out_index).get();
  auto cell_state_out_tensor = tensor_builder->at(cell_state_out_index).get();
  auto output_tensor = tensor_builder->at(output_index).get();

  auto input_tensor = tensor_builder->at(input_index).get();

  auto input_to_forget_weights_tensor = tensor_builder->at(input_to_forget_weights_index).get();
  auto input_to_cell_weights_tensor = tensor_builder->at(input_to_cell_weights_index).get();
  auto input_to_output_weights_tensor = tensor_builder->at(input_to_output_weights_index).get();
  auto recurrent_to_forget_weights_tensor =
      tensor_builder->at(recurrent_to_forget_weights_index).get();
  auto recurrent_to_cell_weights_tensor = tensor_builder->at(recurrent_to_cell_weights_index).get();
  auto recurrent_to_output_weights_tensor =
      tensor_builder->at(recurrent_to_output_weights_index).get();

  auto forget_gate_bias_tensor = tensor_builder->at(forget_gate_bias_index).get();
  auto cell_bias_tensor = tensor_builder->at(cell_bias_index).get();
  auto output_gate_bias_tensor = tensor_builder->at(output_gate_bias_index).get();
  auto output_state_in_tensor = tensor_builder->at(output_state_in_index).get();
  auto cell_state_in_tensor = tensor_builder->at(cell_state_in_index).get();

  auto act_info = ::onert::backend::acl_common::asActivationLayerInfo(activation);

  auto fn = std::make_unique<T_ACLLayer>();

  ::arm_compute::LSTMParams<T_Tensor> lstm_params{};
  if (has_cifg_param)
  {
    auto input_to_input_weights_tensor =
        tensor_builder->at(input_to_input_weights_index).get(); // optional
    auto recurrent_to_input_weights_tensor =
        tensor_builder->at(recurrent_to_input_weights_index).get(); // optional
    auto cell_to_input_weights_handle =
        has_peephole_param ? tensor_builder->at(cell_to_input_weights_index).get()->handle()
                           : nullptr; // optional (non-cifg && peephole)
    auto input_gate_bias_tensor = tensor_builder->at(input_gate_bias_index).get(); // optional
    lstm_params.set_cifg_params(input_to_input_weights_tensor->handle(),
                                recurrent_to_input_weights_tensor->handle(),
                                cell_to_input_weights_handle, input_gate_bias_tensor->handle());
  }
  if (has_peephole_param)
  {
    auto cell_to_forget_weights_tensor =
        tensor_builder->at(cell_to_forget_weights_index).get(); // optional
    auto cell_to_output_weights_tensor =
        tensor_builder->at(cell_to_output_weights_index).get(); // optional
    lstm_params.set_peephole_params(cell_to_forget_weights_tensor->handle(),
                                    cell_to_output_weights_tensor->handle());
  }
  if (has_projection_param)
  {
    auto projection_weights_tensor = tensor_builder->at(projection_weights_index).get(); // optional
    auto projection_bias_handle = has_projection_bias
                                      ? tensor_builder->at(projection_bias_index).get()->handle()
                                      : nullptr; // optional
    lstm_params.set_projection_params(projection_weights_tensor->handle(), projection_bias_handle);
  }

  fn->configure(input_tensor->handle(), input_to_forget_weights_tensor->handle(),
                input_to_cell_weights_tensor->handle(), input_to_output_weights_tensor->handle(),
                recurrent_to_forget_weights_tensor->handle(),
                recurrent_to_cell_weights_tensor->handle(),
                recurrent_to_output_weights_tensor->handle(), forget_gate_bias_tensor->handle(),
                cell_bias_tensor->handle(), output_gate_bias_tensor->handle(),
                output_state_in_tensor->handle(), cell_state_in_tensor->handle(),
                scratch_buffer_tensor->handle(), output_state_out_tensor->handle(),
                cell_state_out_tensor->handle(), output_tensor->handle(), lstm_params, act_info,
                cell_clip, projection_clip);

  return std::make_unique<T_FunctionWrapper>(std::move(fn));
}

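// Builds an exec::IFunction for ACL's fully-connected layer, flattening a
// rank-3/4 input into rank-2 when necessary. A minimal usage sketch, assuming
// the acl_cl types (AclFunction, ::arm_compute::ICLTensor and the extension
// layer ::arm_compute::CLFullyConnectedReshapingLayer) and a frontend layout
// taken from the current op sequence:
//
//   auto fn = kernelGenFullyConnected<AclFunction, ::arm_compute::ICLTensor,
//                                     ::arm_compute::CLFullyConnectedReshapingLayer>(
//       node, operands, tensor_builder, frontend_layout);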
template <typename T_FunctionWrapper, typename T_Tensor, typename T_ACLLayer,
          typename T_TensorBuilder>
std::unique_ptr<exec::IFunction>
kernelGenFullyConnected(const ir::operation::FullyConnected &node, const ir::Operands &operands,
                        const std::shared_ptr<T_TensorBuilder> &tensor_builder, ir::Layout layout)
{
  using ir::operation::FullyConnected;

  const auto output_index{node.getOutputs().at(0)};
  const auto input_index{node.getInputs().at(FullyConnected::Input::INPUT)};
  const auto weight_index{node.getInputs().at(FullyConnected::Input::WEIGHT)};
  const auto bias_index{node.getInputs().at(FullyConnected::Input::BIAS)};

  const auto input_rank = operands.at(input_index).shape().rank();

  const auto output_size =
      operands.at(output_index).shape().dim(operands.at(output_index).shape().rank() - 1);
  UNUSED_RELEASE(output_size);
  assert(operands.at(bias_index).shape().dim(0) == output_size);
  assert(operands.at(weight_index).shape().dim(0) == output_size);
  const auto batch_size =
      operands.at(output_index).shape().dim(operands.at(output_index).shape().rank() - 2);
  const auto input_size =
      operands.at(weight_index).shape().dim(operands.at(weight_index).shape().rank() - 1);

  // Check whether the input needs to be reshaped into a rank-2 tensor
  bool needs_reshape = false;
  ir::Shape reshape(2);
  if (input_rank == 3 || input_rank == 4)
  {
    const auto &ifm_shape = operands.at(input_index).shape();
    auto feature_size = 1;
    for (int i = 0; i < ifm_shape.rank(); ++i)
    {
      feature_size *= ifm_shape.dim(i);
    }

    UNUSED_RELEASE(feature_size);
    assert(feature_size == batch_size * input_size);

    // Flatten to (batch_size, input_size) for the fully-connected kernel
    needs_reshape = true;
    reshape.dim(0) = batch_size; /* H */
    reshape.dim(1) = input_size; /* W */
  }

  auto output_tensor = tensor_builder->at(output_index).get();
  const auto input_tensor = tensor_builder->at(input_index).get();
  const auto weight_tensor = tensor_builder->at(weight_index).get();
  const auto bias_tensor = tensor_builder->at(bias_index).get();
  const auto frontend_layout = layout;
  const auto acl_layout = output_tensor->handle()->info()->data_layout();

  auto fn =
      std::make_unique<T_ACLLayer>(tensor_builder->acl_tensor_manager()->internal_buffer_manager());

  // Constant weights can be reshaped ahead of time, so use the preprocessed-weights kernel.
  typename T_ACLLayer::KernelType kernel_type = T_ACLLayer::KernelType::GENERAL;
  if (operands.at(weight_index).isConstant())
  {
    kernel_type = T_ACLLayer::KernelType::PREPROCESSED_WEIGHTS;
    assert(operands.at(weight_index).data());
  }

  fn->configure(
      input_tensor->handle(), weight_tensor->handle(), bias_tensor->handle(),
      output_tensor->handle(), needs_reshape,
      ::onert::backend::acl_common::asTensorShape(
          reshape, frontend_layout, ::onert::backend::acl_common::asRuntimeLayout(acl_layout)),
      kernel_type);

  return std::make_unique<T_FunctionWrapper>(std::move(fn));
}

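// Builds the raw ACL pooling function for `node`. Unlike the generators above,
// this returns an unwrapped ::arm_compute::IFunction so the caller can still
// append a fused activation before wrapping it. A minimal usage sketch,
// assuming the acl_cl layer ::arm_compute::CLPoolingLayer and max pooling:
//
//   auto raw_fn = kernelGenPool2D<::arm_compute::CLPoolingLayer>(
//       node, operands, tensor_builder, layout, ::arm_compute::PoolingType::MAX);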
template <typename T_ACLLayer, typename T_PoolOp, typename T_TensorBuilder>
std::unique_ptr<::arm_compute::IFunction>
kernelGenPool2D(const T_PoolOp &node, const ir::Operands &operands,
                const std::shared_ptr<T_TensorBuilder> &tensor_builder, ir::Layout layout,
                ::arm_compute::PoolingType pooling_type)
{
  const auto ofm_index{node.getOutputs().at(0)};
  const auto ifm_index{node.getInputs().at(0)};

  const auto ofm_shape = operands.at(ofm_index).shape().asFeature(layout);
  const auto ifm_shape = operands.at(ifm_index).shape().asFeature(layout);

  const auto kh = node.param().kh;
  const auto kw = node.param().kw;
  const auto stride = node.param().stride;
  const auto padding =
      ir::calculatePadding(node.param().padding, ifm_shape, ofm_shape, stride, kw, kh);

  VERBOSE(Pool2DParam) << "IFM_H: " << ifm_shape.H << std::endl;
  VERBOSE(Pool2DParam) << "IFM_W: " << ifm_shape.W << std::endl;
  VERBOSE(Pool2DParam) << "OFM_H: " << ofm_shape.H << std::endl;
  VERBOSE(Pool2DParam) << "OFM_W: " << ofm_shape.W << std::endl;
  VERBOSE(Pool2DParam) << "KER_H: " << kh << std::endl;
  VERBOSE(Pool2DParam) << "KER_W: " << kw << std::endl;
  VERBOSE(Pool2DParam) << "STRIDE_H: " << stride.vertical << std::endl;
  VERBOSE(Pool2DParam) << "STRIDE_W: " << stride.horizontal << std::endl;
  VERBOSE(Pool2DParam) << "PAD(T): " << padding.top << std::endl;
  VERBOSE(Pool2DParam) << "PAD(B): " << padding.bottom << std::endl;
  VERBOSE(Pool2DParam) << "PAD(L): " << padding.left << std::endl;
  VERBOSE(Pool2DParam) << "PAD(R): " << padding.right << std::endl;

  auto ofm_tensor = tensor_builder->at(ofm_index).get();
  auto ifm_tensor = tensor_builder->at(ifm_index).get();

  ::arm_compute::PoolingLayerInfo info{
      pooling_type, ::arm_compute::Size2D{kw, kh}, ifm_tensor->info()->data_layout(),
      acl_common::asPadStrideInfo(padding, stride), true /* exclude_padding */};

  auto fn = std::make_unique<T_ACLLayer>();

  fn->configure(ifm_tensor->handle(), ofm_tensor->handle(), info);

  return fn;
}

} // namespace acl_common
} // namespace backend
} // namespace onert

#endif // __ONERT_BACKEND_ACL_COMMON_ACL_KERNEL_GEN_H_