src/armnn/Runtime.cpp (platform/upstream/armnn.git @ 483eea7165711ed40e23c053a68e2405aa728a9d)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "Runtime.hpp"

#include <armnn/Version.hpp>
#include <armnn/BackendRegistry.hpp>
#include <armnn/Logging.hpp>

#include <armnn/backends/IBackendContext.hpp>
#include <backendsCommon/DynamicBackendUtils.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>

#include <iostream>

#include <backends/BackendProfiling.hpp>

using namespace armnn;
using namespace std;

namespace armnn
{

IRuntime* IRuntime::CreateRaw(const CreationOptions& options)
{
    return new Runtime(options);
}

IRuntimePtr IRuntime::Create(const CreationOptions& options)
{
    return IRuntimePtr(CreateRaw(options), &IRuntime::Destroy);
}

void IRuntime::Destroy(IRuntime* runtime)
{
    delete PolymorphicDowncast<Runtime*>(runtime);
}
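
// Typical client-side usage of the factory functions above (illustrative sketch,
// not part of this file; relies only on the public IRuntime API):
//
//     armnn::IRuntime::CreationOptions options;
//     armnn::IRuntimePtr runtime = armnn::IRuntime::Create(options);
//     // ... load networks and enqueue workloads ...
//     // IRuntimePtr is a unique_ptr with IRuntime::Destroy as its deleter, so the
//     // runtime is destroyed automatically when it goes out of scope.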

int Runtime::GenerateNetworkId()
{
    return m_NetworkIdCounter++;
}

Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork)
{
    std::string ignoredErrorMessage;
    return LoadNetwork(networkIdOut, std::move(inNetwork), ignoredErrorMessage);
}

Status Runtime::LoadNetwork(NetworkId& networkIdOut,
                            IOptimizedNetworkPtr inNetwork,
                            std::string& errorMessage)
{
    INetworkProperties networkProperties;
    return LoadNetwork(networkIdOut, std::move(inNetwork), errorMessage, networkProperties);
}

Status Runtime::LoadNetwork(NetworkId& networkIdOut,
                            IOptimizedNetworkPtr inNetwork,
                            std::string& errorMessage,
                            const INetworkProperties& networkProperties)
{
    IOptimizedNetwork* rawNetwork = inNetwork.release();

    networkIdOut = GenerateNetworkId();

    for (auto&& context : m_BackendContexts)
    {
        context.second->BeforeLoadNetwork(networkIdOut);
    }

    unique_ptr<LoadedNetwork> loadedNetwork = LoadedNetwork::MakeLoadedNetwork(
        std::unique_ptr<OptimizedNetwork>(PolymorphicDowncast<OptimizedNetwork*>(rawNetwork)),
        errorMessage,
        networkProperties,
        m_ProfilingService);

    if (!loadedNetwork)
    {
        return Status::Failure;
    }

    {
        std::lock_guard<std::mutex> lockGuard(m_Mutex);

        // Stores the network
        m_LoadedNetworks[networkIdOut] = std::move(loadedNetwork);
    }

    for (auto&& context : m_BackendContexts)
    {
        context.second->AfterLoadNetwork(networkIdOut);
    }

    if (m_ProfilingService.IsProfilingEnabled())
    {
        m_ProfilingService.IncrementCounterValue(armnn::profiling::NETWORK_LOADS);
    }

    return Status::Success;
}
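
// Typical caller flow for the LoadNetwork() overloads above (illustrative sketch
// only; the network, runtime and backend choice are assumptions, not part of this file):
//
//     armnn::NetworkId netId;
//     armnn::IOptimizedNetworkPtr optNet = armnn::Optimize(*network,
//                                                          {armnn::Compute::CpuRef},
//                                                          runtime->GetDeviceSpec());
//     std::string errorMessage;
//     if (runtime->LoadNetwork(netId, std::move(optNet), errorMessage) != armnn::Status::Success)
//     {
//         ARMNN_LOG(error) << "LoadNetwork failed: " << errorMessage;
//     }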

Status Runtime::UnloadNetwork(NetworkId networkId)
{
    bool unloadOk = true;
    for (auto&& context : m_BackendContexts)
    {
        unloadOk &= context.second->BeforeUnloadNetwork(networkId);
    }

    if (!unloadOk)
    {
        ARMNN_LOG(warning) << "Runtime::UnloadNetwork(): failed to unload "
                              "network with ID: " << networkId << " because BeforeUnloadNetwork failed";
        return Status::Failure;
    }

    {
        std::lock_guard<std::mutex> lockGuard(m_Mutex);

        if (m_LoadedNetworks.erase(networkId) == 0)
        {
            ARMNN_LOG(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!";
            return Status::Failure;
        }

        if (m_ProfilingService.IsProfilingEnabled())
        {
            m_ProfilingService.IncrementCounterValue(armnn::profiling::NETWORK_UNLOADS);
        }
    }

    for (auto&& context : m_BackendContexts)
    {
        context.second->AfterUnloadNetwork(networkId);
    }

    ARMNN_LOG(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId;
    return Status::Success;
}

const std::shared_ptr<IProfiler> Runtime::GetProfiler(NetworkId networkId) const
{
    auto it = m_LoadedNetworks.find(networkId);
    if (it != m_LoadedNetworks.end())
    {
        auto& loadedNetwork = it->second;
        return loadedNetwork->GetProfiler();
    }

    return nullptr;
}
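
// Possible way to inspect per-network profiling data through GetProfiler() above
// (illustrative sketch; assumes profiling was enabled on the returned profiler and
// that this ArmNN version provides IProfiler::Print()):
//
//     std::shared_ptr<armnn::IProfiler> profiler = runtime->GetProfiler(netId);
//     if (profiler)
//     {
//         profiler->Print(std::cout);
//     }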

void Runtime::ReportStructure() // armnn::profiling::IProfilingService& profilingService as param
{
    // No-op for the time being, but this may be useful in future to have the profilingService available
    // if (profilingService.IsProfilingEnabled()){}

    LoadedNetworks::iterator it = m_LoadedNetworks.begin();
    while (it != m_LoadedNetworks.end())
    {
        auto& loadedNetwork = it->second;
        loadedNetwork->SendNetworkStructure();
        // Increment the Iterator to point to next entry
        it++;
    }
}

Runtime::Runtime(const CreationOptions& options)
    : m_NetworkIdCounter(0),
      m_ProfilingService(*this)
{
    ARMNN_LOG(info) << "ArmNN v" << ARMNN_VERSION << "\n";

    if (options.m_ProfilingOptions.m_TimelineEnabled && !options.m_ProfilingOptions.m_EnableProfiling)
    {
        throw RuntimeException("It is not possible to enable timeline reporting without profiling being enabled");
    }

    // Load any available/compatible dynamic backend before the runtime
    // goes through the backend registry
    LoadDynamicBackends(options.m_DynamicBackendsPath);

    BackendIdSet supportedBackends;
    for (const auto& id : BackendRegistryInstance().GetBackendIds())
    {
        // Store backend contexts for the supported ones
        try {
            auto factoryFun = BackendRegistryInstance().GetFactory(id);
            auto backend = factoryFun();
            ARMNN_ASSERT(backend.get() != nullptr);

            auto context = backend->CreateBackendContext(options);

            // backends are allowed to return nullptrs if they
            // don't wish to create a backend specific context
            if (context)
            {
                m_BackendContexts.emplace(std::make_pair(id, std::move(context)));
            }
            supportedBackends.emplace(id);

            unique_ptr<armnn::profiling::IBackendProfiling> profilingIface =
                std::make_unique<armnn::profiling::BackendProfiling>(options, m_ProfilingService, id);

            // Backends may also provide a profiling context. Ask for it now.
            auto profilingContext = backend->CreateBackendProfilingContext(options, profilingIface);
            // Backends that don't support profiling will return a null profiling context.
            if (profilingContext)
            {
                // Pass the context onto the profiling service.
                m_ProfilingService.AddBackendProfilingContext(id, profilingContext);
            }
        }
        catch (const BackendUnavailableException&)
        {
            // Ignore backends which are unavailable
        }
    }

    // pass configuration info to the profiling service
    m_ProfilingService.ConfigureProfilingService(options.m_ProfilingOptions);

    m_DeviceSpec.AddSupportedBackends(supportedBackends);
}
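
// The constructor above validates the profiling options passed in by the caller.
// Illustrative sketch of enabling profiling at runtime creation (the field names are
// taken from the CreationOptions usage above; everything else is an assumption):
//
//     armnn::IRuntime::CreationOptions options;
//     options.m_ProfilingOptions.m_EnableProfiling = true;
//     options.m_ProfilingOptions.m_TimelineEnabled = true;   // requires m_EnableProfiling == true
//     armnn::IRuntimePtr runtime = armnn::IRuntime::Create(options);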

Runtime::~Runtime()
{
    std::vector<int> networkIDs;
    try
    {
        // Coverity fix: The following code may throw an exception of type std::length_error.
        std::transform(m_LoadedNetworks.begin(), m_LoadedNetworks.end(),
                       std::back_inserter(networkIDs),
                       [](const auto &pair) { return pair.first; });
    }
    catch (const std::exception& e)
    {
        // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
        // exception of type std::length_error.
        // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
        std::cerr << "WARNING: An error has occurred when getting the IDs of the networks to unload: " << e.what()
                  << "\nSome of the loaded networks may not be unloaded" << std::endl;
    }
    // We then proceed to unload all the networks whose IDs have been appended to the list
    // up to the point the exception was thrown (if any).

    for (auto networkID : networkIDs)
    {
        try
        {
            // Coverity fix: UnloadNetwork() may throw an exception of type std::length_error,
            // boost::log::v2s_mt_posix::odr_violation or boost::log::v2s_mt_posix::system_error
            UnloadNetwork(networkID);
        }
        catch (const std::exception& e)
        {
            // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
            // exception of type std::length_error.
            // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
            std::cerr << "WARNING: An error has occurred when unloading network " << networkID << ": " << e.what()
                      << std::endl;
        }
    }

    // Clear all dynamic backends.
    DynamicBackendUtils::DeregisterDynamicBackends(m_DeviceSpec.GetDynamicBackends());
    m_DeviceSpec.ClearDynamicBackends();
    m_BackendContexts.clear();
}

LoadedNetwork* Runtime::GetLoadedNetworkPtr(NetworkId networkId) const
{
    std::lock_guard<std::mutex> lockGuard(m_Mutex);
    return m_LoadedNetworks.at(networkId).get();
}

TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
{
    return GetLoadedNetworkPtr(networkId)->GetInputTensorInfo(layerId);
}

TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
{
    return GetLoadedNetworkPtr(networkId)->GetOutputTensorInfo(layerId);
}

Status Runtime::EnqueueWorkload(NetworkId networkId,
                                const InputTensors& inputTensors,
                                const OutputTensors& outputTensors)
{
    LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);

    static thread_local NetworkId lastId = networkId;
    if (lastId != networkId)
    {
        LoadedNetworkFuncSafe(lastId, [](LoadedNetwork* network)
            {
                network->FreeWorkingMemory();
            });
    }
    lastId = networkId;

    return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors);
}
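
// Sketch of how a caller might drive EnqueueWorkload() above, using the tensor-info
// getters defined in this file (illustrative only; the binding IDs, buffer sizes and
// the single input/output layout are assumptions):
//
//     std::vector<float> inputData(inputSize);    // inputSize is an assumed placeholder
//     std::vector<float> outputData(outputSize);  // outputSize is an assumed placeholder
//     armnn::InputTensors inputTensors
//     {
//         {0, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputData.data())}
//     };
//     armnn::OutputTensors outputTensors
//     {
//         {0, armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data())}
//     };
//     runtime->EnqueueWorkload(netId, inputTensors, outputTensors);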

void Runtime::RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func)
{
    LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
    loadedNetwork->RegisterDebugCallback(func);
}

void Runtime::LoadDynamicBackends(const std::string& overrideBackendPath)
{
    // Get the paths from which to load the dynamic backends
    std::vector<std::string> backendPaths = DynamicBackendUtils::GetBackendPaths(overrideBackendPath);

    // Get the shared objects to try to load as dynamic backends
    std::vector<std::string> sharedObjects = DynamicBackendUtils::GetSharedObjects(backendPaths);

    // Create a list of dynamic backends
    m_DynamicBackends = DynamicBackendUtils::CreateDynamicBackends(sharedObjects);

    // Register the dynamic backends in the backend registry
    BackendIdSet registeredBackendIds = DynamicBackendUtils::RegisterDynamicBackends(m_DynamicBackends);

    // Add the registered dynamic backend ids to the list of supported backends
    m_DeviceSpec.AddSupportedBackends(registeredBackendIds, true);
}
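
// The override path used above normally comes from the runtime creation options, as
// seen in the Runtime constructor. Illustrative sketch (the directory shown is an
// assumption, not a real default):
//
//     armnn::IRuntime::CreationOptions options;
//     options.m_DynamicBackendsPath = "/path/to/dynamic/backends";
//     armnn::IRuntimePtr runtime = armnn::IRuntime::Create(options);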

} // namespace armnn