Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / Tanh.test.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/Tanh.h"
19 #include "kernels/TestUtils.h"
20
21 namespace luci_interpreter
22 {
23 namespace kernels
24 {
25 namespace
26 {
27
28 using namespace testing;
29
30 TEST(TanhTest, Float)
31 {
32   Shape input_shape{1, 2, 4, 1};
33   std::vector<float> input_data{
34       0, -6, 2,  4, //
35       3, -2, 10, 1, //
36   };
37   Tensor input_tensor = makeInputTensor<DataType::FLOAT32>(input_shape, input_data);
38   Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
39
40   Tanh kernel(&input_tensor, &output_tensor);
41   kernel.configure();
42   kernel.execute();
43
44   std::vector<float> ref_output_data{
45       0,          -0.9999877, 0.9640275, 0.999329,  //
46       0.99505475, -0.9640275, 1,         0.7615941, //
47   };
48   EXPECT_THAT(extractTensorData<float>(output_tensor),
49               ElementsAreArray(ArrayFloatNear(ref_output_data)));
50 }
51
52 TEST(TanhTest, Uint8)
53 {
54   float kMin = -1;
55   float kMax = 127.f / 128.f;
56   float kTanhTolerance = 2 * (1. / 256);
57   std::pair<float, int32_t> input_quant_param = quantizationParams<uint8_t>(8 * kMin, 8 * kMax);
58   std::pair<float, int32_t> output_quant_param = quantizationParams<uint8_t>(kMin, kMax);
59   std::vector<float> input_data{
60       0,  -6, 2, 4, //
61       -4, -2, 8, 1, //
62       0,  -6, 2, 4, //
63       -4, -2, 8, 1, //
64       0,  -6, 2, 4, //
65       -4, -2, 8, 1, //
66       0,  -6, 2, 4, //
67       -4, -2, 8, 1, //
68       0,  -6, 2, 4, //
69       -4, -2, 8, 1, //
70       0,  -6, 2, 4, //
71       -4, -2, 8, 1, //
72   };
73   Tensor input_tensor{
74       DataType::U8, {2, 6, 4, 1}, {{input_quant_param.first}, {input_quant_param.second}}, ""};
75   Tensor output_tensor =
76       makeOutputTensor(DataType::U8, output_quant_param.first, output_quant_param.second);
77   std::vector<uint8_t> quantize_input =
78       quantize<uint8_t>(input_data, input_quant_param.first, input_quant_param.second);
79   input_tensor.writeData(quantize_input.data(), quantize_input.size() * sizeof(uint8_t));
80
81   Tanh kernel(&input_tensor, &output_tensor);
82   kernel.configure();
83   kernel.execute();
84
85   std::vector<float> ref_output_data{
86       0.0,       -0.999987, 0.964027, 0.999329, //
87       -0.999329, -0.96402,  0.99999,  0.76159,  //
88       0.0,       -0.999987, 0.964027, 0.999329, //
89       -0.999329, -0.96402,  0.99999,  0.76159,  //
90       0.0,       -0.999987, 0.964027, 0.999329, //
91       -0.999329, -0.96402,  0.99999,  0.76159,  //
92       0.0,       -0.999987, 0.964027, 0.999329, //
93       -0.999329, -0.96402,  0.99999,  0.76159,  //
94       0.0,       -0.999987, 0.964027, 0.999329, //
95       -0.999329, -0.96402,  0.99999,  0.76159,  //
96       0.0,       -0.999987, 0.964027, 0.999329, //
97       -0.999329, -0.96402,  0.99999,  0.76159,  //
98   };
99   std::vector<int32_t> ref_output_shape{2, 6, 4, 1};
100   EXPECT_THAT(dequantize<uint8_t>(extractTensorData<uint8_t>(output_tensor), output_tensor.scale(),
101                                   output_tensor.zero_point()),
102               ElementsAreArray(ArrayFloatNear(ref_output_data, kTanhTolerance)));
103   EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape));
104 }
105
106 } // namespace
107 } // namespace kernels
108 } // namespace luci_interpreter