62228f16d00dc9e0c7ffcb7a3e6075c2723a07cd
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Pack.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/Pack.h"
19 #include "kernels/Utils.h"
20
21 #include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
22
23 namespace luci_interpreter
24 {
25 namespace kernels
26 {
27
28 Pack::Pack(std::vector<const Tensor *> inputs, Tensor *output, const PackParams &params)
29   : KernelWithParams<PackParams>(std::move(inputs), {output}, params)
30 {
31 }
32
33 void Pack::configure()
34 {
35   LUCI_INTERPRETER_CHECK(_inputs.size() == static_cast<uint32_t>(params().values_count));
36   const Tensor *t0 = _inputs[0];
37   const int dimension_size = t0->shape().num_dims() + 1;
38   int axis = params().axis;
39   if (axis < 0)
40   {
41     axis += dimension_size;
42   }
43   LUCI_INTERPRETER_CHECK(axis >= 0 && axis <= t0->shape().num_dims());
44
45   if (t0->element_type() != DataType::S32 && t0->element_type() != DataType::FLOAT32 &&
46       t0->element_type() != DataType::U8 && t0->element_type() != DataType::S8 &&
47       t0->element_type() != DataType::S16 && t0->element_type() != DataType::S64)
48   {
49     assert(false && "Unsupported type.");
50   }
51
52   for (uint32_t i = 1; i < _inputs.size(); ++i)
53   {
54     const Tensor *tensor = _inputs[i];
55     LUCI_INTERPRETER_CHECK(tensor->element_type() == t0->element_type());
56     LUCI_INTERPRETER_CHECK(tensor->shape().num_dims() == t0->shape().num_dims());
57     for (int d = 0; d < t0->shape().num_dims(); ++d)
58     {
59       LUCI_INTERPRETER_CHECK(tensor->shape().dim(d) == t0->shape().dim(d));
60     }
61   }
62
63   Shape output_shape(dimension_size);
64   int i = 0;
65   for (int index = 0; index < dimension_size; ++index)
66   {
67     if (index == axis)
68     {
69       output_shape.dim(index) = params().values_count;
70     }
71     else
72     {
73       output_shape.dim(index) = t0->shape().dim(i++);
74     }
75   }
76
77   if (t0->element_type() == DataType::U8 || t0->element_type() == DataType::S8 ||
78       t0->element_type() == DataType::S16)
79   {
80     LUCI_INTERPRETER_CHECK(output()->zero_point() == t0->zero_point());
81     LUCI_INTERPRETER_CHECK(output()->scale() == t0->scale());
82     // Guarantee input/output quantization params match as we do not support
83     // packing quantized tensors.
84     for (int i = 0; i < params().values_count; i++)
85     {
86       LUCI_INTERPRETER_CHECK(_inputs[i]->zero_point() == t0->zero_point());
87       LUCI_INTERPRETER_CHECK(_inputs[i]->scale() == t0->scale());
88     }
89   }
90   // TODO: enable it only if kernel with dynamic shapes
91   output()->resize(output_shape);
92 }
93
94 void Pack::execute() const
95 {
96   switch (_inputs[0]->element_type())
97   {
98     case DataType::FLOAT32:
99       evalGeneric<float>();
100       break;
101     case DataType::U8:
102       evalGeneric<uint8_t>();
103       break;
104     case DataType::S8:
105       evalGeneric<int8_t>();
106       break;
107     case DataType::S16:
108       evalGeneric<int16_t>();
109       break;
110     case DataType::S32:
111       evalGeneric<int32_t>();
112       break;
113     case DataType::S64:
114       evalGeneric<int64_t>();
115       break;
116     default:
117       assert(false && "Unsupported type.");
118   }
119 }
120
121 template <typename T> void Pack::evalGeneric() const
122 {
123   const Tensor *t0 = _inputs[0];
124   const int dimension_size = t0->shape().num_dims() + 1;
125   int axis = params().axis;
126   if (axis < 0)
127   {
128     axis += dimension_size;
129   }
130
131   VectorOfTensors<T, true> inputs(_inputs);
132   tflite::PackParams params{};
133   params.axis = axis;
134   params.inputs_count = _inputs.size();
135   tflite::reference_ops::Pack<T>(params, inputs.shapes(), inputs.data(), getTensorShape(output()),
136                                  getTensorData<T>(output()));
137 }
138
139 } // namespace kernels
140 } // namespace luci_interpreter