Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Unpack.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "kernels/Unpack.h"
18
19 #include "kernels/Utils.h"
20
21 #include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
22
23 namespace luci_interpreter
24 {
25
26 namespace kernels
27 {
28
29 Unpack::Unpack(const Tensor *input, std::vector<Tensor *> outputs, const UnpackParams &params)
30   : KernelWithParams<UnpackParams>({input}, std::move(outputs), params)
31 {
32 }
33
34 void Unpack::configure()
35 {
36   const Shape &input_shape = input()->shape();
37
38   int axis = _params.axis;
39   if (axis < 0)
40     axis += input()->shape().num_dims();
41   assert(axis >= 0 && axis < input_shape.num_dims());
42
43   Shape output_shape(input_shape.num_dims() - 1);
44   int out_index = 0;
45   for (int in_index = 0; in_index < input_shape.num_dims(); ++in_index)
46   {
47     if (in_index != axis)
48       output_shape.dim(out_index++) = input_shape.dim(in_index);
49   }
50
51   // TODO: enable it only if kernel with dynamic shapes
52   for (Tensor *output : _outputs)
53   {
54     assert(output->element_type() == input()->element_type());
55     output->resize(output_shape);
56   }
57 }
58
59 template <typename T> void Unpack::executeImpl() const
60 {
61   tflite::UnpackParams params{};
62   params.axis = _params.axis;
63   params.num_split = _outputs.size();
64   VectorOfTensors<T, false> all_outputs(_outputs);
65   tflite::reference_ops::Unpack<T>(params, getTensorShape(input()), getTensorData<T>(input()),
66                                    **all_outputs.shapes(), all_outputs.data());
67 }
68
69 void Unpack::execute() const
70 {
71   switch (input()->element_type())
72   {
73     case DataType::FLOAT32:
74       return executeImpl<float>();
75     case DataType::U8:
76       return executeImpl<uint8_t>();
77     default:
78       assert(false && "Unsupported type.");
79   }
80 }
81
82 } // namespace kernels
83 } // namespace luci_interpreter