/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 #include "kernels/Split.h"
19 #include "kernels/TestUtils.h"
21 namespace luci_interpreter
28 using namespace testing;
31 void Check(int axis, int num_splits, std::initializer_list<int32_t> input_shape,
32 std::initializer_list<int32_t> output_shape, std::initializer_list<T> input_data,
33 std::vector<std::vector<T>> output_data, DataType element_type)
35 Tensor axis_tensor = makeInputTensor<DataType::S32>({}, {axis});
36 Tensor input_tensor{element_type, input_shape, {}, ""};
37 input_tensor.writeData(input_data.begin(), input_data.size() * sizeof(T));
39 std::vector<Tensor> output_tensors;
40 output_tensors.reserve(num_splits);
41 for (int i = 0; i < num_splits; ++i)
43 output_tensors.emplace_back(makeOutputTensor(element_type));
46 std::vector<Tensor *> output_tensor_ptrs(num_splits);
47 for (int i = 0; i < num_splits; ++i)
49 output_tensor_ptrs[i] = &output_tensors[i];
52 Split kernel(&axis_tensor, &input_tensor, std::move(output_tensor_ptrs));
56 for (int i = 0; i < num_splits; ++i)
58 EXPECT_THAT(extractTensorData<T>(output_tensors[i]),
59 ::testing::ElementsAreArray(output_data[i]));
63 template <typename T> class SplitTest : public ::testing::Test
67 using DataTypes = ::testing::Types<float, uint8_t>;
68 TYPED_TEST_CASE(SplitTest, DataTypes);
70 TYPED_TEST(SplitTest, FourDimensional)
72 Check<TypeParam>(/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
73 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
75 {1, 2, 3, 4, 5, 6, 7, 8}, //
76 {9, 10, 11, 12, 13, 14, 15, 16}, //
78 getElementType<TypeParam>());
80 /*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2},
81 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
83 {1, 2, 3, 4, 9, 10, 11, 12}, //
84 {5, 6, 7, 8, 13, 14, 15, 16}, //
86 getElementType<TypeParam>());
88 /*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2},
89 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
91 {1, 2, 5, 6, 9, 10, 13, 14}, //
92 {3, 4, 7, 8, 11, 12, 15, 16}, //
94 getElementType<TypeParam>());
96 /*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1},
97 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
99 {1, 3, 5, 7, 9, 11, 13, 15}, //
100 {2, 4, 6, 8, 10, 12, 14, 16}, //
102 getElementType<TypeParam>());
105 TYPED_TEST(SplitTest, OneDimensional)
108 /*axis=*/0, /*num_splits=*/8, {8}, {1}, {1, 2, 3, 4, 5, 6, 7, 8},
109 {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}}, getElementType<TypeParam>());
112 TYPED_TEST(SplitTest, NegativeAxis)
115 /*axis=*/-4, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
116 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
118 {1, 2, 3, 4, 5, 6, 7, 8}, //
119 {9, 10, 11, 12, 13, 14, 15, 16},
121 getElementType<TypeParam>());
125 } // namespace kernels
126 } // namespace luci_interpreter