2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #include "kernels/Split.h"
19 #include "kernels/TestUtils.h"
20 #include "luci_interpreter/TestMemoryManager.h"
22 namespace luci_interpreter
29 using namespace testing;
32 void Check(int axis, int num_splits, std::initializer_list<int32_t> input_shape,
33 std::initializer_list<int32_t> output_shape, std::initializer_list<T> input_data,
34 std::vector<std::vector<T>> output_data)
36 std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
38 constexpr DataType element_type = getElementType<T>();
39 Tensor axis_tensor = makeInputTensor<DataType::S32>({}, {axis}, memory_manager.get());
41 makeInputTensor<element_type>(input_shape, input_data, memory_manager.get());
43 std::vector<Tensor> output_tensors;
44 output_tensors.reserve(num_splits);
45 for (int i = 0; i < num_splits; ++i)
47 output_tensors.emplace_back(makeOutputTensor(element_type));
50 std::vector<Tensor *> output_tensor_ptrs(num_splits);
51 for (int i = 0; i < num_splits; ++i)
53 output_tensor_ptrs[i] = &output_tensors[i];
56 Split kernel(&axis_tensor, &input_tensor, std::move(output_tensor_ptrs));
58 for (int i = 0; i < num_splits; ++i)
60 memory_manager->allocate_memory(output_tensors[i]);
64 for (int i = 0; i < num_splits; ++i)
66 EXPECT_THAT(extractTensorData<T>(output_tensors[i]),
67 ::testing::ElementsAreArray(output_data[i]));
71 template <typename T> class SplitTest : public ::testing::Test
75 using DataTypes = ::testing::Types<float, uint8_t>;
76 TYPED_TEST_CASE(SplitTest, DataTypes);
78 TYPED_TEST(SplitTest, FourDimensional)
80 Check<TypeParam>(/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
81 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
83 {1, 2, 3, 4, 5, 6, 7, 8}, //
84 {9, 10, 11, 12, 13, 14, 15, 16}, //
87 /*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2},
88 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
90 {1, 2, 3, 4, 9, 10, 11, 12}, //
91 {5, 6, 7, 8, 13, 14, 15, 16}, //
94 /*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2},
95 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
97 {1, 2, 5, 6, 9, 10, 13, 14}, //
98 {3, 4, 7, 8, 11, 12, 15, 16}, //
101 /*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1},
102 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
104 {1, 3, 5, 7, 9, 11, 13, 15}, //
105 {2, 4, 6, 8, 10, 12, 14, 16}, //
109 TYPED_TEST(SplitTest, OneDimensional)
112 /*axis=*/0, /*num_splits=*/8, {8}, {1}, {1, 2, 3, 4, 5, 6, 7, 8},
113 {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}});
116 TYPED_TEST(SplitTest, NegativeAxis)
119 /*axis=*/-4, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
120 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
122 {1, 2, 3, 4, 5, 6, 7, 8}, //
123 {9, 10, 11, 12, 13, 14, 15, 16},
128 } // namespace kernels
129 } // namespace luci_interpreter