/*
 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #ifndef __MPQSOLVER_TEST_HELPER_H__
18 #define __MPQSOLVER_TEST_HELPER_H__
20 #include <luci/IR/CircleNodes.h>
21 #include <luci/IR/Module.h>
26 SimpleGraph() : _g(loco::make_graph()) {}
31 _input = _g->nodes()->create<luci::CircleInput>();
32 _output = _g->nodes()->create<luci::CircleOutput>();
33 _input->name("input");
34 _output->name("output");
36 auto graph_input = _g->inputs()->create();
37 _input->index(graph_input->index());
38 auto graph_output = _g->outputs()->create();
39 _output->index(graph_output->index());
41 graph_input->dtype(loco::DataType::FLOAT32);
42 _input->dtype(loco::DataType::FLOAT32);
43 _output->dtype(loco::DataType::FLOAT32);
44 graph_output->dtype(loco::DataType::FLOAT32);
46 graph_input->shape({1, _channel_size, _width, _height});
47 _input->shape({1, _channel_size, _width, _height});
48 _output->shape({1, _channel_size, _width, _height});
49 graph_output->shape({1, _channel_size, _width, _height});
51 auto graph_body = insertGraphBody(_input);
52 _output->from(graph_body);
57 virtual ~SimpleGraph() = default;
58 void transfer_to(luci::Module *module)
60 // WARNING: after g is transfered, _graph_inputs, _inputs
61 // and _graph_outputs, _outputs in TestOsGraphlet will be invalid.
62 // arrays are not cleared as this is just helpers to unit tests
63 module->add(std::move(_g));
67 virtual loco::Node *insertGraphBody(loco::Node *input) = 0;
68 virtual void initInput(loco::Node *input){};
71 std::unique_ptr<loco::Graph> _g;
72 luci::CircleInput *_input = nullptr;
73 luci::CircleOutput *_output = nullptr;
74 uint32_t _channel_size = 16;
79 #endif //__MPQSOLVER_TEST_HELPER_H__