/*
 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "luci/CircleExporter.h"

#include <luci/Plan/CircleNodeExecutionPlan.h>
#include <luci/IR/Nodes/CircleInput.h>
#include <luci/IR/Nodes/CircleOutput.h>
#include <luci/IR/Nodes/CircleRelu.h>
#include <luci/UserSettings.h>

#include <mio/circle/schema_generated.h>
#include <flatbuffers/flatbuffers.h>

#include <gtest/gtest.h>

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <memory>
#include <vector>
30 class SampleGraphContract : public luci::CircleExporter::Contract
33 SampleGraphContract() : luci::CircleExporter::Contract(), _buffer(new std::vector<char>)
35 // create needed entities
36 _g = loco::make_graph();
37 auto graph_input = _g->inputs()->create();
38 auto graph_output = _g->outputs()->create();
39 input_node = _g->nodes()->create<luci::CircleInput>();
40 output_node = _g->nodes()->create<luci::CircleOutput>();
41 relu_node = _g->nodes()->create<luci::CircleRelu>();
43 // link nodes and link them to graph
44 relu_node->features(input_node);
45 output_node->from(relu_node);
46 input_node->index(graph_input->index());
47 output_node->index(graph_output->index());
49 // Set needed properties
50 input_node->name("input");
51 output_node->name("output");
52 relu_node->name("relu");
53 input_node->dtype(loco::DataType::FLOAT32);
55 graph_input->shape({1, 2, 3, 4});
56 graph_input->dtype(loco::DataType::FLOAT32);
58 graph_output->shape({1, 2, 3, 4});
59 graph_output->dtype(loco::DataType::FLOAT32);
62 loco::Graph *graph(void) const override { return _g.get(); }
65 bool store(const char *ptr, const size_t size) const override
67 _buffer->resize(size);
68 std::copy(ptr, ptr + size, _buffer->begin());
72 const std::vector<char> &get_buffer() { return *_buffer; }
75 luci::CircleInput *input_node;
76 luci::CircleOutput *output_node;
77 luci::CircleRelu *relu_node;
80 std::unique_ptr<loco::Graph> _g;
81 std::unique_ptr<std::vector<char>> _buffer;
84 TEST(CircleExport, export_execution_plan)
86 SampleGraphContract contract;
87 uint32_t reference_order = 1;
88 uint32_t reference_offset = 100u;
89 luci::add_execution_plan(contract.relu_node,
90 luci::CircleNodeExecutionPlan(reference_order, {reference_offset}));
92 luci::UserSettings::settings()->set(luci::UserSettings::ExecutionPlanGen, true);
93 luci::CircleExporter exporter;
95 exporter.invoke(&contract);
97 ASSERT_FALSE(contract.get_buffer().empty());
98 std::unique_ptr<circle::ModelT> model(circle::GetModel(contract.get_buffer().data())->UnPack());
99 ASSERT_NE(model.get(), nullptr);
100 ASSERT_EQ(model->metadata[0]->name, "ONE_execution_plan_table");
101 auto metadata_buffer = model->metadata[0]->buffer;
102 auto &buffer = model->buffers[metadata_buffer]->data;
103 ASSERT_EQ(buffer.size(), 20);
104 uint32_t *raw_table_contents = reinterpret_cast<uint32_t *>(buffer.data());
106 auto num_entries = raw_table_contents[0];
107 ASSERT_EQ(num_entries, 1);
108 auto node_id = raw_table_contents[1];
109 ASSERT_EQ(node_id, 1); // relu node is second (aka id 1) in tological sort in exporter
110 auto node_plan_size = raw_table_contents[2];
111 ASSERT_EQ(node_plan_size, 2); // 1 for execution order, 1 for memory offset value
112 auto node_plan_order = raw_table_contents[3];
113 ASSERT_EQ(node_plan_order,
114 reference_order); // this value goes from CircleNodeExecutionPlan initialization
115 auto node_plan_offset = raw_table_contents[4];
116 ASSERT_EQ(node_plan_offset,
117 reference_offset); // this value goes from CircleNodeExecutionPlan initialization
120 TEST(CircleExport, export_execution_plan_nosetting_NEG)
122 SampleGraphContract contract;
123 uint32_t reference_order = 1;
124 uint32_t reference_offset = 100u;
125 luci::add_execution_plan(contract.relu_node,
126 luci::CircleNodeExecutionPlan(reference_order, {reference_offset}));
128 luci::UserSettings::settings()->set(luci::UserSettings::ExecutionPlanGen, false);
129 luci::CircleExporter exporter;
131 exporter.invoke(&contract);
133 ASSERT_FALSE(contract.get_buffer().empty());
134 std::unique_ptr<circle::ModelT> model(circle::GetModel(contract.get_buffer().data())->UnPack());
135 ASSERT_NE(model.get(), nullptr);
136 ASSERT_EQ(model->metadata.size(), 0);