/*
 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernels/Pack.h"

#include "kernels/Utils.h"

#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>

#include <stdexcept>

23 namespace luci_interpreter
28 Pack::Pack(std::vector<const Tensor *> inputs, Tensor *output, const PackParams ¶ms)
29 : KernelWithParams<PackParams>(std::move(inputs), {output}, params)
33 void Pack::configure()
35 LUCI_INTERPRETER_CHECK(_inputs.size() == static_cast<uint32_t>(params().values_count));
36 const Tensor *t0 = _inputs[0];
37 const int dimension_size = t0->shape().num_dims() + 1;
38 int axis = params().axis;
41 axis += dimension_size;
43 LUCI_INTERPRETER_CHECK(axis >= 0 && axis <= t0->shape().num_dims());
45 if (t0->element_type() != DataType::S32 && t0->element_type() != DataType::FLOAT32 &&
46 t0->element_type() != DataType::U8 && t0->element_type() != DataType::S8 &&
47 t0->element_type() != DataType::S16 && t0->element_type() != DataType::S64)
49 assert(false && "Unsupported type.");
52 for (uint32_t i = 1; i < _inputs.size(); ++i)
54 const Tensor *tensor = _inputs[i];
55 LUCI_INTERPRETER_CHECK(tensor->element_type() == t0->element_type());
56 LUCI_INTERPRETER_CHECK(tensor->shape().num_dims() == t0->shape().num_dims());
57 for (int d = 0; d < t0->shape().num_dims(); ++d)
59 LUCI_INTERPRETER_CHECK(tensor->shape().dim(d) == t0->shape().dim(d));
63 Shape output_shape(dimension_size);
65 for (int index = 0; index < dimension_size; ++index)
69 output_shape.dim(index) = params().values_count;
73 output_shape.dim(index) = t0->shape().dim(i++);
77 if (t0->element_type() == DataType::U8 || t0->element_type() == DataType::S8 ||
78 t0->element_type() == DataType::S16)
80 LUCI_INTERPRETER_CHECK(output()->zero_point() == t0->zero_point());
81 LUCI_INTERPRETER_CHECK(output()->scale() == t0->scale());
82 // Guarantee input/output quantization params match as we do not support
83 // packing quantized tensors.
84 for (int i = 0; i < params().values_count; i++)
86 LUCI_INTERPRETER_CHECK(_inputs[i]->zero_point() == t0->zero_point());
87 LUCI_INTERPRETER_CHECK(_inputs[i]->scale() == t0->scale());
90 // TODO: enable it only if kernel with dynamic shapes
91 output()->resize(output_shape);
94 void Pack::execute() const
96 switch (_inputs[0]->element_type())
98 case DataType::FLOAT32:
102 evalGeneric<uint8_t>();
105 evalGeneric<int8_t>();
108 evalGeneric<int16_t>();
111 evalGeneric<int32_t>();
114 evalGeneric<int64_t>();
117 assert(false && "Unsupported type.");
121 template <typename T> void Pack::evalGeneric() const
123 const Tensor *t0 = _inputs[0];
124 const int dimension_size = t0->shape().num_dims() + 1;
125 int axis = params().axis;
128 axis += dimension_size;
131 VectorOfTensors<T, true> inputs(_inputs);
132 tflite::PackParams params{};
134 params.inputs_count = _inputs.size();
135 tflite::reference_ops::Pack<T>(params, inputs.shapes(), inputs.data(), getTensorShape(output()),
136 getTensorData<T>(output()));
139 } // namespace kernels
140 } // namespace luci_interpreter