Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / runtime / onert / backend / cpu / KernelGenerator.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "KernelGenerator.h"
18
19 #include "ops/AbsLayer.h"
20 #include "ops/AddLayer.h"
21 #include "ops/ArgMinMaxLayer.h"
22 #include "ops/AvgPoolLayer.h"
23 #include "ops/BatchToSpaceNDLayer.h"
24 #include "ops/CastLayer.h"
25 #include "ops/CompareLayer.h"
26 #include "ops/ConcatLayer.h"
27 #include "ops/ConvolutionLayer.h"
28 #include "ops/CosLayer.h"
29 #include "ops/DepthwiseConvolutionLayer.h"
30 #include "ops/DivLayer.h"
31 #include "ops/EinsumLayer.h"
32 #include "ops/ExpLayer.h"
33 #include "ops/ExpandDimsLayer.h"
34 #include "ops/FillLayer.h"
35 #include "ops/FullyConnectedLayer.h"
36 #include "ops/GatherLayer.h"
37 #include "ops/LogLayer.h"
38 #include "ops/LogisticLayer.h"
39 #include "ops/MaxLayer.h"
40 #include "ops/MaxPoolLayer.h"
41 #include "ops/MeanLayer.h"
42 #include "ops/MinLayer.h"
43 #include "ops/MulLayer.h"
44 #include "ops/NegLayer.h"
45 #include "ops/OneHotLayer.h"
46 #include "ops/OperationUtils.h"
47 #include "ops/PackLayer.h"
48 #include "ops/PadLayer.h"
49 #include "ops/PowLayer.h"
50 #include "ops/RangeLayer.h"
51 #include "ops/ReduceLayer.h"
52 #include "ops/ReLULayer.h"
53 #include "ops/ReLU6Layer.h"
54 #include "ops/ReshapeLayer.h"
55 #include "ops/ResizeBilinearLayer.h"
56 #include "ops/ReverseLayer.h"
57 #include "ops/RoundLayer.h"
58 #include "ops/RsqrtLayer.h"
59 #include "ops/SelectLayer.h"
60 #include "ops/ShapeLayer.h"
61 #include "ops/SinLayer.h"
62 #include "ops/SliceLayer.h"
63 #include "ops/SoftMaxLayer.h"
64 #include "ops/StridedSliceLayer.h"
65 #include "ops/SpaceToBatchNDLayer.h"
66 #include "ops/SpaceToDepthLayer.h"
67 #include "ops/SplitLayer.h"
68 #include "ops/SplitVLayer.h"
69 #include "ops/SubLayer.h"
70 #include "ops/TanhLayer.h"
71 #include "ops/TileLayer.h"
72 #include "ops/TransposeLayer.h"
73 #include "ops/UnpackLayer.h"
74 #include "ops/LogicalNotLayer.h"
75 #include "ops/ZerosLikeLayer.h"
76 #include "ops/SquaredDiffLayer.h"
77 #include "ops/LogicalOrLayer.h"
78 #include "ops/L2NormLayer.h"
79 #include "ops/MatrixBandPartLayer.h"
80 #include "ops/BatchMatMulLayer.h"
81 #include "ops/BroadcastToLayer.h"
82 #include "ops/FusedBatchNormLayer.h"
83 #include "ops/LogSoftMaxLayer.h"
84 #include "ops/QuantizeLayer.h"
85 #include "ops/StatelessRandomUniformLayer.h"
86
87 #include <backend/Backend.h>
88 #include <backend/IConfig.h>
89 #include <memory>
90 #include <util/Utils.h>
91 #include <util/logging.h>
92 #include <exec/DynamicShapeInference.h>
93
94 #include <stdexcept>
95
96 namespace onert
97 {
98 namespace backend
99 {
100 namespace cpu
101 {
102
103 namespace
104 {
105 ops::ReduceType convertReduceType(ir::operation::Reduce::ReduceType reduce_type_ir)
106 {
107   switch (reduce_type_ir)
108   {
109     case ir::operation::Reduce::ReduceType::ALL:
110       return ops::ReduceType::kAll;
111     case ir::operation::Reduce::ReduceType::ANY:
112       return ops::ReduceType::kAny;
113     case ir::operation::Reduce::ReduceType::MAX:
114       return ops::ReduceType::kMax;
115     case ir::operation::Reduce::ReduceType::MIN:
116       return ops::ReduceType::kMin;
117     case ir::operation::Reduce::ReduceType::PROD:
118       return ops::ReduceType::kProd;
119     case ir::operation::Reduce::ReduceType::SUM:
120       return ops::ReduceType::kSum;
121     default:
122       throw std::runtime_error("cpu KernelGenerator : Not supported operation yet");
123   }
124 }
125 } // namespace
126
127 KernelGenerator::KernelGenerator(
128     const ir::Operands &operands_ctx, const ir::Operations &operations_ctx,
129     const std::shared_ptr<TensorBuilder> &tensor_builder,
130     const std::shared_ptr<backend::custom::IKernelBuilder> &kernel_builder,
131     const std::shared_ptr<ExternalContext> &external_context)
132     : _ctx(operands_ctx), _operations_ctx{operations_ctx}, _tensor_builder(tensor_builder),
133       _kernel_builder(kernel_builder), _current_op_seq_layout(ir::Layout::UNKNOWN),
134       _external_context(external_context)
135 {
136   // DO NOTHING
137 }
138
139 void KernelGenerator::visit(const ir::OpSequence &op_seq)
140 {
141   assert(!_return_fn_seq);
142   assert(_tensor_builder->dynamicTensorManager());
143   assert(_tensor_builder->tensorRegistry());
144
145   auto dyn_tensor_manager = _tensor_builder->dynamicTensorManager();
146   auto dyn_shape_inferer = std::make_shared<exec::DynamicShapeInferer>(
147       _ctx, dyn_tensor_manager, _tensor_builder->tensorRegistry());
148
149   _return_fn_seq = std::make_unique<exec::FunctionSequence>();
150
151   // Prepare to handle dynamic tensors later
152   auto dyn_ctx = std::make_shared<exec::FunctionSequence::DynamicTensorCtx>();
153   {
154     dyn_ctx->op_seq = &op_seq;
155     dyn_ctx->operations = &_operations_ctx;
156     dyn_ctx->dynamic_shape_inferer = std::move(dyn_shape_inferer);
157     dyn_ctx->tensor_registry = _tensor_builder->tensorRegistry();
158     dyn_ctx->dynamic_tensor_manager = _tensor_builder->dynamicTensorManager();
159
160     _return_fn_seq->dynamic_tensor_ctx(dyn_ctx);
161   }
162   _return_fn_seq->enableDynamicShapeInferer(true);
163
164   _current_op_seq_layout = op_seq.getLayout();
165   for (const auto &operation_idx : op_seq.operations())
166   {
167     const auto &node = _operations_ctx.at(operation_idx);
168     node.accept(*this);
169     _return_fn_seq->append(releaseFunction());
170
171     for (const auto &ind : (node.getInputs() | ir::Remove::UNDEFINED) + node.getOutputs())
172     {
173       auto portable_tensor = _tensor_builder->portableAt(ind);
174       if (portable_tensor)
175       {
176         assert(portable_tensor->layout() == ir::Layout::NHWC);
177       }
178
179       auto tensor = _tensor_builder->at(ind);
180       if (tensor)
181       {
182         tensor->increase_ref();
183       }
184     }
185   }
186 }
187
188 void KernelGenerator::visit(const ir::operation::Conv2D &node)
189 {
190   using ir::operation::Conv2D;
191
192   const auto ofm_index{node.getOutputs().at(0)};
193   const auto ifm_index{node.getInputs().at(Conv2D::Input::INPUT)};
194   const auto ker_index{node.getInputs().at(Conv2D::Input::KERNEL)};
195   const auto bias_index{node.getInputs().at(Conv2D::Input::BIAS)};
196
197   auto ofm_tensor = _tensor_builder->portableAt(ofm_index).get();
198   auto ifm_tensor = _tensor_builder->portableAt(ifm_index).get();
199   auto ker_tensor = _tensor_builder->portableAt(ker_index).get();
200   auto bias_tensor = _tensor_builder->portableAt(bias_index).get();
201
202   const auto stride = node.param().stride;
203   const auto activation = node.param().activation;
204   const auto param_padding = node.param().padding;
205   auto fn = std::make_unique<ops::ConvolutionLayer>();
206
207   if (_ctx.at(ifm_index).info().isDynamic() || _ctx.at(ker_index).info().isDynamic())
208   {
209     fn->configure(ifm_tensor, ker_tensor, bias_tensor, param_padding.type, param_padding.param.left,
210                   param_padding.param.right, param_padding.param.top, param_padding.param.bottom,
211                   stride.horizontal, stride.vertical, activation, ofm_tensor);
212
213     _return_fn = std::move(fn);
214     return;
215   }
216   const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature(_current_op_seq_layout);
217   const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature(_current_op_seq_layout);
218   // Kernel format is [depth_out, kernel_height, kernel_width, depth_in].
219   const auto &ker_shape = _ctx.at(ker_index).shape();
220   const auto ker_height = ker_shape.dim(1);
221   const auto ker_width = ker_shape.dim(2);
222
223   const auto padding =
224       ir::calculatePadding(param_padding, ifm_shape, ofm_shape, stride, ker_width, ker_height);
225
226   fn->configure(ifm_tensor, ker_tensor, bias_tensor, param_padding.type, padding.left,
227                 padding.right, padding.top, padding.bottom, stride.horizontal, stride.vertical,
228                 activation, ofm_tensor);
229
230   _return_fn = std::move(fn);
231 }
232
233 void KernelGenerator::visit(const ir::operation::DepthwiseConv2D &node)
234 {
235   using ir::operation::DepthwiseConv2D;
236
237   const auto ofm_index{node.getOutputs().at(0)};
238   const auto ifm_index{node.getInputs().at(DepthwiseConv2D::Input::INPUT)};
239   const auto ker_index{node.getInputs().at(DepthwiseConv2D::Input::KERNEL)};
240   const auto bias_index{node.getInputs().at(DepthwiseConv2D::Input::BIAS)};
241
242   const auto stride = node.param().stride;
243   const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature(_current_op_seq_layout);
244   const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature(_current_op_seq_layout);
245   // Kernel format is [1, kernel_height, kernel_width, depth_out].
246   const auto &ker_shape = _ctx.at(ker_index).shape();
247   const auto ker_height = ker_shape.dim(1);
248   const auto ker_width = ker_shape.dim(2);
249   const auto padding = ir::calculatePadding(node.param().padding, ifm_shape, ofm_shape, stride,
250                                             ker_width, ker_height);
251   const auto multiplier = node.param().multiplier;
252   const auto activation = node.param().activation;
253
254   auto ofm_tensor = _tensor_builder->portableAt(ofm_index).get();
255   auto ifm_tensor = _tensor_builder->portableAt(ifm_index).get();
256   auto ker_tensor = _tensor_builder->portableAt(ker_index).get();
257   auto bias_tensor = _tensor_builder->portableAt(bias_index).get();
258
259   auto fn = std::make_unique<ops::DepthwiseConvolutionLayer>();
260
261   fn->configure(ifm_tensor, ker_tensor, bias_tensor, padding.left, padding.right, padding.top,
262                 padding.bottom, stride.horizontal, stride.vertical, multiplier, activation,
263                 ofm_tensor);
264
265   _return_fn = std::move(fn);
266 }
267
268 void KernelGenerator::visit(const ir::operation::MaxPool2D &node)
269 {
270   const auto ofm_index{node.getOutputs().at(0)};
271   const auto ifm_index{node.getInputs().at(ir::operation::MaxPool2D::Input::INPUT)};
272
273   const auto kh = node.param().kh;
274   const auto kw = node.param().kw;
275
276   const auto stride = node.param().stride;
277   const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature(_current_op_seq_layout);
278   const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature(_current_op_seq_layout);
279   const auto padding =
280       ir::calculatePadding(node.param().padding, ifm_shape, ofm_shape, stride, kw, kh);
281   const auto activation = node.param().activation;
282
283   auto ofm_tensor = _tensor_builder->portableAt(ofm_index).get();
284   auto ifm_tensor = _tensor_builder->portableAt(ifm_index).get();
285
286   auto fn = std::make_unique<ops::MaxPoolLayer>();
287
288   fn->configure(ifm_tensor, padding.left, padding.right, padding.top, padding.bottom,
289                 stride.horizontal, stride.vertical, kw, kh, activation, ofm_tensor);
290
291   _return_fn = std::move(fn);
292 }
293
294 void KernelGenerator::visit(const ir::operation::AvgPool2D &node)
295 {
296   const auto ofm_index{node.getOutputs().at(0)};
297   const auto ifm_index{node.getInputs().at(ir::operation::AvgPool2D::Input::INPUT)};
298
299   const auto kh = node.param().kh;
300   const auto kw = node.param().kw;
301   const auto stride = node.param().stride;
302   const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature(_current_op_seq_layout);
303   const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature(_current_op_seq_layout);
304   const auto padding =
305       ir::calculatePadding(node.param().padding, ifm_shape, ofm_shape, stride, kw, kh);
306   const auto activation = node.param().activation;
307
308   auto ofm_tensor = _tensor_builder->portableAt(ofm_index).get();
309   auto ifm_tensor = _tensor_builder->portableAt(ifm_index).get();
310
311   auto fn = std::make_unique<ops::AvgPoolLayer>();
312
313   fn->configure(ifm_tensor, padding.left, padding.right, padding.top, padding.bottom,
314                 stride.horizontal, stride.vertical, kw, kh, activation, ofm_tensor);
315
316   _return_fn = std::move(fn);
317 }
318
319 void KernelGenerator::visit(const ir::operation::Concat &node)
320 {
321   const auto ofm_index{node.getOutputs().at(0)};
322
323   const auto rank = _ctx.at(ofm_index).shape().rank();
324   const auto axis = ops::getAxis(rank, node.param().axis, _current_op_seq_layout);
325
326   auto output_tensor = _tensor_builder->portableAt(ofm_index).get();
327
328   std::vector<const IPortableTensor *> input_tensors;
329   for (auto &ifm_idx : node.getInputs())
330     input_tensors.emplace_back(_tensor_builder->portableAt(ifm_idx).get());
331
332   auto fn = std::make_unique<ops::ConcatLayer>();
333
334   fn->configure(input_tensors, axis, output_tensor);
335
336   _return_fn = std::move(fn);
337 }
338
339 void KernelGenerator::visit(const ir::operation::BatchToSpaceND &node)
340 {
341   const auto output_index{node.getOutputs().at(0)};
342   const auto input_index{node.getInputs().at(ir::operation::BatchToSpaceND::INPUT)};
343   const auto block_size_index{node.getInputs().at(ir::operation::BatchToSpaceND::BLOCK_SIZE)};
344
345   auto output_alloc = _tensor_builder->portableAt(output_index).get();
346   auto input_alloc = _tensor_builder->portableAt(input_index).get();
347   auto block_size_alloc = _tensor_builder->portableAt(block_size_index).get();
348
349   auto fn = std::make_unique<ops::BatchToSpaceNDLayer>();
350
351   IPortableTensor *crops_alloc = nullptr;
352   const auto NNApiInputs = 2;
353
354   if (node.getInputs().size() != NNApiInputs)
355   {
356     const auto crops_data_index{node.getInputs().at(ir::operation::BatchToSpaceND::CROPS_DATA)};
357     crops_alloc = _tensor_builder->portableAt(crops_data_index).get();
358   }
359
360   fn->configure(input_alloc, output_alloc, block_size_alloc, crops_alloc);
361
362   _return_fn = std::move(fn);
363 }
364
365 void KernelGenerator::visit(const ir::operation::Fill &node)
366 {
367   const auto output_index{node.getOutputs().at(0)};
368   const auto input_index{node.getInputs().at(ir::operation::Fill::Input::INPUT)};
369   const auto value_index{node.getInputs().at(ir::operation::Fill::Input::VALUE)};
370
371   auto output_tensor = _tensor_builder->portableAt(output_index).get();
372   auto input_tensor = _tensor_builder->portableAt(input_index).get();
373   auto value_tensor = _tensor_builder->portableAt(value_index).get();
374
375   auto fn = std::make_unique<ops::FillLayer>();
376
377   fn->configure(input_tensor, value_tensor, output_tensor);
378
379   _return_fn = std::move(fn);
380 }
381
382 void KernelGenerator::visit(const ir::operation::FullyConnected &node)
383 {
384   using ir::operation::FullyConnected;
385
386   const auto output_index{node.getOutputs().at(0)};
387   const auto input_index{node.getInputs().at(FullyConnected::Input::INPUT)};
388   const auto weight_index{node.getInputs().at(FullyConnected::Input::WEIGHT)};
389   const auto bias_index{node.getInputs().at(FullyConnected::Input::BIAS)};
390   const auto activation = node.param().activation;
391
392   auto output_tensor = _tensor_builder->portableAt(output_index).get();
393   auto input_tensor = _tensor_builder->portableAt(input_index).get();
394   auto weight_tensor = _tensor_builder->portableAt(weight_index).get();
395   auto bias_tensor =
396       bias_index.undefined() ? nullptr : _tensor_builder->portableAt(bias_index).get();
397
398   auto fn = std::make_unique<ops::FullyConnectedLayer>();
399
400   fn->configure(input_tensor, weight_tensor, bias_tensor, activation, output_tensor,
401                 _external_context);
402
403   _return_fn = std::move(fn);
404 }
405
406 void KernelGenerator::visit(const ir::operation::Reshape &node)
407 {
408   const auto output_index{node.getOutputs().at(0)};
409   const auto input_index{node.getInputs().at(ir::operation::Reshape::Input::INPUT)};
410
411   auto output_tensor = _tensor_builder->portableAt(output_index).get();
412   auto input_tensor = _tensor_builder->portableAt(input_index).get();
413
414   // optional 2nd input
415   IPortableTensor *shape_tensor = nullptr;
416
417   if (node.getInputs().size() == 2)
418   {
419     const auto shape_index{node.getInputs().at(ir::operation::Reshape::Input::SHAPE)};
420     shape_tensor = _tensor_builder->portableAt(shape_index).get();
421   }
422
423   auto fn = std::make_unique<ops::ReshapeLayer>();
424
425   fn->configure(input_tensor, shape_tensor, output_tensor);
426   _return_fn = std::move(fn);
427 }
428
429 void KernelGenerator::visit(const ir::operation::Squeeze &node)
430 {
431   const auto output_index{node.getOutputs().at(0)};
432   const auto input_index{node.getInputs().at(ir::operation::Squeeze::Input::INPUT)};
433
434   auto output_tensor = _tensor_builder->portableAt(output_index).get();
435   auto input_tensor = _tensor_builder->portableAt(input_index).get();
436
437   // Squeeze can share same kernel with reshape
438   auto fn = std::make_unique<ops::ReshapeLayer>();
439
440   fn->configure(input_tensor, nullptr, output_tensor);
441
442   _return_fn = std::move(fn);
443 }
444
445 void KernelGenerator::visit(const ir::operation::Softmax &node)
446 {
447   const auto output_index{node.getOutputs().at(0)};
448   const auto input_index{node.getInputs().at(ir::operation::Softmax::Input::INPUT)};
449
450   const auto beta = node.param().beta;
451
452   auto output_tensor = _tensor_builder->portableAt(output_index).get();
453   auto input_tensor = _tensor_builder->portableAt(input_index).get();
454
455   auto fn = std::make_unique<ops::SoftMaxLayer>();
456
457   fn->configure(input_tensor, beta, output_tensor);
458
459   _return_fn = std::move(fn);
460 }
461
462 void KernelGenerator::visit(const ir::operation::Add &node)
463 {
464   const auto ofm_index{node.getOutputs().at(0)};
465   const auto lhs_index{node.getInputs().at(ir::operation::Add::Input::LHS)};
466   const auto rhs_index{node.getInputs().at(ir::operation::Add::Input::RHS)};
467
468   const auto activation = node.param().activation;
469
470   auto ofm_tensor = _tensor_builder->portableAt(ofm_index).get();
471   auto lhs_tensor = _tensor_builder->portableAt(lhs_index).get();
472   auto rhs_tensor = _tensor_builder->portableAt(rhs_index).get();
473
474   auto fn = std::make_unique<ops::AddLayer>();
475
476   fn->configure(lhs_tensor, rhs_tensor, activation, ofm_tensor);
477
478   _return_fn = std::move(fn);
479 }
480
481 void KernelGenerator::visit(const ir::operation::Comparison &node)
482 {
483   const auto ofm_index{node.getOutputs().at(0)};
484   const auto lhs_index{node.getInputs().at(ir::operation::Comparison::Input::INPUT0)};
485   const auto rhs_index{node.getInputs().at(ir::operation::Comparison::Input::INPUT1)};
486
487   auto ofm_tensor = _tensor_builder->portableAt(ofm_index).get();
488   auto lhs_tensor = _tensor_builder->portableAt(lhs_index).get();
489   auto rhs_tensor = _tensor_builder->portableAt(rhs_index).get();
490
491   auto comparison_type = node.param().comparison_type;
492
493   auto fn = std::make_unique<ops::CompareLayer>();
494
495   fn->configure(lhs_tensor, rhs_tensor, comparison_type, ofm_tensor);
496
497   _return_fn = std::move(fn);
498 }
499
500 void KernelGenerator::visit(const ir::operation::Gather &node)
501 {
502   const auto output_index{node.getOutputs().at(0)};
503   const auto input_index{node.getInputs().at(ir::operation::Gather::Input::INPUT)};
504   const auto indices_index{node.getInputs().at(ir::operation::Gather::Input::INDICES)};
505
506   auto output_tensor = _tensor_builder->portableAt(output_index).get();
507   auto input_tensor = _tensor_builder->portableAt(input_index).get();
508   auto indices_tensor = _tensor_builder->portableAt(indices_index).get();
509
510   const auto backend_layout = output_tensor->layout();
511   UNUSED_RELEASE(backend_layout);
512
513   // NOTE The frontend layout and backend layout must be the same for this operation.
514   //      If not the same, we have to add a stage(?) to perform permutation of output tensor. It
515   //      is not not efficient even if it works well. If so, it would be better to set the
516   //      layout of these backend tensors to the same layout.
517   //      There is also one thing we have to think about. This operation depends on the layout of
518   //      a model. For example, if a model in NHWC has this operation as output rank == 4, indices
519   //      rank == 2 and axis == 2, this operation should work as the axis W and C, but the axis W
520   //      and C are not sequential in NCHW. So the backend in NCHW cannot handle this case.
521   assert(backend_layout == input_tensor->layout());
522   assert(backend_layout == indices_tensor->layout());
523   const auto &input_shape = _ctx.at(input_index).shape();
524   UNUSED_RELEASE(input_shape);
525   assert(input_shape.rank() < 4 || _current_op_seq_layout == backend_layout);
526
527   const auto axis_raw = node.param().axis;
528   const auto axis_value = (axis_raw < 0 ? (input_shape.rank() + axis_raw) : axis_raw);
529
530   auto fn = std::make_unique<ops::GatherLayer>();
531
532   fn->configure(input_tensor, indices_tensor, output_tensor, axis_value);
533
534   _return_fn = std::move(fn);
535 }
536
537 void KernelGenerator::visit(const ir::operation::Sub &node)
538 {
539   // The same as Add
540   const auto ofm_index{node.getOutputs().at(0)};
541   const auto lhs_index{node.getInputs().at(ir::operation::Sub::Input::LHS)};
542   const auto rhs_index{node.getInputs().at(ir::operation::Sub::Input::RHS)};
543
544   const auto activation = node.param().activation;
545
546   auto ofm_tensor = _tensor_builder->portableAt(ofm_index).get();
547   auto lhs_tensor = _tensor_builder->portableAt(lhs_index).get();
548   auto rhs_tensor = _tensor_builder->portableAt(rhs_index).get();
549
550   auto fn = std::make_unique<ops::SubLayer>();
551
552   fn->configure(lhs_tensor, rhs_tensor, activation, ofm_tensor);
553
554   _return_fn = std::move(fn);
555 }
556
557 void KernelGenerator::visit(const ir::operation::Mul &node)
558 {
559   // The same as Add
560   const auto ofm_index{node.getOutputs().at(0)};
561   const auto lhs_index{node.getInputs().at(ir::operation::Mul::Input::LHS)};
562   const auto rhs_index{node.getInputs().at(ir::operation::Mul::Input::RHS)};
563
564   const auto activation = node.param().activation;
565
566   auto ofm_tensor = _tensor_builder->portableAt(ofm_index).get();
567   auto lhs_tensor = _tensor_builder->portableAt(lhs_index).get();
568   auto rhs_tensor = _tensor_builder->portableAt(rhs_index).get();
569
570   auto fn = std::make_unique<ops::MulLayer>();
571
572   fn->configure(lhs_tensor, rhs_tensor, activation, ofm_tensor);
573
574   _return_fn = std::move(fn);
575 }
576
577 void KernelGenerator::visit(const ir::operation::OneHot &node)
578 {
579   const auto output_index{node.getOutputs().at(0)};
580   const auto indices_index{node.getInputs().at(ir::operation::OneHot::INDICES)};
581   const auto depth_index{node.getInputs().at(ir::operation::OneHot::Input::DEPTH)};
582   const auto onvalue_index{node.getInputs().at(ir::operation::OneHot::Input::ON_VALUE)};
583   const auto offvalue_index{node.getInputs().at(ir::operation::OneHot::Input::OFF_VALUE)};
584
585   const auto axis = node.param().axis;
586
587   auto output_tensor = _tensor_builder->portableAt(output_index).get();
588   auto indices_tensor = _tensor_builder->portableAt(indices_index).get();
589   auto depth_tensor = _tensor_builder->portableAt(depth_index).get();
590   auto onvalue_tensor = _tensor_builder->portableAt(onvalue_index).get();
591   auto offvalue_tensor = _tensor_builder->portableAt(offvalue_index).get();
592
593   assert(indices_tensor->data_type() == OperandType::INT32);
594   assert(axis <= static_cast<int>(indices_tensor->num_dimensions()));
595
596   auto fn = std::make_unique<ops::OneHotLayer>();
597
598   fn->configure(indices_tensor, depth_tensor, onvalue_tensor, offvalue_tensor, output_tensor, axis);
599
600   _return_fn = std::move(fn);
601 }
602
603 void KernelGenerator::visit(const ir::operation::Div &node)
604 {
605   // The same as Add
606   const auto ofm_index{node.getOutputs().at(0)};
607   const auto lhs_index{node.getInputs().at(ir::operation::Div::Input::LHS)};
608   const auto rhs_index{node.getInputs().at(ir::operation::Div::Input::RHS)};
609
610   const auto activation = node.param().activation;
611
612   auto ofm_tensor = _tensor_builder->portableAt(ofm_index).get();
613   auto lhs_tensor = _tensor_builder->portableAt(lhs_index).get();
614   auto rhs_tensor = _tensor_builder->portableAt(rhs_index).get();
615
616   auto fn = std::make_unique<ops::DivLayer>();
617
618   fn->configure(lhs_tensor, rhs_tensor, activation, ofm_tensor);
619
620   _return_fn = std::move(fn);
621 }
622
623 void KernelGenerator::visit(const ir::operation::Einsum &node)
624 {
625   const auto ofm_index{node.getOutputs().at(0)};
626
627   auto output_tensor = _tensor_builder->portableAt(ofm_index).get();
628   std::vector<const IPortableTensor *> input_tensors;
629   for (auto &ifm_idx : node.getInputs())
630     input_tensors.emplace_back(_tensor_builder->portableAt(ifm_idx).get());
631
632   const auto equation = node.param().equation;
633
634   auto fn = std::make_unique<ops::EinsumLayer>();
635
636   fn->configure(input_tensors, equation, output_tensor);
637
638   _return_fn = std::move(fn);
639 }
640
641 void KernelGenerator::visit(const ir::operation::Custom &node)
642 {
643   auto fill_op_info = [&](const ir::OperandIndexSequence &opSeq,
644                           std::vector<custom::TypeInfo> &types,
645                           std::vector<std::shared_ptr<IPortableTensor>> &tensors) {
646     for (auto &idx : opSeq)
647     {
648       const auto &operand = _ctx.at(idx);
649       // TODO make sure using `_current_op_seq_layout` is correct for custom operations
650       types.emplace_back(custom::TypeInfo{operand.shape(), operand.typeInfo().type()});
651       auto in_tensor = _tensor_builder->portableAt(idx);
652       tensors.emplace_back(in_tensor);
653     }
654   };
655
656   backend::custom::CustomKernelConfigParams params{};
657
658   fill_op_info(node.getInputs(), params.input_types, params.input_tensors);
659   fill_op_info(node.getOutputs(), params.output_types, params.output_tensors);
660
661   params.userdata = node.userdata().data;
662   params.userdata_size = node.userdata().size;
663
664   auto fn = _kernel_builder->buildKernel(node.id(), std::move(params));
665
666   _return_fn = std::move(fn);
667 }
668
669 void KernelGenerator::visit(const ir::operation::Exp &node)
670 {
671   const auto output_index{node.getOutputs().at(0)};
672   const auto input_index{node.getInputs().at(ir::operation::Exp::Input::INPUT)};
673
674   auto output_tensor = _tensor_builder->portableAt(output_index).get();
675   auto input_tensor = _tensor_builder->portableAt(input_index).get();
676
677   auto fn = std::make_unique<ops::ExpLayer>();
678
679   fn->configure(input_tensor, output_tensor);
680
681   _return_fn = std::move(fn);
682 }
683
684 void KernelGenerator::visit(const ir::operation::ExpandDims &node)
685 {
686   const auto output_index{node.getOutputs().at(0)};
687   const auto input_index{node.getInputs().at(ir::operation::ExpandDims::Input::INPUT)};
688   const auto axis_index{node.getInputs().at(ir::operation::ExpandDims::Input::AXIS)};
689
690   auto output_tensor = _tensor_builder->portableAt(output_index).get();
691   auto input_tensor = _tensor_builder->portableAt(input_index).get();
692   auto axis_tensor = _tensor_builder->portableAt(axis_index).get();
693
694   auto fn = std::make_unique<ops::ExpandDimsLayer>();
695
696   fn->configure(input_tensor, axis_tensor, output_tensor);
697
698   _return_fn = std::move(fn);
699 }
700
701 void KernelGenerator::visit(const ir::operation::Logistic &node)
702 {
703   const auto output_index{node.getOutputs().at(0)};
704   const auto input_index{node.getInputs().at(ir::operation::Logistic::Input::INPUT)};
705
706   auto output_tensor = _tensor_builder->portableAt(output_index).get();
707   auto input_tensor = _tensor_builder->portableAt(input_index).get();
708
709   auto fn = std::make_unique<ops::LogisticLayer>();
710
711   fn->configure(input_tensor, output_tensor);
712
713   _return_fn = std::move(fn);
714 }
715
716 void KernelGenerator::visit(const ir::operation::Tanh &node)
717 {
718   const auto output_index{node.getOutputs().at(0)};
719   const auto input_index{node.getInputs().at(ir::operation::Tanh::Input::INPUT)};
720
721   auto output_tensor = _tensor_builder->portableAt(output_index).get();
722   auto input_tensor = _tensor_builder->portableAt(input_index).get();
723
724   auto fn = std::make_unique<ops::TanhLayer>();
725
726   fn->configure(input_tensor, output_tensor);
727
728   _return_fn = std::move(fn);
729 }
730
731 void KernelGenerator::visit(const ir::operation::Pack &node)
732 {
733   const auto ofm_index{node.getOutputs().at(0)};
734
735   const auto rank = _ctx.at(ofm_index).shape().rank();
736   const auto axis = ops::getAxis(rank, node.param().axis, _current_op_seq_layout);
737
738   assert(-rank <= axis && axis < rank);
739
740   auto output_tensor = _tensor_builder->portableAt(ofm_index).get();
741
742   std::vector<const IPortableTensor *> input_tensors;
743   for (auto &ifm_idx : node.getInputs())
744     input_tensors.emplace_back(_tensor_builder->portableAt(ifm_idx).get());
745
746   auto fn = std::make_unique<ops::PackLayer>();
747
748   fn->configure(input_tensors, axis, output_tensor);
749
750   _return_fn = std::move(fn);
751 }
752
753 void KernelGenerator::visit(const ir::operation::Unpack &node)
754 {
755   const auto input_index{node.getInputs().at(0)};
756
757   const auto rank = _ctx.at(input_index).shape().rank();
758   const auto axis = ops::getAxis(rank, node.param().axis, _current_op_seq_layout);
759
760   assert(rank == 0 || (-rank <= axis && axis < rank));
761
762   auto input_tensor = _tensor_builder->portableAt(input_index).get();
763
764   std::vector<IPortableTensor *> output_tensors;
765   for (auto &output_idx : node.getOutputs())
766     output_tensors.emplace_back(_tensor_builder->portableAt(output_idx).get());
767
768   auto fn = std::make_unique<ops::UnpackLayer>();
769
770   uint32_t axis_resolved = (axis < 0 ? axis + rank : axis);
771
772   fn->configure(input_tensor, axis_resolved, node.param().num, output_tensors);
773
774   _return_fn = std::move(fn);
775 }
776
777 void KernelGenerator::visit(const ir::operation::Pad &node)
778 {
779   const auto input_index{node.getInputs().at(ir::operation::Pad::Input::INPUT)};
780   const auto pad_index{node.getInputs().at(ir::operation::Pad::Input::PAD)};
781   const auto output_index{node.getOutputs().at(0)};
782   assert(_ctx.at(pad_index).data());
783
784   auto input = _tensor_builder->portableAt(input_index).get();
785   auto output = _tensor_builder->portableAt(output_index).get();
786   auto pad_rank = _ctx.at(pad_index).shape().dim(0);
787   auto pad_base = reinterpret_cast<const int32_t *>(_ctx.at(pad_index).data()->base());
788
789   auto fn = std::make_unique<ops::PadLayer>();
790
791   bool isPadV2 = node.getInputs().size() == 3 ? true : false;
792   const void *value = nullptr;
793
794   if (isPadV2)
795   {
796     const auto value_index{node.getInputs().at(ir::operation::Pad::Input::VALUE)};
797     value = reinterpret_cast<const void *>(_ctx.at(value_index).data()->base());
798   }
799
800   fn->configure(input, output, pad_base, pad_rank, value);
801   _return_fn = std::move(fn);
802 }
803
804 void KernelGenerator::visit(const ir::operation::Max &node)
805 {
806   const auto ofm_index{node.getOutputs().at(0)};
807   const auto lhs_index{node.getInputs().at(ir::operation::Max::Input::LHS)};
808   const auto rhs_index{node.getInputs().at(ir::operation::Max::Input::RHS)};
809
810   auto ofm_tensor = _tensor_builder->portableAt(ofm_index).get();
811   auto lhs_tensor = _tensor_builder->portableAt(lhs_index).get();
812   auto rhs_tensor = _tensor_builder->portableAt(rhs_index).get();
813
814   auto fn = std::make_unique<ops::MaxLayer>();
815
816   fn->configure(lhs_tensor, rhs_tensor, ofm_tensor);
817
818   _return_fn = std::move(fn);
819 }
820
821 void KernelGenerator::visit(const ir::operation::Min &node)
822 {
823   const auto ofm_index{node.getOutputs().at(0)};
824   const auto lhs_index{node.getInputs().at(ir::operation::Min::Input::LHS)};
825   const auto rhs_index{node.getInputs().at(ir::operation::Min::Input::RHS)};
826
827   auto ofm_tensor = _tensor_builder->portableAt(ofm_index).get();
828   auto lhs_tensor = _tensor_builder->portableAt(lhs_index).get();
829   auto rhs_tensor = _tensor_builder->portableAt(rhs_index).get();
830
831   auto fn = std::make_unique<ops::MinLayer>();
832
833   fn->configure(lhs_tensor, rhs_tensor, ofm_tensor);
834
835   _return_fn = std::move(fn);
836 }
837
838 void KernelGenerator::visit(const ir::operation::Cast &node)
839 {
840   const auto ofm_index{node.getOutputs().at(0)};
841   const auto ifm_index{node.getInputs().at(ir::operation::Cast::Input::INPUT)};
842
843   auto ofm_tensor = _tensor_builder->portableAt(ofm_index).get();
844   auto ifm_tensor = _tensor_builder->portableAt(ifm_index).get();
845
846   auto fn = std::make_unique<ops::CastLayer>();
847
848   fn->configure(ifm_tensor, ofm_tensor);
849
850   _return_fn = std::move(fn);
851 }
852
853 void KernelGenerator::visit(const ir::operation::Transpose &node)
854 {
855   const auto output_index{node.getOutputs().at(0)};
856   const auto input_index{node.getInputs().at(ir::operation::Transpose::Input::INPUT)};
857
858   auto output_tensor = _tensor_builder->portableAt(output_index).get();
859   auto input_tensor = _tensor_builder->portableAt(input_index).get();
860
861   auto fn = std::make_unique<ops::TransposeLayer>();
862
863   fn->configure(input_tensor, output_tensor, node.param().perm);
864
865   _return_fn = std::move(fn);
866 }
867
868 void KernelGenerator::visit(const ir::operation::Reduce &node)
869 {
870   const auto output_index{node.getOutputs().at(0)};
871   const auto input_index{node.getInputs().at(ir::operation::Reduce::Input::INPUT)};
872   const auto axes_index{node.getInputs().at(ir::operation::Reduce::Input::AXES)};
873
874   const auto keep_dims = node.param().keep_dims;
875   auto output_tensor = _tensor_builder->portableAt(output_index).get();
876   auto input_tensor = _tensor_builder->portableAt(input_index).get();
877   auto axes_tensor = _tensor_builder->portableAt(axes_index).get();
878
879   if (node.param().reduce_type == ir::operation::Reduce::ReduceType::MEAN)
880   {
881     auto fn = std::make_unique<ops::MeanLayer>();
882
883     fn->configure(input_tensor, axes_tensor, output_tensor, keep_dims);
884
885     _return_fn = std::move(fn);
886   }
887   else
888   {
889     auto fn = std::make_unique<ops::ReduceLayer>();
890
891     const auto reduce_type = convertReduceType(node.param().reduce_type);
892     fn->configure(input_tensor, axes_tensor, output_tensor, reduce_type, keep_dims);
893
894     _return_fn = std::move(fn);
895   }
896 }
897
898 void KernelGenerator::visit(const ir::operation::ReLU &node)
899 {
900   const auto output_index{node.getOutputs().at(0)};
901   const auto input_index{node.getInputs().at(0)};
902
903   auto output_tensor = _tensor_builder->portableAt(output_index).get();
904   auto input_tensor = _tensor_builder->portableAt(input_index).get();
905
906   auto fn = std::make_unique<ops::ReLULayer>();
907
908   fn->configure(input_tensor, output_tensor);
909
910   _return_fn = std::move(fn);
911 }
912
913 void KernelGenerator::visit(const ir::operation::ReLU6 &node)
914 {
915   const auto output_index{node.getOutputs().at(0)};
916   const auto input_index{node.getInputs().at(0)};
917
918   auto output_tensor = _tensor_builder->portableAt(output_index).get();
919   auto input_tensor = _tensor_builder->portableAt(input_index).get();
920
921   auto fn = std::make_unique<ops::ReLU6Layer>();
922
923   fn->configure(input_tensor, output_tensor);
924
925   _return_fn = std::move(fn);
926 }
927
928 void KernelGenerator::visit(const ir::operation::Select &node)
929 {
930   const auto output_index{node.getOutputs().at(0)};
931   const auto condition_index{node.getInputs().at(ir::operation::Select::Input::CONDITION)};
932   const auto true_index{node.getInputs().at(ir::operation::Select::Input::INPUT_TRUE)};
933   const auto false_index{node.getInputs().at(ir::operation::Select::Input::INPUT_FALSE)};
934
935   auto output_tensor = _tensor_builder->portableAt(output_index).get();
936   auto condition_tensor = _tensor_builder->portableAt(condition_index).get();
937   auto true_tensor = _tensor_builder->portableAt(true_index).get();
938   auto false_tensor = _tensor_builder->portableAt(false_index).get();
939
940   auto fn = std::make_unique<ops::SelectLayer>();
941
942   fn->configure(condition_tensor, true_tensor, false_tensor, output_tensor);
943
944   _return_fn = std::move(fn);
945 }
946
947 void KernelGenerator::visit(const ir::operation::Slice &node)
948 {
949   const auto output_index{node.getOutputs().at(0)};
950   const auto input_index{node.getInputs().at(ir::operation::Slice::Input::INPUT)};
951   const auto begins_index{node.getInputs().at(ir::operation::Slice::Input::BEGINS)};
952   const auto sizes_index{node.getInputs().at(ir::operation::Slice::Input::SIZES)};
953
954   auto output_tensor = _tensor_builder->portableAt(output_index).get();
955   auto input_tensor = _tensor_builder->portableAt(input_index).get();
956   auto begins_tensor = _tensor_builder->portableAt(begins_index).get();
957   auto sizes_tensor = _tensor_builder->portableAt(sizes_index).get();
958
959   auto fn = std::make_unique<ops::SliceLayer>();
960
961   fn->configure(input_tensor, begins_tensor, sizes_tensor, output_tensor);
962
963   _return_fn = std::move(fn);
964 }
965
966 void KernelGenerator::visit(const ir::operation::StridedSlice &node)
967 {
968   const auto output_index{node.getOutputs().at(0)};
969   const auto input_index{node.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
970   const auto starts_index{node.getInputs().at(ir::operation::StridedSlice::Input::STARTS)};
971   const auto ends_index{node.getInputs().at(ir::operation::StridedSlice::Input::ENDS)};
972   const auto strides_index{node.getInputs().at(ir::operation::StridedSlice::Input::STRIDES)};
973
974   auto output_tensor = _tensor_builder->portableAt(output_index).get();
975   auto input_tensor = _tensor_builder->portableAt(input_index).get();
976   auto starts_tensor = _tensor_builder->portableAt(starts_index).get();
977   auto ends_tensor = _tensor_builder->portableAt(ends_index).get();
978   auto strides_tensor = _tensor_builder->portableAt(strides_index).get();
979
980   auto begin_mask = node.param().begin_mask;
981   auto end_mask = node.param().end_mask;
982   auto shrink_axis_mask = node.param().shrink_axis_mask;
983
984   auto fn = std::make_unique<ops::StridedSliceLayer>();
985
986   fn->configure(input_tensor, starts_tensor, ends_tensor, strides_tensor, output_tensor, begin_mask,
987                 end_mask, shrink_axis_mask);
988
989   _return_fn = std::move(fn);
990 }
991
992 void KernelGenerator::visit(const ir::operation::Split &node)
993 {
994   const auto num_splits = node.param().num_splits;
995   assert(num_splits == static_cast<int>(node.getOutputs().size()));
996
997   const auto input_idx{node.getInputs().at(ir::operation::Split::Input::INPUT)};
998   const auto rank = _ctx.at(input_idx).shape().rank();
999   const auto axis = ops::getAxis(rank, node.param().axis, _current_op_seq_layout);
1000   auto axis_resolved = axis < 0 ? axis + rank : axis;
1001
1002   auto in_tensor = _tensor_builder->portableAt(input_idx).get();
1003
1004   std::vector<IPortableTensor *> out_tensors;
1005   for (auto &output_idx : node.getOutputs())
1006     out_tensors.emplace_back(_tensor_builder->portableAt(output_idx).get());
1007
1008   auto fn = std::make_unique<ops::SplitLayer>();
1009
1010   fn->configure(in_tensor, num_splits, axis_resolved, out_tensors);
1011
1012   _return_fn = std::move(fn);
1013 }
1014
1015 void KernelGenerator::visit(const ir::operation::Abs &node)
1016 {
1017   const auto ofm_index{node.getOutputs().at(0)};
1018   const auto ifm_index{node.getInputs().at(ir::operation::Abs::Input::INPUT)};
1019
1020   auto ofm_tensor = _tensor_builder->portableAt(ofm_index).get();
1021   auto ifm_tensor = _tensor_builder->portableAt(ifm_index).get();
1022
1023   auto fn = std::make_unique<ops::AbsLayer>();
1024
1025   fn->configure(ifm_tensor, ofm_tensor);
1026
1027   _return_fn = std::move(fn);
1028 }
1029
1030 void KernelGenerator::visit(const ir::operation::Sin &node)
1031 {
1032   const auto ofm_index{node.getOutputs().at(0)};
1033   const auto ifm_index{node.getInputs().at(ir::operation::Sin::Input::INPUT)};
1034
1035   auto ofm_tensor = _tensor_builder->portableAt(ofm_index).get();
1036   auto ifm_tensor = _tensor_builder->portableAt(ifm_index).get();
1037
1038   auto fn = std::make_unique<ops::SinLayer>();
1039
1040   fn->configure(ifm_tensor, ofm_tensor);
1041
1042   _return_fn = std::move(fn);
1043 }
1044
1045 void KernelGenerator::visit(const ir::operation::Cos &node)
1046 {
1047   const auto ofm_index{node.getOutputs().at(0)};
1048   const auto ifm_index{node.getInputs().at(ir::operation::Cos::Input::INPUT)};
1049
1050   auto ofm_tensor = _tensor_builder->portableAt(ofm_index).get();
1051   auto ifm_tensor = _tensor_builder->portableAt(ifm_index).get();
1052
1053   auto fn = std::make_unique<ops::CosLayer>();
1054
1055   fn->configure(ifm_tensor, ofm_tensor);
1056
1057   _return_fn = std::move(fn);
1058 }
1059
1060 void KernelGenerator::visit(const ir::operation::RSQRT &node)
1061 {
1062   const auto ofm_index{node.getOutputs().at(0)};
1063   const auto ifm_index{node.getInputs().at(ir::operation::RSQRT::Input::INPUT)};
1064
1065   auto ofm_tensor = _tensor_builder->portableAt(ofm_index).get();
1066   auto ifm_tensor = _tensor_builder->portableAt(ifm_index).get();
1067
1068   auto fn = std::make_unique<ops::RsqrtLayer>();
1069
1070   fn->configure(ifm_tensor, ofm_tensor);
1071
1072   _return_fn = std::move(fn);
1073 }
1074
1075 void KernelGenerator::visit(const ir::operation::Shape &node)
1076 {
1077   const auto ofm_index{node.getOutputs().at(0)};
1078   const auto ifm_index{node.getInputs().at(ir::operation::Shape::Input::INPUT)};
1079
1080   auto ofm_tensor = _tensor_builder->portableAt(ofm_index).get();
1081   auto ifm_tensor = _tensor_builder->portableAt(ifm_index).get();
1082
1083   auto fn = std::make_unique<ops::ShapeLayer>();
1084
1085   fn->configure(ifm_tensor, ofm_tensor);
1086
1087   _return_fn = std::move(fn);
1088 }
1089
1090 void KernelGenerator::visit(const ir::operation::ResizeBilinear &node)
1091 {
1092   const auto output_index{node.getOutputs().at(0)};
1093   const auto input_index{node.getInputs().at(ir::operation::ResizeBilinear::INPUT)};
1094
1095   auto output_height = node.param().height_out;
1096   auto output_width = node.param().width_out;
1097   auto align_corners = node.param().align_corners;
1098   auto half_pixel_centers = node.param().half_pixel_centers;
1099
1100   auto output_tensor = _tensor_builder->portableAt(output_index).get();
1101   auto input_tensor = _tensor_builder->portableAt(input_index).get();
1102
1103   auto fn = std::make_unique<ops::ResizeBilinearLayer>();
1104
1105   fn->configure(input_tensor, output_tensor, output_height, output_width, align_corners,
1106                 half_pixel_centers);
1107
1108   _return_fn = std::move(fn);
1109 }
1110
1111 void KernelGenerator::visit(const ir::operation::Reverse &node)
1112 {
1113   const auto output_index{node.getOutputs().at(0)};
1114   const auto input_index{node.getInputs().at(ir::operation::Reverse::INPUT)};
1115   const auto axis_index{node.getInputs().at(ir::operation::Reverse::AXIS)};
1116
1117   auto output_tensor = _tensor_builder->portableAt(output_index).get();
1118   auto input_tensor = _tensor_builder->portableAt(input_index).get();
1119   auto axis_tensor = _tensor_builder->portableAt(axis_index).get();
1120
1121   auto fn = std::make_unique<ops::ReverseLayer>();
1122
1123   fn->configure(input_tensor, axis_tensor, output_tensor);
1124
1125   _return_fn = std::move(fn);
1126 }
1127
1128 void KernelGenerator::visit(const ir::operation::Neg &node)
1129 {
1130   const auto ofm_index{node.getOutputs().at(0)};
1131   const auto ifm_index{node.getInputs().at(ir::operation::Neg::Input::INPUT)};
1132
1133   auto ofm_tensor = _tensor_builder->portableAt(ofm_index).get();
1134   auto ifm_tensor = _tensor_builder->portableAt(ifm_index).get();
1135
1136   auto fn = std::make_unique<ops::NegLayer>();
1137
1138   fn->configure(ifm_tensor, ofm_tensor);
1139
1140   _return_fn = std::move(fn);
1141 }
1142
1143 void KernelGenerator::visit(const ir::operation::ArgMax &node)
1144 {
1145   const auto output_index{node.getOutputs().at(0)};
1146   const auto input_index{node.getInputs().at(ir::operation::ArgMax::INPUT)};
1147
1148   const auto axis = node.param().axis;
1149
1150   auto output_tensor = _tensor_builder->portableAt(output_index).get();
1151   auto input_tensor = _tensor_builder->portableAt(input_index).get();
1152
1153   auto fn = std::make_unique<ops::ArgMinMaxLayer>();
1154
1155   fn->configure(input_tensor, output_tensor, axis, /* is_arg_max */ true);
1156
1157   _return_fn = std::move(fn);
1158 }
1159
1160 void KernelGenerator::visit(const ir::operation::Pow &node)
1161 {
1162   const auto output_index{node.getOutputs().at(0)};
1163   const auto lhs_index{node.getInputs().at(ir::operation::Pow::LHS)};
1164   const auto rhs_index{node.getInputs().at(ir::operation::Pow::RHS)};
1165
1166   auto output_tensor = _tensor_builder->portableAt(output_index).get();
1167   auto lhs_tensor = _tensor_builder->portableAt(lhs_index).get();
1168   auto rhs_tensor = _tensor_builder->portableAt(rhs_index).get();
1169
1170   auto fn = std::make_unique<ops::PowLayer>();
1171
1172   fn->configure(lhs_tensor, rhs_tensor, ir::Activation::NONE, output_tensor);
1173
1174   _return_fn = std::move(fn);
1175 }
1176
1177 void KernelGenerator::visit(const ir::operation::Log &node)
1178 {
1179   const auto ofm_index{node.getOutputs().at(0)};
1180   const auto ifm_index{node.getInputs().at(ir::operation::Log::Input::INPUT)};
1181
1182   auto ofm_tensor = _tensor_builder->portableAt(ofm_index).get();
1183   auto ifm_tensor = _tensor_builder->portableAt(ifm_index).get();
1184
1185   auto fn = std::make_unique<ops::LogLayer>();
1186
1187   fn->configure(ifm_tensor, ofm_tensor);
1188
1189   _return_fn = std::move(fn);
1190 }
1191
1192 void KernelGenerator::visit(const ir::operation::Round &node)
1193 {
1194   const auto output_index{node.getOutputs().at(0)};
1195   const auto input_index{node.getInputs().at(ir::operation::Round::INPUT)};
1196
1197   auto output_tensor = _tensor_builder->portableAt(output_index).get();
1198   auto input_tensor = _tensor_builder->portableAt(input_index).get();
1199
1200   auto fn = std::make_unique<ops::RoundLayer>();
1201
1202   fn->configure(input_tensor, output_tensor);
1203
1204   _return_fn = std::move(fn);
1205 }
1206
1207 void KernelGenerator::visit(const ir::operation::LogicalNot &node)
1208 {
1209   const auto output_index{node.getOutputs().at(0)};
1210   const auto input_index{node.getInputs().at(ir::operation::LogicalNot::INPUT)};
1211
1212   auto output_tensor = _tensor_builder->portableAt(output_index).get();
1213   auto input_tensor = _tensor_builder->portableAt(input_index).get();
1214
1215   auto fn = std::make_unique<ops::LogicalNotLayer>();
1216
1217   fn->configure(input_tensor, output_tensor);
1218
1219   _return_fn = std::move(fn);
1220 }
1221
1222 void KernelGenerator::visit(const ir::operation::LogicalOr &node)
1223 {
1224   const auto ofm_index{node.getOutputs().at(0)};
1225   const auto lhs_index{node.getInputs().at(0)};
1226   const auto rhs_index{node.getInputs().at(1)};
1227
1228   auto ofm_tensor = _tensor_builder->portableAt(ofm_index).get();
1229   auto lhs_tensor = _tensor_builder->portableAt(lhs_index).get();
1230   auto rhs_tensor = _tensor_builder->portableAt(rhs_index).get();
1231
1232   auto fn = std::make_unique<ops::LogicalOrLayer>();
1233
1234   fn->configure(lhs_tensor, rhs_tensor, ofm_tensor);
1235
1236   _return_fn = std::move(fn);
1237 }
1238
1239 void KernelGenerator::visit(const ir::operation::L2Normalization &node)
1240 {
1241   const auto output_index{node.getOutputs().at(0)};
1242   const auto input_index{node.getInputs().at(0)};
1243
1244   auto output_alloc = _tensor_builder->portableAt(output_index).get();
1245   auto input_alloc = _tensor_builder->portableAt(input_index).get();
1246
1247   auto fn = std::make_unique<ops::L2NormLayer>();
1248
1249   fn->configure(input_alloc, output_alloc);
1250
1251   _return_fn = std::move(fn);
1252 }
1253
1254 void KernelGenerator::visit(const ir::operation::ZerosLike &node)
1255 {
1256   const auto output_index{node.getOutputs().at(0)};
1257   const auto input_index{node.getInputs().at(ir::operation::ZerosLike::INPUT)};
1258
1259   auto output_tensor = _tensor_builder->portableAt(output_index).get();
1260   auto input_tensor = _tensor_builder->portableAt(input_index).get();
1261
1262   auto fn = std::make_unique<ops::ZerosLikeLayer>();
1263
1264   fn->configure(input_tensor, output_tensor);
1265   _return_fn = std::move(fn);
1266 }
1267
1268 void KernelGenerator::visit(const ir::operation::Range &node)
1269 {
1270   const auto output_index{node.getOutputs().at(0)};
1271   const auto start_index{node.getInputs().at(ir::operation::Range::START)};
1272   const auto limit_index{node.getInputs().at(ir::operation::Range::LIMIT)};
1273   const auto delta_index{node.getInputs().at(ir::operation::Range::DELTA)};
1274
1275   auto output_tensor = _tensor_builder->portableAt(output_index).get();
1276   auto start_tensor = _tensor_builder->portableAt(start_index).get();
1277   auto limit_tensor = _tensor_builder->portableAt(limit_index).get();
1278   auto delta_tensor = _tensor_builder->portableAt(delta_index).get();
1279
1280   auto fn = std::make_unique<ops::RangeLayer>();
1281
1282   fn->configure(start_tensor, limit_tensor, delta_tensor, output_tensor);
1283   _return_fn = std::move(fn);
1284 }
1285
1286 void KernelGenerator::visit(const ir::operation::SquaredDifference &node)
1287 {
1288   const auto ofm_index{node.getOutputs().at(0)};
1289   const auto lhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::LHS)};
1290   const auto rhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::RHS)};
1291
1292   auto ofm_tensor = _tensor_builder->portableAt(ofm_index).get();
1293   auto lhs_tensor = _tensor_builder->portableAt(lhs_index).get();
1294   auto rhs_tensor = _tensor_builder->portableAt(rhs_index).get();
1295
1296   auto fn = std::make_unique<ops::SqDiffLayer>();
1297
1298   fn->configure(lhs_tensor, rhs_tensor, ofm_tensor);
1299   _return_fn = std::move(fn);
1300 }
1301
1302 void KernelGenerator::visit(const ir::operation::Tile &node)
1303 {
1304   const auto output_index{node.getOutputs().at(0)};
1305   const auto input_index{node.getInputs().at(ir::operation::Tile::INPUT)};
1306   const auto multiples_index{node.getInputs().at(ir::operation::Tile::MULTIPLES)};
1307
1308   auto output_tensor = _tensor_builder->portableAt(output_index).get();
1309   auto input_tensor = _tensor_builder->portableAt(input_index).get();
1310   auto multiples_tensor = _tensor_builder->portableAt(multiples_index).get();
1311
1312   auto fn = std::make_unique<ops::TileLayer>();
1313
1314   fn->configure(input_tensor, multiples_tensor, output_tensor);
1315   _return_fn = std::move(fn);
1316 }
1317
1318 void KernelGenerator::visit(const ir::operation::MatrixBandPart &node)
1319 {
1320   const auto output_index{node.getOutputs().at(0)};
1321   const auto input_index{node.getInputs().at(ir::operation::MatrixBandPart::INPUT)};
1322   const auto num_lower_index{node.getInputs().at(ir::operation::MatrixBandPart::NUM_LOWER_DIAG)};
1323   const auto num_upper_index{node.getInputs().at(ir::operation::MatrixBandPart::NUM_UPPER_DIAG)};
1324
1325   auto output_tensor = _tensor_builder->portableAt(output_index).get();
1326   auto input_tensor = _tensor_builder->portableAt(input_index).get();
1327   auto num_lower_tensor = _tensor_builder->portableAt(num_lower_index).get();
1328   auto num_upper_tensor = _tensor_builder->portableAt(num_upper_index).get();
1329
1330   auto fn = std::make_unique<ops::MatrixBandPartLayer>();
1331
1332   fn->configure(input_tensor, num_lower_tensor, num_upper_tensor, output_tensor);
1333   _return_fn = std::move(fn);
1334 }
1335
1336 void KernelGenerator::visit(const ir::operation::BatchMatMul &node)
1337 {
1338   const auto output_index{node.getOutputs().at(0)};
1339   const auto lhs_index{node.getInputs().at(ir::operation::BatchMatMul::LHS)};
1340   const auto rhs_index{node.getInputs().at(ir::operation::BatchMatMul::RHS)};
1341
1342   auto output_tensor = _tensor_builder->portableAt(output_index).get();
1343   auto lhs_tensor = _tensor_builder->portableAt(lhs_index).get();
1344   auto rhs_tensor = _tensor_builder->portableAt(rhs_index).get();
1345
1346   const auto adj_x = node.param().adj_x;
1347   const auto adj_y = node.param().adj_y;
1348
1349   auto fn = std::make_unique<ops::BatchMatMulLayer>();
1350
1351   fn->configure(lhs_tensor, rhs_tensor, adj_x, adj_y, output_tensor);
1352   _return_fn = std::move(fn);
1353 }
1354
1355 void KernelGenerator::visit(const ir::operation::BroadcastTo &node)
1356 {
1357   const auto output_index{node.getOutputs().at(0)};
1358   const auto input_index{node.getInputs().at(ir::operation::BroadcastTo::INPUT)};
1359   const auto shape_index{node.getInputs().at(ir::operation::BroadcastTo::SHAPE)};
1360
1361   auto output_tensor = _tensor_builder->portableAt(output_index).get();
1362   auto input_tensor = _tensor_builder->portableAt(input_index).get();
1363   auto shape_tensor = _tensor_builder->portableAt(shape_index).get();
1364
1365   auto fn = std::make_unique<ops::BroadcastToLayer>();
1366
1367   fn->configure(input_tensor, shape_tensor, output_tensor);
1368
1369   _return_fn = std::move(fn);
1370 }
1371
1372 void KernelGenerator::visit(const ir::operation::FusedBatchNorm &node)
1373 {
1374   const auto ofm_index{node.getOutputs().at(0)};
1375
1376   auto output_tensor = _tensor_builder->portableAt(ofm_index).get();
1377   std::vector<const IPortableTensor *> input_tensors;
1378   for (auto &ifm_idx : node.getInputs())
1379     input_tensors.emplace_back(_tensor_builder->portableAt(ifm_idx).get());
1380
1381   const auto epsilon = node.param().epsilon;
1382   const auto is_training = node.param().is_training;
1383   const auto data_format = node.param().data_format;
1384
1385   auto fn = std::make_unique<ops::FusedBatchNormLayer>();
1386
1387   fn->configure(input_tensors, epsilon, is_training, data_format, output_tensor);
1388
1389   _return_fn = std::move(fn);
1390 }
1391
1392 void KernelGenerator::visit(const ir::operation::LogSoftmax &node)
1393 {
1394   const auto output_index{node.getOutputs().at(0)};
1395   const auto input_index{node.getInputs().at(ir::operation::LogSoftmax::Input::INPUT)};
1396
1397   const auto beta = node.param().beta;
1398   const auto axis = node.param().axis;
1399
1400   auto output_tensor = _tensor_builder->portableAt(output_index).get();
1401   auto input_tensor = _tensor_builder->portableAt(input_index).get();
1402
1403   auto fn = std::make_unique<ops::LogSoftMaxLayer>();
1404
1405   fn->configure(input_tensor, beta, axis, output_tensor);
1406
1407   _return_fn = std::move(fn);
1408 }
1409
1410 void KernelGenerator::visit(const ir::operation::SpaceToBatchND &node)
1411 {
1412   const auto output_index{node.getOutputs().at(0)};
1413   const auto input_index{node.getInputs().at(ir::operation::SpaceToBatchND::INPUT)};
1414   const auto block_shape_index{node.getInputs().at(ir::operation::SpaceToBatchND::BLOCK_SIZE)};
1415   const auto padding_index{node.getInputs().at(ir::operation::SpaceToBatchND::PADDINGS)};
1416
1417   auto output_tensor = _tensor_builder->portableAt(output_index).get();
1418   auto input_tensor = _tensor_builder->portableAt(input_index).get();
1419   auto block_shape_tensor = _tensor_builder->portableAt(block_shape_index).get();
1420   auto padding_tensor = _tensor_builder->portableAt(padding_index).get();
1421
1422   auto fn = std::make_unique<ops::SpaceToBatchNDLayer>();
1423
1424   fn->configure(input_tensor, block_shape_tensor, padding_tensor, output_tensor);
1425
1426   _return_fn = std::move(fn);
1427 }
1428
1429 void KernelGenerator::visit(const ir::operation::Quantize &node)
1430 {
1431   const auto input_index{node.getInputs().at(ir::operation::Quantize::Input::INPUT)};
1432   const auto output_index{node.getOutputs().at(0)};
1433
1434   auto input_tensor = _tensor_builder->portableAt(input_index).get();
1435   auto output_tensor = _tensor_builder->portableAt(output_index).get();
1436
1437   auto fn = std::make_unique<ops::QuantizeLayer>();
1438
1439   fn->configure(input_tensor, output_tensor);
1440
1441   _return_fn = std::move(fn);
1442 }
1443
1444 void KernelGenerator::visit(const ir::operation::SpaceToDepth &node)
1445 {
1446   const auto input_index{node.getInputs().at(ir::operation::SpaceToDepth::Input::INPUT)};
1447   const auto output_index{node.getOutputs().at(0)};
1448   auto block_size = node.param().block_size;
1449
1450   auto input_tensor = _tensor_builder->portableAt(input_index).get();
1451   auto output_tensor = _tensor_builder->portableAt(output_index).get();
1452
1453   auto fn = std::make_unique<ops::SpaceToDepthLayer>();
1454
1455   fn->configure(input_tensor, block_size, output_tensor);
1456   _return_fn = std::move(fn);
1457 }
1458
1459 void KernelGenerator::visit(const ir::operation::StatelessRandomUniform &node)
1460 {
1461   const auto output_index{node.getOutputs().at(0)};
1462   const auto shape_index{node.getInputs().at(ir::operation::StatelessRandomUniform::SHAPE)};
1463   const auto seed_index{node.getInputs().at(ir::operation::StatelessRandomUniform::SEED)};
1464
1465   auto output_alloc = _tensor_builder->portableAt(output_index).get();
1466   auto shape_alloc = _tensor_builder->portableAt(shape_index).get();
1467   auto seed_alloc = _tensor_builder->portableAt(seed_index).get();
1468
1469   auto fn = std::make_unique<ops::StatelessRandomUniformLayer>();
1470
1471   fn->configure(shape_alloc, seed_alloc, output_alloc);
1472   _return_fn = std::move(fn);
1473 }
1474
1475 void KernelGenerator::visit(const ir::operation::SplitV &node)
1476 {
1477   const auto num_splits = node.param().num_splits;
1478   assert(num_splits == static_cast<int>(node.getOutputs().size()));
1479
1480   const auto input_idx{node.getInputs().at(ir::operation::SplitV::Input::INPUT)};
1481   const auto size_splits{node.getInputs().at(ir::operation::SplitV::Input::SIZE_SPLITS)};
1482   const auto split_dim{node.getInputs().at(ir::operation::SplitV::Input::SPLIT_DIM)};
1483
1484   auto in_tensor = _tensor_builder->portableAt(input_idx).get();
1485   auto in_size_splits = _tensor_builder->portableAt(size_splits).get();
1486   auto in_split_dim = _tensor_builder->portableAt(split_dim).get();
1487
1488   std::vector<IPortableTensor *> out_tensors;
1489   for (auto &output_idx : node.getOutputs())
1490     out_tensors.emplace_back(_tensor_builder->portableAt(output_idx).get());
1491
1492   auto fn = std::make_unique<ops::SplitVLayer>();
1493
1494   fn->configure(in_tensor, in_size_splits, in_split_dim, num_splits, out_tensors);
1495
1496   _return_fn = std::move(fn);
1497 }
1498
1499 } // namespace cpu
1500 } // namespace backend
1501 } // namespace onert