Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / Sum.test.cpp
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/Sum.h"
19 #include "kernels/TestUtils.h"
20 #include "luci_interpreter/TestMemoryManager.h"
21
22 namespace luci_interpreter
23 {
24 namespace kernels
25 {
26 namespace
27 {
28
29 using namespace testing;
30
31 class SumTest : public ::testing::Test
32 {
33 protected:
34   void SetUp() override { _memory_manager = std::make_unique<TestMemoryManager>(); }
35
36   std::unique_ptr<IMemoryManager> _memory_manager;
37 };
38
39 TEST_F(SumTest, FloatNotKeepDims)
40 {
41   std::vector<float> input_data = {1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,
42                                    9.0,  10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
43                                    17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
44
45   std::vector<int32_t> axis_data{1, 0};
46   Tensor input_tensor =
47     makeInputTensor<DataType::FLOAT32>({4, 3, 2}, input_data, _memory_manager.get());
48   Tensor axis_tensor = makeInputTensor<DataType::S32>({2}, axis_data, _memory_manager.get());
49   Tensor temp_index(DataType::S32, Shape({}), {}, "");
50   Tensor resolved_axes(DataType::S32, Shape({}), {}, "");
51   Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
52
53   ReducerParams params{};
54   params.keep_dims = false;
55
56   Sum kernel(&input_tensor, &axis_tensor, &output_tensor, &temp_index, &resolved_axes, params);
57   kernel.configure();
58   _memory_manager->allocate_memory(temp_index);
59   _memory_manager->allocate_memory(resolved_axes);
60   _memory_manager->allocate_memory(output_tensor);
61   kernel.execute();
62
63   std::vector<float> ref_output_data{144, 156};
64   std::initializer_list<int32_t> ref_output_shape{2};
65   EXPECT_THAT(extractTensorData<float>(output_tensor), FloatArrayNear(ref_output_data));
66   EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape));
67 }
68
69 TEST_F(SumTest, FloatKeepDims)
70 {
71   std::vector<float> input_data = {1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,
72                                    9.0,  10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
73                                    17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
74
75   std::vector<int32_t> axis_data{0, 2};
76   Tensor input_tensor =
77     makeInputTensor<DataType::FLOAT32>({4, 3, 2}, input_data, _memory_manager.get());
78   Tensor axis_tensor = makeInputTensor<DataType::S32>({2}, axis_data, _memory_manager.get());
79   Tensor temp_index(DataType::S32, Shape({}), {}, "");
80   Tensor resolved_axes(DataType::S32, Shape({}), {}, "");
81   Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
82
83   ReducerParams params{};
84   params.keep_dims = true;
85
86   Sum kernel(&input_tensor, &axis_tensor, &output_tensor, &temp_index, &resolved_axes, params);
87   kernel.configure();
88   _memory_manager->allocate_memory(temp_index);
89   _memory_manager->allocate_memory(resolved_axes);
90   _memory_manager->allocate_memory(output_tensor);
91   kernel.execute();
92
93   std::vector<float> ref_output_data{84, 100, 116};
94   std::initializer_list<int32_t> ref_output_shape{1, 3, 1};
95   EXPECT_THAT(extractTensorData<float>(output_tensor), FloatArrayNear(ref_output_data));
96   EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape));
97 }
98
99 TEST_F(SumTest, Input_Output_Type_NEG)
100 {
101   std::vector<float> input_data = {1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,
102                                    9.0,  10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
103                                    17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
104
105   std::vector<int32_t> axis_data{0, 2};
106   Tensor input_tensor =
107     makeInputTensor<DataType::FLOAT32>({4, 3, 2}, input_data, _memory_manager.get());
108   Tensor axis_tensor = makeInputTensor<DataType::S32>({2}, axis_data, _memory_manager.get());
109   Tensor temp_index(DataType::S32, Shape({}), {}, "");
110   Tensor resolved_axes(DataType::S32, Shape({}), {}, "");
111   Tensor output_tensor = makeOutputTensor(DataType::U8);
112
113   ReducerParams params{};
114   params.keep_dims = true;
115
116   Sum kernel(&input_tensor, &axis_tensor, &output_tensor, &temp_index, &resolved_axes, params);
117
118   EXPECT_ANY_THROW(kernel.configure());
119 }
120
121 TEST_F(SumTest, Invalid_Axes_Type_NEG)
122 {
123   std::vector<float> input_data = {1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,
124                                    9.0,  10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
125                                    17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
126
127   std::vector<int64_t> axis_data{0, 2};
128   Tensor input_tensor =
129     makeInputTensor<DataType::FLOAT32>({4, 3, 2}, input_data, _memory_manager.get());
130   Tensor axis_tensor = makeInputTensor<DataType::S64>({2}, axis_data, _memory_manager.get());
131   Tensor temp_index(DataType::S32, Shape({}), {}, "");
132   Tensor resolved_axes(DataType::S32, Shape({}), {}, "");
133   Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
134
135   ReducerParams params{};
136   params.keep_dims = true;
137
138   Sum kernel(&input_tensor, &axis_tensor, &output_tensor, &temp_index, &resolved_axes, params);
139
140   EXPECT_ANY_THROW(kernel.configure());
141 }
142
143 } // namespace
144 } // namespace kernels
145 } // namespace luci_interpreter