2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #include "kernels/SplitV.h"
19 #include "kernels/TestUtils.h"
20 #include "luci_interpreter/TestMemoryManager.h"
22 namespace luci_interpreter
29 using namespace testing;
32 void Check(int axis, std::initializer_list<int32_t> splits_size,
33 std::initializer_list<int32_t> input_shape, std::initializer_list<T> input_data,
34 std::vector<std::vector<T>> output_data)
36 std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
37 constexpr DataType element_type = getElementType<T>();
39 auto num_splits = static_cast<int32_t>(splits_size.size());
41 makeInputTensor<element_type>(input_shape, input_data, memory_manager.get());
43 makeInputTensor<DataType::S32>({num_splits}, splits_size, memory_manager.get());
44 Tensor axis_tensor = makeInputTensor<DataType::S32>({}, {axis}, memory_manager.get());
46 std::vector<Tensor> output_tensors;
47 output_tensors.reserve(num_splits);
48 for (int i = 0; i < num_splits; ++i)
50 output_tensors.emplace_back(makeOutputTensor(element_type));
53 std::vector<Tensor *> output_tensor_ptrs(num_splits);
54 for (int i = 0; i < num_splits; ++i)
56 output_tensor_ptrs[i] = &output_tensors[i];
59 SplitV kernel(&input_tensor, &sizes_tensor, &axis_tensor, std::move(output_tensor_ptrs));
61 for (int i = 0; i < num_splits; ++i)
63 memory_manager->allocate_memory(output_tensors[i]);
67 for (int i = 0; i < num_splits; ++i)
69 auto tmp = extractTensorData<T>(output_tensors[i]);
70 EXPECT_THAT(extractTensorData<T>(output_tensors[i]),
71 ::testing::ElementsAreArray(output_data[i]));
75 template <typename T> class SplitVTest : public ::testing::Test
79 using DataTypes = ::testing::Types<float, uint8_t, int16_t>;
80 TYPED_TEST_CASE(SplitVTest, DataTypes);
82 TYPED_TEST(SplitVTest, ThreeDimensional)
85 /*axis=*/0, /*splits_size=*/{1, 2}, {3, 3, 3},
86 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
87 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27},
89 {1, 2, 3, 4, 5, 6, 7, 8, 9}, //
90 {10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27} //
93 /*axis=*/1, /*splits_size=*/{1, 2}, {3, 3, 3},
94 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
95 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27},
97 {1, 2, 3, 10, 11, 12, 19, 20, 21}, //
98 {4, 5, 6, 7, 8, 9, 13, 14, 15, 16, 17, 18, 22, 23, 24, 25, 26, 27} //
101 /*axis=*/2, /*splits_size=*/{1, 2}, {3, 3, 3},
102 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
103 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27},
105 {1, 4, 7, 10, 13, 16, 19, 22, 25}, //
106 {2, 3, 5, 6, 8, 9, 11, 12, 14, 15, 17, 18, 20, 21, 23, 24, 26, 27} //
111 } // namespace kernels
112 } // namespace luci_interpreter