Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / ReplaceMulAddWithDepthwiseConvPass.test.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h"
18
19 #include <luci/IR/CircleNodes.h>
20
21 #include <gtest/gtest.h>
22
23 namespace
24 {
25
26 /**
27  *  Simple graph for test
28  *
29  *  BEFORE
30  *
31  *             [Node] [gamma]
32  *                |  /
33  *              [Mul]  [beta]
34  *                |   /
35  *               [Add]
36  *
37  *  AFTER
38  *
39  *              [Node]  [weights]  [bias]
40  *                  \      /       /
41  *                [DepthwiseConv2D]
42  */
43 class SimpleGraph
44 {
45 public:
46   SimpleGraph()
47   {
48     input = g.nodes()->create<luci::CircleInput>();
49     mul = g.nodes()->create<luci::CircleMul>();
50     gamma = g.nodes()->create<luci::CircleConst>();
51     add = g.nodes()->create<luci::CircleAdd>();
52     beta = g.nodes()->create<luci::CircleConst>();
53     output = g.nodes()->create<luci::CircleOutput>();
54
55     auto graph_input = g.inputs()->create();
56     input->index(graph_input->index());
57     auto graph_output = g.outputs()->create();
58     output->index(graph_output->index());
59
60     input->dtype(loco::DataType::FLOAT32);
61     mul->dtype(loco::DataType::FLOAT32);
62     gamma->dtype(loco::DataType::FLOAT32);
63     add->dtype(loco::DataType::FLOAT32);
64     beta->dtype(loco::DataType::FLOAT32);
65     output->dtype(loco::DataType::FLOAT32);
66
67     uint32_t channel_size = 16;
68     input->shape({1, 4, 4, channel_size});
69     mul->shape({1, 4, 4, channel_size});
70     gamma->shape({channel_size});
71     add->shape({1, 4, 4, channel_size});
72     beta->shape({channel_size});
73     output->shape({1, 4, 4, channel_size});
74
75     gamma->size<loco::DataType::FLOAT32>(channel_size);
76     beta->size<loco::DataType::FLOAT32>(channel_size);
77     for (uint32_t i = 0; i < channel_size; i++)
78     {
79       gamma->at<loco::DataType::FLOAT32>(i) = i;
80       beta->at<loco::DataType::FLOAT32>(i) = i;
81     }
82
83     mul->x(input);
84     mul->y(gamma);
85     add->x(mul);
86     add->y(beta);
87     output->from(add);
88   }
89
90 public:
91   loco::Graph g;
92   luci::CircleInput *input = nullptr;
93   luci::CircleMul *mul = nullptr;
94   luci::CircleConst *gamma = nullptr;
95   luci::CircleAdd *add = nullptr;
96   luci::CircleConst *beta = nullptr;
97   luci::CircleOutput *output = nullptr;
98 };
99
100 } // namespace
101
102 TEST(ReplaceMulAddWithDepthwiseConv, simple)
103 {
104   SimpleGraph g;
105
106   luci::ReplaceMulAddWithDepthwiseConvPass pass;
107   while (pass.run(&g.g))
108     ;
109
110   auto dwconv = dynamic_cast<luci::CircleDepthwiseConv2D *>(g.output->from());
111   EXPECT_NE(nullptr, dwconv);
112
113   uint32_t channel_size = 16;
114   auto weights = dynamic_cast<luci::CircleConst *>(dwconv->filter());
115   auto bias = dynamic_cast<luci::CircleConst *>(dwconv->bias());
116   EXPECT_NE(nullptr, weights);
117   EXPECT_EQ(4, weights->rank());
118   EXPECT_EQ(channel_size, weights->dim(3).value());
119   EXPECT_NE(nullptr, bias);
120   EXPECT_EQ(1, bias->rank());
121   EXPECT_EQ(channel_size, bias->dim(0).value());
122
123   for (int i = 0; i < channel_size; i++)
124   {
125     EXPECT_FLOAT_EQ(i, weights->at<loco::DataType::FLOAT32>(i));
126     EXPECT_FLOAT_EQ(i, bias->at<loco::DataType::FLOAT32>(i));
127   }
128 }
129
130 TEST(ReplaceMulAddWithDepthwiseConv, wrong_op_NEG)
131 {
132   SimpleGraph g;
133   // swap mul/add (changed to add->mul)
134   g.add->x(g.input);
135   loco::replace(g.add).with(g.mul);
136   g.mul->x(g.add);
137
138   luci::ReplaceMulAddWithDepthwiseConvPass pass;
139   auto changed = pass.run(&g.g);
140
141   EXPECT_EQ(false, changed);
142 }