2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __GRAPH_TESTCASE_H__
18 #define __GRAPH_TESTCASE_H__
20 #include "loco/IR/Graph.h"
21 #include "loco/IR/PermutingCodec.h"
23 #include "GraphBuilder.h"
25 #include <stdex/Memory.h>
46 template <loco::Domain D> loco::Permutation<D> make_NHWC_perm(void);
48 template <> loco::Permutation<loco::Domain::Feature> make_NHWC_perm(void)
50 loco::Permutation<loco::Domain::Feature> perm;
52 perm[loco::FeatureAxis::Count] = 0;
53 perm[loco::FeatureAxis::Height] = 1;
54 perm[loco::FeatureAxis::Width] = 2;
55 perm[loco::FeatureAxis::Depth] = 3;
60 template <loco::Domain D> loco::Permutation<D> make_HWCN_perm(void);
62 /// @note Also known as HWIO permutation
63 template <> loco::Permutation<loco::Domain::Filter> make_HWCN_perm(void)
65 loco::Permutation<loco::Domain::Filter> perm;
67 perm[loco::FilterAxis::Height] = 0;
68 perm[loco::FilterAxis::Width] = 1;
69 perm[loco::FilterAxis::Depth] = 2;
70 perm[loco::FilterAxis::Count] = 3;
75 template <loco::Domain D> loco::Permutation<D> make_HWCM_perm(void);
77 template <> loco::Permutation<loco::Domain::DepthwiseFilter> make_HWCM_perm(void)
79 loco::Permutation<loco::Domain::DepthwiseFilter> perm;
81 perm[loco::DepthwiseFilterAxis::Height] = 0;
82 perm[loco::DepthwiseFilterAxis::Width] = 1;
83 perm[loco::DepthwiseFilterAxis::Depth] = 2;
84 perm[loco::DepthwiseFilterAxis::Multiplier] = 3;
91 template <GraphCode Code> class GraphTestcase;
93 template <> class GraphTestcase<GraphCode::Identity> final
96 void init(std::initializer_list<uint32_t> dims)
98 // Create a sample network
99 _graph = loco::make_graph();
101 auto graph_builder = make_graph_builder(_graph.get());
103 pull_node = graph_builder->push<InputLayer>()->name("input")->shape(dims)->node();
104 push_node = graph_builder->push<OutputLayer>()->name("output")->node();
108 // NOTE This default constructor guarantees backward compatbility.
109 GraphTestcase() { init({1, 4, 8, 3}); }
110 GraphTestcase(std::initializer_list<uint32_t> dims) { init(dims); }
113 loco::Graph *graph() { return _graph.get(); }
115 loco::Pull *pull_node = nullptr;
116 loco::Push *push_node = nullptr;
119 std::unique_ptr<loco::Graph> _graph;
122 template <> class GraphTestcase<GraphCode::ConstGen> final
127 _graph = loco::make_graph();
129 auto graph_builder = make_graph_builder(_graph.get());
131 const_node = graph_builder->push<ConstGenLayer>()->node();
133 push_node = graph_builder->push<OutputLayer>()->name("output")->node();
137 loco::Graph *graph() { return _graph.get(); }
139 loco::ConstGen *const_node = nullptr;
140 loco::Push *push_node = nullptr;
143 std::unique_ptr<loco::Graph> _graph;
146 template <> class GraphTestcase<GraphCode::Relu> final
151 // Create a sample network
152 _graph = loco::make_graph();
154 auto graph_builder = make_graph_builder(_graph.get());
156 pull_node = graph_builder->push<InputLayer>()->name("input")->node();
157 relu_node = graph_builder->push<ReLULayer>()->node();
158 push_node = graph_builder->push<OutputLayer>()->name("output")->node();
162 loco::Graph *graph() { return _graph.get(); }
164 loco::Pull *pull_node = nullptr;
165 loco::ReLU *relu_node = nullptr;
166 loco::Push *push_node = nullptr;
169 std::unique_ptr<loco::Graph> _graph;
172 template <> class GraphTestcase<GraphCode::FeatureCodec> final
177 using namespace loco;
179 Permutation<Domain::Feature> perm;
181 perm[FeatureAxis::Count] = 0;
182 perm[FeatureAxis::Height] = 1;
183 perm[FeatureAxis::Width] = 2;
184 perm[FeatureAxis::Depth] = 3;
186 // Create a sample network
187 _graph = make_graph();
189 auto graph_builder = make_graph_builder(_graph.get());
191 pull_node = graph_builder->push<InputLayer>()->name("input")->node();
192 encode_node = graph_builder->push<FeatureEncodeLayer>()->perm(perm)->node();
193 decode_node = graph_builder->push<FeatureDecodeLayer>()->perm(perm)->node();
194 push_node = graph_builder->push<OutputLayer>()->name("output")->node();
198 loco::Graph *graph() { return _graph.get(); }
200 loco::Pull *pull_node = nullptr;
201 loco::FeatureEncode *encode_node = nullptr;
202 loco::FeatureDecode *decode_node = nullptr;
203 loco::Push *push_node = nullptr;
206 std::unique_ptr<loco::Graph> _graph;
209 template <> class GraphTestcase<GraphCode::AvgPool2D> final
214 using namespace loco;
216 // Create a sample network
217 _graph = make_graph();
219 // Create Graph Input/Output
220 auto graph_input = _graph->inputs()->create();
221 auto graph_output = _graph->outputs()->create();
223 graph_input->name("input");
224 graph_output->name("output");
226 // Create and connect nodes
227 pull_node = _graph->nodes()->create<Pull>();
230 encode_node = _graph->nodes()->create<FeatureEncode>();
231 encode_node->input(pull_node);
233 avgpool2d_node = _graph->nodes()->create<AvgPool2D>();
234 avgpool2d_node->ifm(encode_node);
236 decode_node = _graph->nodes()->create<FeatureDecode>();
237 decode_node->input(avgpool2d_node);
239 push_node = _graph->nodes()->create<loco::Push>();
241 push_node->from(decode_node);
243 // Create a link between input/output and corresponding nodes
244 loco::link(graph_input, pull_node);
245 loco::link(graph_output, push_node);
249 loco::Graph *graph() { return _graph.get(); }
251 loco::Pull *pull_node = nullptr;
252 loco::FeatureEncode *encode_node = nullptr;
253 loco::AvgPool2D *avgpool2d_node = nullptr;
254 loco::FeatureDecode *decode_node = nullptr;
255 loco::Push *push_node = nullptr;
258 std::unique_ptr<loco::Graph> _graph;
261 template <> class GraphTestcase<GraphCode::DepthwiseConv2D> final
266 using namespace loco;
268 _graph = make_graph();
270 auto graph_builder = make_graph_builder(_graph.get());
272 Permutation<Domain::Feature> perm = make_NHWC_perm<Domain::Feature>();
273 Permutation<Domain::DepthwiseFilter> filter_perm = make_HWCM_perm<Domain::DepthwiseFilter>();
275 pull_node = graph_builder->push<InputLayer>()->name("input")->node();
276 encode_node = graph_builder->push<FeatureEncodeLayer>()->perm(perm)->node();
278 const_node = graph_builder->push<ConstGenLayer>()->node();
281 graph_builder->push<DepthwiseFilterEncodeLayer>()->perm(filter_perm)->node();
283 depthwiseconv2d_node = graph_builder->push<DepthwiseConv2DLayer>()->node();
285 decode_node = graph_builder->push<FeatureDecodeLayer>()->perm(perm)->node();
286 push_node = graph_builder->push<OutputLayer>()->name("output")->node();
290 loco::Graph *graph() { return _graph.get(); }
292 loco::Pull *pull_node = nullptr;
293 loco::FeatureEncode *encode_node = nullptr;
294 loco::ConstGen *const_node = nullptr;
295 loco::DepthwiseFilterEncode *filter_encode_node = nullptr;
296 loco::DepthwiseConv2D *depthwiseconv2d_node = nullptr;
297 loco::FeatureDecode *decode_node = nullptr;
298 loco::Push *push_node = nullptr;
301 std::unique_ptr<loco::Graph> _graph;
304 template <> class GraphTestcase<GraphCode::TransposedConv2D> final
309 using namespace loco;
311 // Prepare permutations
312 Permutation<Domain::Feature> feature_perm = make_NHWC_perm<Domain::Feature>();
313 Permutation<Domain::Filter> filter_perm = make_HWCN_perm<Domain::Filter>();
316 _graph = make_graph();
317 auto graph_builder = make_graph_builder(_graph.get());
319 pull_node = graph_builder->push<InputLayer>()->name("input")->node();
320 encode_node = graph_builder->push<FeatureEncodeLayer>()->perm(feature_perm)->node();
321 const_node = graph_builder->push<ConstGenLayer>()->node();
322 filter_encode_node = graph_builder->push<FilterEncodeLayer>()->perm(filter_perm)->node();
323 tr_conv2d_node = graph_builder->push<TransposedConv2DLayer>()->node();
324 decode_node = graph_builder->push<FeatureDecodeLayer>()->perm(feature_perm)->node();
325 push_node = graph_builder->push<OutputLayer>()->name("output")->node();
329 loco::Graph *graph() { return _graph.get(); }
331 loco::Pull *pull_node = nullptr;
332 loco::FeatureEncode *encode_node = nullptr;
333 loco::ConstGen *const_node = nullptr;
334 loco::FilterEncode *filter_encode_node = nullptr;
335 loco::TransposedConv2D *tr_conv2d_node = nullptr;
336 loco::FeatureDecode *decode_node = nullptr;
337 loco::Push *push_node = nullptr;
340 std::unique_ptr<loco::Graph> _graph;
343 template <> class GraphTestcase<GraphCode::MaxPool2D> final
348 using namespace loco;
350 // Create a sample network
351 _graph = make_graph();
353 // Create Graph Input/Output
354 auto graph_input = _graph->inputs()->create();
355 auto graph_output = _graph->outputs()->create();
357 graph_input->name("input");
358 graph_output->name("output");
360 // Create and connect nodes
361 pull_node = _graph->nodes()->create<Pull>();
364 encode_node = _graph->nodes()->create<FeatureEncode>();
365 encode_node->input(pull_node);
367 maxpool2d_node = _graph->nodes()->create<MaxPool2D>();
368 maxpool2d_node->ifm(encode_node);
370 decode_node = _graph->nodes()->create<FeatureDecode>();
371 decode_node->input(maxpool2d_node);
373 push_node = _graph->nodes()->create<loco::Push>();
375 push_node->from(decode_node);
377 // Create a link between input/output and corresponding nodes
378 loco::link(graph_input, pull_node);
379 loco::link(graph_output, push_node);
383 loco::Graph *graph() { return _graph.get(); }
385 loco::Pull *pull_node = nullptr;
386 loco::FeatureEncode *encode_node = nullptr;
387 loco::MaxPool2D *maxpool2d_node = nullptr;
388 loco::FeatureDecode *decode_node = nullptr;
389 loco::Push *push_node = nullptr;
392 std::unique_ptr<loco::Graph> _graph;
395 template <> class GraphTestcase<GraphCode::TensorConcat> final
400 using namespace loco;
402 // Create a sample network
403 _graph = make_graph();
405 // Create Graph Input/Output
406 auto graph_lhs = _graph->inputs()->create();
407 auto graph_rhs = _graph->inputs()->create();
408 auto graph_out = _graph->outputs()->create();
410 graph_lhs->name("lhs");
411 graph_rhs->name("rhs");
412 graph_out->name("output");
414 // Create and connect nodes
415 lhs_node = _graph->nodes()->create<Pull>();
418 rhs_node = _graph->nodes()->create<Pull>();
421 concat_node = _graph->nodes()->create<TensorConcat>();
422 concat_node->lhs(lhs_node);
423 concat_node->rhs(rhs_node);
425 push_node = _graph->nodes()->create<loco::Push>();
427 push_node->from(concat_node);
429 // Create a link between input/output and corresponding nodes
430 loco::link(graph_lhs, lhs_node);
431 loco::link(graph_rhs, rhs_node);
432 loco::link(graph_out, push_node);
436 loco::Graph *graph() { return _graph.get(); }
438 loco::Pull *lhs_node = nullptr;
439 loco::Pull *rhs_node = nullptr;
440 loco::TensorConcat *concat_node = nullptr;
441 loco::Push *push_node = nullptr;
444 std::unique_ptr<loco::Graph> _graph;
447 template <> class GraphTestcase<GraphCode::FixedReshape> final
452 _graph = loco::make_graph();
454 auto graph_builder = make_graph_builder(_graph.get());
456 pull_node = graph_builder->push<InputLayer>()->name("input")->node();
457 reshape_node = graph_builder->push<FixedReshapeLayer>()->node();
458 push_node = graph_builder->push<OutputLayer>()->name("output")->node();
462 loco::Graph *graph() { return _graph.get(); }
464 loco::Pull *pull_node = nullptr;
465 loco::FixedReshape *reshape_node = nullptr;
466 loco::Push *push_node = nullptr;
469 std::unique_ptr<loco::Graph> _graph;
472 template <> class GraphTestcase<GraphCode::TensorBroadcast> final
475 GraphTestcase(std::initializer_list<uint32_t> dims)
477 _graph = loco::make_graph();
479 auto graph_builder = make_graph_builder(_graph.get());
481 pull_node = graph_builder->push<InputLayer>()->name("input")->shape(dims)->node();
482 broadcast_node = graph_builder->push<TensorBroadcastLayer>()->node();
483 push_node = graph_builder->push<OutputLayer>()->name("output")->node();
487 loco::Graph *graph(void) { return _graph.get(); }
489 loco::Pull *pull_node = nullptr;
490 loco::TensorBroadcast *broadcast_node = nullptr;
491 loco::Push *push_node = nullptr;
494 std::unique_ptr<loco::Graph> _graph;
497 template <> class GraphTestcase<GraphCode::TensorTranspose> final
502 using namespace loco;
504 // Create a sample network
505 _graph = make_graph();
507 // Create Graph Input/Output
508 auto graph_input = _graph->inputs()->create();
509 auto graph_output = _graph->outputs()->create();
511 graph_input->name("input");
512 graph_output->name("output");
514 // Create and connect nodes
515 pull_node = _graph->nodes()->create<Pull>();
518 transpose_node = _graph->nodes()->create<TensorTranspose>();
519 transpose_node->input(pull_node);
521 push_node = _graph->nodes()->create<loco::Push>();
523 push_node->from(transpose_node);
525 // Create a link between input/output and corresponding nodes
526 loco::link(graph_input, pull_node);
527 loco::link(graph_output, push_node);
531 loco::Graph *graph() { return _graph.get(); }
533 loco::Pull *pull_node = nullptr;
534 loco::TensorTranspose *transpose_node = nullptr;
535 loco::Push *push_node = nullptr;
538 std::unique_ptr<loco::Graph> _graph;
541 #endif // __GRAPH_TESTCASE_H__