/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "loader/GraphLoader.h"
19 #include "loader/KernelBuilder.h"
21 #include <loco/IR/Algorithm.h>
23 namespace luci_interpreter
28 template <typename NodeT> Shape getNodeShape(const NodeT *node)
30 Shape shape(node->rank());
31 for (uint32_t i = 0; i < node->rank(); ++i)
33 shape.dim(i) = node->dim(i).value();
38 template <DataType DT> const void *getNodeDataImpl(const luci::CircleConst *node, size_t *data_size)
40 const size_t element_size = getDataTypeSize(DT);
41 const int32_t num_elements = node->size<DT>();
43 *data_size = num_elements * element_size;
46 // FIXME There is no good way to get the pointer to the data currently.
47 return &node->at<DT>(0);
52 const void *getNodeData(const luci::CircleConst *node, size_t *data_size)
54 switch (node->dtype())
57 return getNodeDataImpl<DataType::U8>(node, data_size);
58 case DataType::FLOAT32:
59 return getNodeDataImpl<DataType::FLOAT32>(node, data_size);
61 return getNodeDataImpl<DataType::S32>(node, data_size);
63 throw std::runtime_error("Unsupported type.");
67 bool isExecutableNode(const luci::CircleNode *node)
69 switch (node->opcode())
71 // These nodes denote inputs / outputs of a graph.
72 case luci::CircleOpcode::CIRCLECONST:
73 case luci::CircleOpcode::CIRCLEINPUT:
74 case luci::CircleOpcode::CIRCLEOUTPUT:
75 case luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE:
76 // The following nodes denote outputs of multiple-output nodes.
77 case luci::CircleOpcode::CIRCLEIFOUT:
78 case luci::CircleOpcode::CIRCLESPLITOUT:
79 case luci::CircleOpcode::CIRCLEUNPACKOUT:
86 bool isTensorProducingNode(const luci::CircleNode *node)
88 switch (node->opcode())
90 // Output nodes do not produce tensors.
91 case luci::CircleOpcode::CIRCLEOUTPUT:
92 // The following nodes are multiple-output nodes. They do not produce tensors, the tensors
93 // are produced by the corresponding *Out nodes instead.
94 case luci::CircleOpcode::IF:
95 case luci::CircleOpcode::SPLIT:
96 case luci::CircleOpcode::UNPACK:
105 GraphLoader::GraphLoader(
106 const loco::Graph *graph, RuntimeGraph *runtime_graph, RuntimeToIR &runtime_to_ir,
107 const std::unordered_map<const loco::Graph *, RuntimeGraph *> &graph_to_runtime_graph,
108 std::unordered_map<const loco::Node *, Tensor *> &node_to_tensor)
109 : _graph(graph), _runtime_graph(runtime_graph), _runtime_to_ir(runtime_to_ir),
110 _graph_to_runtime_graph(graph_to_runtime_graph), _node_to_tensor(node_to_tensor)
114 void GraphLoader::loadTensors()
116 for (uint32_t i = 0; i < _graph->nodes()->size(); ++i)
118 const auto *node = loco::must_cast<const luci::CircleNode *>(_graph->nodes()->at(i));
120 if (!isTensorProducingNode(node))
123 // Only Input and Const nodes have shapes. Shapes of intermediate tensors will be inferred.
125 if (const auto *input_node = dynamic_cast<const luci::CircleInput *>(node))
127 shape = getNodeShape(input_node);
129 else if (const auto *const_node = dynamic_cast<const luci::CircleConst *>(node))
131 shape = getNodeShape(const_node);
134 AffineQuantization quantization;
135 if (node->quantparam() != nullptr)
137 const luci::CircleQuantParam *params = node->quantparam();
138 quantization.scale.assign(params->scale.cbegin(), params->scale.cend());
139 quantization.zero_point.assign(params->zerop.cbegin(), params->zerop.cend());
140 quantization.quantized_dimension = params->quantized_dimension;
143 auto tensor = std::make_unique<Tensor>(node->dtype(), std::move(shape), std::move(quantization),
146 if (const auto *const_node = dynamic_cast<const luci::CircleConst *>(node))
149 const void *const_data = getNodeData(const_node, &data_size);
150 if (const_data != nullptr)
151 tensor->writeData(const_data, data_size);
154 _node_to_tensor.emplace(node, tensor.get());
155 _runtime_to_ir.tensor_to_node.emplace(tensor.get(), node);
157 _runtime_graph->addTensor(std::move(tensor));
161 void GraphLoader::initInputOutputTensors() const
163 auto input_nodes = loco::input_nodes(_graph);
164 std::vector<Tensor *> input_tensors(input_nodes.size());
165 for (size_t i = 0; i < input_nodes.size(); ++i)
167 input_tensors[i] = _node_to_tensor.at(input_nodes[i]);
169 _runtime_graph->setInputTensors(input_tensors);
171 auto output_nodes = loco::output_nodes(const_cast<loco::Graph *>(_graph));
172 std::vector<Tensor *> output_tensors(output_nodes.size());
173 for (size_t i = 0; i < output_nodes.size(); ++i)
175 const auto *node = loco::must_cast<const luci::CircleOutput *>(output_nodes[i]);
176 output_tensors[i] = _node_to_tensor.at(node->from());
178 _runtime_graph->setOutputTensors(output_tensors);
181 void GraphLoader::loadOperators()
183 KernelBuilder kernel_builder(_graph_to_runtime_graph, _node_to_tensor);
185 // Create kernels for executable nodes. This has to be done in execution order.
186 for (const loco::Node *loco_node :
187 loco::postorder_traversal(loco::output_nodes(const_cast<loco::Graph *>(_graph))))
189 const auto *node = loco::must_cast<const luci::CircleNode *>(loco_node);
191 if (isExecutableNode(node))
193 std::unique_ptr<Kernel> kernel = node->accept(&kernel_builder);
194 _runtime_to_ir.kernel_to_node.emplace(kernel.get(), node);
195 _runtime_graph->addKernel(std::move(kernel));
200 } // namespace luci_interpreter