/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "DynamicTensorManager.h"

#include "util/logging.h"

#include <cassert>
#include <stdexcept>
28 DynamicTensorManager::DynamicTensorManager(const std::shared_ptr<cpu_common::TensorRegistry> ®,
29 const std::shared_ptr<UserTensorRegistry> &user_reg)
30 : _dynamic_mem_mgr{new cpu_common::DynamicMemoryManager()}, _tensors{reg},
31 _user_tensors{user_reg}
36 void DynamicTensorManager::applyShape(const ir::OperandIndex &ind, const ir::Shape &new_shape)
38 // NOTE Handle user tensors first
39 auto user_tensor = _user_tensors->getManagedTensor(ind);
42 // User tensors cannot be reallocated.
43 auto buffer_size = user_tensor->total_size();
44 auto new_size = new_shape.num_elements() * sizeOfDataType(user_tensor->data_type());
45 if (buffer_size < new_size)
46 throw std::runtime_error{"ExecutorBase: output buffer size is less than output tensor size"};
47 user_tensor->setShape(new_shape);
50 // NOTE Then handle managed tensors
51 auto tensor = _tensors->getManagedTensor(ind);
54 bool previously_dynamic = tensor->is_dynamic();
56 auto allocTensorMem = [&](bool overwrite = false) {
57 auto capacity = tensor->total_size();
58 auto alloc = _dynamic_mem_mgr->allocate(ind, capacity);
61 tensor->overwriteBuffer(alloc);
63 tensor->setBuffer(alloc);
66 if (!previously_dynamic)
68 // TODO deallocate tensor->buffer()
69 // issue is that staticTensorManager might have allocate this memory
70 tensor->setShape(new_shape);
71 tensor->set_dynamic();
74 else if (tensor->buffer() == nullptr)
76 tensor->setShape(new_shape);
77 tensor->set_dynamic();
80 // when buffer was already allocated and new_shape requires different size
83 auto previous_size = tensor->total_size();
84 auto new_size = new_shape.num_elements() * sizeOfDataType(tensor->data_type());
85 if (previous_size != new_size)
87 _dynamic_mem_mgr->deallocate(ind);
89 tensor->setShape(new_shape);
90 tensor->set_dynamic();
94 { // when buffer with same size was already allocated, shape could differ
95 tensor->setShape(new_shape);
100 void DynamicTensorManager::buildTensor(const ir::OperandIndex &ind,
101 const ir::OperandInfo &tensor_info,
102 ir::Layout backend_layout)
104 assert(_tensors->getManagedTensor(ind) == nullptr);
105 auto tensor = std::make_shared<cpu_common::Tensor>(tensor_info, backend_layout);
106 _tensors->setManagedTensor(ind, tensor);
109 void DynamicTensorManager::planDealloc(ir::OperationIndex op_ind, ir::OperandIndex operand_ind)
111 auto find = _dealloc_tensor_map.find(op_ind);
112 if (find != _dealloc_tensor_map.end())
114 auto &input_set = find->second;
115 input_set.emplace(operand_ind);
119 _dealloc_tensor_map.emplace(
120 std::make_pair(op_ind, std::unordered_set<ir::OperandIndex>{operand_ind}));
124 void DynamicTensorManager::deallocInput(ir::OperationIndex op_ind)
126 auto find = _dealloc_tensor_map.find(op_ind);
127 if (find == _dealloc_tensor_map.end())
130 auto &input_set = find->second;
131 for (auto input_ind : input_set)
133 if (!_tensors->getManagedTensor(input_ind)->is_dynamic())
136 _dynamic_mem_mgr->deallocate(input_ind);
137 VERBOSE(DynamicTensorManager) << "Deallocating #" << input_ind.value()
138 << " (input of op_ind: " << op_ind.value() << ")" << std::endl;
142 void DynamicTensorManager::deallocSubgraphOutput(ir::OperandIndex output_ind)
144 if (!_tensors->getManagedTensor(output_ind)->is_dynamic())
147 _dynamic_mem_mgr->deallocate(output_ind);
148 VERBOSE(DynamicTensorManager) << "Deallocating #" << output_ind.value()
149 << " (output of a subgraph)" << std::endl;
152 } // namespace controlflow
153 } // namespace backend