Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / FuseAddWithFullyConnectedPass.test.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/FuseAddWithFullyConnectedPass.h"
18
19 #include "helpers/CreateCircleConst.h"
20
21 #include <luci/IR/CircleNodes.h>
22
23 #include <luci/test/TestIOGraph.h>
24
25 #include <gtest/gtest.h>
26
27 namespace
28 {
29
30 using namespace luci::test;
31
32 /**
33  *  Simple graph for test
34  *
35  *  BEFORE
36  *
37  *         [FC]
38  *           |
39  *     [Add w/ Relu]
40  *
41  *  AFTER
42  *
43  *      [FC w/ Relu] (bias updated)
44  *
45  */
46 class FCAddGraphlet
47 {
48 public:
49   FCAddGraphlet() = default;
50
51   void init(loco::Graph *g)
52   {
53     std::vector<float> weights_val(16 * 4);
54     _fc_f = luci::create_const_node(g, loco::DataType::FLOAT32, {16, 4}, weights_val);
55
56     std::vector<float> bias_val(16);
57     _fc_b = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 16}, bias_val);
58
59     _fc = g->nodes()->create<luci::CircleFullyConnected>();
60     _fc->weights(_fc_f);
61     _fc->bias(_fc_b);
62     _fc->fusedActivationFunction(luci::FusedActFunc::NONE);
63     _fc->dtype(loco::DataType::FLOAT32);
64     _fc->shape({1, 16});
65     _fc->name("fc");
66
67     std::vector<float> addition_val;
68     for (uint32_t i = 0; i < 16; i++)
69       addition_val.push_back(static_cast<float>(i));
70     _add_c = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 16}, addition_val);
71
72     _add = g->nodes()->create<luci::CircleAdd>();
73     _add->x(_fc);
74     _add->y(_add_c);
75     _add->fusedActivationFunction(luci::FusedActFunc::RELU);
76     _add->dtype(loco::DataType::FLOAT32);
77     _add->shape({1, 16});
78     _add->name("add");
79   }
80
81 public:
82   luci::CircleFullyConnected *fc() { return _fc; }
83
84 public:
85   void to_fm_bias(void)
86   {
87     assert(_fc != nullptr); // FIX_ME_UNLESS
88
89     auto new_fc = _fc->graph()->nodes()->create<luci::CircleFullyConnected>();
90     _fc->bias(new_fc);
91   }
92
93 protected:
94   luci::CircleFullyConnected *_fc = nullptr;
95   luci::CircleAdd *_add = nullptr;
96   luci::CircleConst *_fc_f = nullptr;
97   luci::CircleConst *_fc_b = nullptr;
98   luci::CircleConst *_add_c = nullptr;
99 };
100
101 class FuseAddWithFCTestGraph : public TestIOGraph, public FCAddGraphlet
102 {
103 public:
104   FuseAddWithFCTestGraph() = default;
105
106   void init(void)
107   {
108     TestIOGraph::init({1, 4}, {1, 16});
109     FCAddGraphlet::init(g());
110
111     _fc->input(input());
112
113     output()->from(_add);
114   }
115 };
116
117 class FuseAddWithFullyConnectedPassTest : public ::testing::Test
118 {
119 public:
120   FuseAddWithFCTestGraph g;
121   luci::FuseAddWithFullyConnectedPass pass;
122 };
123
124 } // namespace
125
126 TEST_F(FuseAddWithFullyConnectedPassTest, simple_test)
127 {
128   g.init();
129
130   auto ret = pass.run(g.g());
131   EXPECT_EQ(true, ret);
132
133   auto fc = dynamic_cast<luci::CircleFullyConnected *>(g.output()->from());
134   EXPECT_NE(nullptr, fc);
135
136   auto bias = loco::must_cast<luci::CircleConst *>(g.fc()->bias());
137   for (uint32_t i = 0; i < bias->size<loco::DataType::FLOAT32>(); i++)
138   {
139     EXPECT_EQ(i, bias->at<loco::DataType::FLOAT32>(i));
140   }
141 }
142
143 TEST_F(FuseAddWithFullyConnectedPassTest, fm_bias_NEG)
144 {
145   g.init();
146
147   // Bias is a feature map. Add is not fused.
148   g.to_fm_bias();
149
150   auto ret = pass.run(g.g());
151   EXPECT_EQ(false, ret);
152 }