Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / Split.test.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/Split.h"
19 #include "kernels/TestUtils.h"
20 #include "luci_interpreter/TestMemoryManager.h"
21
22 namespace luci_interpreter
23 {
24 namespace kernels
25 {
26 namespace
27 {
28
29 using namespace testing;
30
31 template <typename T>
32 void Check(int axis, int num_splits, std::initializer_list<int32_t> input_shape,
33            std::initializer_list<int32_t> output_shape, std::initializer_list<T> input_data,
34            std::vector<std::vector<T>> output_data)
35 {
36   std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
37
38   constexpr DataType element_type = getElementType<T>();
39   Tensor axis_tensor = makeInputTensor<DataType::S32>({}, {axis}, memory_manager.get());
40   Tensor input_tensor =
41     makeInputTensor<element_type>(input_shape, input_data, memory_manager.get());
42
43   std::vector<Tensor> output_tensors;
44   output_tensors.reserve(num_splits);
45   for (int i = 0; i < num_splits; ++i)
46   {
47     output_tensors.emplace_back(makeOutputTensor(element_type));
48   }
49
50   std::vector<Tensor *> output_tensor_ptrs(num_splits);
51   for (int i = 0; i < num_splits; ++i)
52   {
53     output_tensor_ptrs[i] = &output_tensors[i];
54   }
55
56   Split kernel(&axis_tensor, &input_tensor, std::move(output_tensor_ptrs));
57   kernel.configure();
58   for (int i = 0; i < num_splits; ++i)
59   {
60     memory_manager->allocate_memory(output_tensors[i]);
61   }
62   kernel.execute();
63
64   for (int i = 0; i < num_splits; ++i)
65   {
66     EXPECT_THAT(extractTensorData<T>(output_tensors[i]),
67                 ::testing::ElementsAreArray(output_data[i]));
68   }
69 }
70
71 template <typename T> class SplitTest : public ::testing::Test
72 {
73 };
74
75 using DataTypes = ::testing::Types<float, uint8_t>;
76 TYPED_TEST_CASE(SplitTest, DataTypes);
77
78 TYPED_TEST(SplitTest, FourDimensional)
79 {
80   Check<TypeParam>(/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
81                    {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
82                    {
83                      {1, 2, 3, 4, 5, 6, 7, 8},        //
84                      {9, 10, 11, 12, 13, 14, 15, 16}, //
85                    });
86   Check<TypeParam>(
87     /*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2},
88     {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
89     {
90       {1, 2, 3, 4, 9, 10, 11, 12},  //
91       {5, 6, 7, 8, 13, 14, 15, 16}, //
92     });
93   Check<TypeParam>(
94     /*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2},
95     {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
96     {
97       {1, 2, 5, 6, 9, 10, 13, 14},  //
98       {3, 4, 7, 8, 11, 12, 15, 16}, //
99     });
100   Check<TypeParam>(
101     /*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1},
102     {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
103     {
104       {1, 3, 5, 7, 9, 11, 13, 15},  //
105       {2, 4, 6, 8, 10, 12, 14, 16}, //
106     });
107 }
108
109 TYPED_TEST(SplitTest, OneDimensional)
110 {
111   Check<TypeParam>(
112     /*axis=*/0, /*num_splits=*/8, {8}, {1}, {1, 2, 3, 4, 5, 6, 7, 8},
113     {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}});
114 }
115
116 TYPED_TEST(SplitTest, NegativeAxis)
117 {
118   Check<TypeParam>(
119     /*axis=*/-4, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
120     {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
121     {
122       {1, 2, 3, 4, 5, 6, 7, 8}, //
123       {9, 10, 11, 12, 13, 14, 15, 16},
124     });
125 }
126
127 } // namespace
128 } // namespace kernels
129 } // namespace luci_interpreter