cc8cf3a8954a63c0225b2266a7c03b49c9a0cfa0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Unpack.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/Unpack.h"
19
20 #include "kernels/Utils.h"
21
22 #include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
23
24 namespace luci_interpreter
25 {
26
27 namespace kernels
28 {
29
30 Unpack::Unpack(const Tensor *input, std::vector<Tensor *> outputs, const UnpackParams &params)
31   : KernelWithParams<UnpackParams>({input}, std::move(outputs), params)
32 {
33 }
34
35 void Unpack::configure()
36 {
37   const Shape &input_shape = input()->shape();
38
39   int axis = _params.axis;
40   if (axis < 0)
41     axis += input()->shape().num_dims();
42   assert(axis >= 0 && axis < input_shape.num_dims());
43
44   Shape output_shape(input_shape.num_dims() - 1);
45   int out_index = 0;
46   for (int in_index = 0; in_index < input_shape.num_dims(); ++in_index)
47   {
48     if (in_index != axis)
49       output_shape.dim(out_index++) = input_shape.dim(in_index);
50   }
51
52   // TODO: enable it only if kernel with dynamic shapes
53   for (Tensor *output : _outputs)
54   {
55     assert(output->element_type() == input()->element_type());
56     output->resize(output_shape);
57   }
58 }
59
60 template <typename T> void Unpack::executeImpl() const
61 {
62   tflite::UnpackParams params{};
63   params.axis = _params.axis;
64   params.num_split = _outputs.size();
65   VectorOfTensors<T, false> all_outputs(_outputs);
66   tflite::reference_ops::Unpack<T>(params, getTensorShape(input()), getTensorData<T>(input()),
67                                    **all_outputs.shapes(), all_outputs.data());
68 }
69
70 void Unpack::execute() const
71 {
72   switch (input()->element_type())
73   {
74     case DataType::FLOAT32:
75       return executeImpl<float>();
76     case DataType::U8:
77       return executeImpl<uint8_t>();
78     default:
79       assert(false && "Unsupported type.");
80   }
81 }
82
83 } // namespace kernels
84 } // namespace luci_interpreter