2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #ifndef __NNFW_CKER_SHAPE_H__
19 #define __NNFW_CKER_SHAPE_H__
// Expands to a void-cast of its argument. The functions in this file apply it to
// parameters and locals that are otherwise referenced only inside assert() expressions.
26 #define UNUSED_RELEASE(a) (void)(a)
// Threshold between the two storage modes: the accessors in this class read `_dims`
// while `_size <= kMaxSmallSize` and `_dims_pointer` (allocated with new[]) otherwise.
36 // Shapes with dimensions up to 5 are stored directly in the structure, while
37 // larger shapes are separately allocated.
38 static constexpr int kMaxSmallSize = 5;
// Copy assignment is deleted. ReplaceWith(const Shape &) is the member that copies another
// shape's dimensions into an already-constructed object.
40 Shape &operator=(Shape const &) = delete;
44 explicit Shape(int dimensions_count) : _size(dimensions_count)
46 if (dimensions_count > kMaxSmallSize)
48 _dims_pointer = new int32_t[dimensions_count];
// Constructs a shape from a dimension count and a single int32_t `value`.
// NOTE(review): only the signature and the loop header over [0, shape_size) are visible in
// this listing. Presumably the body first sizes the shape and then stores `value` in every
// dimension; confirm against the full file.
52 Shape(int shape_size, int32_t value) : _size(0)
55 for (int i = 0; i < shape_size; ++i)
61 Shape(int dimensions_count, const int32_t *dims_data) : _size(0)
63 ReplaceWith(dimensions_count, dims_data);
66 Shape(const std::initializer_list<int> init_list) : _size(0) { BuildFrom(init_list); }
68 // Avoid using this constructor. We should be able to delete it when C++17
70 Shape(Shape const &other) : _size(other.DimensionsCount())
72 if (_size > kMaxSmallSize)
74 _dims_pointer = new int32_t[_size];
76 std::memcpy(DimsData(), other.DimsData(), sizeof(int32_t) * _size);
79 bool operator==(const Shape &comp) const
81 return this->_size == comp._size &&
82 std::memcmp(DimsData(), comp.DimsData(), _size * sizeof(int32_t)) == 0;
// Releases the heap-allocated dimension array, which is in use only when the shape holds
// more than kMaxSmallSize dimensions.
// NOTE(review): the enclosing function header is not visible in this listing. Presumably
// this is the body of ~Shape(); confirm against the full file.
87 if (_size > kMaxSmallSize)
89 delete[] _dims_pointer;
93 inline int32_t DimensionsCount() const { return _size; }
// Returns the extent of dimension `i`, read from `_dims_pointer` when the shape has more
// than kMaxSmallSize dimensions and from the inline `_dims` array otherwise.
// NOTE(review): lines between the signature and the return are not visible in this
// listing. Possibly bounds checks on `i`; confirm against the full file.
94 inline int32_t Dims(int i) const
98 return _size > kMaxSmallSize ? _dims_pointer[i] : _dims[i];
// Stores `val` as the extent of dimension `i`.
// NOTE(review): only the branch for shapes with more than kMaxSmallSize dimensions is
// visible in this listing. Presumably the remaining branch writes `_dims[i]`, and the
// elided leading lines may validate `i`; confirm against the full file.
100 inline void SetDim(int i, int32_t val)
104 if (_size > kMaxSmallSize)
106 _dims_pointer[i] = val;
114 inline int32_t *DimsData() { return _size > kMaxSmallSize ? _dims_pointer : _dims; }
115 inline const int32_t *DimsData() const { return _size > kMaxSmallSize ? _dims_pointer : _dims; }
116 // The caller must ensure that the shape is no bigger than 4-D.
117 inline const int32_t *DimsDataUpTo4D() const { return _dims; }
119 inline void Resize(int dimensions_count)
121 if (_size > kMaxSmallSize)
123 delete[] _dims_pointer;
125 _size = dimensions_count;
126 if (dimensions_count > kMaxSmallSize)
128 _dims_pointer = new int32_t[dimensions_count];
132 inline void ReplaceWith(int dimensions_count, const int32_t *dims_data)
134 Resize(dimensions_count);
135 int32_t *dst_dims = DimsData();
136 std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32_t));
139 inline void ReplaceWith(const Shape &other)
141 ReplaceWith(other.DimensionsCount(), other.DimsData());
// Takes over the contents of the rvalue `other`: the two sizes are swapped, after which
// either the whole inline array is copied from `other` or its heap pointer is adopted.
// NOTE(review): a statement between the signature and the swap, and the line joining the
// copy with the pointer assignment, are not visible in this listing. Presumably the pointer
// assignment is the alternative branch for sizes above kMaxSmallSize; confirm against the
// full file.
144 inline void ReplaceWith(Shape &&other)
147 std::swap(_size, other._size);
148 if (_size <= kMaxSmallSize)
149 std::copy(other._dims, other._dims + kMaxSmallSize, _dims);
151 _dims_pointer = other._dims_pointer;
// Rebuilds this shape from any iterable that offers begin()/end(): the number of elements
// becomes the new dimension count, and a range-for then walks the source values with
// `data` pointing at the start of the resized storage.
// NOTE(review): the loop body is not visible in this listing. Presumably it writes each
// value through `data` and advances the pointer; confirm against the full file.
154 template <typename T> inline void BuildFrom(const T &src_iterable)
156 const int dimensions_count = std::distance(src_iterable.begin(), src_iterable.end());
157 Resize(dimensions_count);
158 int32_t *data = DimsData();
159 for (auto &&it : src_iterable)
166 // This will probably be factored out. Old code made substantial use of 4-D
167 // shapes, and so this function is used to extend smaller shapes. Note that
168 // (a) as Dims<4>-dependent code is eliminated, the reliance on this should be
169 // reduced, and (b) some kernels are stricly 4-D, but then the shapes of their
170 // inputs should already be 4-D, so this function should not be needed.
171 inline static Shape ExtendedShape(int new_shape_size, const Shape &shape)
173 return Shape(new_shape_size, shape, 1);
176 inline void BuildFrom(const std::initializer_list<int> init_list)
178 BuildFrom<const std::initializer_list<int>>(init_list);
// Walks every dimension of the shape, reading each extent into `dim`.
// NOTE(review): the accumulator, anything done with `dim`, and the return statement are not
// visible in this listing. Presumably the extents are multiplied together to give the
// element count described below; confirm against the full file.
181 // Returns the total count of elements, that is the size when flattened into a
183 inline int FlatSize() const
186 const int *dims_data = DimsData();
187 for (int i = 0; i < _size; i++)
189 const int dim = dims_data[i];
195 bool operator!=(const Shape &comp) const { return !((*this) == comp); }
198 // For use only by ExtendedShape(), written to guarantee (return-value) copy
200 // This creates a shape padded to the desired size with the specified value.
201 Shape(int new_shape_size, const Shape &shape, int pad_value) : _size(0)
203 assert(new_shape_size >= shape.DimensionsCount());
204 assert(new_shape_size <= kMaxSmallSize);
205 Resize(new_shape_size);
206 const int size_increase = new_shape_size - shape.DimensionsCount();
207 for (int i = 0; i < size_increase; ++i)
209 SetDim(i, pad_value);
211 std::memcpy(DimsData() + size_increase, shape.DimsData(),
212 sizeof(int32_t) * shape.DimensionsCount());
// Inline storage, used by the accessors while the shape has at most kMaxSmallSize dimensions.
217 int32_t _dims[kMaxSmallSize];
// Heap storage (allocated with new[]), used once the shape has more than kMaxSmallSize
// dimensions.
// NOTE(review): the enclosing declaration is not visible in this listing; whether these two
// members share storage should be confirmed against the full file.
218 int32_t *_dims_pointer{nullptr};
222 inline int MatchingDim(const Shape &shape1, int index1, const Shape &shape2, int index2)
224 UNUSED_RELEASE(shape2);
225 UNUSED_RELEASE(index2);
226 assert(shape1.Dims(index1) == shape2.Dims(index2));
227 return shape1.Dims(index1);
230 template <typename... Args>
231 int MatchingDim(const Shape &shape1, int index1, const Shape &shape2, int index2, Args... args)
233 assert(shape1.Dims(index1) == shape2.Dims(index2));
234 UNUSED_RELEASE(shape2);
235 UNUSED_RELEASE(index2);
236 return MatchingDim(shape1, index1, args...);
239 inline Shape GetShape(const std::vector<int32_t> &data) { return Shape(data.size(), data.data()); }
241 inline int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
243 assert(shape.DimensionsCount() == 4);
244 const int *dims_data = shape.DimsDataUpTo4D();
245 assert(i0 >= 0 && i0 < dims_data[0]);
246 assert(i1 >= 0 && i1 < dims_data[1]);
247 assert(i2 >= 0 && i2 < dims_data[2]);
248 assert(i3 >= 0 && i3 < dims_data[3]);
249 return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
252 inline int Offset(const Shape &shape, int *index)
254 return Offset(shape, index[0], index[1], index[2], index[3]);
// Multiplies together the extents of `shape`, substituting 1 for dimension `skip_dim`.
// Asserts that `skip_dim` is a valid dimension index for `shape`.
// NOTE(review): the declaration of `flat_size` and the return statement are not visible in
// this listing. Presumably the accumulator starts at 1 and is returned; confirm against the
// full file.
257 inline int FlatSizeSkipDim(const Shape &shape, int skip_dim)
259 const int dims_count = shape.DimensionsCount();
260 assert(skip_dim >= 0 && skip_dim < dims_count);
261 const auto *dims_data = shape.DimsData();
263 for (int i = 0; i < dims_count; ++i)
265 flat_size *= (i == skip_dim) ? 1 : dims_data[i];
// Compares `shape` against each shape in `check_shapes`, which are first gathered into a
// local array.
// NOTE(review): the statements executed on a mismatch and the final return are not visible
// in this listing. Presumably a mismatch yields false and a full pass yields true; confirm
// against the full file.
270 // Flat size calculation, checking that dimensions match with one or more other
272 template <typename... Ts> inline bool checkMatching(const Shape &shape, Ts... check_shapes)
274 const Shape check_shapes_array[sizeof...(Ts)] = {std::forward<Ts>(check_shapes)...};
275 for (const auto &check_shape : check_shapes_array)
// The comparison below is skipped only when both shapes have at most one dimension and a
// flat size of exactly 1.
277 // Check matching of shapes except the case of that two shapes can be scalar
278 if (shape.DimensionsCount() > 1 || check_shape.DimensionsCount() > 1 || shape.FlatSize() != 1 ||
279 check_shape.FlatSize() != 1)
// First the dimension counts are compared, then every extent pairwise.
281 if (shape.DimensionsCount() != check_shape.DimensionsCount())
285 for (int i = 0; i < shape.DimensionsCount(); ++i)
287 if (shape.Dims(i) != check_shape.Dims(i))
299 template <typename... Args> UNUSED_ALL(Args const &...) {}
301 template <typename... Ts> inline int MatchingFlatSize(const Shape &shape, Ts... check_shapes)
303 UNUSED_ALL{check_shapes...};
304 assert(checkMatching(shape, std::forward<Ts>(check_shapes)...));
305 return shape.FlatSize();
// Returns FlatSizeSkipDim(shape, skip_dim) after asserting, dimension by dimension, that
// `shape` and `check_shape_0` have equal extents.
// NOTE(review): a line between the loop header and the assert is not visible in this
// listing. Presumably it exempts `skip_dim` from the comparison; confirm against the full
// file.
308 inline int MatchingFlatSizeSkipDim(const Shape &shape, int skip_dim, const Shape &check_shape_0)
// check_shape_0 is otherwise referenced only inside the assert.
310 UNUSED_RELEASE(check_shape_0);
311 const int dims_count = shape.DimensionsCount();
312 for (int i = 0; i < dims_count; ++i)
316 assert(shape.Dims(i) == check_shape_0.Dims(i));
319 return FlatSizeSkipDim(shape, skip_dim);
// Two-shape variant: asserts that `shape` and `check_shape_0` have equal extents, then
// delegates to the single-check overload to compare against `check_shape_1` and compute
// the result.
// NOTE(review): a line between the loop header and the assert is not visible in this
// listing. Presumably it exempts `skip_dim` from the comparison; confirm against the full
// file.
322 inline int MatchingFlatSizeSkipDim(const Shape &shape, int skip_dim, const Shape &check_shape_0,
323 const Shape &check_shape_1)
// check_shape_0 is otherwise referenced only inside the assert.
325 UNUSED_RELEASE(check_shape_0);
326 const int dims_count = shape.DimensionsCount();
327 for (int i = 0; i < dims_count; ++i)
331 assert(shape.Dims(i) == check_shape_0.Dims(i));
334 return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1);
// Computes the flat size of three shapes and asserts that all three are equal. Only the
// element counts are compared here, not the individual extents.
// NOTE(review): the return statement is not visible in this listing. Presumably the common
// size (`size_1`) is returned; confirm against the full file.
337 inline int MatchingElementsSize(const Shape &shape, const Shape &check_shape_0,
338 const Shape &check_shape_1)
340 const int size_1 = shape.FlatSize();
341 const int size_2 = check_shape_0.FlatSize();
342 const int size_3 = check_shape_1.FlatSize();
343 assert(size_1 == size_2);
344 assert(size_2 == size_3);
// size_2 and size_3 are otherwise referenced only inside the asserts.
345 UNUSED_RELEASE(size_2);
346 UNUSED_RELEASE(size_3);
353 #endif // __NNFW_CKER_SHAPE_H__