Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / ReplaceNonConstFCWithBatchMatMulPass.test.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h"
18
19 #include "helpers/CreateCircleConst.h"
20
21 #include <luci/test/TestIOGraph.h>
22 #include <luci/IR/CircleNodes.h>
23
24 #include <gtest/gtest.h>
25
26 namespace
27 {
28
29 using namespace luci::test;
30
31 /**
32  *  Simple graph for test
33  *
34  *  BEFORE
35  *
36  *   [IFM1] [IFM2] [BIAS]
37  *        \   |   /
38  *          [FC]
39  *            |
40  *          [Res]
41  *
42  *  AFTER
43  *   [IFM1] [IFM2]
44  *        \   |
45  *      [BatchMatMul] [BIAS]
46  *              \      /
47  *               [Add]
48  *                 |
49  *               [Res]
50  *
51  */
52 struct FCGraphlet
53 {
54 public:
55   FCGraphlet() = default;
56   virtual ~FCGraphlet() = default;
57
58   void init(loco::Graph *g, const ShapeU32 r_shape, const float bv)
59   {
60     _tr_y = g->nodes()->create<luci::CircleTranspose>();
61     _tr_y->a(_y);
62     std::vector<int32_t> tr_val = {1, 0};
63     _tr_y->perm(luci::create_const_node(g, loco::DataType::S32, {2}, tr_val));
64
65     _fc = g->nodes()->create<luci::CircleFullyConnected>();
66     _fc->input(_x);
67     _fc->weights(_tr_y);
68     _fc->fusedActivationFunction(luci::FusedActFunc::NONE);
69     _fc->dtype(loco::DataType::FLOAT32);
70     _fc->shape(r_shape);
71     auto l = _fc->dim(_fc->rank() - 1).value();
72     std::vector<float> bias_val(l, bv);
73     _fc->bias(luci::create_const_node(g, loco::DataType::FLOAT32, {l}, bias_val));
74     _fc->name("fc");
75   }
76
77 public:
78   luci::CircleFullyConnected *fc() { return _fc; }
79
80 protected:
81   luci::CircleFullyConnected *_fc = nullptr;
82   luci::CircleTranspose *_tr_y = nullptr;
83   luci::CircleInput *_x = nullptr;
84   luci::CircleInput *_y = nullptr;
85 };
86
87 struct FCGraph : public TestIsGraphlet<2>, public TestOGraphlet, public FCGraphlet
88 {
89   FCGraph() = default;
90   virtual ~FCGraph() = default;
91   void init(const ShapeU32 x_shape, const ShapeU32 y_shape, const ShapeU32 r_shape, const float bv)
92   {
93     TestIsGraphlet<2>::init(g(), {x_shape, y_shape});
94     TestOGraphlet::init(g(), r_shape);
95     _x = input(0);
96     _y = input(1);
97     FCGraphlet::init(g(), r_shape, bv);
98     output()->from(_fc);
99   }
100 };
101
102 class ReplaceNonConstFCWithBatchMatMulPassTest : public ::testing::Test
103 {
104 public:
105   FCGraph g;
106   luci::ReplaceNonConstFCWithBatchMatMulPass pass;
107 };
108
109 } // namespace
110
111 TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, simple_test)
112 {
113   g.init({2, 3}, {2, 3}, {2, 2}, 0.0f);
114
115   auto ret = pass.run(g.g());
116   EXPECT_EQ(true, ret);
117
118   auto res = dynamic_cast<luci::CircleReshape *>(g.output()->from());
119   EXPECT_NE(nullptr, res);
120 }
121
122 TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, nonzero_bias_test)
123 {
124   g.init({2, 3}, {2, 3}, {2, 2}, 1.0f);
125
126   auto ret = pass.run(g.g());
127   EXPECT_EQ(true, ret);
128
129   auto mm = dynamic_cast<luci::CircleAdd *>(g.output()->from());
130   EXPECT_NE(nullptr, mm);
131 }
132
133 TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, wrong_op_NEG)
134 {
135   loco::Graph g;
136
137   auto inp = g.nodes()->create<luci::CircleInput>();
138   auto relu = g.nodes()->create<luci::CircleRelu>();
139   relu->features(inp);
140
141   luci::ReplaceNonConstFCWithBatchMatMulPass pass;
142   auto changed = pass.run(&g);
143
144   EXPECT_EQ(false, changed);
145 }