Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / circle-mpqsolver / src / bisection / DepthParameterizer.test.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <gtest/gtest.h>
18
19 #include "DepthParameterizer.h"
20 #include <core/TestHelper.h>
21
22 #include <luci/IR/CircleNodes.h>
23
24 namespace
25 {
26
27 class NConvGraph final : public SimpleGraph
28 {
29 protected:
30   loco::Node *insertGraphBody(loco::Node *input) override
31   {
32     _filter = _g->nodes()->create<luci::CircleConst>();
33     _filter->dtype(loco::DataType::FLOAT32);
34     _filter->shape({_channel_size, 1, 1, _channel_size});
35     _filter->name("conv_filter");
36
37     _bias = _g->nodes()->create<luci::CircleConst>();
38     _bias->dtype(loco::DataType::FLOAT32);
39     _bias->shape({_channel_size});
40     _bias->name("conv_bias");
41
42     _conv = _g->nodes()->create<luci::CircleConv2D>();
43     _conv->padding(luci::Padding::SAME);
44     _conv->fusedActivationFunction(luci::FusedActFunc::NONE);
45     _conv->dtype(loco::DataType::FLOAT32);
46     _conv->shape({1, _width, _height, _channel_size});
47     _conv->name("conv");
48     _conv->filter(_filter);
49     _conv->bias(_bias);
50     _conv->input(input);
51
52     return _conv;
53   }
54
55 public:
56   luci::CircleConv2D *_conv = nullptr;
57   luci::CircleConst *_filter = nullptr;
58   luci::CircleConst *_bias = nullptr;
59 };
60
61 } // namespace
62
63 TEST(CircleMPQSolverDepthParameteriserTest, verifyResultsTest)
64 {
65   auto m = luci::make_module();
66   NConvGraph g;
67   g.init();
68   auto conv = g._conv;
69   auto input = g._input;
70   auto output = g._output;
71
72   g.transfer_to(m.get());
73
74   mpqsolver::bisection::NodeDepthType nodes_depth;
75   float min_depth = std::numeric_limits<float>().max();
76   float max_depth = -std::numeric_limits<float>().max();
77   auto status = mpqsolver::bisection::compute_depth(m.get(), nodes_depth, min_depth, max_depth);
78
79   EXPECT_TRUE(status == mpqsolver::bisection::ParameterizerResult::SUCCESS);
80   EXPECT_TRUE(max_depth == 2 && min_depth == 0);
81   EXPECT_TRUE(nodes_depth[input] == min_depth);
82   EXPECT_TRUE(nodes_depth[conv] == 1);
83   EXPECT_TRUE(nodes_depth[output] == max_depth);
84 }
85
86 TEST(CircleMPQSolverDepthParameteriserTest, verifyResultsTest_NEG)
87 {
88   auto m = luci::make_module();
89   mpqsolver::bisection::NodeDepthType nodes_depth;
90   float min_depth = std::numeric_limits<float>().max();
91   float max_depth = -std::numeric_limits<float>().max();
92   auto status = mpqsolver::bisection::compute_depth(m.get(), nodes_depth, min_depth, max_depth);
93
94   EXPECT_TRUE(status == mpqsolver::bisection::ParameterizerResult::FAILURE);
95 }