2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __TEST_GRAPH_H__
18 #define __TEST_GRAPH_H__
20 #include <luci/IR/CircleNodes.h>
27 // TODO Change all Canonical nodes to Circle nodes
// Graph in which every node of this helper is created.
37 std::unique_ptr<loco::Graph> g;
// Input node of the test graph; null until the constructor assigns it.
38 luci::CircleInput *input_node = nullptr;
// Output node of the test graph; null until the constructor assigns it.
39 luci::CircleOutput *output_node = nullptr;
// Builds a new graph containing one CircleInput and one CircleOutput node and
// links each of them to a newly created graph-level input/output.
41 TestGraph() // creates CircleInput and CircleOutput
43 g = loco::make_graph();
// Create the two boundary nodes inside the new graph.
45 input_node = g->nodes()->create<luci::CircleInput>();
47 output_node = g->nodes()->create<luci::CircleOutput>();
// Register a graph-level input and tie it to the CircleInput node.
49 auto input = g->inputs()->create();
52 luci::link(input, input_node);
// Register a graph-level output named "output" and tie it to the CircleOutput node.
54 auto output = g->outputs()->create();
56 output->name("output");
57 luci::link(output, output_node);
// complete() without arguments connects the output to _next_input, so it starts at the input node.
60 _next_input = input_node;
/// @brief Returns a non-owning pointer to the graph held by this object
63 loco::Graph *graph() { return g.get(); }
65 /// @brief Creates node with NO arg and appends it to graph
/// @tparam T node type to create
66 template <class T> T *append()
// The node is created through the graph's node container.
68 auto node = g->nodes()->create<T>();
74 /// @brief Creates op T (arity=1) with arg1 as an input and appends it to graph
/// @tparam T node type to create
/// @param arg1 node intended as the single input of the new op
75 template <class T> T *append(luci::CircleNode *arg1)
// The node is created through the graph's node container.
77 auto node = g->nodes()->create<T>();
84 /// @brief Creates op T (arity=2) with arg1, arg2 as inputs and appends it to graph
/// @tparam T node type to create; without a dedicated two-input setInput overload
///           for T, the generic overload is selected and asserts "NYI"
/// @param arg1 first input, forwarded to setInput
/// @param arg2 second input, forwarded to setInput
85 template <class T> T *append(luci::CircleNode *arg1, luci::CircleNode *arg2)
87 auto node = g->nodes()->create<T>();
// Overload resolution on T* selects the setInput that knows T's input setters.
88 setInput(node, arg1, arg2);
94 /// @brief Creates op T (arity=3) with arg1, arg2, arg3 as inputs and appends it to graph
/// @param arg1 first input, forwarded to setInput
/// @param arg2 second input, forwarded to setInput
/// @param arg3 third input, forwarded to setInput
96 T *append(luci::CircleNode *arg1, luci::CircleNode *arg2, luci::CircleNode *arg3)
98 auto node = g->nodes()->create<T>();
// NOTE(review): the only three-input setInput visible here is the generic one that asserts "NYI";
// confirm a dedicated overload exists for any T used with this function.
99 setInput(node, arg1, arg2, arg3);
105 // output will get the last appended node
// NOTE(review): this connects the output to _next_input; confirm that append() keeps
// _next_input pointing at the most recently appended node.
106 void complete() { output_node->from(_next_input); }
// Connects the output to an explicitly chosen node instead of _next_input.
108 void complete(luci::CircleNode *last_node) { output_node->from(last_node); }
// Fallback for single-input ops that have no dedicated overload: asserts "NYI".
112 void setInput(luci::CircleNode *, luci::CircleNode *) { assert(false && "NYI"); };
// Dedicated overloads: each one forwards the input to the op-specific setter.
114 void setInput(luci::CircleAveragePool2D *node, luci::CircleNode *input) { node->value(input); };
115 void setInput(luci::CircleRelu *node, luci::CircleNode *input) { node->features(input); };
116 void setInput(luci::CircleSqueeze *node, luci::CircleNode *input) { node->input(input); };
// GatherNd: connects its two inputs, params and indices.
118 void setInput(luci::CircleGatherNd *node, luci::CircleNode *params, luci::CircleNode *indices)
120 node->params(params);
121 node->indices(indices);
// Fallback for two-input ops that have no dedicated overload: asserts "NYI".
125 void setInput(luci::CircleNode *, luci::CircleNode *, luci::CircleNode *)
127 assert(false && "NYI");
// Dedicated two-input overloads for ExpandDims, Transpose, ResizeBilinear and
// ResizeNearestNeighbor.
130 void setInput(luci::CircleExpandDims *node, luci::CircleNode *arg1, luci::CircleNode *arg2)
136 void setInput(luci::CircleTranspose *node, luci::CircleNode *arg1, luci::CircleNode *arg2)
142 void setInput(luci::CircleResizeBilinear *node, luci::CircleNode *input, luci::CircleNode *size)
148 void setInput(luci::CircleResizeNearestNeighbor *node, luci::CircleNode *input,
149 luci::CircleNode *size)
// Fallback for three-input ops: no op-specific wiring is implemented, so it asserts "NYI".
156 void setInput(luci::CircleNode *, luci::CircleNode *, luci::CircleNode *, luci::CircleNode *)
158 assert(false && "NYI");
// Node that complete() connects to the output node; set to input_node by the constructor.
162 loco::Node *_next_input;
// Tag type used as the template argument that selects an ExampleGraph specialization.
165 enum class ExampleGraphType
// Primary template is declared only; concrete example graphs are provided as explicit
// specializations (e.g. the one for ExampleGraphType::CircleTranspose).
170 template <ExampleGraphType T> class ExampleGraph;
173 * @brief Class to create the following:
175 * CircleInput -- CircleTranspose -- CircleOutput
// Example graph built on TestGraph: the inherited input node feeds a CircleTranspose
// whose second input is a CircleConst, and the transpose feeds the output node.
177 template <> class ExampleGraph<ExampleGraphType::CircleTranspose> : public TestGraph
// Second input of the transpose; created with the no-argument append().
// NOTE(review): no dtype/shape/values are assigned to it in the visible code — confirm whether callers fill them in.
180 luci::CircleConst *const_perm = nullptr;
// The transpose node under test.
181 luci::CircleTranspose *transpose_node = nullptr;
// Create the constant first so it can be passed as the transpose's second input.
186 const_perm = append<luci::CircleConst>();
187 transpose_node = append<luci::CircleTranspose>(input_node, const_perm);
// Connect the output node to the transpose explicitly.
188 complete(transpose_node);
200 /// @brief This will set GraphInput shape from CircleInput shape
/// @param input CircleInput node whose shape is used
/// NOTE(review): presumably updates the GraphInput associated with this node; the
/// definition is not in this header — confirm.
201 void graph_input_shape(luci::CircleInput *input);
203 /// @brief This will set GraphOutput shape from CircleOutput shape
/// @param output CircleOutput node whose shape is used
204 void graph_output_shape(luci::CircleOutput *output);
206 /// @brief This will set GraphInput dtype from CircleInput dtype
/// @param input CircleInput node whose dtype is used
207 void graph_input_dtype(luci::CircleInput *input);
209 /// @brief This will set GraphOutput dtype from CircleOutput dtype
/// @param output CircleOutput node whose dtype is used
210 void graph_output_dtype(luci::CircleOutput *output);
215 #endif // __TEST_GRAPH_H__