2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include <gtest/gtest.h>
19 #include "ErrorApproximator.h"
20 #include "TestHelper.h"
22 #include <luci/IR/CircleNodeDecl.h>
/**
 * @brief Computes the flat row-major offset of an element in a 4-D buffer.
 *
 * @param shape   extents of the four dimensions. shape[0] only bounds
 *                indices[0], so it does not take part in the computation.
 * @param indices four coordinates, one per dimension (indices[3] is innermost)
 * @return indices[0]*shape[1]*shape[2]*shape[3] + indices[1]*shape[2]*shape[3]
 *         + indices[2]*shape[3] + indices[3]
 *
 * The sum is evaluated in Horner form, ((i0*s1 + i1)*s2 + i2)*s3 + i3.
 * This is algebraically identical to the expanded sum, and it stays identical
 * under uint32_t modular wrap-around.
 */
inline uint32_t cal_offset(const uint32_t shape[4], const uint32_t *indices)
{
  uint32_t offset = indices[0];
  // Fold each remaining axis in: scale by that axis' extent, then add its index.
  for (uint32_t axis = 1; axis < 4; ++axis)
    offset = offset * shape[axis] + indices[axis];
  return offset;
}
35 class NConvGraph final : public SimpleGraph
38 void initInput(loco::Node *input) override
40 auto ci_input = loco::must_cast<luci::CircleNode *>(input);
41 ci_input->shape_status(luci::ShapeStatus::VALID);
42 auto qparam = std::make_unique<luci::CircleQuantParam>();
43 qparam->min.assign(_channel_size, _a_min);
44 qparam->max.assign(_channel_size, _a_max);
45 ci_input->quantparam(std::move(qparam));
48 loco::Node *insertGraphBody(loco::Node *input) override
50 _filter = _g->nodes()->create<luci::CircleConst>();
51 _filter->dtype(loco::DataType::FLOAT32);
52 _filter->shape({_channel_size, _f_w, _f_h, _channel_size});
53 _filter->shape_status(luci::ShapeStatus::VALID);
54 _filter->name("conv_filter");
55 uint32_t indices[4] = {
59 uint32_t w_shape[4] = {_filter->dim(0).value(), _filter->dim(1).value(),
60 _filter->dim(2).value(), _filter->dim(3).value()};
62 _filter->size<loco::DataType::FLOAT32>(w_shape[0] * w_shape[1] * w_shape[2] * w_shape[3]);
64 for (indices[0] = 0; indices[0] < w_shape[0]; ++indices[0])
66 for (indices[1] = 0; indices[1] < w_shape[1]; ++indices[1])
68 for (indices[2] = 0; indices[2] < w_shape[2]; ++indices[2])
70 for (indices[3] = 0; indices[3] < w_shape[3]; ++indices[3])
72 uint32_t offset = cal_offset(w_shape, indices);
73 _filter->at<loco::DataType::FLOAT32>(offset) = (offset % 2 == 0) ? _w_max : _w_min;
79 _bias = _g->nodes()->create<luci::CircleConst>();
80 _bias->dtype(loco::DataType::FLOAT32);
81 _bias->shape({_channel_size});
82 _bias->name("conv_bias");
84 _conv = _g->nodes()->create<luci::CircleConv2D>();
85 _conv->padding(luci::Padding::SAME);
86 _conv->fusedActivationFunction(luci::FusedActFunc::NONE);
87 _conv->dtype(loco::DataType::FLOAT32);
88 _conv->shape({1, _width, _height, _channel_size});
89 _conv->shape_status(luci::ShapeStatus::VALID);
91 _conv->filter(_filter);
99 luci::CircleConv2D *_conv = nullptr;
100 luci::CircleConst *_filter = nullptr;
101 luci::CircleConst *_bias = nullptr;
112 TEST(CircleMPQSolverErrorApproximatorTest, verifyResultsTest)
117 auto value = mpqsolver::bisection::approximate(g._conv);
118 float expected = ((g._w_max - g._w_min) * g._channel_size * std::max(g._a_max, g._a_min) +
119 (g._a_max - g._a_min) * g._channel_size * std::max(g._w_max, g._w_min)) *
120 g._f_h * g._f_w * g._height * g._width * g._channel_size / 1.e+6f;
121 EXPECT_FLOAT_EQ(expected, value);
124 TEST(CircleMPQSolverErrorApproximatorTest, verifyResultsTest_NEG)
129 auto value = mpqsolver::bisection::approximate(g._input);
130 float expected = 0.f;
131 EXPECT_FLOAT_EQ(expected, value);
133 value = mpqsolver::bisection::approximate(g._output);
135 EXPECT_FLOAT_EQ(expected, value);