2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #include "kernels/Cast.h"
19 #include "kernels/TestUtils.h"
20 #include "luci_interpreter/TestMemoryManager.h"
// The Cast kernel tests are declared inside the luci_interpreter namespace
// (its closing brace is at the end of the file).
22 namespace luci_interpreter
// Brings the `testing` namespace into scope for the tests that follow.
29 using namespace testing;
// Test helper: builds a Cast kernel over an input tensor holding `input_data`
// (element type T1) and verifies the output tensor against `output_data`
// (element type T2).
//
// shape       - dimensions of the input tensor; the output tensor is expected
//               to report the same shape.
// input_data  - values stored in the input tensor.
// output_data - values the output tensor is expected to contain, in order.
//
// NOTE(review): any statements that configure/execute the kernel are not
// visible in this excerpt - confirm against the full file.
31 template <typename T1, typename T2>
32 void Check(std::initializer_list<int32_t> shape, std::initializer_list<T1> input_data,
33 std::initializer_list<T2> output_data)
// A TestMemoryManager is used for the tensors of this check.
35 std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
// Tensor data types are derived at compile time from the C++ element types.
36 constexpr DataType input_type = getElementType<T1>();
37 constexpr DataType output_type = getElementType<T2>();
39 Tensor input_tensor = makeInputTensor<input_type>(shape, input_data, memory_manager.get());
// The output tensor is created from its data type only; no shape or data is passed here.
40 Tensor output_tensor = makeOutputTensor(output_type);
42 Cast kernel(&input_tensor, &output_tensor);
// Output storage is requested from the memory manager before the results are inspected.
44 memory_manager->allocate_memory(output_tensor);
// Both the values and the shape of the output must match the expectations.
47 EXPECT_THAT(extractTensorData<T2>(output_tensor), ::testing::ElementsAreArray(output_data));
48 EXPECT_THAT(extractTensorShape(output_tensor), shape);
// Test helper for casts whose source is boolean: builds a Cast kernel over a
// BOOL input tensor and verifies the output tensor against `output_data`
// (element type T).
//
// shape       - dimensions of the input tensor; the output tensor is expected
//               to report the same shape.
// input_data  - boolean values stored in the input tensor.
// output_data - values the output tensor is expected to contain, in order.
//
// NOTE(review): the `template <typename T>` header and any statements that
// configure/execute the kernel are not visible in this excerpt - confirm
// against the full file.
52 void CheckBoolTo(std::initializer_list<int32_t> shape, std::initializer_list<bool> input_data,
53 std::initializer_list<T> output_data)
55 std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
// The input type is fixed to BOOL; only the output type depends on T.
56 constexpr DataType input_type = loco::DataType::BOOL;
57 constexpr DataType output_type = getElementType<T>();
// Copy the bools element by element into a vector of the type DataTypeImpl maps BOOL to.
// NOTE(review): presumably needed because that type differs from C++ `bool` - confirm.
58 std::vector<typename DataTypeImpl<input_type>::Type> input_data_converted;
59 for (auto elem : input_data)
61 input_data_converted.push_back(elem);
// The converted vector (not the original initializer list) feeds the input tensor.
65 makeInputTensor<input_type>(shape, input_data_converted, memory_manager.get());
// The output tensor is created from its data type only; no shape or data is passed here.
66 Tensor output_tensor = makeOutputTensor(output_type);
68 Cast kernel(&input_tensor, &output_tensor);
// Output storage is requested from the memory manager before the results are inspected.
70 memory_manager->allocate_memory(output_tensor);
// Both the values and the shape of the output must match the expectations.
73 EXPECT_THAT(extractTensorData<T>(output_tensor), ::testing::ElementsAreArray(output_data));
74 EXPECT_THAT(extractTensorShape(output_tensor), shape);
// Typed-test fixture for the Cast kernel tests; it derives from ::testing::Test.
77 template <typename T> class CastTest : public ::testing::Test
// Unsigned and signed integer element types, 8 to 64 bits wide.
// NOTE(review): presumably this is the list that `IntDataTypes` names - the alias
// declaration itself is not visible in this excerpt; confirm.
82 ::testing::Types<uint8_t, uint16_t, uint32_t, uint64_t, int8_t, int16_t, int32_t, int64_t>;
// Registers CastTest as a typed test suite parameterized over IntDataTypes.
83 TYPED_TEST_SUITE(CastTest, IntDataTypes);
// Typed test: float input cast to the integer type under test (TypeParam),
// using a tensor of shape {1, 1, 1, 4}.
85 TYPED_TEST(CastTest, FloatToInt)
87 Check<float, TypeParam>(/*shape=*/{1, 1, 1, 4},
// Float values - given Check's (shape, input, output) parameter order these are
// presumably the input data; the remaining arguments are not visible in this excerpt.
90 1.0f, 9.0f, 7.0f, 3.0f, //
// Typed test: input of the integer type under test (TypeParam) cast to float,
// using a tensor of shape {1, 1, 1, 4}.
99 TYPED_TEST(CastTest, IntToFloat)
101 Check<TypeParam, float>(/*shape=*/{1, 1, 1, 4},
// Float values - since float is the output type here these are presumably the
// expected output; the integer input data is not visible in this excerpt.
108 1.0f, 9.0f, 7.0f, 3.0f, //
// Helper that forwards to Check<T1, T2> with a tensor of shape {1, 1, 1, 4}.
// It takes no arguments; the element types are chosen through T1 (source) and T2 (target).
// NOTE(review): the input/expected data passed to Check are not visible in this excerpt.
113 template <typename T1, typename T2> void check_int()
115 Check<T1, T2>(/*shape=*/{1, 1, 1, 4},
// Typed test: runs check_int from the integer type under test (TypeParam) to each
// of the eight integer widths/signs listed below.
127 TYPED_TEST(CastTest, IntToInt)
// Unsigned targets.
129 check_int<TypeParam, uint8_t>();
130 check_int<TypeParam, uint16_t>();
131 check_int<TypeParam, uint32_t>();
132 check_int<TypeParam, uint64_t>();
// Signed targets.
133 check_int<TypeParam, int8_t>();
134 check_int<TypeParam, int16_t>();
135 check_int<TypeParam, int32_t>();
136 check_int<TypeParam, int64_t>();
// Typed test: input of the integer type under test (TypeParam) cast to bool,
// using a tensor of shape {1, 1, 1, 4}.
140 TYPED_TEST(CastTest, IntToBool)
142 Check<TypeParam, bool>(/*shape=*/{1, 1, 1, 4},
// Boolean values - since bool is the output type here these are presumably the
// expected output; the integer input data is not visible in this excerpt.
149 true, false, true, false, //
// Typed test: boolean input cast to the integer type under test (TypeParam),
// using a tensor of shape {1, 1, 1, 4}.
154 TYPED_TEST(CastTest, BoolToInt)
156 CheckBoolTo<TypeParam>(/*shape=*/{1, 1, 1, 4},
// Boolean values - presumably the input data, given CheckBoolTo's bool input list;
// the expected integer output is not visible in this excerpt.
159 true, false, false, true, //
// Float input cast to bool on a tensor of shape {1, 1, 1, 4}.
168 TEST(CastTest, FloatToBool)
170 Check<float, bool>(/*shape=*/{1, 1, 1, 4},
// Input: a mix of zero and non-zero floats.
173 1.0f, 0.0f, 7.0f, 0.0f, //
// Expected: true where the input is non-zero, false where it is 0.0f.
177 true, false, true, false, //
// Boolean input cast to float on a tensor of shape {1, 1, 1, 4}.
182 TEST(CastTest, BoolToFloat)
184 CheckBoolTo<float>(/*shape=*/{1, 1, 1, 4},
// Input booleans.
187 true, false, false, true, //
// Expected: true maps to 1.0f and false maps to 0.0f.
191 1.0f, 0.0f, 0.0f, 1.0f, //
// Same-type cast (float to float) on a tensor of shape {1, 1, 1, 4}.
196 TEST(CastTest, FloatToFloat)
198 Check<float, float>(/*shape=*/{1, 1, 1, 4},
// Input values.
201 1.0f, 0.0f, 7.0f, 0.0f, //
// Expected output: identical to the input.
205 1.0f, 0.0f, 7.0f, 0.0f, //
// Same-type cast (bool to bool) on a tensor of shape {1, 1, 1, 4}.
210 TEST(CastTest, BoolToBool)
212 CheckBoolTo<bool>(/*shape=*/{1, 1, 1, 4},
// Input values.
215 true, true, false, false, //
// Expected output: identical to the input.
219 true, true, false, false, //
// Negative test: configuring a Cast kernel whose output tensor has
// DataType::Unknown is expected to throw.
224 TEST(CastTest, UnsupportedType_NEG)
226 std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
// FLOAT32 input tensor of shape {1, 1, 2, 4}; its data rows are not visible in this excerpt.
227 Tensor input_tensor = makeInputTensor<DataType::FLOAT32>({1, 1, 2, 4},
232 memory_manager.get());
// The output tensor is deliberately created with an unknown element type.
233 Tensor output_tensor = makeOutputTensor(DataType::Unknown);
235 Cast kernel(&input_tensor, &output_tensor);
// Any exception type satisfies the expectation; only configure() is exercised here.
236 EXPECT_ANY_THROW(kernel.configure());
241 } // namespace kernels
242 } // namespace luci_interpreter