/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "loader/GraphLoader.h"
19 #include "loader/ModuleLoader.h"
20 #include "loader/KernelBuilder.h"
22 #include <loco/IR/Algorithm.h>
24 namespace luci_interpreter
29 template <typename NodeT> Shape getNodeShape(const NodeT *node)
31 Shape shape(node->rank());
32 for (uint32_t i = 0; i < node->rank(); ++i)
34 shape.dim(i) = node->dim(i).value();
39 template <DataType DT> const void *getNodeDataImpl(const luci::CircleConst *node, size_t *data_size)
41 const size_t element_size = getDataTypeSize(DT);
42 const int32_t num_elements = node->size<DT>();
44 *data_size = num_elements * element_size;
47 // FIXME There is no good way to get the pointer to the data currently.
48 return &node->at<DT>(0);
53 const void *getNodeData(const luci::CircleConst *node, size_t *data_size)
55 switch (node->dtype())
58 return getNodeDataImpl<DataType::U8>(node, data_size);
59 case DataType::FLOAT32:
60 return getNodeDataImpl<DataType::FLOAT32>(node, data_size);
62 return getNodeDataImpl<DataType::S32>(node, data_size);
64 throw std::runtime_error("Unsupported type.");
68 bool isExecutableNode(const luci::CircleNode *node)
70 switch (node->opcode())
72 // These nodes denote inputs / outputs of a graph.
73 case luci::CircleOpcode::CONST:
74 case luci::CircleOpcode::CIRCLEINPUT:
75 case luci::CircleOpcode::CIRCLEOUTPUT:
76 // The following nodes denote outputs of multiple-output nodes.
77 case luci::CircleOpcode::CIRCLEIFOUT:
78 case luci::CircleOpcode::CIRCLESPLITOUT:
79 case luci::CircleOpcode::CIRCLEUNPACKOUT:
86 bool isTensorProducingNode(const luci::CircleNode *node)
88 switch (node->opcode())
90 // Output nodes do not produce tensors.
91 case luci::CircleOpcode::CIRCLEOUTPUT:
92 // The following nodes are multiple-output nodes. They do not produce tensors, the tensors
93 // are produced by the corresponding *Out nodes instead.
94 case luci::CircleOpcode::IF:
95 case luci::CircleOpcode::SPLIT:
96 case luci::CircleOpcode::UNPACK:
105 GraphLoader::GraphLoader(const ModuleLoader &module_loader, const loco::Graph *graph,
106 RuntimeGraph *runtime_graph, RuntimeToIR &runtime_to_ir,
107 std::unordered_map<const loco::Node *, Tensor *> &node_to_tensor)
108 : _module_loader(module_loader), _graph(graph), _runtime_graph(runtime_graph),
109 _runtime_to_ir(runtime_to_ir), _node_to_tensor(node_to_tensor)
113 void GraphLoader::loadTensors()
115 for (uint32_t i = 0; i < _graph->nodes()->size(); ++i)
117 const auto *node = loco::must_cast<const luci::CircleNode *>(_graph->nodes()->at(i));
119 if (!isTensorProducingNode(node))
122 // Only Input and Const nodes have shapes. Shapes of intermediate tensors will be inferred.
124 if (const auto *input_node = dynamic_cast<const luci::CircleInput *>(node))
126 shape = getNodeShape(input_node);
128 else if (const auto *const_node = dynamic_cast<const luci::CircleConst *>(node))
130 shape = getNodeShape(const_node);
133 AffineQuantization quantization;
134 if (node->quantparam() != nullptr)
136 const luci::CircleQuantParam *params = node->quantparam();
137 quantization.scale.assign(params->scale.cbegin(), params->scale.cend());
138 quantization.zero_point.assign(params->zerop.cbegin(), params->zerop.cend());
141 auto tensor = std::make_unique<Tensor>(node->dtype(), std::move(shape), std::move(quantization),
144 if (const auto *const_node = dynamic_cast<const luci::CircleConst *>(node))
147 const void *const_data = getNodeData(const_node, &data_size);
148 if (const_data != nullptr)
149 tensor->writeData(const_data, data_size);
152 _node_to_tensor.emplace(node, tensor.get());
153 _runtime_to_ir.tensor_to_node.emplace(tensor.get(), node);
155 _runtime_graph->addTensor(std::move(tensor));
159 void GraphLoader::initInputOutputTensors() const
161 auto input_nodes = loco::input_nodes(_graph);
162 std::vector<Tensor *> input_tensors(input_nodes.size());
163 for (size_t i = 0; i < input_nodes.size(); ++i)
165 input_tensors[i] = _node_to_tensor.at(input_nodes[i]);
167 _runtime_graph->setInputTensors(input_tensors);
169 auto output_nodes = loco::output_nodes(const_cast<loco::Graph *>(_graph));
170 std::vector<Tensor *> output_tensors(output_nodes.size());
171 for (size_t i = 0; i < output_nodes.size(); ++i)
173 const auto *node = loco::must_cast<const luci::CircleOutput *>(output_nodes[i]);
174 output_tensors[i] = _node_to_tensor.at(node->from());
176 _runtime_graph->setOutputTensors(output_tensors);
179 void GraphLoader::loadOperators()
181 KernelBuilder kernel_builder(_module_loader, *this);
183 // Create kernels for executable nodes. This has to be done in execution order.
184 for (const loco::Node *loco_node :
185 loco::postorder_traversal(loco::output_nodes(const_cast<loco::Graph *>(_graph))))
187 const auto *node = loco::must_cast<const luci::CircleNode *>(loco_node);
189 if (isExecutableNode(node))
191 std::unique_ptr<Kernel> kernel = node->accept(&kernel_builder);
192 _runtime_to_ir.kernel_to_node.emplace(kernel.get(), node);
193 _runtime_graph->addKernel(std::move(kernel));
198 void GraphLoader::load()
201 initInputOutputTensors();
205 } // namespace luci_interpreter