2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include <gtest/gtest.h>
19 #include "DepthParameterizer.h"
20 #include "TestHelper.h"
22 #include <luci/IR/CircleNodes.h>
// Test graph fixture: extends SimpleGraph (declared in TestHelper.h, not shown
// here) by supplying a body made of a single CircleConv2D whose filter and bias
// are CircleConst nodes.
27 class NConvGraph final : public SimpleGraph
// Builds the Conv2D node and its constant operands inside graph _g.
// `input` is the node handed in by SimpleGraph; presumably it is connected as
// the Conv2D's input and _conv is returned as the body's output — confirm.
30 loco::Node *insertGraphBody(loco::Node *input) override
// Constant FLOAT32 filter of shape {_channel_size, 1, 1, _channel_size}: a 1x1
// kernel mapping _channel_size channels to _channel_size channels (assuming
// Circle's OHWI filter layout — confirm). _channel_size, _width and _height are
// not declared in this class; presumably they are SimpleGraph members.
32 _filter = _g->nodes()->create<luci::CircleConst>();
33 _filter->dtype(loco::DataType::FLOAT32);
34 _filter->shape({_channel_size, 1, 1, _channel_size});
35 _filter->name("conv_filter");
// Constant FLOAT32 bias with one element per output channel.
37 _bias = _g->nodes()->create<luci::CircleConst>();
38 _bias->dtype(loco::DataType::FLOAT32);
39 _bias->shape({_channel_size});
40 _bias->name("conv_bias");
// FLOAT32 Conv2D with SAME padding and no fused activation.
// NOTE(review): Circle Conv2D tensors are NHWC, so the conventional output shape
// would be {1, _height, _width, _channel_size}; the {_width, _height} order used
// below is only equivalent when the two values are equal — confirm.
42 _conv = _g->nodes()->create<luci::CircleConv2D>();
43 _conv->padding(luci::Padding::SAME);
44 _conv->fusedActivationFunction(luci::FusedActFunc::NONE);
45 _conv->dtype(loco::DataType::FLOAT32);
46 _conv->shape({1, _width, _height, _channel_size});
48 _conv->filter(_filter);
// Handles to the nodes created in insertGraphBody(); they stay null until that
// method runs.
56 luci::CircleConv2D *_conv = nullptr;
57 luci::CircleConst *_filter = nullptr;
58 luci::CircleConst *_bias = nullptr;
// Runs compute_depth on the module holding the Conv2D test graph and checks the
// result. The call must succeed and report a depth range of [0, 2]. The input
// must sit at the minimum depth (0), the Conv2D at depth 1, and the output at
// the maximum depth (2).
63 TEST(CircleMPQSolverDepthParameteriserTest, verifyResultsTest)
65 auto m = luci::make_module();
// Keep handles to the graph's boundary nodes so their depths can be looked up
// after the graph has been handed to the module.
69 auto input = g._input;
70 auto output = g._output;
// Presumably moves the test graph into module `m` so that compute_depth can
// traverse it — confirm against SimpleGraph::transfer_to.
72 g.transfer_to(m.get());
// min/max are seeded with opposite float extremes. Presumably compute_depth
// narrows them to the observed depth range — confirm.
// NOTE(review): std::numeric_limits<float>::max() / ::lowest() is the idiomatic
// spelling; the `()` below value-initializes a throwaway numeric_limits object
// just to call a static member. Also, <limits> is not among the visible includes.
74 mpqsolver::bisection::NodeDepthType nodes_depth;
75 float min_depth = std::numeric_limits<float>().max();
76 float max_depth = -std::numeric_limits<float>().max();
77 auto status = mpqsolver::bisection::compute_depth(m.get(), nodes_depth, min_depth, max_depth);
// NOTE(review): EXPECT_EQ / EXPECT_FLOAT_EQ would print the compared values on
// failure; EXPECT_TRUE(a == b) only reports that the expression was false.
79 EXPECT_TRUE(status == mpqsolver::bisection::ParameterizerResult::SUCCESS);
80 EXPECT_TRUE(max_depth == 2 && min_depth == 0);
81 EXPECT_TRUE(nodes_depth[input] == min_depth);
82 EXPECT_TRUE(nodes_depth[conv] == 1);
83 EXPECT_TRUE(nodes_depth[output] == max_depth);
// Negative case: compute_depth is called on a freshly created module to which
// no graph has been transferred, and it must report FAILURE. Presumably it
// rejects modules that contain no graph — confirm in DepthParameterizer.
86 TEST(CircleMPQSolverDepthParameteriserTest, verifyResultsTest_NEG)
88 auto m = luci::make_module();
89 mpqsolver::bisection::NodeDepthType nodes_depth;
90 float min_depth = std::numeric_limits<float>().max();
91 float max_depth = -std::numeric_limits<float>().max();
92 auto status = mpqsolver::bisection::compute_depth(m.get(), nodes_depth, min_depth, max_depth);
94 EXPECT_TRUE(status == mpqsolver::bisection::ParameterizerResult::FAILURE);