Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / compute / cker / include / cker / operation / SplitV.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_SPLIT_V_H__
19 #define __NNFW_CKER_SPLIT_V_H__
20
21 #include "cker/Shape.h"
22 #include "cker/Types.h"
23
24 namespace nnfw
25 {
26 namespace cker
27 {
28
29 template <typename Scalar>
30 void SplitV(const SplitVParams &params, const Shape &input_shape, const Scalar *input_data,
31             std::vector<nnfw::cker::Shape> &output_shapes, Scalar *const *output_data)
32 {
33   const int split_dimensions = input_shape.DimensionsCount();
34   int axis = params.axis < 0 ? params.axis + split_dimensions : params.axis;
35   int outputs_count = params.num_split;
36
37   int64_t split_size = 0;
38
39   for (int i = 0; i < outputs_count; i++)
40   {
41     // TFLITE_DCHECK_EQ(output_shapes[i]->DimensionsCount(), split_dimensions);
42     for (int j = 0; j < split_dimensions; j++)
43     {
44       if (j != axis)
45       {
46         MatchingDim(output_shapes[i], j, input_shape, j);
47       }
48     }
49     split_size += output_shapes[i].Dims(axis);
50   }
51
52   int64_t outer_size = 1;
53   for (int i = 0; i < axis; ++i)
54   {
55     outer_size *= input_shape.Dims(i);
56   }
57   // For all output arrays,
58   // FlatSize() = outer_size * Dims(axis) * base_inner_size;
59   int64_t base_inner_size = 1;
60   for (int i = axis + 1; i < split_dimensions; ++i)
61   {
62     base_inner_size *= input_shape.Dims(i);
63   }
64
65   const Scalar *input_ptr = input_data;
66   int copy_size = 0;
67   for (int k = 0; k < outer_size; k++)
68   {
69     for (int i = 0; i < outputs_count; ++i)
70     {
71       copy_size = output_shapes[i].Dims(axis) * base_inner_size;
72       memcpy(output_data[i] + k * copy_size, input_ptr, copy_size * sizeof(Scalar));
73       input_ptr += copy_size;
74     }
75   }
76 }
77
78 } // namespace cker
79 } // namespace nnfw
80
81 #endif // __NNFW_CKER_SPLIT_V_H__