2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "exec/Execution.h"
19 #include "util/logging.h"
26 Execution::Execution(const std::shared_ptr<ExecutorMap> &executors) : _executors{executors}
28 assert(executors != nullptr);
29 assert(executors->at(ir::SubgraphIndex{0}) != nullptr);
30 const auto &primary_subg = primary_subgraph();
31 _io_desc.inputs.resize(primary_subg.getInputs().size());
32 _io_desc.outputs.resize(primary_subg.getOutputs().size());
35 void Execution::changeInputShape(const ir::IOIndex &index, const ir::Shape &new_shape)
37 // This should be called BEFORE setInput.
38 if (_io_desc.inputs.at(index.value()) != 0)
39 throw std::runtime_error("Error in calling order");
41 _io_desc.input_shape_signature[index] = new_shape;
44 // TODO Remove default parameter
// Stores `buffer`/`length` as the descriptor of input `index`, using the
// operand info that the primary subgraph holds for that input.
// Throws std::runtime_error{"Too small length"} when `length` is smaller than
// the required size computed below.
// NOTE(review): the tail of the parameter list (which must declare the `layout`
// used in the last statement) is not visible in this listing - confirm against
// the full source.
45 void Execution::setInput(const ir::IOIndex &index, const void *buffer, size_t length,
// Resolve the operand behind this input and take a copy of its info.
48 const auto input_index = primary_subgraph().getInputs().at(index);
49 const auto info = primary_subgraph().operands().at(input_index).info();
51 // TODO handle when (!buffer && length != 0) : setting the input as an optional tensor
53 // check if size enough for input is passed
54 // if input_shape_sig is set, input_shape_sig overrides shape in info
55 // note: input_shape_sig contains shape passed by nnfw_set_input_tensorinfo()
57 auto input_shape_sig = _io_desc.input_shape_signature.find(index);
// With a signature present, the required size is element count * element size.
// NOTE(review): the fallback arm of this conditional (no signature recorded) is
// not visible here; presumably it uses the size from `info` - confirm.
58 auto size_required = (input_shape_sig != _io_desc.input_shape_signature.end())
59 ? input_shape_sig->second.num_elements() *
60 onert::ir::sizeOfDataType(info.typeInfo().type())
63 if (length < size_required)
65 throw std::runtime_error{"Too small length"};
// Any descriptor previously stored in this slot is replaced.
69 _io_desc.inputs.at(index.value()) = std::make_unique<InputDesc>(info, buffer, length, layout);
72 // TODO Remove default parameter
73 void Execution::setInput(const ir::IOIndex &index, const ir::TypeInfo &type, const ir::Shape &shape,
74 const void *buffer, size_t length, ir::Layout layout)
76 auto info = ir::OperandInfo::createStaticInfo(shape, type);
78 if (length < info.total_size())
80 throw std::runtime_error{"Too small length"};
83 _io_desc.inputs.at(index.value()) = std::make_unique<InputDesc>(info, buffer, length, layout);
86 // TODO Remove default parameter
87 void Execution::setOutput(const ir::IOIndex &index, void *buffer, size_t length, ir::Layout layout)
89 const auto output_index = primary_subgraph().getOutputs().at(index);
90 const auto info = primary_subgraph().operands().at(output_index).info();
92 if (length < info.total_size())
94 throw std::runtime_error{"Too small length"};
97 _io_desc.outputs.at(index.value()) = std::make_unique<OutputDesc>(info, buffer, length, layout);
100 // TODO Remove default parameter
101 void Execution::setOutput(const ir::IOIndex &index, const ir::TypeInfo &type,
102 const ir::Shape &shape, void *buffer, size_t length, ir::Layout layout)
104 auto info = ir::OperandInfo::createStaticInfo(shape, type);
106 if (length < info.total_size())
108 throw std::runtime_error{"Too small length"};
111 _io_desc.outputs.at(index.value()) = std::make_unique<OutputDesc>(info, buffer, length, layout);
114 void Execution::setInputLayout(const ir::IOIndex &index, ir::Layout layout)
116 const auto &input_desc = _io_desc.inputs.at(index.value());
117 _io_desc.inputs.at(index.value()) =
118 std::make_unique<InputDesc>(input_desc->info, input_desc->buffer, input_desc->size, layout);
121 void Execution::setOutputLayout(const ir::IOIndex &index, ir::Layout layout)
123 const auto &output_desc = _io_desc.outputs.at(index.value());
124 _io_desc.outputs.at(index.value()) = std::make_unique<OutputDesc>(
125 output_desc->info, output_desc->buffer, output_desc->size, layout);
// Runs the primary executor on the stored I/O descriptors (`_io_desc`),
// emitting a verbose log line before and after the call.
128 void Execution::execute()
130 VERBOSE(Execution) << "Start execution" << std::endl;
132 primary_executor()->execute(_io_desc);
// NOTE(review): the statements between the executor call and the final log are
// not visible in this listing; presumably the `finished` flag returned by
// isFinished() is updated here - confirm against the full source.
135 VERBOSE(Execution) << "Execution finished" << std::endl;
138 void Execution::startExecute()
140 VERBOSE(Execution) << "Create asynchronous execution thread" << std::endl;
142 _exec_thread = std::make_unique<std::thread>(&Execution::execute, this);
// Joins the thread held in `_exec_thread` (created by startExecute()).
// `_exec_thread` is dereferenced unconditionally, so a thread must have been
// started before this is called.
145 void Execution::waitFinish()
147 VERBOSE(Execution) << "Wait to finish execution" << std::endl;
149 _exec_thread->join();
// NOTE(review): the lines following join() are not visible in this listing;
// any post-join bookkeeping must be confirmed against the full source.
153 bool Execution::isFinished(void) const { return finished; }
// Returns the shape of input `ind`.
// When no shape signature has been recorded for `ind`, the shape of the
// corresponding operand in the primary subgraph is returned.
155 ir::Shape Execution::getInputShape(ir::IOIndex ind) const
157 auto itr = _io_desc.input_shape_signature.find(ind);
158 if (itr == _io_desc.input_shape_signature.end())
// No signature recorded: fall back to the operand shape held by the model.
160 auto operand_idx = primary_subgraph().getInputs().at(ind.value());
161 return primary_subgraph().operands().at(operand_idx).shape();
// NOTE(review): the branch taken when a signature IS found is not visible in
// this listing; presumably it returns the recorded signature - confirm.
// Returns the shape stored in the descriptor of output `ind`.
// Can throw std::runtime_error("Cannot get output shape before execution is finished").
169 ir::Shape Execution::getOutputShape(ir::IOIndex ind) const
// NOTE(review): the condition guarding this throw is not visible in this
// listing; judging by the message it checks that execution has finished - confirm.
172 throw std::runtime_error("Cannot get output shape before execution is finished");
// The descriptor is dereferenced unconditionally, so the output must have been set.
174 const auto &output_desc = _io_desc.outputs.at(ind.value());
176 return output_desc->info.shape();