Imported Upstream version 1.22.1
[platform/core/ml/nnfw.git] / compiler / circle-mpqsolver / src / bisection / ErrorApproximator.test.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <gtest/gtest.h>
18
19 #include "ErrorApproximator.h"
20 #include "TestHelper.h"
21
22 #include <luci/IR/CircleNodeDecl.h>
23
24 #include <cmath>
25
26 namespace
27 {
28
29 inline uint32_t cal_offset(uint32_t shape[4], uint32_t *indices)
30 {
31   return indices[0] * shape[1] * shape[2] * shape[3] + indices[1] * shape[2] * shape[3] +
32          indices[2] * shape[3] + indices[3];
33 }
34
35 class NConvGraph final : public SimpleGraph
36 {
37 protected:
38   void initInput(loco::Node *input) override
39   {
40     auto ci_input = loco::must_cast<luci::CircleNode *>(input);
41     ci_input->shape_status(luci::ShapeStatus::VALID);
42     auto qparam = std::make_unique<luci::CircleQuantParam>();
43     qparam->min.assign(_channel_size, _a_min);
44     qparam->max.assign(_channel_size, _a_max);
45     ci_input->quantparam(std::move(qparam));
46   }
47
48   loco::Node *insertGraphBody(loco::Node *input) override
49   {
50     _filter = _g->nodes()->create<luci::CircleConst>();
51     _filter->dtype(loco::DataType::FLOAT32);
52     _filter->shape({_channel_size, _f_w, _f_h, _channel_size});
53     _filter->shape_status(luci::ShapeStatus::VALID);
54     _filter->name("conv_filter");
55     uint32_t indices[4] = {
56       0,
57     };
58
59     uint32_t w_shape[4] = {_filter->dim(0).value(), _filter->dim(1).value(),
60                            _filter->dim(2).value(), _filter->dim(3).value()};
61
62     _filter->size<loco::DataType::FLOAT32>(w_shape[0] * w_shape[1] * w_shape[2] * w_shape[3]);
63
64     for (indices[0] = 0; indices[0] < w_shape[0]; ++indices[0])
65     {
66       for (indices[1] = 0; indices[1] < w_shape[1]; ++indices[1])
67       {
68         for (indices[2] = 0; indices[2] < w_shape[2]; ++indices[2])
69         {
70           for (indices[3] = 0; indices[3] < w_shape[3]; ++indices[3])
71           {
72             uint32_t offset = cal_offset(w_shape, indices);
73             _filter->at<loco::DataType::FLOAT32>(offset) = (offset % 2 == 0) ? _w_max : _w_min;
74           }
75         }
76       }
77     }
78
79     _bias = _g->nodes()->create<luci::CircleConst>();
80     _bias->dtype(loco::DataType::FLOAT32);
81     _bias->shape({_channel_size});
82     _bias->name("conv_bias");
83
84     _conv = _g->nodes()->create<luci::CircleConv2D>();
85     _conv->padding(luci::Padding::SAME);
86     _conv->fusedActivationFunction(luci::FusedActFunc::NONE);
87     _conv->dtype(loco::DataType::FLOAT32);
88     _conv->shape({1, _width, _height, _channel_size});
89     _conv->shape_status(luci::ShapeStatus::VALID);
90     _conv->name("conv");
91     _conv->filter(_filter);
92     _conv->bias(_bias);
93     _conv->input(input);
94
95     return _conv;
96   }
97
98 public:
99   luci::CircleConv2D *_conv = nullptr;
100   luci::CircleConst *_filter = nullptr;
101   luci::CircleConst *_bias = nullptr;
102   uint32_t _f_w = 1;
103   uint32_t _f_h = 1;
104   float _w_min = -1.f;
105   float _w_max = 1.f;
106   float _a_min = -1.f;
107   float _a_max = 1.f;
108 };
109
110 } // namespace
111
112 TEST(CircleMPQSolverErrorApproximatorTest, verifyResultsTest)
113 {
114   NConvGraph g;
115   g.init();
116
117   auto value = mpqsolver::bisection::approximate(g._conv);
118   float expected = ((g._w_max - g._w_min) * g._channel_size * std::max(g._a_max, g._a_min) +
119                     (g._a_max - g._a_min) * g._channel_size * std::max(g._w_max, g._w_min)) *
120                    g._f_h * g._f_w * g._height * g._width * g._channel_size / 1.e+6f;
121   EXPECT_FLOAT_EQ(expected, value);
122 }
123
124 TEST(CircleMPQSolverErrorApproximatorTest, verifyResultsTest_NEG)
125 {
126   NConvGraph g;
127   g.init();
128
129   auto value = mpqsolver::bisection::approximate(g._input);
130   float expected = 0.f;
131   EXPECT_FLOAT_EQ(expected, value);
132
133   value = mpqsolver::bisection::approximate(g._output);
134   expected = 0.f;
135   EXPECT_FLOAT_EQ(expected, value);
136 }