2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef LUCI_INTERPRETER_CORE_TENSOR_H
18 #define LUCI_INTERPRETER_CORE_TENSOR_H
20 #include "luci_interpreter/core/DataType.h"
21 #include "luci_interpreter/core/reader/CircleMicroReader.h"
30 namespace luci_interpreter
// Largest rank a RuntimeShape can hold: the size of its inline dimension buffer.
// (A namespace-scope constexpr already has internal linkage, so `static` is not needed.)
constexpr int kMaxSmallSize = 5;
38 RuntimeShape(const RuntimeShape &other) : _size(other.dimensionsCount())
40 std::memcpy(dimsData(), other.dimsData(), sizeof(int32_t) * _size);
43 // Returns the total count of elements, i.e. the size when flattened into a vector:
// the product of the first dimensionsCount() dimension values.
// NOTE(review): the accumulator declaration and the return statement are not
// visible in this view. The result is returned as `int` and can overflow for very
// large shapes. The reinterpret_cast below only changes const int32_t* to
// const int*, so it is a no-op wherever int32_t is int.
45 inline int flatSize() const
48 const int *dims_data = reinterpret_cast<const int *>(dimsData());
49 for (int i = 0; i < _size; i++)
51 buffer_size *= dims_data[i];
56 inline int32_t *dimsData() { return _dims; }
57 inline const int32_t *dimsData() const { return _dims; }
59 RuntimeShape() : _size(0) {}
61 explicit RuntimeShape(int dimensions_count) : _size(dimensions_count)
63 assert(dimensions_count <= kMaxSmallSize);
64 assert(dimensions_count >= 0);
67 RuntimeShape(int dimensions_count, const int32_t *dims_data) : _size(0)
69 resize(dimensions_count);
70 int32_t *dst_dims = dimsData();
71 std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32_t));
// Builds a shape of rank `new_shape_size` from `shape`. The dimensions of `shape`
// are copied into the trailing slots; the leading
// (new_shape_size - shape.dimensionsCount()) slots are iterated by the loop below,
// whose body is not visible in this view (presumably setDim(i, pad_value) — TODO confirm).
// NOTE(review): nothing visible here checks new_shape_size >= shape.dimensionsCount().
// If it is smaller, size_increase is negative and the memcpy destination points
// before the start of the dimension storage.
74 RuntimeShape(int new_shape_size, const RuntimeShape &shape, int pad_value) : _size(0)
76 resize(new_shape_size);
77 const int size_increase = new_shape_size - shape.dimensionsCount();
78 for (int i = 0; i < size_increase; ++i)
82 std::memcpy(dimsData() + size_increase, shape.dimsData(),
83 sizeof(int32_t) * shape.dimensionsCount());
// Creates a shape of rank `shape_size`, presumably with every dimension set to
// `value`. Only the loop header is visible in this view; the resize call and the
// loop body are not (TODO confirm).
86 RuntimeShape(int shape_size, int32_t value) : _size(0)
89 for (int i = 0; i < shape_size; ++i)
95 inline static RuntimeShape extendedShape(int new_shape_size, const RuntimeShape &shape)
97 return RuntimeShape(new_shape_size, shape, 1);
100 bool operator==(const RuntimeShape &comp) const
102 return this->_size == comp._size &&
103 std::memcmp(dimsData(), comp.dimsData(), _size * sizeof(int32_t)) == 0;
106 inline int32_t dimensionsCount() const { return _size; }
// Presumably returns the value of dimension `i`. The body is not visible in this view.
// NOTE(review): confirm that its bounds assertion rejects i == dimensionsCount()
// (i.e. uses `i < _size`, not `i <= _size`).
108 inline int32_t dims(int i) const
// Presumably sets dimension `i` to `val`. The body is not visible in this view, and
// the same bounds-check caveat as for dims() applies.
114 inline void setDim(int i, int32_t val)
121 inline void resize(int dimensions_count)
123 assert(dimensions_count <= kMaxSmallSize);
124 assert(dimensions_count >= 0);
125 _size = dimensions_count;
// Inline storage for up to kMaxSmallSize dimension values. Only the first _size
// entries are meaningful: the constructors, resize() and operator== only
// read or write that prefix, and the array itself has no initializer.
130 int32_t _dims[kMaxSmallSize];
// Returns the first entry of the tensor's quantization `scale` list.
// Asserts when the tensor has no quantization parameters. The fallback taken when
// asserts are compiled out is not visible in this view (TODO confirm).
// NOTE(review): quant_params->scale() is dereferenced without a null or empty
// check, so a missing or empty scale list would be undefined behavior here.
137 static float scale(const circle::Tensor *circle_tensor)
139 const auto *quant_params = circle_tensor->quantization();
140 if (quant_params == nullptr)
142 assert(false && "There is no quantization params");
146 return *quant_params->scale()->cbegin();
// Returns the first entry of the tensor's quantization `zero_point` list.
// Asserts when the tensor has no quantization parameters. The fallback taken when
// asserts are compiled out is not visible in this view (TODO confirm).
// NOTE(review): quant_params->zero_point() is dereferenced without a null or empty
// check, so a missing or empty zero-point list would be undefined behavior here.
149 static int32_t zero_point(const circle::Tensor *circle_tensor)
151 const auto *quant_params = circle_tensor->quantization();
152 if (quant_params == nullptr)
154 assert(false && "There is no quantization params");
158 return *quant_params->zero_point()->cbegin();
// Copies every entry of the tensor's quantization `scale` list into a std::vector.
// Asserts when quantization parameters, or their scale list, are absent. The
// return statement is not visible in this view (presumably returns the local
// `scales` — TODO confirm).
// NOTE(review): returning `const std::vector` by value prevents callers from moving
// the result. The local `scales` also shadows this function's own name.
161 static const std::vector<float> scales(const circle::Tensor *circle_tensor)
163 const auto *quant_params = circle_tensor->quantization();
164 if (quant_params == nullptr)
166 assert(false && "There is no quantization params");
169 assert(quant_params->scale() != nullptr);
170 std::vector<float> scales(quant_params->scale()->cbegin(), quant_params->scale()->cend());
// Copies every entry of the tensor's quantization `zero_point` list into a std::vector.
// Asserts when quantization parameters, or their zero-point list, are absent.
// The return statement is not visible in this view (presumably returns the local
// `zero_points` — TODO confirm).
// NOTE(review): returning `const std::vector` by value prevents callers from moving
// the result. The local `zero_points` also shadows this function's own name.
175 static const std::vector<int32_t> zero_points(const circle::Tensor *circle_tensor)
177 const auto *quant_params = circle_tensor->quantization();
178 if (quant_params == nullptr)
180 assert(false && "There is no quantization params");
183 assert(quant_params->zero_point() != nullptr);
184 std::vector<int32_t> zero_points(quant_params->zero_point()->cbegin(),
185 quant_params->zero_point()->cend());
// Returns the `quantized_dimension` field of the tensor's quantization parameters.
// Presumably this is the channel axis used for per-channel quantization; confirm
// against the circle schema.
// Asserts when the tensor has no quantization parameters. The fallback taken when
// asserts are compiled out is not visible in this view (TODO confirm).
190 static int32_t quantized_dimension(const circle::Tensor *circle_tensor)
192 const auto *quant_params = circle_tensor->quantization();
193 if (quant_params == nullptr)
195 assert(false && "There is no quantization params");
198 return quant_params->quantized_dimension();
202 static bool is_constant_tensor(const luci_interpreter::CircleReader *reader,
203 const circle::Tensor *circle_tensor)
205 return reader->buffers()[circle_tensor->buffer()]->data() != nullptr;
208 static DataType element_type(const circle::Tensor *circle_tensor)
210 return luci_datatype(circle_tensor->type());
213 static VectorWrapper<int32_t> tensor_shape(const circle::Tensor *circle_tensor)
215 return wrap(circle_tensor->shape());
218 static int num_dims(const circle::Tensor *circle_tensor)
220 // TODO check removing of wrap
221 auto const const_dims = wrap(circle_tensor->shape());
222 return const_dims.size();
// Returns entry `i` of the tensor's shape, i.e. the size of dimension `i`.
// NOTE(review): `i` is a signed int compared against const_dims.size(). If size()
// returns an unsigned type, this triggers a sign-compare warning, and a negative `i`
// is caught only through the unsigned conversion. One line of this body is not
// visible in this view and may already hold an explicit `i >= 0` check (TODO confirm).
225 static int32_t dim(const circle::Tensor *circle_tensor, int i)
227 // TODO: check whether the wrap() call can be removed
229 auto const const_dims = wrap(circle_tensor->shape());
230 assert(i < const_dims.size());
232 return const_dims[i];
// Iterates over the tensor's shape entries, presumably to return their product
// (the element count). The accumulator initialization, the loop body and the
// return are not visible in this view (TODO confirm).
// NOTE(review): the loop variable `dim` shadows the static member function dim().
// If this is a product, an int32_t accumulator can overflow for very large shapes.
235 static int32_t num_elements(const circle::Tensor *circle_tensor)
238 auto const const_dims = wrap(circle_tensor->shape());
239 for (const int32_t dim : const_dims)
247 } // namespace luci_interpreter
249 #endif // LUCI_INTERPRETER_CORE_TENSOR_H