/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "loader/GraphLoader.h"
19 #include "loader/KernelBuilder.h"
21 #include <loco/IR/Algorithm.h>
23 namespace luci_interpreter
28 template <typename NodeT> Shape getNodeShape(const NodeT *node)
30 Shape shape(node->rank());
31 for (uint32_t i = 0; i < node->rank(); ++i)
33 shape.dim(i) = node->dim(i).value();
38 template <DataType DT> const void *getNodeDataImpl(const luci::CircleConst *node, size_t *data_size)
40 const size_t element_size = getDataTypeSize(DT);
41 const int32_t num_elements = node->size<DT>();
43 *data_size = num_elements * element_size;
46 // FIXME There is no good way to get the pointer to the data currently.
47 return &node->at<DT>(0);
52 const void *getNodeData(const luci::CircleConst *node, size_t *data_size)
54 switch (node->dtype())
57 return getNodeDataImpl<DataType::U8>(node, data_size);
58 case DataType::FLOAT32:
59 return getNodeDataImpl<DataType::FLOAT32>(node, data_size);
61 return getNodeDataImpl<DataType::S16>(node, data_size);
63 return getNodeDataImpl<DataType::S32>(node, data_size);
65 return getNodeDataImpl<DataType::S64>(node, data_size);
67 throw std::runtime_error("Unsupported type.");
71 bool isExecutableNode(const luci::CircleNode *node)
73 switch (node->opcode())
75 // These nodes denote inputs / outputs of a graph.
76 case luci::CircleOpcode::CIRCLECONST:
77 case luci::CircleOpcode::CIRCLEINPUT:
78 case luci::CircleOpcode::CIRCLEOUTPUT:
79 case luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE:
80 // The following nodes denote outputs of multiple-output nodes.
81 case luci::CircleOpcode::CIRCLEIFOUT:
82 case luci::CircleOpcode::CIRCLESPLITOUT:
83 case luci::CircleOpcode::CIRCLEUNPACKOUT:
90 bool isTensorProducingNode(const luci::CircleNode *node)
92 switch (node->opcode())
94 // Output nodes do not produce tensors.
95 case luci::CircleOpcode::CIRCLEOUTPUT:
96 // The following nodes are multiple-output nodes. They do not produce tensors, the tensors
97 // are produced by the corresponding *Out nodes instead.
98 case luci::CircleOpcode::IF:
99 case luci::CircleOpcode::SPLIT:
100 case luci::CircleOpcode::UNPACK:
109 GraphLoader::GraphLoader(
110 const loco::Graph *graph, RuntimeGraph *runtime_graph, RuntimeToIR &runtime_to_ir,
111 const std::unordered_map<const loco::Graph *, RuntimeGraph *> &graph_to_runtime_graph,
112 std::unordered_map<const loco::Node *, Tensor *> &node_to_tensor)
113 : _graph(graph), _runtime_graph(runtime_graph), _runtime_to_ir(runtime_to_ir),
114 _graph_to_runtime_graph(graph_to_runtime_graph), _node_to_tensor(node_to_tensor)
118 void GraphLoader::loadTensors()
120 for (uint32_t i = 0; i < _graph->nodes()->size(); ++i)
122 const auto *node = loco::must_cast<const luci::CircleNode *>(_graph->nodes()->at(i));
124 if (!isTensorProducingNode(node))
127 // Only Input and Const nodes have shapes. Shapes of intermediate tensors will be inferred.
129 if (const auto *input_node = dynamic_cast<const luci::CircleInput *>(node))
131 shape = getNodeShape(input_node);
133 else if (const auto *const_node = dynamic_cast<const luci::CircleConst *>(node))
135 shape = getNodeShape(const_node);
138 AffineQuantization quantization;
139 if (node->quantparam() != nullptr)
141 const luci::CircleQuantParam *params = node->quantparam();
142 assert(params->scale.size() == params->zerop.size());
143 quantization.scale.assign(params->scale.cbegin(), params->scale.cend());
144 quantization.zero_point.assign(params->zerop.cbegin(), params->zerop.cend());
145 quantization.quantized_dimension = params->quantized_dimension;
148 auto tensor = std::make_unique<Tensor>(node->dtype(), std::move(shape), std::move(quantization),
151 if (const auto *const_node = dynamic_cast<const luci::CircleConst *>(node))
154 const void *const_data = getNodeData(const_node, &data_size);
155 if (const_data != nullptr)
156 tensor->writeData(const_data, data_size);
159 _node_to_tensor.emplace(node, tensor.get());
160 _runtime_to_ir.tensor_to_node.emplace(tensor.get(), node);
162 _runtime_graph->addTensor(std::move(tensor));
166 void GraphLoader::initInputOutputTensors() const
168 auto input_nodes = loco::input_nodes(_graph);
169 std::vector<Tensor *> input_tensors(input_nodes.size());
170 for (size_t i = 0; i < input_nodes.size(); ++i)
172 input_tensors[i] = _node_to_tensor.at(input_nodes[i]);
174 _runtime_graph->setInputTensors(input_tensors);
176 auto output_nodes = loco::output_nodes(const_cast<loco::Graph *>(_graph));
177 std::vector<Tensor *> output_tensors(output_nodes.size());
178 for (size_t i = 0; i < output_nodes.size(); ++i)
180 const auto *node = loco::must_cast<const luci::CircleOutput *>(output_nodes[i]);
181 output_tensors[i] = _node_to_tensor.at(node->from());
183 _runtime_graph->setOutputTensors(output_tensors);
186 void GraphLoader::loadOperators()
188 KernelBuilder kernel_builder(_graph_to_runtime_graph, _node_to_tensor);
190 // Create kernels for executable nodes. This has to be done in execution order.
191 for (const loco::Node *loco_node :
192 loco::postorder_traversal(loco::output_nodes(const_cast<loco::Graph *>(_graph))))
194 const auto *node = loco::must_cast<const luci::CircleNode *>(loco_node);
196 if (isExecutableNode(node))
198 std::unique_ptr<Kernel> kernel = node->accept(&kernel_builder);
199 _runtime_to_ir.kernel_to_node.emplace(kernel.get(), node);
200 _runtime_graph->addKernel(std::move(kernel));
205 } // namespace luci_interpreter