Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / circle-mpqsolver / src / core / TestHelper.h
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __MPQSOLVER_TEST_HELPER_H__
18 #define __MPQSOLVER_TEST_HELPER_H__
19
20 #include <luci/IR/CircleNodes.h>
21 #include <luci/IR/Module.h>
22
23 class SimpleGraph
24 {
25 public:
26   SimpleGraph() : _g(loco::make_graph()) {}
27
28 public:
29   void init()
30   {
31     _input = _g->nodes()->create<luci::CircleInput>();
32     _output = _g->nodes()->create<luci::CircleOutput>();
33     _input->name("input");
34     _output->name("output");
35
36     auto graph_input = _g->inputs()->create();
37     _input->index(graph_input->index());
38     auto graph_output = _g->outputs()->create();
39     _output->index(graph_output->index());
40
41     graph_input->dtype(loco::DataType::FLOAT32);
42     _input->dtype(loco::DataType::FLOAT32);
43     _output->dtype(loco::DataType::FLOAT32);
44     graph_output->dtype(loco::DataType::FLOAT32);
45
46     graph_input->shape({1, _channel_size, _width, _height});
47     _input->shape({1, _channel_size, _width, _height});
48     _output->shape({1, _channel_size, _width, _height});
49     graph_output->shape({1, _channel_size, _width, _height});
50
51     auto graph_body = insertGraphBody(_input);
52     _output->from(graph_body);
53
54     initInput(_input);
55   }
56
57   virtual ~SimpleGraph() = default;
58   void transfer_to(luci::Module *module)
59   {
60     // WARNING: after g is transfered, _graph_inputs, _inputs
61     //          and _graph_outputs, _outputs in TestOsGraphlet will be invalid.
62     //          arrays are not cleared as this is just helpers to unit tests
63     module->add(std::move(_g));
64   }
65
66 protected:
67   virtual loco::Node *insertGraphBody(loco::Node *input) = 0;
68   virtual void initInput(loco::Node *input){};
69
70 public:
71   std::unique_ptr<loco::Graph> _g;
72   luci::CircleInput *_input = nullptr;
73   luci::CircleOutput *_output = nullptr;
74   uint32_t _channel_size = 16;
75   uint32_t _width = 4;
76   uint32_t _height = 4;
77 };
78
79 #endif //__MPQSOLVER_TEST_HELPER_H__