035bc2122fbf02ec2d27659d84c524245fa41667
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / SplitV.test.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/SplitV.h"
19 #include "kernels/TestUtils.h"
20 #include "luci_interpreter/TestMemoryManager.h"
21
22 namespace luci_interpreter
23 {
24 namespace kernels
25 {
26 namespace
27 {
28
29 using namespace testing;
30
31 template <typename T>
32 void Check(int axis, std::initializer_list<int32_t> splits_size,
33            std::initializer_list<int32_t> input_shape, std::initializer_list<T> input_data,
34            std::vector<std::vector<T>> output_data)
35 {
36   std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
37   constexpr DataType element_type = getElementType<T>();
38
39   auto num_splits = static_cast<int32_t>(splits_size.size());
40   Tensor input_tensor =
41     makeInputTensor<element_type>(input_shape, input_data, memory_manager.get());
42   Tensor sizes_tensor =
43     makeInputTensor<DataType::S32>({num_splits}, splits_size, memory_manager.get());
44   Tensor axis_tensor = makeInputTensor<DataType::S32>({}, {axis}, memory_manager.get());
45
46   std::vector<Tensor> output_tensors;
47   output_tensors.reserve(num_splits);
48   for (int i = 0; i < num_splits; ++i)
49   {
50     output_tensors.emplace_back(makeOutputTensor(element_type));
51   }
52
53   std::vector<Tensor *> output_tensor_ptrs(num_splits);
54   for (int i = 0; i < num_splits; ++i)
55   {
56     output_tensor_ptrs[i] = &output_tensors[i];
57   }
58
59   SplitV kernel(&input_tensor, &sizes_tensor, &axis_tensor, std::move(output_tensor_ptrs));
60   kernel.configure();
61   for (int i = 0; i < num_splits; ++i)
62   {
63     memory_manager->allocate_memory(output_tensors[i]);
64   }
65   kernel.execute();
66
67   for (int i = 0; i < num_splits; ++i)
68   {
69     auto tmp = extractTensorData<T>(output_tensors[i]);
70     EXPECT_THAT(extractTensorData<T>(output_tensors[i]),
71                 ::testing::ElementsAreArray(output_data[i]));
72   }
73 }
74
75 template <typename T> class SplitVTest : public ::testing::Test
76 {
77 };
78
79 using DataTypes = ::testing::Types<float, uint8_t, int16_t>;
80 TYPED_TEST_SUITE(SplitVTest, DataTypes);
81
82 TYPED_TEST(SplitVTest, ThreeDimensional)
83 {
84   Check<TypeParam>(
85     /*axis=*/0, /*splits_size=*/{1, 2}, {3, 3, 3},
86     {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14,
87      15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27},
88     {
89       {1, 2, 3, 4, 5, 6, 7, 8, 9},                                             //
90       {10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27} //
91     });
92   Check<TypeParam>(
93     /*axis=*/1, /*splits_size=*/{1, 2}, {3, 3, 3},
94     {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14,
95      15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27},
96     {
97       {1, 2, 3, 10, 11, 12, 19, 20, 21},                                 //
98       {4, 5, 6, 7, 8, 9, 13, 14, 15, 16, 17, 18, 22, 23, 24, 25, 26, 27} //
99     });
100   Check<TypeParam>(
101     /*axis=*/2, /*splits_size=*/{1, 2}, {3, 3, 3},
102     {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14,
103      15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27},
104     {
105       {1, 4, 7, 10, 13, 16, 19, 22, 25},                                 //
106       {2, 3, 5, 6, 8, 9, 11, 12, 14, 15, 17, 18, 20, 21, 23, 24, 26, 27} //
107     });
108 }
109
110 } // namespace
111 } // namespace kernels
112 } // namespace luci_interpreter