5b401ecf8f04202de040a31c4fc4c7d7dd08153c
[platform/core/ml/nnfw.git] / runtime / onert / core / src / exec / Execution.cc
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "exec/Execution.h"
18
19 #include "util/logging.h"
20
21 namespace onert
22 {
23 namespace exec
24 {
25
26 Execution::Execution(const std::shared_ptr<ExecutorMap> &executors) : _executors{executors}
27 {
28   assert(executors != nullptr);
29   assert(executors->at(ir::SubgraphIndex{0}) != nullptr);
30   const auto &primary_subg = primary_subgraph();
31   _io_desc.inputs.resize(primary_subg.getInputs().size());
32   _io_desc.outputs.resize(primary_subg.getOutputs().size());
33 }
34
35 void Execution::changeInputShape(const ir::IOIndex &index, const ir::Shape &new_shape)
36 {
37   // This should be called BEFORE setInput.
38   if (_io_desc.inputs.at(index.value()) != 0)
39     throw std::runtime_error("Error in calling order");
40
41   _io_desc.input_shape_signature[index] = new_shape;
42 }
43
44 // TODO Remove default parameter
45 void Execution::setInput(const ir::IOIndex &index, const void *buffer, size_t length,
46                          ir::Layout layout)
47 {
48   const auto input_index = primary_subgraph().getInputs().at(index);
49   const auto info = primary_subgraph().operands().at(input_index).info();
50
51   // TODO handle when (!buffer && length != 0) : setting the input as an optional tensor
52
53   // check if size enough for input is passed
54   // if input_shape_sig is set, input_shape_sig overrides shape in info
55   // note: input_shape_sig contains shape passed by nnfw_set_input_tensorinfo()
56   {
57     auto input_shape_sig = _io_desc.input_shape_signature.find(index);
58     auto size_required = (input_shape_sig != _io_desc.input_shape_signature.end())
59                              ? input_shape_sig->second.num_elements() *
60                                    onert::ir::sizeOfDataType(info.typeInfo().type())
61                              : info.total_size();
62
63     if (length < size_required)
64     {
65       throw std::runtime_error{"Too small length"};
66     }
67   }
68
69   _io_desc.inputs.at(index.value()) = std::make_unique<InputDesc>(info, buffer, length, layout);
70 }
71
72 // TODO Remove default parameter
73 void Execution::setInput(const ir::IOIndex &index, const ir::TypeInfo &type, const ir::Shape &shape,
74                          const void *buffer, size_t length, ir::Layout layout)
75 {
76   auto info = ir::OperandInfo::createStaticInfo(shape, type);
77
78   if (length < info.total_size())
79   {
80     throw std::runtime_error{"Too small length"};
81   }
82
83   _io_desc.inputs.at(index.value()) = std::make_unique<InputDesc>(info, buffer, length, layout);
84 }
85
86 // TODO Remove default parameter
87 void Execution::setOutput(const ir::IOIndex &index, void *buffer, size_t length, ir::Layout layout)
88 {
89   const auto output_index = primary_subgraph().getOutputs().at(index);
90   const auto info = primary_subgraph().operands().at(output_index).info();
91
92   if (length < info.total_size())
93   {
94     throw std::runtime_error{"Too small length"};
95   }
96
97   _io_desc.outputs.at(index.value()) = std::make_unique<OutputDesc>(info, buffer, length, layout);
98 }
99
100 // TODO Remove default parameter
101 void Execution::setOutput(const ir::IOIndex &index, const ir::TypeInfo &type,
102                           const ir::Shape &shape, void *buffer, size_t length, ir::Layout layout)
103 {
104   auto info = ir::OperandInfo::createStaticInfo(shape, type);
105
106   if (length < info.total_size())
107   {
108     throw std::runtime_error{"Too small length"};
109   }
110
111   _io_desc.outputs.at(index.value()) = std::make_unique<OutputDesc>(info, buffer, length, layout);
112 }
113
114 void Execution::setInputLayout(const ir::IOIndex &index, ir::Layout layout)
115 {
116   const auto &input_desc = _io_desc.inputs.at(index.value());
117   _io_desc.inputs.at(index.value()) =
118       std::make_unique<InputDesc>(input_desc->info, input_desc->buffer, input_desc->size, layout);
119 }
120
121 void Execution::setOutputLayout(const ir::IOIndex &index, ir::Layout layout)
122 {
123   const auto &output_desc = _io_desc.outputs.at(index.value());
124   _io_desc.outputs.at(index.value()) = std::make_unique<OutputDesc>(
125       output_desc->info, output_desc->buffer, output_desc->size, layout);
126 }
127
128 void Execution::execute()
129 {
130   VERBOSE(Execution) << "Start execution" << std::endl;
131
132   primary_executor()->execute(_io_desc);
133   finished = true;
134
135   VERBOSE(Execution) << "Execution finished" << std::endl;
136 }
137
138 void Execution::startExecute()
139 {
140   VERBOSE(Execution) << "Create asynchronous execution thread" << std::endl;
141
142   _exec_thread = std::make_unique<std::thread>(&Execution::execute, this);
143 }
144
145 void Execution::waitFinish()
146 {
147   VERBOSE(Execution) << "Wait to finish execution" << std::endl;
148
149   _exec_thread->join();
150   finished = true;
151 }
152
153 bool Execution::isFinished(void) const { return finished; }
154
155 ir::Shape Execution::getInputShape(ir::IOIndex ind) const
156 {
157   auto itr = _io_desc.input_shape_signature.find(ind);
158   if (itr == _io_desc.input_shape_signature.end())
159   {
160     auto operand_idx = primary_subgraph().getInputs().at(ind.value());
161     return primary_subgraph().operands().at(operand_idx).shape();
162   }
163   else
164   {
165     return itr->second;
166   }
167 }
168
169 ir::Shape Execution::getOutputShape(ir::IOIndex ind) const
170 {
171   if (!isFinished())
172     throw std::runtime_error("Cannot get output shape before execution is finished");
173
174   const auto &output_desc = _io_desc.outputs.at(ind.value());
175
176   return output_desc->info.shape();
177 }
178
179 } // namespace exec
180 } // namespace onert