Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compute / ruy / include / ruy / Shape.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_RUY_SHAPE_H__
19 #define __NNFW_RUY_SHAPE_H__
20
21 #include <algorithm>
22 #include <cstring>
23 #include <cassert>
24 #include <vector>
25
26 #define UNUSED_RELEASE(a) (void)(a)
27
28 namespace nnfw
29 {
30 namespace ruy
31 {
32
33 class Shape
34 {
35 public:
36   // Shapes with dimensions up to 5 are stored directly in the structure, while
37   // larger shapes are separately allocated.
38   static constexpr int kMaxSmallSize = 5;
39
40   Shape &operator=(Shape const &) = delete;
41
42   Shape() : _size(0) {}
43
44   explicit Shape(int dimensions_count) : _size(dimensions_count)
45   {
46     if (dimensions_count > kMaxSmallSize)
47     {
48       _dims_pointer = new int32_t[dimensions_count];
49     }
50   }
51
52   Shape(int shape_size, int32_t value) : _size(0)
53   {
54     Resize(shape_size);
55     for (int i = 0; i < shape_size; ++i)
56     {
57       SetDim(i, value);
58     }
59   }
60
61   Shape(int dimensions_count, const int32_t *dims_data) : _size(0)
62   {
63     ReplaceWith(dimensions_count, dims_data);
64   }
65
66   Shape(const std::initializer_list<int> init_list) : _size(0) { BuildFrom(init_list); }
67
68   // Avoid using this constructor.  We should be able to delete it when C++17
69   // rolls out.
70   Shape(Shape const &other) : _size(other.DimensionsCount())
71   {
72     if (_size > kMaxSmallSize)
73     {
74       _dims_pointer = new int32_t[_size];
75     }
76     std::memcpy(DimsData(), other.DimsData(), sizeof(int32_t) * _size);
77   }
78
79   bool operator==(const Shape &comp) const
80   {
81     return this->_size == comp._size &&
82            std::memcmp(DimsData(), comp.DimsData(), _size * sizeof(int32_t)) == 0;
83   }
84
85   ~Shape()
86   {
87     if (_size > kMaxSmallSize)
88     {
89       delete[] _dims_pointer;
90     }
91   }
92
93   inline int32_t DimensionsCount() const { return _size; }
94   inline int32_t Dims(int i) const
95   {
96     assert(i >= 0);
97     assert(i < _size);
98     return _size > kMaxSmallSize ? _dims_pointer[i] : _dims[i];
99   }
100   inline void SetDim(int i, int32_t val)
101   {
102     assert(i >= 0);
103     assert(i < _size);
104     if (_size > kMaxSmallSize)
105     {
106       _dims_pointer[i] = val;
107     }
108     else
109     {
110       _dims[i] = val;
111     }
112   }
113
114   inline int32_t *DimsData() { return _size > kMaxSmallSize ? _dims_pointer : _dims; }
115   inline const int32_t *DimsData() const { return _size > kMaxSmallSize ? _dims_pointer : _dims; }
116   // The caller must ensure that the shape is no bigger than 4-D.
117   inline const int32_t *DimsDataUpTo4D() const { return _dims; }
118
119   inline void Resize(int dimensions_count)
120   {
121     if (_size > kMaxSmallSize)
122     {
123       delete[] _dims_pointer;
124     }
125     _size = dimensions_count;
126     if (dimensions_count > kMaxSmallSize)
127     {
128       _dims_pointer = new int32_t[dimensions_count];
129     }
130   }
131
132   inline void ReplaceWith(int dimensions_count, const int32_t *dims_data)
133   {
134     Resize(dimensions_count);
135     int32_t *dst_dims = DimsData();
136     std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32_t));
137   }
138
139   inline void ReplaceWith(const Shape &other)
140   {
141     ReplaceWith(other.DimensionsCount(), other.DimsData());
142   }
143
144   inline void ReplaceWith(Shape &&other)
145   {
146     Resize(0);
147     std::swap(_size, other._size);
148     if (_size <= kMaxSmallSize)
149       std::copy(other._dims, other._dims + kMaxSmallSize, _dims);
150     else
151       _dims_pointer = other._dims_pointer;
152   }
153
154   template <typename T> inline void BuildFrom(const T &src_iterable)
155   {
156     const int dimensions_count = std::distance(src_iterable.begin(), src_iterable.end());
157     Resize(dimensions_count);
158     int32_t *data = DimsData();
159     for (auto it : src_iterable)
160     {
161       *data = it;
162       ++data;
163     }
164   }
165
166   // This will probably be factored out. Old code made substantial use of 4-D
167   // shapes, and so this function is used to extend smaller shapes. Note that
168   // (a) as Dims<4>-dependent code is eliminated, the reliance on this should be
169   // reduced, and (b) some kernels are stricly 4-D, but then the shapes of their
170   // inputs should already be 4-D, so this function should not be needed.
171   inline static Shape ExtendedShape(int new_shape_size, const Shape &shape)
172   {
173     return Shape(new_shape_size, shape, 1);
174   }
175
176   inline void BuildFrom(const std::initializer_list<int> init_list)
177   {
178     BuildFrom<const std::initializer_list<int>>(init_list);
179   }
180
181   // Returns the total count of elements, that is the size when flattened into a
182   // vector.
183   inline int FlatSize() const
184   {
185     int buffer_size = 1;
186     const int *dims_data = DimsData();
187     for (int i = 0; i < _size; i++)
188     {
189       const int dim = dims_data[i];
190       assert(dim >= 1);
191       buffer_size *= dim;
192     }
193     return buffer_size;
194   }
195
196   bool operator!=(const Shape &comp) const { return !((*this) == comp); }
197
198 private:
199   // For use only by ExtendedShape(), written to guarantee (return-value) copy
200   // elision in C++17.
201   // This creates a shape padded to the desired size with the specified value.
202   Shape(int new_shape_size, const Shape &shape, int pad_value) : _size(0)
203   {
204     assert(new_shape_size >= shape.DimensionsCount());
205     assert(new_shape_size <= kMaxSmallSize);
206     Resize(new_shape_size);
207     const int size_increase = new_shape_size - shape.DimensionsCount();
208     for (int i = 0; i < size_increase; ++i)
209     {
210       SetDim(i, pad_value);
211     }
212     std::memcpy(DimsData() + size_increase, shape.DimsData(),
213                 sizeof(int32_t) * shape.DimensionsCount());
214   }
215
216   int32_t _size;
217   union {
218     int32_t _dims[kMaxSmallSize];
219     int32_t *_dims_pointer{nullptr};
220   };
221 };
222
223 inline int MatchingDim(const Shape &shape1, int index1, const Shape &shape2, int index2)
224 {
225   UNUSED_RELEASE(shape2);
226   UNUSED_RELEASE(index2);
227   assert(shape1.Dims(index1) == shape2.Dims(index2));
228   return shape1.Dims(index1);
229 }
230
231 template <typename... Args>
232 int MatchingDim(const Shape &shape1, int index1, const Shape &shape2, int index2, Args... args)
233 {
234   assert(shape1.Dims(index1) == shape2.Dims(index2));
235   UNUSED_RELEASE(shape2);
236   UNUSED_RELEASE(index2);
237   return MatchingDim(shape1, index1, args...);
238 }
239
240 inline Shape GetShape(const std::vector<int32_t> &data) { return Shape(data.size(), data.data()); }
241
242 inline int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
243 {
244   assert(shape.DimensionsCount() == 4);
245   const int *dims_data = shape.DimsDataUpTo4D();
246   assert(i0 >= 0 && i0 < dims_data[0]);
247   assert(i1 >= 0 && i1 < dims_data[1]);
248   assert(i2 >= 0 && i2 < dims_data[2]);
249   assert(i3 >= 0 && i3 < dims_data[3]);
250   return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
251 }
252
253 inline int Offset(const Shape &shape, int *index)
254 {
255   return Offset(shape, index[0], index[1], index[2], index[3]);
256 }
257
258 inline int FlatSizeSkipDim(const Shape &shape, int skip_dim)
259 {
260   const int dims_count = shape.DimensionsCount();
261   assert(skip_dim >= 0 && skip_dim < dims_count);
262   const auto *dims_data = shape.DimsData();
263   int flat_size = 1;
264   for (int i = 0; i < dims_count; ++i)
265   {
266     flat_size *= (i == skip_dim) ? 1 : dims_data[i];
267   }
268   return flat_size;
269 }
270
271 // Flat size calculation, checking that dimensions match with one or more other
272 // arrays.
273 template <typename... Ts> inline bool checkMatching(const Shape &shape, Ts... check_shapes)
274 {
275   const Shape check_shapes_array[sizeof...(Ts)] = {std::forward<Ts>(check_shapes)...};
276   for (const auto &check_shape : check_shapes_array)
277   {
278     // Check matching of shapes except the case of that two shapes can be scalar
279     if (shape.DimensionsCount() > 1 || check_shape.DimensionsCount() > 1 || shape.FlatSize() != 1 ||
280         check_shape.FlatSize() != 1)
281     {
282       if (shape.DimensionsCount() != check_shape.DimensionsCount())
283       {
284         return false;
285       }
286       for (int i = 0; i < shape.DimensionsCount(); ++i)
287       {
288         if (shape.Dims(i) != check_shape.Dims(i))
289         {
290           return false;
291         }
292       }
293     }
294   }
295   return true;
296 }
297
298 struct UNUSED_ALL
299 {
300   template <typename... Args> UNUSED_ALL(Args const &...) {}
301 };
302 template <typename... Ts> inline int MatchingFlatSize(const Shape &shape, Ts... check_shapes)
303 {
304   UNUSED_ALL{check_shapes...};
305   assert(checkMatching(shape, std::forward<Ts>(check_shapes)...));
306   return shape.FlatSize();
307 }
308
309 inline int MatchingFlatSizeSkipDim(const Shape &shape, int skip_dim, const Shape &check_shape_0)
310 {
311   UNUSED_RELEASE(check_shape_0);
312   const int dims_count = shape.DimensionsCount();
313   for (int i = 0; i < dims_count; ++i)
314   {
315     if (i != skip_dim)
316     {
317       assert(shape.Dims(i) == check_shape_0.Dims(i));
318     }
319   }
320   return FlatSizeSkipDim(shape, skip_dim);
321 }
322
323 inline int MatchingFlatSizeSkipDim(const Shape &shape, int skip_dim, const Shape &check_shape_0,
324                                    const Shape &check_shape_1)
325 {
326   UNUSED_RELEASE(check_shape_0);
327   const int dims_count = shape.DimensionsCount();
328   for (int i = 0; i < dims_count; ++i)
329   {
330     if (i != skip_dim)
331     {
332       assert(shape.Dims(i) == check_shape_0.Dims(i));
333     }
334   }
335   return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1);
336 }
337
338 inline int MatchingElementsSize(const Shape &shape, const Shape &check_shape_0,
339                                 const Shape &check_shape_1)
340 {
341   const int size_1 = shape.FlatSize();
342   const int size_2 = check_shape_0.FlatSize();
343   const int size_3 = check_shape_1.FlatSize();
344   assert(size_1 == size_2);
345   assert(size_2 == size_3);
346   UNUSED_RELEASE(size_2);
347   UNUSED_RELEASE(size_3);
348   return size_1;
349 }
350
351 } // namespace ruy
352 } // namespace nnfw
353
354 #endif // __NNFW_RUY_SHAPE_H__