/*
 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "kernels/Cast.h"
18 #include "kernels/TestUtils.h"
19 #include "luci_interpreter/TestMemoryManager.h"
21 namespace luci_interpreter
28 using namespace testing;
30 template <typename T1, typename T2>
31 void Check(std::initializer_list<int32_t> shape, std::initializer_list<T1> input_data,
32 std::initializer_list<T2> output_data)
34 std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
35 constexpr DataType input_type = getElementType<T1>();
36 constexpr DataType output_type = getElementType<T2>();
38 Tensor input_tensor = makeInputTensor<input_type>(shape, input_data, memory_manager.get());
39 Tensor output_tensor = makeOutputTensor(output_type);
41 Cast kernel(&input_tensor, &output_tensor);
43 memory_manager->allocate_memory(output_tensor);
46 EXPECT_THAT(extractTensorData<T2>(output_tensor), ::testing::ElementsAreArray(output_data));
47 EXPECT_THAT(extractTensorShape(output_tensor), shape);
51 void CheckBoolTo(std::initializer_list<int32_t> shape, std::initializer_list<bool> input_data,
52 std::initializer_list<T> output_data)
54 std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
55 constexpr DataType input_type = loco::DataType::BOOL;
56 constexpr DataType output_type = getElementType<T>();
57 std::vector<typename DataTypeImpl<input_type>::Type> input_data_converted;
58 for (auto elem : input_data)
60 input_data_converted.push_back(elem);
64 makeInputTensor<input_type>(shape, input_data_converted, memory_manager.get());
65 Tensor output_tensor = makeOutputTensor(output_type);
67 Cast kernel(&input_tensor, &output_tensor);
69 memory_manager->allocate_memory(output_tensor);
72 EXPECT_THAT(extractTensorData<T>(output_tensor), ::testing::ElementsAreArray(output_data));
73 EXPECT_THAT(extractTensorShape(output_tensor), shape);
76 template <typename T> class CastTest : public ::testing::Test
81 ::testing::Types<uint8_t, uint16_t, uint32_t, uint64_t, int8_t, int16_t, int32_t, int64_t>;
82 TYPED_TEST_CASE(CastTest, IntDataTypes);
84 TYPED_TEST(CastTest, FloatToInt)
86 Check<float, TypeParam>(/*shape=*/{1, 1, 1, 4},
89 1.0f, 9.0f, 7.0f, 3.0f, //
98 TYPED_TEST(CastTest, IntToFloat)
100 Check<TypeParam, float>(/*shape=*/{1, 1, 1, 4},
107 1.0f, 9.0f, 7.0f, 3.0f, //
112 template <typename T1, typename T2> void check_int()
114 Check<T1, T2>(/*shape=*/{1, 1, 1, 4},
126 TYPED_TEST(CastTest, IntToInt)
128 check_int<TypeParam, uint8_t>();
129 check_int<TypeParam, uint16_t>();
130 check_int<TypeParam, uint32_t>();
131 check_int<TypeParam, uint64_t>();
132 check_int<TypeParam, int8_t>();
133 check_int<TypeParam, int16_t>();
134 check_int<TypeParam, int32_t>();
135 check_int<TypeParam, int64_t>();
139 TYPED_TEST(CastTest, IntToBool)
141 Check<TypeParam, bool>(/*shape=*/{1, 1, 1, 4},
148 true, false, true, false, //
153 TYPED_TEST(CastTest, BoolToInt)
155 CheckBoolTo<TypeParam>(/*shape=*/{1, 1, 1, 4},
158 true, false, false, true, //
167 TEST(CastTest, FloatToBool)
169 Check<float, bool>(/*shape=*/{1, 1, 1, 4},
172 1.0f, 0.0f, 7.0f, 0.0f, //
176 true, false, true, false, //
181 TEST(CastTest, BoolToFloat)
183 CheckBoolTo<float>(/*shape=*/{1, 1, 1, 4},
186 true, false, false, true, //
190 1.0f, 0.0f, 0.0f, 1.0f, //
195 TEST(CastTest, FloatToFloat)
197 Check<float, float>(/*shape=*/{1, 1, 1, 4},
200 1.0f, 0.0f, 7.0f, 0.0f, //
204 1.0f, 0.0f, 7.0f, 0.0f, //
209 TEST(CastTest, BoolToBool)
211 CheckBoolTo<bool>(/*shape=*/{1, 1, 1, 4},
214 true, true, false, false, //
218 true, true, false, false, //
223 TEST(CastTest, UnsupportedType_NEG)
225 std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
226 Tensor input_tensor = makeInputTensor<DataType::FLOAT32>({1, 1, 2, 4},
231 memory_manager.get());
232 Tensor output_tensor = makeOutputTensor(DataType::Unknown);
234 Cast kernel(&input_tensor, &output_tensor);
235 EXPECT_ANY_THROW(kernel.configure());
240 } // namespace kernels
241 } // namespace luci_interpreter