Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / luci / service / src / TestGraph.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __TEST_GRAPH_H__
18 #define __TEST_GRAPH_H__
19
20 #include <luci/IR/CircleNodes.h>
21
22 #include <loco.h>
23
24 #include <cassert>
25 #include <memory>
26
27 // TODO Change all Canonical nodes to Circle nodes
28
29 namespace luci
30 {
31 namespace test
32 {
33
34 class TestGraph
35 {
36 public:
37   std::unique_ptr<loco::Graph> g;
38   luci::CircleInput *input_node = nullptr;
39   luci::CircleOutput *output_node = nullptr;
40
41   TestGraph() // creates Pull and Push
42   {
43     g = loco::make_graph();
44
45     input_node = g->nodes()->create<luci::CircleInput>();
46
47     output_node = g->nodes()->create<luci::CircleOutput>();
48
49     auto input = g->inputs()->create();
50     {
51       input->name("input");
52       luci::link(input, input_node);
53     }
54     auto output = g->outputs()->create();
55     {
56       output->name("output");
57       luci::link(output, output_node);
58     }
59
60     _next_input = input_node;
61   }
62
63   loco::Graph *graph() { return g.get(); }
64
65   /// @brief Creates node with NO arg and appends it to graph
66   template <class T> T *append()
67   {
68     auto node = g->nodes()->create<T>();
69     _next_input = node;
70
71     return node;
72   }
73
74   /// @brief Creates op T (arity=1) with arg1 as an input and appends it to graph
75   template <class T> T *append(luci::CircleNode *arg1)
76   {
77     auto node = g->nodes()->create<T>();
78     setInput(node, arg1);
79     _next_input = node;
80
81     return node;
82   }
83
84   /// @brief Creates op T (arity=2) with arg1, arg2 as inputs and appends it to graph
85   template <class T> T *append(luci::CircleNode *arg1, luci::CircleNode *arg2)
86   {
87     auto node = g->nodes()->create<T>();
88     setInput(node, arg1, arg2);
89     _next_input = node;
90
91     return node;
92   }
93
94   /// @brief Creates op T (arity=3) with arg1, arg2, arg3 as inputs and appends it to graph
95   template <class T>
96   T *append(luci::CircleNode *arg1, luci::CircleNode *arg2, luci::CircleNode *arg3)
97   {
98     auto node = g->nodes()->create<T>();
99     setInput(node, arg1, arg2, arg3);
100     _next_input = node;
101
102     return node;
103   }
104
105   // output will get the last appended node
106   void complete() { output_node->from(_next_input); }
107
108   void complete(luci::CircleNode *last_node) { output_node->from(last_node); }
109
110 private:
111   // arity 1
112   void setInput(luci::CircleNode *, luci::CircleNode *) { assert(false && "NYI"); };
113
114   void setInput(luci::CircleAveragePool2D *node, luci::CircleNode *input) { node->value(input); };
115   void setInput(luci::CircleRelu *node, luci::CircleNode *input) { node->features(input); };
116   void setInput(luci::CircleSqueeze *node, luci::CircleNode *input) { node->input(input); };
117
118   void setInput(luci::CircleGatherNd *node, luci::CircleNode *params, luci::CircleNode *indices)
119   {
120     node->params(params);
121     node->indices(indices);
122   };
123
124   // arity 2
125   void setInput(luci::CircleNode *, luci::CircleNode *, luci::CircleNode *)
126   {
127     assert(false && "NYI");
128   };
129
130   void setInput(luci::CircleExpandDims *node, luci::CircleNode *arg1, luci::CircleNode *arg2)
131   {
132     node->input(arg1);
133     node->axis(arg2);
134   };
135
136   void setInput(luci::CircleTranspose *node, luci::CircleNode *arg1, luci::CircleNode *arg2)
137   {
138     node->a(arg1);
139     node->perm(arg2);
140   };
141
142   void setInput(luci::CircleResizeBilinear *node, luci::CircleNode *input, luci::CircleNode *size)
143   {
144     node->input(input);
145     node->size(size);
146   };
147
148   void setInput(luci::CircleResizeNearestNeighbor *node, luci::CircleNode *input,
149                 luci::CircleNode *size)
150   {
151     node->input(input);
152     node->size(size);
153   };
154
155   // arity 3
156   void setInput(luci::CircleNode *, luci::CircleNode *, luci::CircleNode *, luci::CircleNode *)
157   {
158     assert(false && "NYI");
159   };
160
161 private:
162   loco::Node *_next_input;
163 };
164
165 enum class ExampleGraphType
166 {
167   CircleTranspose,
168 };
169
170 template <ExampleGraphType T> class ExampleGraph;
171
172 /**
173  * @brief Class to create the following:
174  *
175  *     CircleInput -- CircleTranspose -- CircleOutput
176  */
177 template <> class ExampleGraph<ExampleGraphType::CircleTranspose> : public TestGraph
178 {
179 public:
180   luci::CircleConst *const_perm = nullptr;
181   luci::CircleTranspose *transpose_node = nullptr;
182
183 public:
184   ExampleGraph()
185   {
186     const_perm = append<luci::CircleConst>();
187     transpose_node = append<luci::CircleTranspose>(input_node, const_perm);
188     complete(transpose_node);
189   }
190 };
191
192 } // namespace test
193 } // namespace luci
194
195 namespace luci
196 {
197 namespace test
198 {
199
200 /// @brief This will set GraphInput shape from CircleInput shape
201 void graph_input_shape(luci::CircleInput *input);
202
203 /// @brief This will set GraphOutput shape from CircleOutput shape
204 void graph_output_shape(luci::CircleOutput *output);
205
206 /// @brief This will set GraphInput dtype from CircleInput dtype
207 void graph_input_dtype(luci::CircleInput *input);
208
209 /// @brief This will set GraphOutput dtype from CircleOutput dtype
210 void graph_output_dtype(luci::CircleOutput *output);
211
212 } // namespace test
213 } // namespace luci
214
215 #endif // __TEST_GRAPH_H__