Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compute / cker / src / train / Relu.test.cc
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <cker/operation/ReLU.h>
18 #include <cker/train/operation/ReLU.h>
19
20 #include <gtest/gtest.h>
21 #include <vector>
22
23 namespace
24 {
25
26 template <typename T> class ReluOpVerifier
27 {
28 public:
29   ReluOpVerifier(const std::vector<T> &input, const std::vector<T> &expected_output,
30                  const std::vector<T> &backprop_output,
31                  const std::vector<T> &expected_backprop_input)
32     : _input{input}, _expected_output{expected_output}, _backprop_output{backprop_output},
33       _expected_backprop_input{expected_backprop_input}
34   {
35     EXPECT_TRUE(input.size() == expected_output.size());
36     _output.resize(_expected_output.size());
37     _backprop_input.resize(_expected_backprop_input.size());
38   }
39
40 public:
41   void verifyExpected()
42   {
43     nnfw::cker::ReLU(nnfw::cker::Shape{static_cast<int>(_input.size())}, _input.data(),
44                      nnfw::cker::Shape{static_cast<int>(_output.size())}, _output.data());
45
46     for (size_t i = 0; i < _output.size(); ++i)
47       ASSERT_EQ(_output[i], _expected_output[i]);
48
49     if (_backprop_output.size() > 0)
50     {
51       nnfw::cker::train::ReLUGrad(
52         nnfw::cker::Shape{static_cast<int>(_output.size())}, _output.data(),
53         nnfw::cker::Shape{static_cast<int>(_backprop_output.size())}, _backprop_output.data(),
54         nnfw::cker::Shape{static_cast<int>(_backprop_input.size())}, _backprop_input.data());
55
56       for (size_t i = 0; i < _backprop_input.size(); ++i)
57         ASSERT_EQ(_backprop_input[i], _expected_backprop_input[i]);
58     }
59   }
60
61 private:
62   std::vector<T> _input;
63   std::vector<T> _output;
64   std::vector<T> _expected_output;
65   std::vector<T> _backprop_output;
66   std::vector<T> _backprop_input;
67   std::vector<T> _expected_backprop_input;
68 };
69
70 } // namespace
71
72 TEST(CKer_Operation, ReLU)
73 {
74   {
75     std::vector<float> input_forward = {-1, 2, 3, -4};
76     std::vector<float> expected_forward = {0, 2, 3, 0};
77     std::vector<float> incoming_backward = {-5, 6, -7, 8};
78     std::vector<float> expected_backward = {0, 6, -7, 0};
79     ReluOpVerifier<float> verifier{input_forward, expected_forward, incoming_backward,
80                                    expected_backward};
81     verifier.verifyExpected();
82   }
83
84   {
85     std::vector<float> input_forward = {0, -1, 2, 3, -4, 5, 6, -7};
86     std::vector<float> expected_forward = {0, 0, 2, 3, 0, 5, 6, 0};
87     std::vector<float> incoming_backward = {8, -9, 10, 11, -12, -13, 14, -15};
88     std::vector<float> expected_backward = {0, 0, 10, 11, 0, -13, 14, 0};
89     ReluOpVerifier<float> verifier{input_forward, expected_forward, incoming_backward,
90                                    expected_backward};
91     verifier.verifyExpected();
92   }
93 }
94
95 TEST(CKer_Operation, neg_ReLU)
96 {
97   {
98     // Unmatched shape
99     std::vector<float> input_forward = {0, -1, 2, 3, -4};
100     std::vector<float> expected_forward = {0, 0, 2, 3, 0};
101     std::vector<float> incoming_backward = {-5, 6, -7, 8};
102     std::vector<float> expected_backward = {0, 6, -7, 0};
103     ReluOpVerifier<float> verifier{input_forward, expected_forward, incoming_backward,
104                                    expected_backward};
105     EXPECT_ANY_THROW(verifier.verifyExpected());
106   }
107 }