Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Cast.test.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "kernels/Cast.h"
18 #include "kernels/TestUtils.h"
19 #include "luci_interpreter/TestMemoryManager.h"
20
21 namespace luci_interpreter
22 {
23 namespace kernels
24 {
25 namespace
26 {
27
28 using namespace testing;
29
30 template <typename T1, typename T2>
31 void Check(std::initializer_list<int32_t> shape, std::initializer_list<T1> input_data,
32            std::initializer_list<T2> output_data)
33 {
34   std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
35   constexpr DataType input_type = getElementType<T1>();
36   constexpr DataType output_type = getElementType<T2>();
37
38   Tensor input_tensor = makeInputTensor<input_type>(shape, input_data, memory_manager.get());
39   Tensor output_tensor = makeOutputTensor(output_type);
40
41   Cast kernel(&input_tensor, &output_tensor);
42   kernel.configure();
43   memory_manager->allocate_memory(output_tensor);
44   kernel.execute();
45
46   EXPECT_THAT(extractTensorData<T2>(output_tensor), ::testing::ElementsAreArray(output_data));
47   EXPECT_THAT(extractTensorShape(output_tensor), shape);
48 }
49
50 template <typename T>
51 void CheckBoolTo(std::initializer_list<int32_t> shape, std::initializer_list<bool> input_data,
52                  std::initializer_list<T> output_data)
53 {
54   std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
55   constexpr DataType input_type = loco::DataType::BOOL;
56   constexpr DataType output_type = getElementType<T>();
57   std::vector<typename DataTypeImpl<input_type>::Type> input_data_converted;
58   for (auto elem : input_data)
59   {
60     input_data_converted.push_back(elem);
61   }
62
63   Tensor input_tensor =
64     makeInputTensor<input_type>(shape, input_data_converted, memory_manager.get());
65   Tensor output_tensor = makeOutputTensor(output_type);
66
67   Cast kernel(&input_tensor, &output_tensor);
68   kernel.configure();
69   memory_manager->allocate_memory(output_tensor);
70   kernel.execute();
71
72   EXPECT_THAT(extractTensorData<T>(output_tensor), ::testing::ElementsAreArray(output_data));
73   EXPECT_THAT(extractTensorShape(output_tensor), shape);
74 }
75
76 template <typename T> class CastTest : public ::testing::Test
77 {
78 };
79
80 using IntDataTypes =
81   ::testing::Types<uint8_t, uint16_t, uint32_t, uint64_t, int8_t, int16_t, int32_t, int64_t>;
82 TYPED_TEST_SUITE(CastTest, IntDataTypes);
83
84 TYPED_TEST(CastTest, FloatToInt)
85 {
86   Check<float, TypeParam>(/*shape=*/{1, 1, 1, 4},
87                           /*input_data=*/
88                           {
89                             1.0f, 9.0f, 7.0f, 3.0f, //
90                           },
91                           /*output_data=*/
92                           {
93                             1, 9, 7, 3, //
94                           });
95   SUCCEED();
96 }
97
98 TYPED_TEST(CastTest, IntToFloat)
99 {
100   Check<TypeParam, float>(/*shape=*/{1, 1, 1, 4},
101                           /*input_data=*/
102                           {
103                             1, 9, 7, 3, //
104                           },
105                           /*output_data=*/
106                           {
107                             1.0f, 9.0f, 7.0f, 3.0f, //
108                           });
109   SUCCEED();
110 }
111
112 template <typename T1, typename T2> void check_int()
113 {
114   Check<T1, T2>(/*shape=*/{1, 1, 1, 4},
115                 /*input_data=*/
116                 {
117                   1, 9, 7, 3, //
118                 },
119                 /*output_data=*/
120                 {
121                   1, 9, 7, 3, //
122                 });
123   SUCCEED();
124 }
125
126 TYPED_TEST(CastTest, IntToInt)
127 {
128   check_int<TypeParam, uint8_t>();
129   check_int<TypeParam, uint16_t>();
130   check_int<TypeParam, uint32_t>();
131   check_int<TypeParam, uint64_t>();
132   check_int<TypeParam, int8_t>();
133   check_int<TypeParam, int16_t>();
134   check_int<TypeParam, int32_t>();
135   check_int<TypeParam, int64_t>();
136   SUCCEED();
137 }
138
139 TYPED_TEST(CastTest, IntToBool)
140 {
141   Check<TypeParam, bool>(/*shape=*/{1, 1, 1, 4},
142                          /*input_data=*/
143                          {
144                            1, 0, 7, 0, //
145                          },
146                          /*output_data=*/
147                          {
148                            true, false, true, false, //
149                          });
150   SUCCEED();
151 }
152
153 TYPED_TEST(CastTest, BoolToInt)
154 {
155   CheckBoolTo<TypeParam>(/*shape=*/{1, 1, 1, 4},
156                          /*input_data=*/
157                          {
158                            true, false, false, true, //
159                          },
160                          /*output_data=*/
161                          {
162                            1, 0, 0, 1, //
163                          });
164   SUCCEED();
165 }
166
167 TEST(CastTest, FloatToBool)
168 {
169   Check<float, bool>(/*shape=*/{1, 1, 1, 4},
170                      /*input_data=*/
171                      {
172                        1.0f, 0.0f, 7.0f, 0.0f, //
173                      },
174                      /*output_data=*/
175                      {
176                        true, false, true, false, //
177                      });
178   SUCCEED();
179 }
180
181 TEST(CastTest, BoolToFloat)
182 {
183   CheckBoolTo<float>(/*shape=*/{1, 1, 1, 4},
184                      /*input_data=*/
185                      {
186                        true, false, false, true, //
187                      },
188                      /*output_data=*/
189                      {
190                        1.0f, 0.0f, 0.0f, 1.0f, //
191                      });
192   SUCCEED();
193 }
194
195 TEST(CastTest, FloatToFloat)
196 {
197   Check<float, float>(/*shape=*/{1, 1, 1, 4},
198                       /*input_data=*/
199                       {
200                         1.0f, 0.0f, 7.0f, 0.0f, //
201                       },
202                       /*output_data=*/
203                       {
204                         1.0f, 0.0f, 7.0f, 0.0f, //
205                       });
206   SUCCEED();
207 }
208
209 TEST(CastTest, BoolToBool)
210 {
211   CheckBoolTo<bool>(/*shape=*/{1, 1, 1, 4},
212                     /*input_data=*/
213                     {
214                       true, true, false, false, //
215                     },
216                     /*output_data=*/
217                     {
218                       true, true, false, false, //
219                     });
220   SUCCEED();
221 }
222
223 TEST(CastTest, UnsupportedType_NEG)
224 {
225   std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
226   Tensor input_tensor = makeInputTensor<DataType::FLOAT32>({1, 1, 2, 4},
227                                                            {
228                                                              1, 2, 7, 8, //
229                                                              1, 9, 7, 3, //
230                                                            },
231                                                            memory_manager.get());
232   Tensor output_tensor = makeOutputTensor(DataType::Unknown);
233
234   Cast kernel(&input_tensor, &output_tensor);
235   EXPECT_ANY_THROW(kernel.configure());
236   SUCCEED();
237 }
238
239 } // namespace
240 } // namespace kernels
241 } // namespace luci_interpreter