2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "core/RuntimeGraph.h"
18 #include "kernels/KernelBuilder.h"
23 namespace luci_interpreter
27 RuntimeGraph::RuntimeGraph(SimpleMemoryManager *memory_manager, CircleReader *circle_reader)
28 : _memory_manager(memory_manager),
29 _tensor_to_data(std::unordered_map<const circle::Tensor *, uint8_t *>{}),
30 _reader(circle_reader), _inplace_op_indexes(std::unordered_set<uint32_t>{})
34 RuntimeGraph::~RuntimeGraph()
36 for (auto &idx_to_tensor : _tensor_to_data)
38 auto *data = idx_to_tensor.second;
40 _memory_manager->release_memory(data);
45 void RuntimeGraph::buildAllocDeallocPlan()
48 using Lifetime = std::pair<int32_t, int32_t>;
49 std::map<const circle::Tensor *, Lifetime> lifetimes;
50 const size_t num_kernels = _reader->operators().size();
52 for (const auto input_ind : _reader->inputs())
54 const auto raw_tensor = _reader->tensors()[input_ind];
56 assert(lifetimes.count(raw_tensor) == 0);
57 lifetimes[raw_tensor] = Lifetime(-1, 0);
60 for (int32_t index = 0; index < num_kernels; ++index)
62 const auto kernel = _reader->operators().at(index);
63 assert(kernel != nullptr);
65 for (int32_t j = 0; j < kernel->inputs()->size(); ++j)
67 const auto input_index = kernel->inputs()->operator[](j);
69 if (input_index == -1)
72 const auto raw_tensor = _reader->tensors()[input_index];
74 // Pass constant tensors
75 auto const &buffer = wrap(_reader->buffers()[raw_tensor->buffer()]->data());
76 if (not buffer.empty())
78 // unknown shape tensor and scalar tensor
82 if (lifetimes.count(raw_tensor) > 0)
84 if (_inplace_op_indexes.find(index) != _inplace_op_indexes.end())
85 lifetimes.at(raw_tensor).second = -1;
87 lifetimes.at(raw_tensor).second = index;
91 for (int32_t j = 0; j < kernel->outputs()->size(); ++j)
93 const auto output_index = kernel->outputs()->operator[](j);
94 const auto raw_tensor = _reader->tensors()[output_index];
96 assert(lifetimes.count(raw_tensor) == 0);
97 if (_inplace_op_indexes.find(index) != _inplace_op_indexes.end())
98 lifetimes[raw_tensor] = Lifetime(-1, index);
100 lifetimes[raw_tensor] = Lifetime(index, index);
104 for (const auto output_ind : _reader->outputs())
106 const auto raw_tensor = _reader->tensors()[output_ind];
108 if (lifetimes.count(raw_tensor) > 0)
109 lifetimes.at(raw_tensor).second = num_kernels;
112 _alloc_plan.assign(num_kernels, std::vector<const circle::Tensor *>());
113 _dealloc_plan.assign(num_kernels + 1, std::vector<const circle::Tensor *>());
114 for (const auto &item : lifetimes)
116 if (item.second.first != -1)
117 _alloc_plan[item.second.first].push_back(item.first);
118 if (item.second.second != -1)
119 _dealloc_plan[item.second.second].push_back(item.first);
124 void RuntimeGraph::allocate(size_t kernel_index)
126 assert(_is_valid && kernel_index < _alloc_plan.size());
127 for (const circle::Tensor *tensor : _alloc_plan[kernel_index])
129 if (_tensor_to_data.find(tensor) != _tensor_to_data.end())
131 auto *data = _tensor_to_data.at(tensor);
132 _memory_manager->release_memory(data);
134 auto *data = _memory_manager->allocate_memory(tensor);
135 _tensor_to_data[tensor] = data;
139 void RuntimeGraph::deallocate(size_t kernel_index)
141 assert(_is_valid && kernel_index < _dealloc_plan.size());
142 for (const circle::Tensor *tensor : _dealloc_plan[kernel_index])
144 const auto it = _tensor_to_data.find(tensor);
145 assert(it != _tensor_to_data.end());
147 auto *data = _tensor_to_data.at(tensor);
148 _memory_manager->release_memory(data);
150 _tensor_to_data.erase(it);
154 void RuntimeGraph::resetOutputTensorsData()
156 for (int i = 0; i < _reader->outputs().size(); ++i)
158 const auto tensor_index = _reader->outputs()[i];
159 assert(tensor_index != -1);
160 const auto tensor = _reader->tensors()[tensor_index];
161 assert(tensor != nullptr);
163 auto tensor_it = _tensor_to_data.find(tensor);
164 if (tensor_it != _tensor_to_data.end())
166 auto *data = _tensor_to_data.at(tensor);
167 _memory_manager->release_memory(data);
168 _tensor_to_data.erase(tensor_it);
173 uint8_t *RuntimeGraph::configureGraphInput(int32_t input_index)
175 resetOutputTensorsData();
177 const auto tensor_index = _reader->inputs()[input_index];
178 assert(tensor_index != -1);
179 const auto tensor = _reader->tensors()[tensor_index];
180 assert(tensor != nullptr);
182 if (_tensor_to_data.find(tensor) != _tensor_to_data.end())
184 auto *data = _tensor_to_data.at(tensor);
185 _memory_manager->release_memory(data);
188 auto *data = _memory_manager->allocate_memory(tensor);
189 _tensor_to_data[tensor] = data;
195 // TODO maybe remove it
196 void RuntimeGraph::configureGraphInput(int32_t input_index, uint8_t *data)
198 resetOutputTensorsData();
200 const auto tensor_index = _reader->inputs()[input_index];
201 assert(tensor_index != -1);
202 const auto tensor = _reader->tensors()[tensor_index];
203 assert(tensor != nullptr);
205 if (_tensor_to_data.find(tensor) != _tensor_to_data.end())
207 auto *data_prev = _tensor_to_data.at(tensor);
208 _memory_manager->release_memory(data_prev);
210 _tensor_to_data[tensor] = data;
213 int32_t RuntimeGraph::getInputDataSizeByIndex(int32_t input_index)
215 const auto tensor_index = _reader->inputs()[input_index];
216 assert(tensor_index != -1);
217 const auto tensor = _reader->tensors()[tensor_index];
218 assert(tensor != nullptr);
220 return Tensor::num_elements(tensor) * size(Tensor::element_type(tensor));
223 int32_t RuntimeGraph::getOutputDataSizeByIndex(int32_t output_index)
225 const auto tensor_index = _reader->outputs()[output_index];
226 assert(tensor_index != -1);
227 const auto tensor = _reader->tensors()[tensor_index];
228 assert(tensor != nullptr);
230 return Tensor::num_elements(tensor) * size(Tensor::element_type(tensor));
233 uint8_t *RuntimeGraph::getOutputDataByIndex(int32_t output_index)
235 const auto tensor_index = _reader->outputs()[output_index];
236 assert(tensor_index != -1);
237 const auto tensor = _reader->tensors()[tensor_index];
238 assert(tensor != nullptr);
240 assert(_tensor_to_data.find(tensor) != _tensor_to_data.end());
242 return _tensor_to_data[tensor];
245 uint8_t *RuntimeGraph::getDataByTensor(const circle::Tensor *raw_tensor)
247 if (raw_tensor == nullptr)
250 if (_tensor_to_data.find(raw_tensor) == _tensor_to_data.end())
255 return _tensor_to_data.at(raw_tensor);
258 void RuntimeGraph::makeInplaceOperation(const circle::Tensor *src_tensor,
259 const circle::Tensor *dst_tensor)
261 if (src_tensor == nullptr or dst_tensor == nullptr)
264 auto src_it = _tensor_to_data.find(src_tensor);
266 assert(src_it != _tensor_to_data.end() && "Failed makeInplaceOperation");
268 auto *data = _tensor_to_data[src_tensor];
270 _tensor_to_data.erase(src_it);
272 assert(_tensor_to_data.find(dst_tensor) == _tensor_to_data.end() &&
273 "Failed makeInplaceOperation");
274 _tensor_to_data[dst_tensor] = data;
277 uint8_t *RuntimeGraph::getConstDataByTensor(const circle::Tensor *raw_tensor)
279 if (raw_tensor == nullptr)
282 auto const &buffer = wrap(_reader->buffers()[raw_tensor->buffer()]->data());
284 return const_cast<uint8_t *>(buffer.data());
287 const circle::Tensor *RuntimeGraph::getCircleTensorByIndex(int32_t index)
292 const auto raw_tensor = _reader->tensors()[index];
297 void RuntimeGraph::configure()
299 KernelConfigureRegistry kernel_configure;
301 for (uint32_t i = 0; i < _reader->operators().size(); ++i)
303 const auto op = _reader->operators().at(i);
304 assert(op != nullptr);
306 const auto opcode = _reader->builtin_code(op);
308 kernel_configure.configure_kernel(op, opcode, this);
312 buildAllocDeallocPlan();
317 void RuntimeGraph::execute()
322 KernelExecuteRegistry kernel_executor;
324 for (uint32_t i = 0; i < _reader->operators().size(); ++i)
326 const auto op = _reader->operators().at(i);
327 assert(op != nullptr);
329 const auto opcode = _reader->builtin_code(op);
333 bool is_inplace = false;
335 if (_inplace_op_indexes.find(i) != _inplace_op_indexes.end())
338 kernel_executor.execute_kernel(op, opcode, this, is_inplace);
344 } // namespace luci_interpreter