2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "loco/IR/Graph.h"
19 #include <gtest/gtest.h>
24 /// @brief Mockup class for loco::NamedEntity
25 struct NamedElement final : private loco::NamedEntity
27 LOCO_NAMED_ENTITY_EXPOSE;
32 TEST(NamedTest, constructor)
36 ASSERT_EQ("", elem.name());
39 TEST(NamedTest, setter_and_getter)
44 ASSERT_EQ("name", elem.name());
47 TEST(DataTypedMixinTest, constructor)
49 loco::Mixin<loco::Trait::DataTyped> mixin;
51 ASSERT_EQ(loco::DataType::Unknown, mixin.dtype());
54 TEST(DataTypedMixinTest, setter_and_getter)
56 loco::Mixin<loco::Trait::DataTyped> mixin;
58 mixin.dtype(loco::DataType::FLOAT32);
59 ASSERT_EQ(loco::DataType::FLOAT32, mixin.dtype());
62 TEST(TensorShapedMixinTest, setter_and_getter)
64 loco::Mixin<loco::Trait::TensorShaped> mixin;
66 mixin.shape({1, 2, 3, 4});
67 ASSERT_NE(mixin.shape(), nullptr);
68 ASSERT_EQ(4, mixin.shape()->rank());
69 ASSERT_EQ(1, mixin.shape()->dim(0));
70 ASSERT_EQ(2, mixin.shape()->dim(1));
71 ASSERT_EQ(3, mixin.shape()->dim(2));
72 ASSERT_EQ(4, mixin.shape()->dim(3));
75 TEST(GraphTest, create_and_destroy_node)
77 auto g = loco::make_graph();
79 auto pull = g->nodes()->create<loco::Pull>();
81 ASSERT_NO_THROW(g->nodes()->destroy(pull));
82 ASSERT_THROW(g->nodes()->destroy(pull), std::invalid_argument);
85 TEST(GraphTest, create_input)
87 auto g = loco::make_graph();
89 auto input = g->inputs()->create();
91 // TODO Add more checks
92 ASSERT_EQ(nullptr, input->shape());
93 ASSERT_EQ(0, input->index());
96 TEST(GraphTest, create_output)
98 auto g = loco::make_graph();
100 auto output = g->outputs()->create();
102 // TODO Add more checks
103 ASSERT_EQ(nullptr, output->shape());
104 ASSERT_EQ(0, output->index());
109 // temp node with multple params for ctor. loco::CanonicalOpcode::ReLU is used for simplicity
111 : public loco::CanonicalNodeDef<loco::CanonicalOpcode::ReLU, loco::FixedArity<0>::Mixin>
114 ParamCtorNode(int i, float f)
120 int i() { return _i; }
121 float f() { return _f; }
129 TEST(GraphTest, consturctor_with_param_node)
131 auto g = loco::make_graph();
133 auto test_node = g->nodes()->create<ParamCtorNode>(22, 11.11);
135 ASSERT_EQ(g.get(), test_node->graph());
136 ASSERT_EQ(g.get(), const_cast<const ParamCtorNode *>(test_node)->graph());
138 ASSERT_EQ(22, test_node->i());
139 ASSERT_FLOAT_EQ(test_node->f(), 11.11);
141 ASSERT_NO_THROW(g->nodes()->destroy(test_node));
142 ASSERT_THROW(g->nodes()->destroy(test_node), std::invalid_argument);
145 TEST(GraphTest, getters_over_const_instance)
147 auto g = loco::make_graph();
149 auto pull = g->nodes()->create<loco::Pull>();
150 auto push = g->nodes()->create<loco::Push>();
152 loco::link(g->inputs()->create(), pull);
153 loco::link(g->outputs()->create(), push);
155 auto ptr = const_cast<const loco::Graph *>(g.get());
157 EXPECT_EQ(ptr->nodes()->size(), 2);
158 EXPECT_EQ(ptr->inputs()->size(), 1);
161 TEST(GraphTest, graph_node_enumeration)
163 auto g = loco::make_graph();
165 auto pull_1 = g->nodes()->create<loco::Pull>();
166 auto push_1 = g->nodes()->create<loco::Push>();
168 auto nodes = loco::all_nodes(g.get());
170 // Returns true if "nodes" includes a given node
171 auto member = [&nodes](loco::Node *node) { return nodes.find(node) != nodes.end(); };
173 ASSERT_EQ(2, nodes.size());
174 ASSERT_TRUE(member(pull_1));
175 ASSERT_TRUE(member(push_1));
178 TEST(GraphTest, graph_inout_enumeration)
180 auto g = loco::make_graph();
182 std::vector<loco::Pull *> pull_nodes;
184 auto pull_1 = g->nodes()->create<loco::Pull>();
185 auto pull_2 = g->nodes()->create<loco::Pull>();
186 auto pull_3 = g->nodes()->create<loco::Pull>();
188 auto push_1 = g->nodes()->create<loco::Push>();
189 auto push_2 = g->nodes()->create<loco::Push>();
190 auto push_3 = g->nodes()->create<loco::Push>();
192 loco::link(g->inputs()->create(), pull_2);
193 loco::link(g->inputs()->create(), pull_1);
195 loco::link(g->outputs()->create(), push_1);
196 loco::link(g->outputs()->create(), push_3);
198 auto output_nodes = loco::output_nodes(g.get());
200 ASSERT_EQ(2, output_nodes.size());
201 ASSERT_EQ(push_1, output_nodes.at(0));
202 ASSERT_EQ(push_3, output_nodes.at(1));
205 TEST(GraphTest, graph_name)
207 auto g = loco::make_graph();
209 g->name("HelloGraph");
210 ASSERT_TRUE(g->name() == "HelloGraph");
213 TEST(GraphTest, graph_name_nullptr_NEG)
215 auto g = loco::make_graph();
217 EXPECT_ANY_THROW(g->name(nullptr));