Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compute / cker / include / cker / operation / Transpose.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_TRANSPOSE_H__
19 #define __NNFW_CKER_TRANSPOSE_H__
20
21 #include "cker/Shape.h"
22 #include "cker/Types.h"
23 #include "cker/Utils.h"
24
25 namespace nnfw
26 {
27 namespace cker
28 {
29 namespace reference
30 {
31
32 template <typename T>
33 void TransposeImpl(const TransposeParams &params, const Shape &unextended_input_shape,
34                    const T *input_data, const Shape &unextended_output_shape, T *output_data)
35 {
36   const int unextended_output_size = unextended_output_shape.DimensionsCount();
37   assert(unextended_input_shape.DimensionsCount() <= 4);
38   assert(unextended_output_size <= 4);
39   assert(unextended_output_size == params.perm_count);
40   const Shape input_shape = Shape::ExtendedShape(4, unextended_input_shape);
41   const Shape output_shape = Shape::ExtendedShape(4, unextended_output_shape);
42   const int input_ext_size = 4 - unextended_input_shape.DimensionsCount();
43   const int output_ext_size = 4 - unextended_output_size;
44
45   // The perm data is extended to match the output, each index incremented by
46   // the amount of front padding of the input shape.
47   int extended_perm[4];
48   for (int i = 0; i < output_ext_size; ++i)
49   {
50     extended_perm[i] = i;
51   }
52   for (int i = 0; i < unextended_output_size; ++i)
53   {
54     extended_perm[i + output_ext_size] = params.perm[i] + input_ext_size;
55   }
56
57   int out_sizes[4];
58   // Compute the inverse permutation array so we can do an output centered
59   // transpose. Also, check to make sure output_dims is matching input_dims.
60   for (int k = 0; k < 4; k++)
61   {
62     out_sizes[k] = MatchingDim(input_shape, extended_perm[k], output_shape, k);
63   }
64
65   // Naive transpose loop (iterate on output index and compute input index).
66   int o[4]; // loop index (on output).
67   int i[4];
68   for (o[3] = 0; o[3] < out_sizes[3]; o[3]++)
69   {
70     i[extended_perm[3]] = o[3];
71     for (o[2] = 0; o[2] < out_sizes[2]; o[2]++)
72     {
73       i[extended_perm[2]] = o[2];
74       for (o[1] = 0; o[1] < out_sizes[1]; o[1]++)
75       {
76         i[extended_perm[1]] = o[1];
77         for (o[0] = 0; o[0] < out_sizes[0]; o[0]++)
78         {
79           i[extended_perm[0]] = o[0];
80           output_data[Offset(output_shape, o)] = input_data[Offset(input_shape, i)];
81         }
82       }
83     }
84   }
85 }
86
87 template <typename T>
88 void Transpose(const TransposeParams &params, const Shape &unextended_input_shape,
89                const T *input_data, const Shape &unextended_output_shape, T *output_data)
90 {
91   // Transpose kernel only does rearranging values not numeric evaluations on
92   // each cell. It's safe to implement per size of scalar type and this trick
93   // keeps the total code size in a reasonable range.
94   switch (sizeof(T))
95   {
96     case 1:
97       TransposeImpl<int8_t>(params, unextended_input_shape,
98                             reinterpret_cast<const int8_t *>(input_data), unextended_output_shape,
99                             reinterpret_cast<int8_t *>(output_data));
100       break;
101     case 2:
102       TransposeImpl<int16_t>(params, unextended_input_shape,
103                              reinterpret_cast<const int16_t *>(input_data), unextended_output_shape,
104                              reinterpret_cast<int16_t *>(output_data));
105       break;
106
107     case 4:
108       TransposeImpl<int32_t>(params, unextended_input_shape,
109                              reinterpret_cast<const int32_t *>(input_data), unextended_output_shape,
110                              reinterpret_cast<int32_t *>(output_data));
111       break;
112     case 8:
113       TransposeImpl<int64_t>(params, unextended_input_shape,
114                              reinterpret_cast<const int64_t *>(input_data), unextended_output_shape,
115                              reinterpret_cast<int64_t *>(output_data));
116       break;
117   }
118 }
119 } // namespace reference
120
121 namespace
122 {
123
124 bool IsTranspose2DApplicable(const TransposeParams &params, const Shape &input_shape, int *dim0,
125                              int *dim1)
126 {
127   const int dims_cnt = input_shape.DimensionsCount();
128
129   if (dims_cnt == 2)
130   {
131     *dim0 = input_shape.Dims(0);
132     *dim1 = input_shape.Dims(1);
133     return true;
134   }
135
136   const int first_perm = params.perm[0];
137   for (int i = 1; i < dims_cnt; ++i)
138   {
139     int rebased = params.perm[i] - first_perm;
140     if (rebased < 0)
141     {
142       rebased += dims_cnt;
143     }
144     if (rebased != i)
145     {
146       return false;
147     }
148   }
149   *dim0 = 1;
150   *dim1 = 1;
151   for (int i = 0; i < dims_cnt; ++i)
152   {
153     if (i < first_perm)
154     {
155       *dim0 *= input_shape.Dims(i);
156     }
157     else
158     {
159       *dim1 *= input_shape.Dims(i);
160     }
161   }
162   return true;
163 }
164
165 void RemoveOneSizeDimensions(Shape *input_shape, Shape *output_shape, TransposeParams *params)
166 {
167   const int dims_cnt = input_shape->DimensionsCount();
168   assert(params->perm_count == dims_cnt);
169
170   bool foundOneSizeDim = false;
171   for (int i = 0; i < dims_cnt; ++i)
172   {
173     if (input_shape->Dims(i) == 1)
174     {
175       foundOneSizeDim = true;
176       break;
177     }
178   }
179
180   // Return here if there is no one size dimension.
181   if (!foundOneSizeDim)
182     return;
183
184   // Handle the case where all the dimension size is one.
185   if (input_shape->FlatSize() == 1)
186   {
187     input_shape->Resize(1);
188     input_shape->SetDim(0, 1);
189     output_shape->Resize(1);
190     output_shape->SetDim(0, 1);
191     params->perm_count = 1;
192     params->perm[0] = 0;
193     return;
194   }
195
196   // Resize input shape.
197   int new_dims_cnt = 0;
198   for (int i = 0; i < dims_cnt; ++i)
199   {
200     if (input_shape->Dims(i) == 1)
201     {
202       continue;
203     }
204     input_shape->SetDim(new_dims_cnt, input_shape->Dims(i));
205     ++new_dims_cnt;
206   }
207   input_shape->Resize(new_dims_cnt);
208
209   // Resize output shape and re-calculate the perm parameter.
210   TransposeParams new_params;
211   new_dims_cnt = 0;
212   for (int i = 0; i < dims_cnt; ++i)
213   {
214     if (output_shape->Dims(i) == 1)
215     {
216       continue;
217     }
218     new_params.perm[new_dims_cnt] = params->perm[i];
219     output_shape->SetDim(new_dims_cnt, output_shape->Dims(i));
220     ++new_dims_cnt;
221   }
222   output_shape->Resize(new_dims_cnt);
223   new_params.perm_count = new_dims_cnt;
224
225   for (int i = 0; i < new_dims_cnt; ++i)
226   {
227     int min_val_idx = -1;
228     for (int j = 0; j < new_dims_cnt; ++j)
229     {
230       if (new_params.perm[j] >= i &&
231           (min_val_idx == -1 || new_params.perm[min_val_idx] > new_params.perm[j]))
232       {
233         min_val_idx = j;
234       }
235     }
236     new_params.perm[min_val_idx] = i;
237   }
238   *params = new_params;
239 }
240
241 size_t Flatten(const Shape &input_shape, const Shape &output_shape, const TransposeParams &params,
242                Shape *non_flatten_input_shape, Shape *non_flatten_output_shape,
243                TransposeParams *non_flatten_params)
244 {
245   // Calculate the total size of non-flatten dimensions.
246   int skip_dims_cnt = 0;
247   size_t flat_size = input_shape.FlatSize();
248   for (int i = 0; i < params.perm_count; ++i)
249   {
250     if (params.perm[i] == i)
251     {
252       flat_size /= input_shape.Dims(i);
253       ++skip_dims_cnt;
254     }
255     else
256     {
257       break;
258     }
259   }
260
261   // Shrink the shapes and re-calculate the perm parameter.
262   const int new_dims_cnt = params.perm_count - skip_dims_cnt;
263   non_flatten_input_shape->Resize(new_dims_cnt);
264   non_flatten_output_shape->Resize(new_dims_cnt);
265   non_flatten_params->perm_count = new_dims_cnt;
266
267   for (int i = skip_dims_cnt; i < params.perm_count; ++i)
268   {
269     non_flatten_input_shape->SetDim(i - skip_dims_cnt, input_shape.Dims(i));
270     non_flatten_output_shape->SetDim(i - skip_dims_cnt, output_shape.Dims(i));
271     non_flatten_params->perm[i - skip_dims_cnt] = params.perm[i];
272   }
273   for (int i = 0; i < new_dims_cnt; ++i)
274   {
275     int min_val_idx = -1;
276     for (int j = 0; j < new_dims_cnt; ++j)
277     {
278       if (non_flatten_params->perm[j] >= i &&
279           (min_val_idx == -1 ||
280            non_flatten_params->perm[min_val_idx] > non_flatten_params->perm[j]))
281       {
282         min_val_idx = j;
283       }
284     }
285     non_flatten_params->perm[min_val_idx] = i;
286   }
287
288   return flat_size;
289 }
290
291 } // namespace anonymous (util)
292
293 // Transpose2D only deals with typical 2D matrix transpose ops.
294 // Perform transpose by transposing 4x4 blocks of the input, proceeding from
295 // left to right (down the rows) of the input, and then from top to bottom.
296 template <typename T>
297 inline void Transpose2D(const Shape &input_shape, const T *input_data, const Shape &output_shape,
298                         T *output_data)
299 {
300   assert(input_shape.DimensionsCount() == 2);
301   assert(output_shape.DimensionsCount() == 2);
302   UNUSED_RELEASE(output_shape);
303
304   const int d0 = input_shape.DimsData()[0];
305   const int d1 = input_shape.DimsData()[1];
306   const int kLines = 4;
307   const int kSkipSize = (kLines - 1) * d1;
308
309   const T *input = input_data;
310
311   int i = 0;
312   for (; i <= d0 - kLines; i += kLines)
313   {
314     T *output = output_data + i;
315
316     const T *input_ptr = input;
317     optimized_ops_preload_l1_keep(input_ptr);
318     input_ptr += d1;
319     optimized_ops_preload_l1_keep(input_ptr);
320     input_ptr += d1;
321     optimized_ops_preload_l1_keep(input_ptr);
322     input_ptr += d1;
323     optimized_ops_preload_l1_keep(input_ptr);
324
325     int j = 0;
326     for (; j <= d1 - kLines; j += kLines)
327     {
328       input_ptr = input;
329       const T a00 = input_ptr[0];
330       const T a01 = input_ptr[1];
331       const T a02 = input_ptr[2];
332       const T a03 = input_ptr[3];
333       input_ptr += d1;
334       const T a10 = input_ptr[0];
335       const T a11 = input_ptr[1];
336       const T a12 = input_ptr[2];
337       const T a13 = input_ptr[3];
338       input_ptr += d1;
339       const T a20 = input_ptr[0];
340       const T a21 = input_ptr[1];
341       const T a22 = input_ptr[2];
342       const T a23 = input_ptr[3];
343       input_ptr += d1;
344       const T a30 = input_ptr[0];
345       const T a31 = input_ptr[1];
346       const T a32 = input_ptr[2];
347       const T a33 = input_ptr[3];
348
349       output[0] = a00;
350       output[1] = a10;
351       output[2] = a20;
352       output[3] = a30;
353       output += d0;
354
355       output[0] = a01;
356       output[1] = a11;
357       output[2] = a21;
358       output[3] = a31;
359       output += d0;
360
361       output[0] = a02;
362       output[1] = a12;
363       output[2] = a22;
364       output[3] = a32;
365       output += d0;
366
367       output[0] = a03;
368       output[1] = a13;
369       output[2] = a23;
370       output[3] = a33;
371       output += d0;
372
373       input += kLines;
374     }
375     if (j == d1)
376     {
377       input += kSkipSize;
378     }
379     else
380     {
381       for (int p = 0; p < kLines; ++p)
382       {
383         for (int q = 0; q < d1 - j; ++q)
384         {
385           *(output + q * d0 + p) = *(input + p * d1 + q);
386         }
387       }
388       input += (d1 - j) + kSkipSize;
389     }
390   }
391   for (; i < d0; ++i)
392   {
393     T *output = output_data + i;
394     for (int j = 0; j < d1; ++j)
395     {
396       *output = *input;
397       output += d0;
398       ++input;
399     }
400   }
401 }
402
403 // TODO(alanchiao): see if we can reduce the number
404 // of lines of code in branching without affecting latency.
405 template <typename T>
406 inline void Transpose3D(const TransposeParams &params, const Shape &input_shape,
407                         const T *input_data, const Shape &, T *output_data)
408 {
409   int s2, s3;
410   s2 = input_shape.Dims(1);
411   s3 = input_shape.Dims(2);
412
413   int p1 = 0;
414   int p2 = 0;
415   int p3 = 0;
416
417   if (params.perm[0] == 2)
418   {
419     p1 = 1;
420   }
421   else if (params.perm[1] == 2)
422   {
423     p2 = 1;
424   }
425   else
426   {
427     p3 = 1;
428   }
429
430   if (params.perm[0] == 1)
431   {
432     p1 = s3;
433   }
434   else if (params.perm[1] == 1)
435   {
436     p2 = s3;
437   }
438   else
439   {
440     p3 = s3;
441   }
442
443   if (params.perm[0] == 0)
444   {
445     p1 = s2 * s3;
446   }
447   else if (params.perm[1] == 0)
448   {
449     p2 = s2 * s3;
450   }
451   else
452   {
453     p3 = s2 * s3;
454   }
455
456   int o_s[3];
457   o_s[0] = input_shape.Dims(params.perm[0]);
458   o_s[1] = input_shape.Dims(params.perm[1]);
459   o_s[2] = input_shape.Dims(params.perm[2]);
460
461   for (int i1 = 0; i1 < o_s[0]; ++i1)
462   {
463     for (int i2 = 0; i2 < o_s[1]; ++i2)
464     {
465       for (int i3 = 0; i3 < o_s[2]; ++i3)
466       {
467         const int i = i1 * p1 + i2 * p2 + i3 * p3;
468         const int o = i1 * o_s[1] * o_s[2] + i2 * o_s[2] + i3;
469         output_data[o] = input_data[i];
470       }
471     }
472   }
473 }
474
475 template <typename T>
476 void TransposeImpl(const TransposeParams &params, const Shape &input_shape, const T *input_data,
477                    const Shape &output_shape, T *output_data)
478 {
479   const int dims_cnt = input_shape.DimensionsCount();
480
481   int dim0, dim1;
482   if (IsTranspose2DApplicable(params, input_shape, &dim0, &dim1))
483   {
484     Transpose2D(Shape({dim0, dim1}), input_data, Shape({dim1, dim0}), output_data);
485     return;
486   }
487
488   // TODO(b/141217325): notably Eigen is better suited for
489   // larger inputs whereas Transpose3D is generally
490   // better for smaller ones.
491   //
492   // E.g. on Nexus 5, Eigen is better for size 96^3 and up
493   // and Transpose3D is better for 72^3 and down.
494   //
495   // 96^3 is not mobile-friendly for certain usecases
496   // (e.g. model used in beam search for seq2seq) but is in others.
497   // Consider tradeoffs.
498   if (dims_cnt == 3)
499   {
500     Transpose3D(params, input_shape, input_data, output_shape, output_data);
501     return;
502   }
503
504   // Reroute to the reference version if an optimized method for the given data
505   // is not available.
506   reference::Transpose(params, input_shape, input_data, output_shape, output_data);
507 }
508
509 template <typename T>
510 void Transpose(const TransposeParams &unshrunk_params, const Shape &unshrunk_input_shape,
511                const T *input_data, const Shape &unshrunk_output_shape, T *output_data)
512 {
513   const int output_size = unshrunk_output_shape.DimensionsCount();
514   assert(unshrunk_input_shape.DimensionsCount() <= 4);
515   assert(output_size <= 4);
516   assert(output_size == unshrunk_params.perm_count);
517
518   Shape shrunk_input_shape = Shape(unshrunk_input_shape);
519
520   Shape shrunk_output_shape = Shape(unshrunk_output_shape);
521
522   TransposeParams shrunk_params = unshrunk_params;
523
524   // Reduce any dimensions that have one size. Lower transpose op usually
525   // performs better since memory access patterns will be improved.
526   RemoveOneSizeDimensions(&shrunk_input_shape, &shrunk_output_shape, &shrunk_params);
527
528   // Handle identity cases.
529   // TODO(b/140779653): Add an optimization pass in the conversion process to
530   // remove transpose op nodes where they do nothing like the below one.
531   bool identical = true;
532   for (int i = 0; i < shrunk_params.perm_count; ++i)
533
534   {
535     if (shrunk_params.perm[i] != i)
536
537     {
538       identical = false;
539       break;
540     }
541   }
542   if (identical)
543   {
544     memcpy(output_data, input_data, unshrunk_input_shape.FlatSize() * sizeof(T));
545     return;
546   }
547
548   // Reduce dimensions by flattening.
549   if (shrunk_params.perm[0] == 0 && output_size >= 3)
550
551   {
552     Shape non_flatten_input_shape;
553     Shape non_flatten_output_shape;
554     TransposeParams non_flatten_params;
555     const int total_size = shrunk_input_shape.FlatSize();
556
557     const int non_flatten_size =
558       Flatten(shrunk_input_shape, shrunk_output_shape, shrunk_params,
559
560               &non_flatten_input_shape, &non_flatten_output_shape, &non_flatten_params);
561     assert(non_flatten_params.perm[0] != 0);
562
563     for (int i = 0; i < total_size; i += non_flatten_size)
564     {
565       TransposeImpl(non_flatten_params, non_flatten_input_shape, input_data + i,
566                     non_flatten_output_shape, output_data + i);
567     }
568     return;
569   }
570
571   // Call non-flattened case.
572   TransposeImpl(shrunk_params, shrunk_input_shape, input_data, shrunk_output_shape,
573
574                 output_data);
575 }
576
577 } // namespace cker
578 } // namespace nnfw
579
580 #endif // __NNFW_CKER_TRANSPOSE_H__