f8b2462f966766c0cdc1dd604fa08267c0cebaac
[platform/upstream/armnn.git] / src / armnn / Runtime.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "Runtime.hpp"
6
7 #include <armnn/Version.hpp>
8 #include <backendsCommon/BackendRegistry.hpp>
9 #include <backendsCommon/IBackendContext.hpp>
10
11 #include <iostream>
12
13 #include <boost/log/trivial.hpp>
14 #include <boost/polymorphic_cast.hpp>
15
16 using namespace armnn;
17 using namespace std;
18
19 namespace armnn
20 {
21
22 IRuntime* IRuntime::CreateRaw(const CreationOptions& options)
23 {
24     return new Runtime(options);
25 }
26
27 IRuntimePtr IRuntime::Create(const CreationOptions& options)
28 {
29     return IRuntimePtr(CreateRaw(options), &IRuntime::Destroy);
30 }
31
32 void IRuntime::Destroy(IRuntime* runtime)
33 {
34     delete boost::polymorphic_downcast<Runtime*>(runtime);
35 }
36
37 int Runtime::GenerateNetworkId()
38 {
39     return m_NetworkIdCounter++;
40 }
41
42 Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork)
43 {
44     std::string ignoredErrorMessage;
45     return LoadNetwork(networkIdOut, std::move(inNetwork), ignoredErrorMessage);
46 }
47
48 Status Runtime::LoadNetwork(NetworkId& networkIdOut,
49                             IOptimizedNetworkPtr inNetwork,
50                             std::string & errorMessage)
51 {
52     IOptimizedNetwork* rawNetwork = inNetwork.release();
53
54     networkIdOut = GenerateNetworkId();
55
56     for (auto&& context : m_BackendContexts)
57     {
58         context.second->BeforeLoadNetwork(networkIdOut);
59     }
60
61     unique_ptr<LoadedNetwork> loadedNetwork = LoadedNetwork::MakeLoadedNetwork(
62         std::unique_ptr<OptimizedNetwork>(boost::polymorphic_downcast<OptimizedNetwork*>(rawNetwork)),
63         errorMessage);
64
65     if (!loadedNetwork)
66     {
67         return Status::Failure;
68     }
69
70     {
71         std::lock_guard<std::mutex> lockGuard(m_Mutex);
72
73         // Stores the network
74         m_LoadedNetworks[networkIdOut] = std::move(loadedNetwork);
75     }
76
77     for (auto&& context : m_BackendContexts)
78     {
79         context.second->AfterLoadNetwork(networkIdOut);
80     }
81
82     return Status::Success;
83 }
84
85 Status Runtime::UnloadNetwork(NetworkId networkId)
86 {
87     bool unloadOk = true;
88     for (auto&& context : m_BackendContexts)
89     {
90         unloadOk &= context.second->BeforeUnloadNetwork(networkId);
91     }
92
93     if (!unloadOk)
94     {
95         BOOST_LOG_TRIVIAL(warning) << "Runtime::UnloadNetwork(): failed to unload "
96                                       "network with ID:" << networkId << " because BeforeUnloadNetwork failed";
97         return Status::Failure;
98     }
99
100     {
101         std::lock_guard<std::mutex> lockGuard(m_Mutex);
102
103         if (m_LoadedNetworks.erase(networkId) == 0)
104         {
105             BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!";
106             return Status::Failure;
107         }
108     }
109
110     for (auto&& context : m_BackendContexts)
111     {
112         context.second->AfterUnloadNetwork(networkId);
113     }
114
115     BOOST_LOG_TRIVIAL(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId;
116     return Status::Success;
117 }
118
119 const std::shared_ptr<IProfiler> Runtime::GetProfiler(NetworkId networkId) const
120 {
121     auto it = m_LoadedNetworks.find(networkId);
122     if (it != m_LoadedNetworks.end())
123     {
124         auto& loadedNetwork = it->second;
125         return loadedNetwork->GetProfiler();
126     }
127
128     return nullptr;
129 }
130
131 Runtime::Runtime(const CreationOptions& options)
132     : m_NetworkIdCounter(0)
133     , m_DeviceSpec{BackendRegistryInstance().GetBackendIds()}
134 {
135     BOOST_LOG_TRIVIAL(info) << "ArmNN v" << ARMNN_VERSION << "\n";
136
137     for (const auto& id : BackendRegistryInstance().GetBackendIds())
138     {
139         // Store backend contexts for the supported ones
140         if (m_DeviceSpec.GetSupportedBackends().count(id) > 0)
141         {
142             auto factoryFun = BackendRegistryInstance().GetFactory(id);
143             auto backend = factoryFun();
144             BOOST_ASSERT(backend.get() != nullptr);
145
146             auto context = backend->CreateBackendContext(options);
147
148             // backends are allowed to return nullptrs if they
149             // don't wish to create a backend specific context
150             if (context)
151             {
152                 m_BackendContexts.emplace(std::make_pair(id, std::move(context)));
153             }
154         }
155     }
156 }
157
158 Runtime::~Runtime()
159 {
160     std::vector<int> networkIDs;
161     try
162     {
163         // Coverity fix: The following code may throw an exception of type std::length_error.
164         std::transform(m_LoadedNetworks.begin(), m_LoadedNetworks.end(),
165                        std::back_inserter(networkIDs),
166                        [](const auto &pair) { return pair.first; });
167     }
168     catch (const std::exception& e)
169     {
170         // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
171         // exception of type std::length_error.
172         // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
173         std::cerr << "WARNING: An error has occurred when getting the IDs of the networks to unload: " << e.what()
174                   << "\nSome of the loaded networks may not be unloaded" << std::endl;
175     }
176     // We then proceed to unload all the networks which IDs have been appended to the list
177     // up to the point the exception was thrown (if any).
178
179     for (auto networkID : networkIDs)
180     {
181         try
182         {
183             // Coverity fix: UnloadNetwork() may throw an exception of type std::length_error,
184             // boost::log::v2s_mt_posix::odr_violation or boost::log::v2s_mt_posix::system_error
185             UnloadNetwork(networkID);
186         }
187         catch (const std::exception& e)
188         {
189             // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
190             // exception of type std::length_error.
191             // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
192             std::cerr << "WARNING: An error has occurred when unloading network " << networkID << ": " << e.what()
193                       << std::endl;
194         }
195     }
196 }
197
198 LoadedNetwork* Runtime::GetLoadedNetworkPtr(NetworkId networkId) const
199 {
200     std::lock_guard<std::mutex> lockGuard(m_Mutex);
201     return m_LoadedNetworks.at(networkId).get();
202 }
203
204 TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
205 {
206     return GetLoadedNetworkPtr(networkId)->GetInputTensorInfo(layerId);
207 }
208
209 TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
210 {
211     return GetLoadedNetworkPtr(networkId)->GetOutputTensorInfo(layerId);
212 }
213
214
215 Status Runtime::EnqueueWorkload(NetworkId networkId,
216                                 const InputTensors& inputTensors,
217                                 const OutputTensors& outputTensors)
218 {
219     LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
220
221     static thread_local NetworkId lastId = networkId;
222     if (lastId != networkId)
223     {
224         LoadedNetworkFuncSafe(lastId, [](LoadedNetwork* network)
225             {
226                 network->FreeWorkingMemory();
227             });
228     }
229     lastId=networkId;
230
231     return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors);
232 }
233
234 void Runtime::RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func)
235 {
236     LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
237     loadedNetwork->RegisterDebugCallback(func);
238 }
239
240 }