/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "SelectLayer.h"
19 #include "OperationUtils.h"
21 #include <cker/operation/Select.h>
32 SelectLayer::SelectLayer()
33 : _cond(nullptr), _input_true(nullptr), _input_false(nullptr), _output(nullptr)
38 void SelectLayer::configure(const IPortableTensor *cond, const IPortableTensor *input_true,
39 const IPortableTensor *input_false, IPortableTensor *output)
42 _input_true = input_true;
43 _input_false = input_false;
47 void SelectLayer::run()
50 #define KERNEL_SELECT(type, op) \
51 nnfw::cker::op(getShape(_cond), getBuffer<uint8_t>(_cond), getShape(_input_true), \
52 getBuffer<type>(_input_true), getShape(_input_false), \
53 getBuffer<type>(_input_false), getShape(_output), getBuffer<type>(_output));
55 #define KERNEL_SWITCH(type, op) \
59 case OperandType::FLOAT32: \
60 KERNEL_SELECT(float, op); \
63 throw std::runtime_error{"Select: unsupported data type"}; \
66 auto input_type = _input_true->data_type();
67 bool require_broadcast =
68 !HaveSameShapes(_input_true, _cond) || !HaveSameShapes(_input_false, _cond);
69 bool rank_one_select = ((_input_true->getShape().rank() == 1) && !require_broadcast);
73 KERNEL_SWITCH(input_type, RankOneSelect);
75 else if (require_broadcast)
77 KERNEL_SWITCH(input_type, BroadcastSelect4DSlow);
81 KERNEL_SWITCH(input_type, Select);
87 } // namespace backend