Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / backend / cpu / ops / SelectLayer.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SelectLayer.h"
18
19 #include "OperationUtils.h"
20
21 #include <cker/operation/Select.h>
22
23 namespace onert
24 {
25 namespace backend
26 {
27 namespace cpu
28 {
29 namespace ops
30 {
31
32 SelectLayer::SelectLayer()
33   : _cond(nullptr), _input_true(nullptr), _input_false(nullptr), _output(nullptr)
34 {
35   // DO NOTHING
36 }
37
38 void SelectLayer::configure(const IPortableTensor *cond, const IPortableTensor *input_true,
39                             const IPortableTensor *input_false, IPortableTensor *output)
40 {
41   _cond = cond;
42   _input_true = input_true;
43   _input_false = input_false;
44   _output = output;
45 }
46
47 void SelectLayer::run()
48 {
49
50 #define KERNEL_SELECT(type, op)                                                     \
51   nnfw::cker::op(getShape(_cond), getBuffer<uint8_t>(_cond), getShape(_input_true), \
52                  getBuffer<type>(_input_true), getShape(_input_false),              \
53                  getBuffer<type>(_input_false), getShape(_output), getBuffer<type>(_output));
54
55 #define KERNEL_SWITCH(type, op)                                  \
56   switch (type)                                                  \
57   {                                                              \
58     break;                                                       \
59     case OperandType::FLOAT32:                                   \
60       KERNEL_SELECT(float, op);                                  \
61       break;                                                     \
62     default:                                                     \
63       throw std::runtime_error{"Select: unsupported data type"}; \
64   }
65
66   auto input_type = _input_true->data_type();
67   bool require_broadcast =
68     !HaveSameShapes(_input_true, _cond) || !HaveSameShapes(_input_false, _cond);
69   bool rank_one_select = ((_input_true->getShape().rank() == 1) && !require_broadcast);
70
71   if (rank_one_select)
72   {
73     KERNEL_SWITCH(input_type, RankOneSelect);
74   }
75   else if (require_broadcast)
76   {
77     KERNEL_SWITCH(input_type, BroadcastSelect4DSlow);
78   }
79   else
80   {
81     KERNEL_SWITCH(input_type, Select);
82   }
83 }
84
85 } // namespace ops
86 } // namespace cpu
87 } // namespace backend
88 } // namespace onert