Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / Split.test.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/Split.h"
19 #include "kernels/TestUtils.h"
20
21 namespace luci_interpreter
22 {
23 namespace kernels
24 {
25 namespace
26 {
27
28 using namespace testing;
29
30 template <typename T>
31 void Check(int axis, int num_splits, std::initializer_list<int32_t> input_shape,
32            std::initializer_list<int32_t> output_shape, std::initializer_list<T> input_data,
33            std::vector<std::vector<T>> output_data, DataType element_type)
34 {
35   Tensor axis_tensor = makeInputTensor<DataType::S32>({}, {axis});
36   Tensor input_tensor{element_type, input_shape, {}, ""};
37   input_tensor.writeData(input_data.begin(), input_data.size() * sizeof(T));
38
39   std::vector<Tensor> output_tensors;
40   output_tensors.reserve(num_splits);
41   for (int i = 0; i < num_splits; ++i)
42   {
43     output_tensors.emplace_back(makeOutputTensor(element_type));
44   }
45
46   std::vector<Tensor *> output_tensor_ptrs(num_splits);
47   for (int i = 0; i < num_splits; ++i)
48   {
49     output_tensor_ptrs[i] = &output_tensors[i];
50   }
51
52   Split kernel(&axis_tensor, &input_tensor, std::move(output_tensor_ptrs));
53   kernel.configure();
54   kernel.execute();
55
56   for (int i = 0; i < num_splits; ++i)
57   {
58     EXPECT_THAT(extractTensorData<T>(output_tensors[i]),
59                 ::testing::ElementsAreArray(output_data[i]));
60   }
61 }
62
63 template <typename T> class SplitTest : public ::testing::Test
64 {
65 };
66
67 using DataTypes = ::testing::Types<float, uint8_t>;
68 TYPED_TEST_CASE(SplitTest, DataTypes);
69
70 TYPED_TEST(SplitTest, FourDimensional)
71 {
72   Check<TypeParam>(/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
73                    {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
74                    {
75                        {1, 2, 3, 4, 5, 6, 7, 8},        //
76                        {9, 10, 11, 12, 13, 14, 15, 16}, //
77                    },
78                    getElementType<TypeParam>());
79   Check<TypeParam>(
80       /*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2},
81       {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
82       {
83           {1, 2, 3, 4, 9, 10, 11, 12},  //
84           {5, 6, 7, 8, 13, 14, 15, 16}, //
85       },
86       getElementType<TypeParam>());
87   Check<TypeParam>(
88       /*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2},
89       {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
90       {
91           {1, 2, 5, 6, 9, 10, 13, 14},  //
92           {3, 4, 7, 8, 11, 12, 15, 16}, //
93       },
94       getElementType<TypeParam>());
95   Check<TypeParam>(
96       /*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1},
97       {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
98       {
99           {1, 3, 5, 7, 9, 11, 13, 15},  //
100           {2, 4, 6, 8, 10, 12, 14, 16}, //
101       },
102       getElementType<TypeParam>());
103 }
104
105 TYPED_TEST(SplitTest, OneDimensional)
106 {
107   Check<TypeParam>(
108       /*axis=*/0, /*num_splits=*/8, {8}, {1}, {1, 2, 3, 4, 5, 6, 7, 8},
109       {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}}, getElementType<TypeParam>());
110 }
111
112 TYPED_TEST(SplitTest, NegativeAxis)
113 {
114   Check<TypeParam>(
115       /*axis=*/-4, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
116       {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
117       {
118           {1, 2, 3, 4, 5, 6, 7, 8}, //
119           {9, 10, 11, 12, 13, 14, 15, 16},
120       },
121       getElementType<TypeParam>());
122 }
123
124 } // namespace
125 } // namespace kernels
126 } // namespace luci_interpreter