2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #ifndef __NNFW_CKER_TRANSPOSE_H__
19 #define __NNFW_CKER_TRANSPOSE_H__
#include "cker/Shape.h"
#include "cker/Types.h"
#include "cker/Utils.h"

#include <cassert>
#include <cstdint>
#include <cstring>
33 void TransposeImpl(const TransposeParams ¶ms, const Shape &unextended_input_shape,
34 const T *input_data, const Shape &unextended_output_shape, T *output_data)
36 const int unextended_output_size = unextended_output_shape.DimensionsCount();
37 assert(unextended_input_shape.DimensionsCount() <= 4);
38 assert(unextended_output_size <= 4);
39 assert(unextended_output_size == params.perm_count);
40 const Shape input_shape = Shape::ExtendedShape(4, unextended_input_shape);
41 const Shape output_shape = Shape::ExtendedShape(4, unextended_output_shape);
42 const int input_ext_size = 4 - unextended_input_shape.DimensionsCount();
43 const int output_ext_size = 4 - unextended_output_size;
45 // The perm data is extended to match the output, each index incremented by
46 // the amount of front padding of the input shape.
48 for (int i = 0; i < output_ext_size; ++i)
52 for (int i = 0; i < unextended_output_size; ++i)
54 extended_perm[i + output_ext_size] = params.perm[i] + input_ext_size;
58 // Compute the inverse permutation array so we can do an output centered
59 // transpose. Also, check to make sure output_dims is matching input_dims.
60 for (int k = 0; k < 4; k++)
62 out_sizes[k] = MatchingDim(input_shape, extended_perm[k], output_shape, k);
65 // Naive transpose loop (iterate on output index and compute input index).
66 int o[4]; // loop index (on output).
68 for (o[3] = 0; o[3] < out_sizes[3]; o[3]++)
70 i[extended_perm[3]] = o[3];
71 for (o[2] = 0; o[2] < out_sizes[2]; o[2]++)
73 i[extended_perm[2]] = o[2];
74 for (o[1] = 0; o[1] < out_sizes[1]; o[1]++)
76 i[extended_perm[1]] = o[1];
77 for (o[0] = 0; o[0] < out_sizes[0]; o[0]++)
79 i[extended_perm[0]] = o[0];
80 output_data[Offset(output_shape, o)] = input_data[Offset(input_shape, i)];
88 void Transpose(const TransposeParams ¶ms, const Shape &unextended_input_shape,
89 const T *input_data, const Shape &unextended_output_shape, T *output_data)
91 // Transpose kernel only does rearranging values not numeric evaluations on
92 // each cell. It's safe to implement per size of scalar type and this trick
93 // keeps the total code size in a reasonable range.
97 TransposeImpl<int8_t>(params, unextended_input_shape,
98 reinterpret_cast<const int8_t *>(input_data), unextended_output_shape,
99 reinterpret_cast<int8_t *>(output_data));
102 TransposeImpl<int16_t>(params, unextended_input_shape,
103 reinterpret_cast<const int16_t *>(input_data), unextended_output_shape,
104 reinterpret_cast<int16_t *>(output_data));
108 TransposeImpl<int32_t>(params, unextended_input_shape,
109 reinterpret_cast<const int32_t *>(input_data), unextended_output_shape,
110 reinterpret_cast<int32_t *>(output_data));
113 TransposeImpl<int64_t>(params, unextended_input_shape,
114 reinterpret_cast<const int64_t *>(input_data), unextended_output_shape,
115 reinterpret_cast<int64_t *>(output_data));
119 } // namespace reference
124 bool IsTranspose2DApplicable(const TransposeParams ¶ms, const Shape &input_shape, int *dim0,
127 const int dims_cnt = input_shape.DimensionsCount();
131 *dim0 = input_shape.Dims(0);
132 *dim1 = input_shape.Dims(1);
136 const int first_perm = params.perm[0];
137 for (int i = 1; i < dims_cnt; ++i)
139 int rebased = params.perm[i] - first_perm;
151 for (int i = 0; i < dims_cnt; ++i)
155 *dim0 *= input_shape.Dims(i);
159 *dim1 *= input_shape.Dims(i);
165 void RemoveOneSizeDimensions(Shape *input_shape, Shape *output_shape, TransposeParams *params)
167 const int dims_cnt = input_shape->DimensionsCount();
168 assert(params->perm_count == dims_cnt);
170 bool foundOneSizeDim = false;
171 for (int i = 0; i < dims_cnt; ++i)
173 if (input_shape->Dims(i) == 1)
175 foundOneSizeDim = true;
180 // Return here if there is no one size dimension.
181 if (!foundOneSizeDim)
184 // Handle the case where all the dimension size is one.
185 if (input_shape->FlatSize() == 1)
187 input_shape->Resize(1);
188 input_shape->SetDim(0, 1);
189 output_shape->Resize(1);
190 output_shape->SetDim(0, 1);
191 params->perm_count = 1;
196 // Resize input shape.
197 int new_dims_cnt = 0;
198 for (int i = 0; i < dims_cnt; ++i)
200 if (input_shape->Dims(i) == 1)
204 input_shape->SetDim(new_dims_cnt, input_shape->Dims(i));
207 input_shape->Resize(new_dims_cnt);
209 // Resize output shape and re-calculate the perm parameter.
210 TransposeParams new_params;
212 for (int i = 0; i < dims_cnt; ++i)
214 if (output_shape->Dims(i) == 1)
218 new_params.perm[new_dims_cnt] = params->perm[i];
219 output_shape->SetDim(new_dims_cnt, output_shape->Dims(i));
222 output_shape->Resize(new_dims_cnt);
223 new_params.perm_count = new_dims_cnt;
225 for (int i = 0; i < new_dims_cnt; ++i)
227 int min_val_idx = -1;
228 for (int j = 0; j < new_dims_cnt; ++j)
230 if (new_params.perm[j] >= i &&
231 (min_val_idx == -1 || new_params.perm[min_val_idx] > new_params.perm[j]))
236 new_params.perm[min_val_idx] = i;
238 *params = new_params;
241 size_t Flatten(const Shape &input_shape, const Shape &output_shape, const TransposeParams ¶ms,
242 Shape *non_flatten_input_shape, Shape *non_flatten_output_shape,
243 TransposeParams *non_flatten_params)
245 // Calculate the total size of non-flatten dimensions.
246 int skip_dims_cnt = 0;
247 size_t flat_size = input_shape.FlatSize();
248 for (int i = 0; i < params.perm_count; ++i)
250 if (params.perm[i] == i)
252 flat_size /= input_shape.Dims(i);
261 // Shrink the shapes and re-calculate the perm parameter.
262 const int new_dims_cnt = params.perm_count - skip_dims_cnt;
263 non_flatten_input_shape->Resize(new_dims_cnt);
264 non_flatten_output_shape->Resize(new_dims_cnt);
265 non_flatten_params->perm_count = new_dims_cnt;
267 for (int i = skip_dims_cnt; i < params.perm_count; ++i)
269 non_flatten_input_shape->SetDim(i - skip_dims_cnt, input_shape.Dims(i));
270 non_flatten_output_shape->SetDim(i - skip_dims_cnt, output_shape.Dims(i));
271 non_flatten_params->perm[i - skip_dims_cnt] = params.perm[i];
273 for (int i = 0; i < new_dims_cnt; ++i)
275 int min_val_idx = -1;
276 for (int j = 0; j < new_dims_cnt; ++j)
278 if (non_flatten_params->perm[j] >= i &&
279 (min_val_idx == -1 ||
280 non_flatten_params->perm[min_val_idx] > non_flatten_params->perm[j]))
285 non_flatten_params->perm[min_val_idx] = i;
291 } // namespace anonymous (util)
293 // Transpose2D only deals with typical 2D matrix transpose ops.
294 // Perform transpose by transposing 4x4 blocks of the input, proceeding from
295 // left to right (down the rows) of the input, and then from top to bottom.
296 template <typename T>
297 inline void Transpose2D(const Shape &input_shape, const T *input_data, const Shape &output_shape,
300 assert(input_shape.DimensionsCount() == 2);
301 assert(output_shape.DimensionsCount() == 2);
302 UNUSED_RELEASE(output_shape);
304 const int d0 = input_shape.DimsData()[0];
305 const int d1 = input_shape.DimsData()[1];
306 const int kLines = 4;
307 const int kSkipSize = (kLines - 1) * d1;
309 const T *input = input_data;
312 for (; i <= d0 - kLines; i += kLines)
314 T *output = output_data + i;
316 const T *input_ptr = input;
317 optimized_ops_preload_l1_keep(input_ptr);
319 optimized_ops_preload_l1_keep(input_ptr);
321 optimized_ops_preload_l1_keep(input_ptr);
323 optimized_ops_preload_l1_keep(input_ptr);
326 for (; j <= d1 - kLines; j += kLines)
329 const T a00 = input_ptr[0];
330 const T a01 = input_ptr[1];
331 const T a02 = input_ptr[2];
332 const T a03 = input_ptr[3];
334 const T a10 = input_ptr[0];
335 const T a11 = input_ptr[1];
336 const T a12 = input_ptr[2];
337 const T a13 = input_ptr[3];
339 const T a20 = input_ptr[0];
340 const T a21 = input_ptr[1];
341 const T a22 = input_ptr[2];
342 const T a23 = input_ptr[3];
344 const T a30 = input_ptr[0];
345 const T a31 = input_ptr[1];
346 const T a32 = input_ptr[2];
347 const T a33 = input_ptr[3];
381 for (int p = 0; p < kLines; ++p)
383 for (int q = 0; q < d1 - j; ++q)
385 *(output + q * d0 + p) = *(input + p * d1 + q);
388 input += (d1 - j) + kSkipSize;
393 T *output = output_data + i;
394 for (int j = 0; j < d1; ++j)
403 // TODO(alanchiao): see if we can reduce the number
404 // of lines of code in branching without affecting latency.
405 template <typename T>
406 inline void Transpose3D(const TransposeParams ¶ms, const Shape &input_shape,
407 const T *input_data, const Shape &, T *output_data)
410 s2 = input_shape.Dims(1);
411 s3 = input_shape.Dims(2);
417 if (params.perm[0] == 2)
421 else if (params.perm[1] == 2)
430 if (params.perm[0] == 1)
434 else if (params.perm[1] == 1)
443 if (params.perm[0] == 0)
447 else if (params.perm[1] == 0)
457 o_s[0] = input_shape.Dims(params.perm[0]);
458 o_s[1] = input_shape.Dims(params.perm[1]);
459 o_s[2] = input_shape.Dims(params.perm[2]);
461 for (int i1 = 0; i1 < o_s[0]; ++i1)
463 for (int i2 = 0; i2 < o_s[1]; ++i2)
465 for (int i3 = 0; i3 < o_s[2]; ++i3)
467 const int i = i1 * p1 + i2 * p2 + i3 * p3;
468 const int o = i1 * o_s[1] * o_s[2] + i2 * o_s[2] + i3;
469 output_data[o] = input_data[i];
475 template <typename T>
476 void TransposeImpl(const TransposeParams ¶ms, const Shape &input_shape, const T *input_data,
477 const Shape &output_shape, T *output_data)
479 const int dims_cnt = input_shape.DimensionsCount();
482 if (IsTranspose2DApplicable(params, input_shape, &dim0, &dim1))
484 Transpose2D(Shape({dim0, dim1}), input_data, Shape({dim1, dim0}), output_data);
488 // TODO(b/141217325): notably Eigen is better suited for
489 // larger inputs whereas Transpose3D is generally
490 // better for smaller ones.
492 // E.g. on Nexus 5, Eigen is better for size 96^3 and up
493 // and Transpose3D is better for 72^3 and down.
495 // 96^3 is not mobile-friendly for certain usecases
496 // (e.g. model used in beam search for seq2seq) but is in others.
497 // Consider tradeoffs.
500 Transpose3D(params, input_shape, input_data, output_shape, output_data);
504 // Reroute to the reference version if an optimized method for the given data
506 reference::Transpose(params, input_shape, input_data, output_shape, output_data);
509 template <typename T>
510 void Transpose(const TransposeParams &unshrunk_params, const Shape &unshrunk_input_shape,
511 const T *input_data, const Shape &unshrunk_output_shape, T *output_data)
513 const int output_size = unshrunk_output_shape.DimensionsCount();
514 assert(unshrunk_input_shape.DimensionsCount() <= 4);
515 assert(output_size <= 4);
516 assert(output_size == unshrunk_params.perm_count);
518 Shape shrunk_input_shape = Shape(unshrunk_input_shape);
520 Shape shrunk_output_shape = Shape(unshrunk_output_shape);
522 TransposeParams shrunk_params = unshrunk_params;
524 // Reduce any dimensions that have one size. Lower transpose op usually
525 // performs better since memory access patterns will be improved.
526 RemoveOneSizeDimensions(&shrunk_input_shape, &shrunk_output_shape, &shrunk_params);
528 // Handle identity cases.
529 // TODO(b/140779653): Add an optimization pass in the conversion process to
530 // remove transpose op nodes where they do nothing like the below one.
531 bool identical = true;
532 for (int i = 0; i < shrunk_params.perm_count; ++i)
535 if (shrunk_params.perm[i] != i)
544 memcpy(output_data, input_data, unshrunk_input_shape.FlatSize() * sizeof(T));
548 // Reduce dimensions by flattening.
549 if (shrunk_params.perm[0] == 0 && output_size >= 3)
552 Shape non_flatten_input_shape;
553 Shape non_flatten_output_shape;
554 TransposeParams non_flatten_params;
555 const int total_size = shrunk_input_shape.FlatSize();
557 const int non_flatten_size =
558 Flatten(shrunk_input_shape, shrunk_output_shape, shrunk_params,
560 &non_flatten_input_shape, &non_flatten_output_shape, &non_flatten_params);
561 assert(non_flatten_params.perm[0] != 0);
563 for (int i = 0; i < total_size; i += non_flatten_size)
565 TransposeImpl(non_flatten_params, non_flatten_input_shape, input_data + i,
566 non_flatten_output_shape, output_data + i);
571 // Call non-flattened case.
572 TransposeImpl(shrunk_params, shrunk_input_shape, input_data, shrunk_output_shape,
580 #endif // __NNFW_CKER_TRANSPOSE_H__