2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #include "kernels/Unpack.h"
19 #include "kernels/TestUtils.h"
20 #include "luci_interpreter/TestMemoryManager.h"
22 namespace luci_interpreter
29 using namespace testing;
32 void Check(int axis, Shape input_shape, std::initializer_list<T> input_data,
33 const std::vector<std::initializer_list<int32_t>> &exp_output_shape,
34 std::vector<std::initializer_list<T>> exp_output_data)
36 std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
37 constexpr DataType element_type = getElementType<T>();
38 const int num_outputs = input_shape.dim(axis < 0 ? axis + input_shape.num_dims() : axis);
41 makeInputTensor<element_type>(input_shape, input_data, memory_manager.get());
42 std::vector<Tensor> output_tensors;
43 output_tensors.reserve(num_outputs);
44 for (int i = 0; i < num_outputs; ++i)
46 output_tensors.push_back(makeOutputTensor(element_type));
49 std::vector<Tensor *> output_tensor_ptrs(num_outputs);
50 for (int i = 0; i < num_outputs; ++i)
52 output_tensor_ptrs[i] = &output_tensors[i];
55 UnpackParams params{};
58 Unpack kernel(&input_tensor, std::move(output_tensor_ptrs), params);
60 for (int i = 0; i < num_outputs; i++)
62 memory_manager->allocate_memory(output_tensors[i]);
66 for (int i = 0; i < num_outputs; ++i)
68 EXPECT_THAT(extractTensorData<T>(output_tensors[i]),
69 ::testing::ElementsAreArray(exp_output_data[i]));
73 template <typename T> class UnpackTest : public ::testing::Test
77 using DataTypes = ::testing::Types<float, uint8_t>;
78 TYPED_TEST_CASE(UnpackTest, DataTypes);
80 TYPED_TEST(UnpackTest, ThreeOutputs)
82 Check<TypeParam>(/*axis=*/0, /*input_shape=*/{3, 2},
83 /*input_data=*/{1, 2, 3, 4, 5, 6},
84 /*exp_output_shape=*/{{2}, {2}, {2}},
85 /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}});
88 TYPED_TEST(UnpackTest, ThreeOutputsAxisOne)
90 Check<TypeParam>(/*axis=*/1, /*input_shape=*/{3, 2},
91 /*input_data=*/{1, 2, 3, 4, 5, 6},
92 /*exp_output_shape=*/{{3}, {3}},
93 /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}});
96 TYPED_TEST(UnpackTest, ThreeOutputsNegativeAxisOne)
98 Check<TypeParam>(/*axis=*/-1, /*input_shape=*/{3, 2},
99 /*input_data=*/{1, 2, 3, 4, 5, 6},
100 /*exp_output_shape=*/{{3}, {3}},
101 /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}});
104 TYPED_TEST(UnpackTest, ThreeOutputsNegativeAxisTwo)
106 Check<TypeParam>(/*axis=*/-2, /*input_shape=*/{3, 2},
107 /*input_data=*/{1, 2, 3, 4, 5, 6},
108 /*exp_output_shape=*/{{2}, {2}, {2}},
109 /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}});
112 TYPED_TEST(UnpackTest, OneOutput)
114 Check<TypeParam>(/*axis=*/0, /*input_shape=*/{1, 6},
115 /*input_data=*/{1, 2, 3, 4, 5, 6},
116 /*exp_output_shape=*/{{6}},
117 /*exp_output_data=*/{{1, 2, 3, 4, 5, 6}});
120 TYPED_TEST(UnpackTest, ThreeDimensionsTwoOutputs)
122 Check<TypeParam>(/*axis=*/2, /*input_shape=*/{2, 2, 2},
123 /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
124 /*exp_output_shape=*/{{2, 2}, {2, 2}},
125 /*exp_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}});
128 TYPED_TEST(UnpackTest, FiveDimensionsTwoOutputs)
131 /*axis=*/2, /*input_shape=*/{2, 2, 2, 2, 1},
132 /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
133 /*exp_output_shape=*/{{2, 2, 2, 1}, {2, 2, 2, 1}},
135 {{1, 2, 5, 6, 9, 10, 13, 14}, {3, 4, 7, 8, 11, 12, 15, 16}});
138 TYPED_TEST(UnpackTest, VectorToScalar)
140 Check<TypeParam>(/*axis=*/0, /*input_shape=*/{5},
141 /*input_data=*/{1, 2, 3, 4, 5},
142 /*exp_output_shape=*/{{}, {}, {}, {}, {}},
143 /*exp_output_data=*/{{1}, {2}, {3}, {4}, {5}});
147 } // namespace kernels
148 } // namespace luci_interpreter