6a463d29ebadf527954326dd8116c560aa3054d1
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Cast.test.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/Cast.h"
19 #include "kernels/TestUtils.h"
20 #include "luci_interpreter/TestMemoryManager.h"
21
22 namespace luci_interpreter
23 {
24 namespace kernels
25 {
26 namespace
27 {
28
29 using namespace testing;
30
31 template <typename T1, typename T2>
32 void Check(std::initializer_list<int32_t> shape, std::initializer_list<T1> input_data,
33            std::initializer_list<T2> output_data)
34 {
35   std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
36   constexpr DataType input_type = getElementType<T1>();
37   constexpr DataType output_type = getElementType<T2>();
38
39   Tensor input_tensor = makeInputTensor<input_type>(shape, input_data, memory_manager.get());
40   Tensor output_tensor = makeOutputTensor(output_type);
41
42   Cast kernel(&input_tensor, &output_tensor);
43   kernel.configure();
44   memory_manager->allocate_memory(output_tensor);
45   kernel.execute();
46
47   EXPECT_THAT(extractTensorData<T2>(output_tensor), ::testing::ElementsAreArray(output_data));
48   EXPECT_THAT(extractTensorShape(output_tensor), shape);
49 }
50
51 template <typename T>
52 void CheckBoolTo(std::initializer_list<int32_t> shape, std::initializer_list<bool> input_data,
53                  std::initializer_list<T> output_data)
54 {
55   std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
56   constexpr DataType input_type = loco::DataType::BOOL;
57   constexpr DataType output_type = getElementType<T>();
58   std::vector<typename DataTypeImpl<input_type>::Type> input_data_converted;
59   for (auto elem : input_data)
60   {
61     input_data_converted.push_back(elem);
62   }
63
64   Tensor input_tensor =
65     makeInputTensor<input_type>(shape, input_data_converted, memory_manager.get());
66   Tensor output_tensor = makeOutputTensor(output_type);
67
68   Cast kernel(&input_tensor, &output_tensor);
69   kernel.configure();
70   memory_manager->allocate_memory(output_tensor);
71   kernel.execute();
72
73   EXPECT_THAT(extractTensorData<T>(output_tensor), ::testing::ElementsAreArray(output_data));
74   EXPECT_THAT(extractTensorShape(output_tensor), shape);
75 }
76
77 template <typename T> class CastTest : public ::testing::Test
78 {
79 };
80
81 using IntDataTypes =
82   ::testing::Types<uint8_t, uint16_t, uint32_t, uint64_t, int8_t, int16_t, int32_t, int64_t>;
83 TYPED_TEST_SUITE(CastTest, IntDataTypes);
84
85 TYPED_TEST(CastTest, FloatToInt)
86 {
87   Check<float, TypeParam>(/*shape=*/{1, 1, 1, 4},
88                           /*input_data=*/
89                           {
90                             1.0f, 9.0f, 7.0f, 3.0f, //
91                           },
92                           /*output_data=*/
93                           {
94                             1, 9, 7, 3, //
95                           });
96   SUCCEED();
97 }
98
99 TYPED_TEST(CastTest, IntToFloat)
100 {
101   Check<TypeParam, float>(/*shape=*/{1, 1, 1, 4},
102                           /*input_data=*/
103                           {
104                             1, 9, 7, 3, //
105                           },
106                           /*output_data=*/
107                           {
108                             1.0f, 9.0f, 7.0f, 3.0f, //
109                           });
110   SUCCEED();
111 }
112
113 template <typename T1, typename T2> void check_int()
114 {
115   Check<T1, T2>(/*shape=*/{1, 1, 1, 4},
116                 /*input_data=*/
117                 {
118                   1, 9, 7, 3, //
119                 },
120                 /*output_data=*/
121                 {
122                   1, 9, 7, 3, //
123                 });
124   SUCCEED();
125 }
126
127 TYPED_TEST(CastTest, IntToInt)
128 {
129   check_int<TypeParam, uint8_t>();
130   check_int<TypeParam, uint16_t>();
131   check_int<TypeParam, uint32_t>();
132   check_int<TypeParam, uint64_t>();
133   check_int<TypeParam, int8_t>();
134   check_int<TypeParam, int16_t>();
135   check_int<TypeParam, int32_t>();
136   check_int<TypeParam, int64_t>();
137   SUCCEED();
138 }
139
140 TYPED_TEST(CastTest, IntToBool)
141 {
142   Check<TypeParam, bool>(/*shape=*/{1, 1, 1, 4},
143                          /*input_data=*/
144                          {
145                            1, 0, 7, 0, //
146                          },
147                          /*output_data=*/
148                          {
149                            true, false, true, false, //
150                          });
151   SUCCEED();
152 }
153
154 TYPED_TEST(CastTest, BoolToInt)
155 {
156   CheckBoolTo<TypeParam>(/*shape=*/{1, 1, 1, 4},
157                          /*input_data=*/
158                          {
159                            true, false, false, true, //
160                          },
161                          /*output_data=*/
162                          {
163                            1, 0, 0, 1, //
164                          });
165   SUCCEED();
166 }
167
168 TEST(CastTest, FloatToBool)
169 {
170   Check<float, bool>(/*shape=*/{1, 1, 1, 4},
171                      /*input_data=*/
172                      {
173                        1.0f, 0.0f, 7.0f, 0.0f, //
174                      },
175                      /*output_data=*/
176                      {
177                        true, false, true, false, //
178                      });
179   SUCCEED();
180 }
181
182 TEST(CastTest, BoolToFloat)
183 {
184   CheckBoolTo<float>(/*shape=*/{1, 1, 1, 4},
185                      /*input_data=*/
186                      {
187                        true, false, false, true, //
188                      },
189                      /*output_data=*/
190                      {
191                        1.0f, 0.0f, 0.0f, 1.0f, //
192                      });
193   SUCCEED();
194 }
195
196 TEST(CastTest, FloatToFloat)
197 {
198   Check<float, float>(/*shape=*/{1, 1, 1, 4},
199                       /*input_data=*/
200                       {
201                         1.0f, 0.0f, 7.0f, 0.0f, //
202                       },
203                       /*output_data=*/
204                       {
205                         1.0f, 0.0f, 7.0f, 0.0f, //
206                       });
207   SUCCEED();
208 }
209
210 TEST(CastTest, BoolToBool)
211 {
212   CheckBoolTo<bool>(/*shape=*/{1, 1, 1, 4},
213                     /*input_data=*/
214                     {
215                       true, true, false, false, //
216                     },
217                     /*output_data=*/
218                     {
219                       true, true, false, false, //
220                     });
221   SUCCEED();
222 }
223
224 TEST(CastTest, UnsupportedType_NEG)
225 {
226   std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
227   Tensor input_tensor = makeInputTensor<DataType::FLOAT32>({1, 1, 2, 4},
228                                                            {
229                                                              1, 2, 7, 8, //
230                                                              1, 9, 7, 3, //
231                                                            },
232                                                            memory_manager.get());
233   Tensor output_tensor = makeOutputTensor(DataType::Unknown);
234
235   Cast kernel(&input_tensor, &output_tensor);
236   EXPECT_ANY_THROW(kernel.configure());
237   SUCCEED();
238 }
239
240 } // namespace
241 } // namespace kernels
242 } // namespace luci_interpreter