2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/PropagateQuantParamPass.h"
19 #include <luci/IR/CircleNodes.h>
21 #include <gtest/gtest.h>
26 void addQuantParam(luci::CircleNode *node, const std::vector<float> &scale,
27 const std::vector<int64_t> &zp)
29 assert(node->quantparam() == nullptr);
31 auto quantparam = std::make_unique<luci::CircleQuantParam>();
32 quantparam->scale = scale;
33 quantparam->zerop = zp;
34 node->quantparam(std::move(quantparam));
38 * Simple graph for test
44 * [Reshape] (qparam 2)
50 * [Reshape] (qparam 2)
58 input = g.nodes()->create<luci::CircleInput>();
59 conv = g.nodes()->create<luci::CircleConv2D>();
60 reshape = g.nodes()->create<luci::CircleReshape>();
61 output = g.nodes()->create<luci::CircleOutput>();
63 auto graph_input = g.inputs()->create();
64 input->index(graph_input->index());
65 auto graph_output = g.outputs()->create();
66 output->index(graph_output->index());
68 addQuantParam(conv, {0.1, 0.2, 0.3}, {0, 10, 20});
69 addQuantParam(reshape, {0.2, 0.4, 0.6}, {-10, 0, 10});
72 reshape->tensor(conv);
73 output->from(reshape);
78 luci::CircleInput *input;
79 luci::CircleConv2D *conv;
80 luci::CircleReshape *reshape;
81 luci::CircleOutput *output;
86 TEST(PropagateQuantParam, simple)
90 luci::PropagateQuantParamPass pass;
91 while (pass.run(&g.g))
94 EXPECT_FLOAT_EQ(0.2, g.conv->quantparam()->scale[0]);
95 EXPECT_FLOAT_EQ(0.4, g.conv->quantparam()->scale[1]);
96 EXPECT_FLOAT_EQ(0.6, g.conv->quantparam()->scale[2]);
97 EXPECT_EQ(-10, g.conv->quantparam()->zerop[0]);
98 EXPECT_EQ(0, g.conv->quantparam()->zerop[1]);
99 EXPECT_EQ(10, g.conv->quantparam()->zerop[2]);
102 TEST(PropagateQuantParam, wrong_op_NEG)
105 g.output->from(g.conv);
108 luci::PropagateQuantParamPass pass;
109 while (pass.run(&g.g))
112 EXPECT_FLOAT_EQ(0.1, g.conv->quantparam()->scale[0]);
113 EXPECT_FLOAT_EQ(0.2, g.conv->quantparam()->scale[1]);
114 EXPECT_FLOAT_EQ(0.3, g.conv->quantparam()->scale[2]);
115 EXPECT_EQ(0, g.conv->quantparam()->zerop[0]);
116 EXPECT_EQ(10, g.conv->quantparam()->zerop[1]);
117 EXPECT_EQ(20, g.conv->quantparam()->zerop[2]);