Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / runtime / onert / backend / cpu / ops / SplitVLayer.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SplitVLayer.h"
18
19 #include "OperationUtils.h"
20
21 #include <cker/operation/SplitV.h>
22
23 namespace onert
24 {
25 namespace backend
26 {
27 namespace cpu
28 {
29 namespace ops
30 {
31
32 SplitVLayer::SplitVLayer()
33     : _input(nullptr), _size_splits(nullptr), _split_dim(nullptr), _num_splits(0), _outputs()
34 {
35   // DO NOTHING
36 }
37
38 template <typename T> void SplitVLayer::splitV(void)
39 {
40   nnfw::cker::SplitVParams op_params;
41   op_params.axis = *(reinterpret_cast<const int32_t *>(_split_dim->buffer()));
42   op_params.num_split = _num_splits;
43
44   std::vector<T *> outputPtrs;
45   std::vector<nnfw::cker::Shape> outshape;
46
47   for (const auto output : _outputs)
48   {
49     assert(output->total_size() == sizeOfData(output->data_type(), output->getShape().dims()));
50     outputPtrs.emplace_back(reinterpret_cast<T *>(output->buffer()));
51     outshape.emplace_back(getTensorShape(output));
52   }
53
54   assert(_input->total_size() == sizeOfData(_input->data_type(), _input->getShape().dims()));
55   nnfw::cker::SplitV<T>(op_params, getTensorShape(_input), reinterpret_cast<T *>(_input->buffer()),
56                         outshape, outputPtrs.data());
57 }
58
59 void SplitVLayer::configure(const IPortableTensor *input, const IPortableTensor *size_splits,
60                             const IPortableTensor *split_dim, uint16_t num_splits,
61                             std::vector<IPortableTensor *> &outputs)
62 {
63   assert(input != nullptr);
64
65   _num_splits = num_splits;
66   _size_splits = size_splits;
67   _input = input;
68   _split_dim = split_dim;
69   _outputs = outputs;
70 }
71
72 void SplitVLayer::run()
73 {
74   if (_input->data_type() == OperandType::FLOAT32)
75   {
76     splitV<float>();
77   }
78   else if (_input->data_type() == OperandType::QUANT_UINT8_ASYMM)
79   {
80     splitV<uint8_t>();
81   }
82   else if (_input->data_type() == OperandType::INT32)
83   {
84     splitV<int32_t>();
85   }
86   else if (_input->data_type() == OperandType::INT64)
87   {
88     splitV<int64_t>();
89   }
90   else
91   {
92     throw std::runtime_error{"SplitV: unsupported input type"};
93   }
94 }
95
96 } // namespace ops
97 } // namespace cpu
98 } // namespace backend
99 } // namespace onert