Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / FoldDepthwiseConv2DPass.test.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/FoldDepthwiseConv2DPass.h"
18 #include "PassTestGraphs.h"
19
20 #include <luci/IR/CircleNodes.h>
21
22 #include <gtest/gtest.h>
23
24 namespace
25 {
26
27 /**
28  *  Graph has an DepthwiseConv2D Op with constant inputs
29  *
30  *    BEFORE
31  *
32  *    [CircleConst] [CircleConst]
33  *               |   |
34  *       [CircleDepthwiseConv2D]
35  *
36  *    AFTER
37  *
38  *           [CircleConst]
39  */
40 class FoldDepthwiseConv2DTest : public luci::ConstantFoldingTestGraph, public ::testing::Test
41 {
42 public:
43   FoldDepthwiseConv2DTest() : luci::ConstantFoldingTestGraph({1, 4, 4, 1}, loco::DataType::FLOAT32)
44   {
45     _dconv = _g.nodes()->create<luci::CircleDepthwiseConv2D>();
46     _dconv_input = _g.nodes()->create<luci::CircleConst>();
47     _dconv_filter = _g.nodes()->create<luci::CircleConst>();
48     _dconv_bias = _g.nodes()->create<luci::CircleConst>();
49
50     _dconv->dtype(loco::DataType::FLOAT32);
51     _dconv->padding(luci::Padding::VALID);
52     _dconv->fusedActivationFunction(luci::FusedActFunc::NONE);
53     _dconv->input(_dconv_input);
54     _dconv->filter(_dconv_filter);
55     _dconv->bias(_dconv_bias);
56     _dconv->shape({1, 4, 4, 1});
57     _dconv->stride()->h(1);
58     _dconv->stride()->w(1);
59     _dconv->depthMultiplier(1);
60
61     _dconv_input->dtype(loco::DataType::FLOAT32);
62     _dconv_input->shape({1, 4, 4, 1});
63     _dconv_input->size<loco::DataType::FLOAT32>(16);
64
65     _dconv_filter->dtype(loco::DataType::FLOAT32);
66     _dconv_filter->shape({1, 1, 1, 1});
67     _dconv_filter->size<loco::DataType::FLOAT32>(1);
68
69     _dconv_bias->dtype(loco::DataType::FLOAT32);
70     _dconv_bias->shape({1});
71     _dconv_bias->size<loco::DataType::FLOAT32>(1);
72
73     _output->from(_dconv);
74   }
75
76 protected:
77   void init() final {}
78
79 protected:
80   loco::Node *createFoldedPattern() final { return nullptr; }
81
82 protected:
83   luci::CircleConst *getFoldedPattern() final
84   {
85     return loco::must_cast<luci::CircleConst *>(_output->from());
86   }
87
88 protected:
89   luci::CircleDepthwiseConv2D *_dconv = nullptr;
90   luci::CircleConst *_dconv_input = nullptr;
91   luci::CircleConst *_dconv_filter = nullptr;
92   luci::CircleConst *_dconv_bias = nullptr;
93 };
94
95 } // namespace
96
97 TEST(FoldDepthwiseConv2DPass, name)
98 {
99   luci::FoldDepthwiseConv2DPass pass;
100   auto const name = pass.name();
101   ASSERT_NE(nullptr, name);
102 }
103
104 TEST_F(FoldDepthwiseConv2DTest, fold_depthwise_conv2d)
105 {
106   for (uint32_t i = 0; i < 16; ++i)
107     _dconv_input->at<loco::DataType::FLOAT32>(i) = 0.5;
108   _dconv_filter->at<loco::DataType::FLOAT32>(0) = 0.5;
109
110   luci::FoldDepthwiseConv2DPass pass;
111   ASSERT_TRUE(pass.run(&_g));
112
113   auto folded_const = getFoldedPattern();
114   EXPECT_EQ(folded_const->dtype(), loco::DataType::FLOAT32);
115   EXPECT_NEAR(folded_const->at<loco::DataType::FLOAT32>(0), 0.25,
116               std::numeric_limits<float>::min());
117   EXPECT_NEAR(folded_const->at<loco::DataType::FLOAT32>(15), 0.25,
118               std::numeric_limits<float>::min());
119 }
120
121 TEST_F(FoldDepthwiseConv2DTest, fold_non_constant_NEG)
122 {
123   _dconv->input(_input);
124
125   luci::FoldDepthwiseConv2DPass pass;
126   ASSERT_FALSE(pass.run(&_g));
127 }