Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / PropagateQuantParamPass.test.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/PropagateQuantParamPass.h"
18
19 #include <luci/IR/CircleNodes.h>
20
21 #include <gtest/gtest.h>
22
23 namespace
24 {
25
26 void addQuantParam(luci::CircleNode *node, const std::vector<float> &scale,
27                    const std::vector<int64_t> &zp)
28 {
29   assert(node->quantparam() == nullptr);
30
31   auto quantparam = std::make_unique<luci::CircleQuantParam>();
32   quantparam->scale = scale;
33   quantparam->zerop = zp;
34   node->quantparam(std::move(quantparam));
35 }
36
37 /**
38  *  Simple graph for test
39  *
40  *  BEFORE
41  *
42  *        [Conv] (qparam 1)
43  *           |
44  *       [Reshape] (qparam 2)
45  *
46  *  AFTER
47  *
48  *        [Conv] (qparam 2)
49  *           |
50  *       [Reshape] (qparam 2)
51  *
52  */
53 class SimpleGraph
54 {
55 public:
56   SimpleGraph()
57   {
58     input = g.nodes()->create<luci::CircleInput>();
59     conv = g.nodes()->create<luci::CircleConv2D>();
60     reshape = g.nodes()->create<luci::CircleReshape>();
61     output = g.nodes()->create<luci::CircleOutput>();
62
63     auto graph_input = g.inputs()->create();
64     input->index(graph_input->index());
65     auto graph_output = g.outputs()->create();
66     output->index(graph_output->index());
67
68     addQuantParam(conv, {0.1, 0.2, 0.3}, {0, 10, 20});
69     addQuantParam(reshape, {0.2, 0.4, 0.6}, {-10, 0, 10});
70
71     conv->input(input);
72     reshape->tensor(conv);
73     output->from(reshape);
74   }
75
76 public:
77   loco::Graph g;
78   luci::CircleInput *input;
79   luci::CircleConv2D *conv;
80   luci::CircleReshape *reshape;
81   luci::CircleOutput *output;
82 };
83
84 } // namespace
85
86 TEST(PropagateQuantParam, simple)
87 {
88   SimpleGraph g;
89
90   luci::PropagateQuantParamPass pass;
91   while (pass.run(&g.g))
92     ;
93
94   EXPECT_FLOAT_EQ(0.2, g.conv->quantparam()->scale[0]);
95   EXPECT_FLOAT_EQ(0.4, g.conv->quantparam()->scale[1]);
96   EXPECT_FLOAT_EQ(0.6, g.conv->quantparam()->scale[2]);
97   EXPECT_EQ(-10, g.conv->quantparam()->zerop[0]);
98   EXPECT_EQ(0, g.conv->quantparam()->zerop[1]);
99   EXPECT_EQ(10, g.conv->quantparam()->zerop[2]);
100 }
101
102 TEST(PropagateQuantParam, wrong_op_NEG)
103 {
104   SimpleGraph g;
105   g.output->from(g.conv);
106   g.reshape->drop();
107
108   luci::PropagateQuantParamPass pass;
109   while (pass.run(&g.g))
110     ;
111
112   EXPECT_FLOAT_EQ(0.1, g.conv->quantparam()->scale[0]);
113   EXPECT_FLOAT_EQ(0.2, g.conv->quantparam()->scale[1]);
114   EXPECT_FLOAT_EQ(0.3, g.conv->quantparam()->scale[2]);
115   EXPECT_EQ(0, g.conv->quantparam()->zerop[0]);
116   EXPECT_EQ(10, g.conv->quantparam()->zerop[1]);
117   EXPECT_EQ(20, g.conv->quantparam()->zerop[2]);
118 }