/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 #ifndef __NNFW_CKER_SHAPE_H__
19 #define __NNFW_CKER_SHAPE_H__
// Casts a value to void so that parameters/locals which are only read inside
// assert() do not produce unused-variable warnings when NDEBUG compiles the
// assertions away.
#define UNUSED_RELEASE(a) (void)(a)
// Shapes with dimensions up to 5 are stored directly in the structure, while
// larger shapes are separately allocated.
static constexpr int kMaxSmallSize = 5;

// Copy assignment is deleted: an existing Shape cannot be overwritten with `=`.
// Use ReplaceWith() to change the contents of an existing shape instead.
Shape &operator=(Shape const &) = delete;
44 explicit Shape(int dimensions_count) : _size(dimensions_count)
46 if (dimensions_count > kMaxSmallSize)
48 _dims_pointer = new int32_t[dimensions_count];
52 Shape(int shape_size, int32_t value) : _size(0)
55 for (int i = 0; i < shape_size; ++i)
61 Shape(int dimensions_count, const int32_t *dims_data) : _size(0)
63 ReplaceWith(dimensions_count, dims_data);
66 Shape(const std::initializer_list<int> init_list) : _size(0) { BuildFrom(init_list); }
68 // Avoid using this constructor. We should be able to delete it when C++17
70 Shape(Shape const &other) : _size(other.DimensionsCount())
72 if (_size > kMaxSmallSize)
74 _dims_pointer = new int32_t[_size];
76 std::memcpy(DimsData(), other.DimsData(), sizeof(int32_t) * _size);
79 bool operator==(const Shape &comp) const
81 return this->_size == comp._size &&
82 std::memcmp(DimsData(), comp.DimsData(), _size * sizeof(int32_t)) == 0;
87 if (_size > kMaxSmallSize)
89 delete[] _dims_pointer;
93 inline int32_t DimensionsCount() const { return _size; }
94 inline int32_t Dims(int i) const
98 return _size > kMaxSmallSize ? _dims_pointer[i] : _dims[i];
100 inline void SetDim(int i, int32_t val)
104 if (_size > kMaxSmallSize)
106 _dims_pointer[i] = val;
114 inline int32_t *DimsData() { return _size > kMaxSmallSize ? _dims_pointer : _dims; }
115 inline const int32_t *DimsData() const { return _size > kMaxSmallSize ? _dims_pointer : _dims; }
116 // The caller must ensure that the shape is no bigger than 4-D.
117 inline const int32_t *DimsDataUpTo4D() const { return _dims; }
119 inline void Resize(int dimensions_count)
121 if (_size > kMaxSmallSize)
123 delete[] _dims_pointer;
125 _size = dimensions_count;
126 if (dimensions_count > kMaxSmallSize)
128 _dims_pointer = new int32_t[dimensions_count];
132 inline void ReplaceWith(int dimensions_count, const int32_t *dims_data)
134 Resize(dimensions_count);
135 int32_t *dst_dims = DimsData();
136 std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32_t));
139 template <typename T> inline void BuildFrom(const T &src_iterable)
141 const int dimensions_count = std::distance(src_iterable.begin(), src_iterable.end());
142 Resize(dimensions_count);
143 int32_t *data = DimsData();
144 for (auto it : src_iterable)
151 // This will probably be factored out. Old code made substantial use of 4-D
152 // shapes, and so this function is used to extend smaller shapes. Note that
153 // (a) as Dims<4>-dependent code is eliminated, the reliance on this should be
154 // reduced, and (b) some kernels are stricly 4-D, but then the shapes of their
155 // inputs should already be 4-D, so this function should not be needed.
156 inline static Shape ExtendedShape(int new_shape_size, const Shape &shape)
158 return Shape(new_shape_size, shape, 1);
161 inline void BuildFrom(const std::initializer_list<int> init_list)
163 BuildFrom<const std::initializer_list<int>>(init_list);
166 // Returns the total count of elements, that is the size when flattened into a
168 inline int FlatSize() const
171 const int *dims_data = DimsData();
172 for (int i = 0; i < _size; i++)
174 const int dim = dims_data[i];
181 bool operator!=(const Shape &comp) const { return !((*this) == comp); }
184 // For use only by ExtendedShape(), written to guarantee (return-value) copy
186 // This creates a shape padded to the desired size with the specified value.
187 Shape(int new_shape_size, const Shape &shape, int pad_value) : _size(0)
189 assert(new_shape_size >= shape.DimensionsCount());
190 assert(new_shape_size <= kMaxSmallSize);
191 Resize(new_shape_size);
192 const int size_increase = new_shape_size - shape.DimensionsCount();
193 for (int i = 0; i < size_increase; ++i)
195 SetDim(i, pad_value);
197 std::memcpy(DimsData() + size_increase, shape.DimsData(),
198 sizeof(int32_t) * shape.DimensionsCount());
// Inline storage for the extents when the rank is at most kMaxSmallSize.
int32_t _dims[kMaxSmallSize];
// Heap buffer holding the extents when the rank exceeds kMaxSmallSize; owned
// by this object and released in Resize() and the destructor.
int32_t *_dims_pointer{nullptr};
208 inline int MatchingDim(const Shape &shape1, int index1, const Shape &shape2, int index2)
210 UNUSED_RELEASE(shape2);
211 UNUSED_RELEASE(index2);
212 assert(shape1.Dims(index1) == shape2.Dims(index2));
213 return shape1.Dims(index1);
216 template <typename... Args>
217 int MatchingDim(const Shape &shape1, int index1, const Shape &shape2, int index2, Args... args)
219 assert(shape1.Dims(index1) == shape2.Dims(index2));
220 UNUSED_RELEASE(shape2);
221 UNUSED_RELEASE(index2);
222 return MatchingDim(shape1, index1, args...);
225 inline Shape GetShape(const std::vector<int32_t> &data) { return Shape(data.size(), data.data()); }
227 inline int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
229 assert(shape.DimensionsCount() == 4);
230 const int *dims_data = shape.DimsDataUpTo4D();
231 assert(i0 >= 0 && i0 < dims_data[0]);
232 assert(i1 >= 0 && i1 < dims_data[1]);
233 assert(i2 >= 0 && i2 < dims_data[2]);
234 assert(i3 >= 0 && i3 < dims_data[3]);
235 return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
238 inline int Offset(const Shape &shape, int *index)
240 return Offset(shape, index[0], index[1], index[2], index[3]);
243 inline int FlatSizeSkipDim(const Shape &shape, int skip_dim)
245 const int dims_count = shape.DimensionsCount();
246 assert(skip_dim >= 0 && skip_dim < dims_count);
247 const auto *dims_data = shape.DimsData();
249 for (int i = 0; i < dims_count; ++i)
251 flat_size *= (i == skip_dim) ? 1 : dims_data[i];
256 // Flat size calculation, checking that dimensions match with one or more other
258 template <typename... Ts> inline bool checkMatching(const Shape &shape, Ts... check_shapes)
260 const Shape check_shapes_array[sizeof...(Ts)] = {std::forward<Ts>(check_shapes)...};
261 for (const auto &check_shape : check_shapes_array)
263 // Check matching of shapes except the case of that two shapes can be scalar
264 if (shape.DimensionsCount() > 1 || check_shape.DimensionsCount() > 1 || shape.FlatSize() != 1 ||
265 check_shape.FlatSize() != 1)
267 if (shape.DimensionsCount() != check_shape.DimensionsCount())
271 for (int i = 0; i < shape.DimensionsCount(); ++i)
273 if (shape.Dims(i) != check_shape.Dims(i))
// Accepts and ignores any number of arguments of any type. Constructing one
// from a parameter pack (as MatchingFlatSize does) marks the pack as used, so
// release builds, where assert() disappears, do not warn about it.
struct UNUSED_ALL
{
  template <typename... Args> UNUSED_ALL(Args const &...) {}
};
287 template <typename... Ts> inline int MatchingFlatSize(const Shape &shape, Ts... check_shapes)
289 UNUSED_ALL{check_shapes...};
290 assert(checkMatching(shape, std::forward<Ts>(check_shapes)...));
291 return shape.FlatSize();
294 inline int MatchingFlatSizeSkipDim(const Shape &shape, int skip_dim, const Shape &check_shape_0)
296 UNUSED_RELEASE(check_shape_0);
297 const int dims_count = shape.DimensionsCount();
298 for (int i = 0; i < dims_count; ++i)
302 assert(shape.Dims(i) == check_shape_0.Dims(i));
305 return FlatSizeSkipDim(shape, skip_dim);
308 inline int MatchingFlatSizeSkipDim(const Shape &shape, int skip_dim, const Shape &check_shape_0,
309 const Shape &check_shape_1)
311 UNUSED_RELEASE(check_shape_0);
312 const int dims_count = shape.DimensionsCount();
313 for (int i = 0; i < dims_count; ++i)
317 assert(shape.Dims(i) == check_shape_0.Dims(i));
320 return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1);
323 inline int MatchingElementsSize(const Shape &shape, const Shape &check_shape_0,
324 const Shape &check_shape_1)
326 const int size_1 = shape.FlatSize();
327 const int size_2 = check_shape_0.FlatSize();
328 const int size_3 = check_shape_1.FlatSize();
329 assert(size_1 == size_2);
330 assert(size_2 == size_3);
331 UNUSED_RELEASE(size_2);
332 UNUSED_RELEASE(size_3);
339 #endif // __NNFW_CKER_SHAPE_H__