/*
 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 #include "kernels/Gelu.h"
19 #include "kernels/TestUtils.h"
20 #include "luci_interpreter/TestMemoryManager.h"
22 namespace luci_interpreter
29 using namespace testing;
31 void Check(std::initializer_list<int32_t> input_shape, std::initializer_list<int32_t> output_shape,
32 std::initializer_list<float> input_data, std::initializer_list<float> output_data,
35 std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
36 constexpr DataType element_type = getElementType<float>();
38 makeInputTensor<element_type>(input_shape, input_data, memory_manager.get());
39 Tensor output_tensor = makeOutputTensor(element_type);
42 params.approximate = approximate;
44 Gelu kernel(&input_tensor, &output_tensor, params);
47 memory_manager->allocate_memory(output_tensor);
50 EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape));
51 EXPECT_THAT(extractTensorData<float>(output_tensor), FloatArrayNear(output_data));
54 class GeluTest : public ::testing::Test
58 TEST_F(GeluTest, Simple)
60 Check(/*input_shape=*/{2, 3}, /*output_shape=*/{2, 3},
63 0.0f, 1.0f, 3.0f, // Row 1
64 1.0f, -1.0f, -2.0f, // Row 2
68 0.0f, 0.841345f, 2.99595f, // Row 1
69 0.841345f, -0.158655f, -0.0455003f, // Row 2
71 /*approximate=*/false);
76 TEST_F(GeluTest, Approximate)
78 Check(/*input_shape=*/{2, 3}, /*output_shape=*/{2, 3},
81 0.0f, 1.0f, 3.0f, // Row 1
82 1.0f, -1.0f, -2.0f, // Row 2
86 0.0f, 0.841192f, 2.99636f, // Row 1
87 0.841192f, -0.158808f, -0.0454023f, // Row 2
89 /*approximate=*/true);
94 TEST_F(GeluTest, DifferentInOutType_NEG)
96 std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
97 Tensor input_tensor = makeInputTensor<DataType::FLOAT32>({2, 3},
99 0.0f, 1.0f, 3.0f, // Row 1
100 1.0f, -1.0f, -2.0f, // Row 2
102 memory_manager.get());
103 Tensor output_tensor = makeOutputTensor(DataType::U8);
106 params.approximate = false;
108 Gelu kernel(&input_tensor, &output_tensor, params);
110 EXPECT_ANY_THROW(kernel.configure());
114 } // namespace kernels
115 } // namespace luci_interpreter