/*
 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
22 #include <tensorflow/lite/kernels/internal/optimized/optimized_ops.h>
24 namespace luci_interpreter
29 SplitV::SplitV(const Tensor *input, const Tensor *size_splits, const Tensor *axis,
30 std::vector<Tensor *> outputs)
31 : Kernel({input, size_splits, axis}, std::move(outputs))
35 void SplitV::configure()
37 assert(axis()->shape().num_elements() == 1);
38 _axis_value = getTensorData<int32_t>(axis())[0];
40 _axis_value += input()->shape().num_dims();
41 assert(_axis_value >= 0 && _axis_value < input()->shape().num_dims());
43 auto num_split = static_cast<int32_t>(_outputs.size());
44 auto sizes_data = getTensorData<int32_t>(size_splits());
46 assert(size_splits()->shape().num_dims() == 1);
49 const auto num_dims_size_spits = size_splits()->shape().dim(0);
50 int32_t count_neg_dim = 0;
52 for (int32_t i = 0; i < num_dims_size_spits - 1; ++i)
54 if (sizes_data[i] != -1)
63 assert(count_neg_dim < 2);
64 assert(size_splits()->shape().num_elements() == num_split);
66 // TODO: enable it only if kernel with dynamic shapes
67 auto output_shape = input()->shape();
68 for (int32_t i = 0; i < num_split; ++i)
70 if (sizes_data[i] == -1)
72 output_shape.dim(_axis_value) = input()->shape().dim(_axis_value) - sum;
76 output_shape.dim(_axis_value) = sizes_data[i];
78 _outputs[i]->resize(output_shape);
82 void SplitV::execute() const
84 tflite::SplitParams params{};
85 params.num_split = _outputs.size();
86 params.axis = _axis_value;
88 #define TF_LITE_SPLIT(scalar) \
90 VectorOfTensors<scalar, false> all_outputs(_outputs); \
91 tflite::optimized_ops::Split(params, getTensorShape(input()), getTensorData<scalar>(input()), \
92 all_outputs.shapes(), all_outputs.data()); \
95 switch (input()->element_type())
97 case DataType::FLOAT32:
101 TF_LITE_SPLIT(uint8_t);
104 TF_LITE_SPLIT(int16_t);
107 assert(false && "Unsupported type.");
112 } // namespace kernels
113 } // namespace luci_interpreter