2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Importer.h"
18 #include "CircleImportMetadata.h"
19 #include "PostImport.h"
21 #include "luci/Import/GraphBuilder.h"
22 #include "luci/Import/GraphBuilderContext.h"
23 #include "luci/Import/GraphBuilderRegistry.h"
24 #include "luci/Import/CircleReader.h"
25 #include "luci/Import/Nodes/CircleConst.h"
27 #include <luci/IR/Module.h>
28 #include <luci/IR/CircleNodes.h>
29 #include <luci/Profile/CircleNodeID.h>
30 #include <luci/Profile/CircleNodeOrigin.h>
31 #include <luci/Plan/CircleNodeExecutionPlan.h>
33 #include <luci/LogHelper.h>
35 #include <oops/InternalExn.h>
36 #include <oops/UserExn.h>
// Populates the target graph (referred to as `graph` in the body) from the tensors and
// operators that `reader` exposes, using operator builders looked up in `source`.
//
// Steps, in order:
//   1. record the index of every operator output tensor in `tensoroutputs`;
//   2. create a CircleInput node and a graph input for each reader input;
//   3. register a CircleConst node for each tensor where create_circleconst() is non-null;
//   4. validate and build every operator, tagging the result with its id and origin;
//   5. create a CircleOutput node and a graph output for each reader output.
//
// Throws oops::UserExn("Invalid operator", ...) when a builder rejects an operator. The
// "Not supported" throw presumably covers a failed builder lookup.
//
// NOTE(review): this excerpt does not show several short lines (the remainder of the
// signature, braces, `else` keywords, the logger declaration used by INFO(l)); confirm
// the exact control flow against the full file.
43 void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &reader,
// Lookup helpers handed to the builders through gb_context: nodefinder maps a tensor
// index to the node registered for it, tensoroutputs records tensor indices that are
// produced by something.
48 auto nodefinder = std::make_unique<luci::IndexNodeFinder>();
49 auto tensoroutputs = std::make_unique<luci::IndexTensorOutputs>();
51 luci::GraphBuilderContext gb_context(graph, &reader, nodefinder.get(), tensoroutputs.get());
53 const auto &operators = reader.operators();
54 const auto &tensors = reader.tensors();
// tensors_ptr is consulted below only to test whether a tensor's shape() is null.
55 auto tensors_ptr = reader.tensors_ptr();
56 assert(tensors_ptr != nullptr);
57 auto circle_metadata = std::make_unique<luci::CircleImportMetadata>(reader);
59 // build a cache to identify if a tensor is output of an operator
60 // if this is set, we should not create a CircleConst for this tensor
61 for (uint32_t i = 0; i < operators.size(); ++i)
63 const circle::OperatorT &op = *operators[i];
64 const auto &outputs = op.outputs;
66 for (uint32_t j = 0; j < outputs.size(); ++j)
68 auto tidx = outputs[j];
69 tensoroutputs->enroll(tidx);
73 // graph inputs; there are no input nodes in TFlite but just Tensors
74 // creating virtual input nodes will make possible to connect nodes that uses them
75 // all attributes of tensor should be copied to CircleInput node
76 for (const auto input : reader.inputs())
78 auto input_node = graph->nodes()->create<luci::CircleInput>();
79 assert(input_node != nullptr);
80 const circle::TensorT &tensor = *tensors[input];
82 luci::copy_tensor_attributes(tensor, input_node);
// A null shape() on the tensor marks the node NOSHAPE; the VALID assignment that
// follows is presumably the `else` branch (the keyword is not visible in this excerpt).
83 if (tensors_ptr->Get(input)->shape() == nullptr)
84 input_node->shape_status(luci::ShapeStatus::NOSHAPE);
86 input_node->shape_status(luci::ShapeStatus::VALID);
88 INFO(l) << "[luci] NodeFinder INPUT(" << input << ") = " << input_node << std::endl;
89 nodefinder->enroll(input, input_node);
91 // input_node is also an output to a tensor
92 tensoroutputs->enroll(input);
// The graph-level input takes its name and dtype from the CircleInput node.
95 auto graph_input = graph->inputs()->create();
96 graph_input->name(input_node->name());
98 // Set GraphInputOutputIndex for graph
99 input_node->index(graph_input->index());
102 graph_input->dtype(input_node->dtype());
// shape_signature is either empty or has exactly one entry per dimension of shape.
104 assert(tensor.shape_signature.size() == 0 ||
105 tensor.shape_signature.size() == tensor.shape.size());
107 // Shape of GraphInput
108 auto input_shape = std::make_unique<loco::TensorShape>();
109 const std::vector<int32_t> &input_dims = tensor.shape; // in NHWC
110 input_shape->rank(input_dims.size());
111 for (uint32_t r = 0; r < input_dims.size(); ++r)
// A shape_signature entry of -1 leaves the dimension unset; the set() call below is
// presumably the `else` branch and copies the size from tensor.shape.
113 if (tensor.shape_signature.size() > 0 && tensor.shape_signature.at(r) == -1)
114 input_shape->dim(r).unset();
116 input_shape->dim(r).set(input_dims[r]);
118 graph_input->shape(std::move(input_shape));
121 // Create CircleConst nodes for constant tensors.
122 for (uint32_t i = 0; i < tensors.size(); ++i)
// create_circleconst() may return nullptr; only non-null nodes are registered.
124 luci::CircleConst *const_node = luci::create_circleconst(&gb_context, i);
125 if (const_node != nullptr)
126 nodefinder->enroll(i, const_node);
129 // Import the operators.
130 // Note that operators in model are stored in execution order. This means that when importing
131 // an operator, its input operators have already been imported. We exploit this fact to set up
132 // node's inputs right after creating the node.
133 auto origin_table = circle_metadata->origin_table();
134 for (uint32_t i = 0; i < operators.size(); ++i)
136 const circle::OperatorT &op = *operators[i];
137 circle::BuiltinOperator builtincode = reader.builtin_code(op);
// Only operators with a builder registered in `source` can be imported.
139 if (const auto *builder = source.lookup(builtincode))
141 luci::GraphBuilder::ValidateArgs args(op, reader);
142 if (!builder->validate(args))
144 throw oops::UserExn("Invalid operator", reader.opcode_name(op));
// Tag the built node with its operator index, then attach origin information: the
// metadata entry for this index when one exists, otherwise (presumably the `else`
// branch) a single origin made from the index and the node name.
147 auto built_op = builder->build(op, &gb_context);
148 set_node_id(built_op, i);
149 if (origin_table.find(i) != origin_table.end())
150 add_origin(built_op, origin_table.at(i));
152 add_origin(built_op, luci::single_origin(i, built_op->name()));
// NOTE(review): presumably the `else` of the builder lookup above — confirm.
156 throw oops::UserExn("Not supported", reader.opcode_name(op));
// Create a CircleOutput node and a graph-level output for each reader output.
161 for (auto output : reader.outputs())
163 const circle::TensorT &tensor = *tensors[output];
165 auto output_node = graph->nodes()->create<luci::CircleOutput>();
166 assert(output_node != nullptr);
// Connect the output to the node registered for this tensor index, when there is one.
167 auto output_from = nodefinder->node(output);
168 if (output_from != nullptr)
169 output_node->from(output_from);
172 // NOTE loco::Graph requires all input node(s) to a node should exist.
173 // Here, CircleOutput needs an input node.
174 // We add a dummy node to make it happy.
175 auto output_dummy = graph->nodes()->create<luci::CircleOutputDummy>();
176 assert(output_dummy != nullptr);
177 output_node->from(output_dummy);
// The dummy node receives the tensor attributes and the same NOSHAPE/VALID shape
// status rule that is applied to input nodes.
179 luci::copy_tensor_attributes(tensor, output_dummy);
180 if (tensors_ptr->Get(output)->shape() == nullptr)
181 output_dummy->shape_status(luci::ShapeStatus::NOSHAPE);
183 output_dummy->shape_status(luci::ShapeStatus::VALID);
186 INFO(l) << "[luci] NodeFinder OUTPUT(" << output << ") = " << output_node << std::endl;
188 // set the graph output name and node object
189 auto graph_output = graph->outputs()->create();
190 std::string tname = luci::tensor_name(tensor);
191 assert(tname.length() > 0);
192 graph_output->name(tname);
194 luci::copy_tensor_attributes(tensor, output_node);
196 // Set GraphInputOutputIndex for graph
197 output_node->index(graph_output->index());
// Same shape_signature invariant as for inputs: empty, or one entry per dimension.
199 assert(tensor.shape_signature.size() == 0 ||
200 tensor.shape_signature.size() == tensor.shape.size());
// Build the graph output shape; a shape_signature entry of -1 leaves the dimension unset.
203 auto output_shape = std::make_unique<loco::TensorShape>();
204 const std::vector<int32_t> &output_dims = tensor.shape; // in NHWC
205 output_shape->rank(output_dims.size());
206 for (uint32_t r = 0; r < output_dims.size(); ++r)
208 if (tensor.shape_signature.size() > 0 && tensor.shape_signature.at(r) == -1)
209 output_shape->dim(r).unset();
211 output_shape->dim(r).set(output_dims[r]);
213 graph_output->shape(std::move(output_shape));
// Unlike graph inputs, the output dtype is derived directly from the tensor's type.
216 auto dtype = luci::luci_datatype(tensor.type);
217 graph_output->dtype(dtype);
// loco::ErrorListener implementation that reports MissingArgument validation errors.
// An instance is handed to loco::valid() by the import functions in this file.
// NOTE(review): the class braces, access specifier and the logger declaration used by
// INFO(l) are not visible in this excerpt.
221 class ValidateCollector final : public loco::ErrorListener
// Logs the node and index carried by the error detail; takes no other action.
224 void notify(const loco::ErrorDetail<loco::ErrorCategory::MissingArgument> &d) override
227 INFO(l) << "[luci] GraphValidate error " << d.node() << "(" << d.index() << ")" << std::endl;
// Imports a circle model that contains exactly one subgraph into a new loco::Graph.
//
// Builders come from the user-supplied _source when it is non-null, otherwise from
// GraphBuilderRegistry::get(). Raises INTERNAL_EXN when the model does not have exactly
// one subgraph; such models are expected to go through importModule() instead.
//
// NOTE(review): the declaration of `reader`, the statements guarded by the
// `!reader.parse(model)` and `!reader.select_subgraph(0)` checks, the logger declaration
// and the final return are not visible in this excerpt — confirm against the full file.
241 std::unique_ptr<loco::Graph> Importer::import(const circle::Model *model) const
243 auto graph = loco::make_graph();
// Default to the global registry; a user-defined source overrides it below.
245 const GraphBuilderSource *source_ptr = &GraphBuilderRegistry::get();
247 if (_source != nullptr)
249 // Use user-defined GraphBuilderSource
250 source_ptr = _source;
254 if (!reader.parse(model))
257 if (reader.num_subgraph() != 1)
259 INTERNAL_EXN("Use 'importModule()' for multiple subgraphs");
261 if (!reader.select_subgraph(0))
264 // Convert circle::Model to loco::Graph
265 convert_graph(*source_ptr, reader, graph.get());
// Dump the imported graph at verbosity level 3.
268 VERBOSE(l, 3) << "--- graph dump begin -------------------------------------------";
269 VERBOSE(l, 3) << "Name: " << graph->name();
270 VERBOSE(l, 3) << fmt(graph.get());
271 VERBOSE(l, 3) << "--- graph dump end ---------------------------------------------";
// Validation is wrapped in assert(), so it is compiled out when NDEBUG is defined.
273 assert(loco::valid(graph.get(), std::make_unique<ValidateCollector>()));
278 std::unique_ptr<Module> Importer::importModule(const circle::Model *model) const
280 auto module = make_module();
282 const GraphBuilderSource *source_ptr = &GraphBuilderRegistry::get();
284 if (_source != nullptr)
286 // Use user-defined GraphBuilderSource
287 source_ptr = _source;
291 if (!reader.parse(model))
294 for (uint32_t g = 0; g < reader.num_subgraph(); ++g)
296 auto graph = loco::make_graph();
298 if (!reader.select_subgraph(g))
301 graph->name(reader.name());
303 // Convert circle::Model to loco::Graph
304 convert_graph(*source_ptr, reader, graph.get());
307 VERBOSE(l, 3) << "--- graph dump begin -------------------------------------------";
308 VERBOSE(l, 3) << "Name: " << graph->name();
309 VERBOSE(l, 3) << fmt(graph.get());
310 VERBOSE(l, 3) << "--- graph dump end ---------------------------------------------";
312 assert(loco::valid(graph.get(), std::make_unique<ValidateCollector>()));
314 module->add(std::move(graph));
317 post_import_graph(module.get(), reader);
319 // Initialize 'source_table'
320 auto circle_metadata = std::make_unique<luci::CircleImportMetadata>(reader);
321 if (circle_metadata->source_table().size() > 0)
323 // If there is 'source_table' metadata in circle model, copy the table.
324 module->source_table(circle_metadata->source_table());
328 // If there is no 'source_table' metadata in circle model,
329 // create new table with circle nodes.
330 std::map<uint32_t, std::string> table;
332 // NOTE Only first subgraph is considered
333 for (auto node : loco::all_nodes(module->graph(0)))
335 auto circle_node = loco::must_cast<luci::CircleNode *>(node);
337 // Virtual nodes may not have id
338 if (!has_node_id(circle_node))
341 assert(table.find(get_node_id(circle_node)) == table.end());
342 table.insert({get_node_id(circle_node), circle_node->name()});
345 module->source_table(table);
348 // Add execution_plan annotations
349 if (circle_metadata->execution_plan_table().size() > 0)
351 auto execution_plan_table = circle_metadata->execution_plan_table();
352 auto node_position = 0;
353 for (auto node : loco::postorder_traversal(loco::output_nodes(module->graph())))
355 if (auto circle_node = dynamic_cast<luci::CircleNode *>(node))
357 auto node_plan = execution_plan_table[node_position];
358 luci::add_execution_plan(
360 luci::CircleNodeExecutionPlan(
361 node_plan[0], std::vector<uint32_t>(node_plan.begin() + 1, node_plan.end())));