IVGCVSW-3595 Implement the LoadDynamicBackends function in the Runtime class
[platform/upstream/armnn.git] / src / armnn / Runtime.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "Runtime.hpp"
6
7 #include <armnn/Version.hpp>
8
9 #include <backendsCommon/BackendRegistry.hpp>
10 #include <backendsCommon/IBackendContext.hpp>
11 #include <backendsCommon/DynamicBackendUtils.hpp>
12 #include <backendsCommon/DynamicBackend.hpp>
13
14 #include <iostream>
15
16 #include <boost/log/trivial.hpp>
17 #include <boost/polymorphic_cast.hpp>
18
19 using namespace armnn;
20 using namespace std;
21
22 namespace armnn
23 {
24
25 IRuntime* IRuntime::CreateRaw(const CreationOptions& options)
26 {
27     return new Runtime(options);
28 }
29
30 IRuntimePtr IRuntime::Create(const CreationOptions& options)
31 {
32     return IRuntimePtr(CreateRaw(options), &IRuntime::Destroy);
33 }
34
35 void IRuntime::Destroy(IRuntime* runtime)
36 {
37     delete boost::polymorphic_downcast<Runtime*>(runtime);
38 }
39
40 int Runtime::GenerateNetworkId()
41 {
42     return m_NetworkIdCounter++;
43 }
44
45 Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork)
46 {
47     std::string ignoredErrorMessage;
48     return LoadNetwork(networkIdOut, std::move(inNetwork), ignoredErrorMessage);
49 }
50
51 Status Runtime::LoadNetwork(NetworkId& networkIdOut,
52                             IOptimizedNetworkPtr inNetwork,
53                             std::string & errorMessage)
54 {
55     IOptimizedNetwork* rawNetwork = inNetwork.release();
56
57     networkIdOut = GenerateNetworkId();
58
59     for (auto&& context : m_BackendContexts)
60     {
61         context.second->BeforeLoadNetwork(networkIdOut);
62     }
63
64     unique_ptr<LoadedNetwork> loadedNetwork = LoadedNetwork::MakeLoadedNetwork(
65         std::unique_ptr<OptimizedNetwork>(boost::polymorphic_downcast<OptimizedNetwork*>(rawNetwork)),
66         errorMessage);
67
68     if (!loadedNetwork)
69     {
70         return Status::Failure;
71     }
72
73     {
74         std::lock_guard<std::mutex> lockGuard(m_Mutex);
75
76         // Stores the network
77         m_LoadedNetworks[networkIdOut] = std::move(loadedNetwork);
78     }
79
80     for (auto&& context : m_BackendContexts)
81     {
82         context.second->AfterLoadNetwork(networkIdOut);
83     }
84
85     return Status::Success;
86 }
87
88 Status Runtime::UnloadNetwork(NetworkId networkId)
89 {
90     bool unloadOk = true;
91     for (auto&& context : m_BackendContexts)
92     {
93         unloadOk &= context.second->BeforeUnloadNetwork(networkId);
94     }
95
96     if (!unloadOk)
97     {
98         BOOST_LOG_TRIVIAL(warning) << "Runtime::UnloadNetwork(): failed to unload "
99                                       "network with ID:" << networkId << " because BeforeUnloadNetwork failed";
100         return Status::Failure;
101     }
102
103     {
104         std::lock_guard<std::mutex> lockGuard(m_Mutex);
105
106         if (m_LoadedNetworks.erase(networkId) == 0)
107         {
108             BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!";
109             return Status::Failure;
110         }
111     }
112
113     for (auto&& context : m_BackendContexts)
114     {
115         context.second->AfterUnloadNetwork(networkId);
116     }
117
118     BOOST_LOG_TRIVIAL(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId;
119     return Status::Success;
120 }
121
122 const std::shared_ptr<IProfiler> Runtime::GetProfiler(NetworkId networkId) const
123 {
124     auto it = m_LoadedNetworks.find(networkId);
125     if (it != m_LoadedNetworks.end())
126     {
127         auto& loadedNetwork = it->second;
128         return loadedNetwork->GetProfiler();
129     }
130
131     return nullptr;
132 }
133
134 Runtime::Runtime(const CreationOptions& options)
135     : m_NetworkIdCounter(0)
136     , m_DeviceSpec{BackendRegistryInstance().GetBackendIds()}
137 {
138     BOOST_LOG_TRIVIAL(info) << "ArmNN v" << ARMNN_VERSION << "\n";
139
140     // Load any available/compatible dynamic backend before the runtime
141     // goes through the backend registry
142     LoadDynamicBackends(options.m_DynamicBackendsPath);
143
144     for (const auto& id : BackendRegistryInstance().GetBackendIds())
145     {
146         // Store backend contexts for the supported ones
147         if (m_DeviceSpec.GetSupportedBackends().count(id) > 0)
148         {
149             auto factoryFun = BackendRegistryInstance().GetFactory(id);
150             auto backend = factoryFun();
151             BOOST_ASSERT(backend.get() != nullptr);
152
153             auto context = backend->CreateBackendContext(options);
154
155             // backends are allowed to return nullptrs if they
156             // don't wish to create a backend specific context
157             if (context)
158             {
159                 m_BackendContexts.emplace(std::make_pair(id, std::move(context)));
160             }
161         }
162     }
163 }
164
165 Runtime::~Runtime()
166 {
167     std::vector<int> networkIDs;
168     try
169     {
170         // Coverity fix: The following code may throw an exception of type std::length_error.
171         std::transform(m_LoadedNetworks.begin(), m_LoadedNetworks.end(),
172                        std::back_inserter(networkIDs),
173                        [](const auto &pair) { return pair.first; });
174     }
175     catch (const std::exception& e)
176     {
177         // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
178         // exception of type std::length_error.
179         // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
180         std::cerr << "WARNING: An error has occurred when getting the IDs of the networks to unload: " << e.what()
181                   << "\nSome of the loaded networks may not be unloaded" << std::endl;
182     }
183     // We then proceed to unload all the networks which IDs have been appended to the list
184     // up to the point the exception was thrown (if any).
185
186     for (auto networkID : networkIDs)
187     {
188         try
189         {
190             // Coverity fix: UnloadNetwork() may throw an exception of type std::length_error,
191             // boost::log::v2s_mt_posix::odr_violation or boost::log::v2s_mt_posix::system_error
192             UnloadNetwork(networkID);
193         }
194         catch (const std::exception& e)
195         {
196             // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
197             // exception of type std::length_error.
198             // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
199             std::cerr << "WARNING: An error has occurred when unloading network " << networkID << ": " << e.what()
200                       << std::endl;
201         }
202     }
203 }
204
205 LoadedNetwork* Runtime::GetLoadedNetworkPtr(NetworkId networkId) const
206 {
207     std::lock_guard<std::mutex> lockGuard(m_Mutex);
208     return m_LoadedNetworks.at(networkId).get();
209 }
210
211 TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
212 {
213     return GetLoadedNetworkPtr(networkId)->GetInputTensorInfo(layerId);
214 }
215
216 TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
217 {
218     return GetLoadedNetworkPtr(networkId)->GetOutputTensorInfo(layerId);
219 }
220
221
222 Status Runtime::EnqueueWorkload(NetworkId networkId,
223                                 const InputTensors& inputTensors,
224                                 const OutputTensors& outputTensors)
225 {
226     LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
227
228     static thread_local NetworkId lastId = networkId;
229     if (lastId != networkId)
230     {
231         LoadedNetworkFuncSafe(lastId, [](LoadedNetwork* network)
232             {
233                 network->FreeWorkingMemory();
234             });
235     }
236     lastId=networkId;
237
238     return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors);
239 }
240
241 void Runtime::RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func)
242 {
243     LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
244     loadedNetwork->RegisterDebugCallback(func);
245 }
246
247 void Runtime::LoadDynamicBackends(const std::string& overrideBackendPath)
248 {
249     // Get the paths where to load the dynamic backends from
250     std::vector<std::string> backendPaths = DynamicBackendUtils::GetBackendPaths(overrideBackendPath);
251
252     // Get the shared objects to try to load as dynamic backends
253     std::vector<std::string> sharedObjects = DynamicBackendUtils::GetSharedObjects(backendPaths);
254
255     // Create a list of dynamic backends
256     DynamicBackendUtils::CreateDynamicBackends(sharedObjects);
257 }
258
259 } // namespace armnn