Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / loco / src / IR / Graph.test.cpp
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "loco/IR/Graph.h"
18
19 #include <gtest/gtest.h>
20
21 namespace
22 {
23
24 /// @brief Mockup class for loco::NamedEntity
25 struct NamedElement final : private loco::NamedEntity
26 {
27   LOCO_NAMED_ENTITY_EXPOSE;
28 };
29
30 } // namespace
31
32 TEST(NamedTest, constructor)
33 {
34   NamedElement elem;
35
36   ASSERT_EQ("", elem.name());
37 }
38
39 TEST(NamedTest, setter_and_getter)
40 {
41   NamedElement elem;
42
43   elem.name("name");
44   ASSERT_EQ("name", elem.name());
45 }
46
47 TEST(DataTypedMixinTest, constructor)
48 {
49   loco::Mixin<loco::Trait::DataTyped> mixin;
50
51   ASSERT_EQ(loco::DataType::Unknown, mixin.dtype());
52 }
53
54 TEST(DataTypedMixinTest, setter_and_getter)
55 {
56   loco::Mixin<loco::Trait::DataTyped> mixin;
57
58   mixin.dtype(loco::DataType::FLOAT32);
59   ASSERT_EQ(loco::DataType::FLOAT32, mixin.dtype());
60 }
61
62 TEST(TensorShapedMixinTest, setter_and_getter)
63 {
64   loco::Mixin<loco::Trait::TensorShaped> mixin;
65
66   mixin.shape({1, 2, 3, 4});
67   ASSERT_NE(mixin.shape(), nullptr);
68   ASSERT_EQ(4, mixin.shape()->rank());
69   ASSERT_EQ(1, mixin.shape()->dim(0));
70   ASSERT_EQ(2, mixin.shape()->dim(1));
71   ASSERT_EQ(3, mixin.shape()->dim(2));
72   ASSERT_EQ(4, mixin.shape()->dim(3));
73 }
74
75 TEST(GraphTest, create_and_destroy_node)
76 {
77   auto g = loco::make_graph();
78
79   auto pull = g->nodes()->create<loco::Pull>();
80
81   ASSERT_NO_THROW(g->nodes()->destroy(pull));
82   ASSERT_THROW(g->nodes()->destroy(pull), std::invalid_argument);
83 }
84
85 TEST(GraphTest, create_input)
86 {
87   auto g = loco::make_graph();
88
89   auto input = g->inputs()->create();
90
91   // TODO Add more checks
92   ASSERT_EQ(nullptr, input->shape());
93   ASSERT_EQ(0, input->index());
94 }
95
96 TEST(GraphTest, create_output)
97 {
98   auto g = loco::make_graph();
99
100   auto output = g->outputs()->create();
101
102   // TODO Add more checks
103   ASSERT_EQ(nullptr, output->shape());
104   ASSERT_EQ(0, output->index());
105 }
106
107 namespace
108 {
109 // temp node with multple params for ctor. loco::CanonicalOpcode::ReLU is used for simplicity
110 class ParamCtorNode
111     : public loco::CanonicalNodeDef<loco::CanonicalOpcode::ReLU, loco::FixedArity<0>::Mixin>
112 {
113 public:
114   ParamCtorNode(int i, float f)
115   {
116     _i = i;
117     _f = f;
118   }
119
120   int i() { return _i; }
121   float f() { return _f; }
122
123 private:
124   int _i;
125   float _f;
126 };
127 } // namespace
128
129 TEST(GraphTest, consturctor_with_param_node)
130 {
131   auto g = loco::make_graph();
132
133   auto test_node = g->nodes()->create<ParamCtorNode>(22, 11.11);
134
135   ASSERT_EQ(g.get(), test_node->graph());
136   ASSERT_EQ(g.get(), const_cast<const ParamCtorNode *>(test_node)->graph());
137
138   ASSERT_EQ(22, test_node->i());
139   ASSERT_FLOAT_EQ(test_node->f(), 11.11);
140
141   ASSERT_NO_THROW(g->nodes()->destroy(test_node));
142   ASSERT_THROW(g->nodes()->destroy(test_node), std::invalid_argument);
143 }
144
145 TEST(GraphTest, getters_over_const_instance)
146 {
147   auto g = loco::make_graph();
148
149   auto pull = g->nodes()->create<loco::Pull>();
150   auto push = g->nodes()->create<loco::Push>();
151
152   loco::link(g->inputs()->create(), pull);
153   loco::link(g->outputs()->create(), push);
154
155   auto ptr = const_cast<const loco::Graph *>(g.get());
156
157   EXPECT_EQ(ptr->nodes()->size(), 2);
158   EXPECT_EQ(ptr->inputs()->size(), 1);
159 }
160
161 TEST(GraphTest, graph_node_enumeration)
162 {
163   auto g = loco::make_graph();
164
165   auto pull_1 = g->nodes()->create<loco::Pull>();
166   auto push_1 = g->nodes()->create<loco::Push>();
167
168   auto nodes = loco::all_nodes(g.get());
169
170   // Returns true if "nodes" includes a given node
171   auto member = [&nodes](loco::Node *node) { return nodes.find(node) != nodes.end(); };
172
173   ASSERT_EQ(2, nodes.size());
174   ASSERT_TRUE(member(pull_1));
175   ASSERT_TRUE(member(push_1));
176 }
177
178 TEST(GraphTest, graph_inout_enumeration)
179 {
180   auto g = loco::make_graph();
181
182   std::vector<loco::Pull *> pull_nodes;
183
184   auto pull_1 = g->nodes()->create<loco::Pull>();
185   auto pull_2 = g->nodes()->create<loco::Pull>();
186   auto pull_3 = g->nodes()->create<loco::Pull>();
187
188   auto push_1 = g->nodes()->create<loco::Push>();
189   auto push_2 = g->nodes()->create<loco::Push>();
190   auto push_3 = g->nodes()->create<loco::Push>();
191
192   loco::link(g->inputs()->create(), pull_2);
193   loco::link(g->inputs()->create(), pull_1);
194
195   loco::link(g->outputs()->create(), push_1);
196   loco::link(g->outputs()->create(), push_3);
197
198   auto output_nodes = loco::output_nodes(g.get());
199
200   ASSERT_EQ(2, output_nodes.size());
201   ASSERT_EQ(push_1, output_nodes.at(0));
202   ASSERT_EQ(push_3, output_nodes.at(1));
203 }
204
205 TEST(GraphTest, graph_name)
206 {
207   auto g = loco::make_graph();
208
209   g->name("HelloGraph");
210   ASSERT_TRUE(g->name() == "HelloGraph");
211 }
212
213 TEST(GraphTest, graph_name_nullptr_NEG)
214 {
215   auto g = loco::make_graph();
216
217   EXPECT_ANY_THROW(g->name(nullptr));
218 }