2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
7 #include <armnn/Version.hpp>
8 #include <armnn/BackendRegistry.hpp>
9 #include <armnn/Logging.hpp>
11 #include <armnn/backends/IBackendContext.hpp>
12 #include <backendsCommon/DynamicBackendUtils.hpp>
13 #include <armnn/utility/PolymorphicDowncast.hpp>
17 #include <backends/BackendProfiling.hpp>
19 using namespace armnn;
25 IRuntime* IRuntime::CreateRaw(const CreationOptions& options)
27 return new Runtime(options);
30 IRuntimePtr IRuntime::Create(const CreationOptions& options)
32 return IRuntimePtr(CreateRaw(options), &IRuntime::Destroy);
35 void IRuntime::Destroy(IRuntime* runtime)
37 delete PolymorphicDowncast<Runtime*>(runtime);
40 int Runtime::GenerateNetworkId()
42 return m_NetworkIdCounter++;
45 Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork)
47 std::string ignoredErrorMessage;
48 return LoadNetwork(networkIdOut, std::move(inNetwork), ignoredErrorMessage);
51 Status Runtime::LoadNetwork(NetworkId& networkIdOut,
52 IOptimizedNetworkPtr inNetwork,
53 std::string& errorMessage)
55 INetworkProperties networkProperties;
56 return LoadNetwork(networkIdOut, std::move(inNetwork), errorMessage, networkProperties);
59 Status Runtime::LoadNetwork(NetworkId& networkIdOut,
60 IOptimizedNetworkPtr inNetwork,
61 std::string& errorMessage,
62 const INetworkProperties& networkProperties)
64 IOptimizedNetwork* rawNetwork = inNetwork.release();
66 networkIdOut = GenerateNetworkId();
68 for (auto&& context : m_BackendContexts)
70 context.second->BeforeLoadNetwork(networkIdOut);
73 unique_ptr<LoadedNetwork> loadedNetwork = LoadedNetwork::MakeLoadedNetwork(
74 std::unique_ptr<OptimizedNetwork>(PolymorphicDowncast<OptimizedNetwork*>(rawNetwork)),
81 return Status::Failure;
85 std::lock_guard<std::mutex> lockGuard(m_Mutex);
88 m_LoadedNetworks[networkIdOut] = std::move(loadedNetwork);
91 for (auto&& context : m_BackendContexts)
93 context.second->AfterLoadNetwork(networkIdOut);
96 if (m_ProfilingService.IsProfilingEnabled())
98 m_ProfilingService.IncrementCounterValue(armnn::profiling::NETWORK_LOADS);
101 return Status::Success;
104 Status Runtime::UnloadNetwork(NetworkId networkId)
106 bool unloadOk = true;
107 for (auto&& context : m_BackendContexts)
109 unloadOk &= context.second->BeforeUnloadNetwork(networkId);
114 ARMNN_LOG(warning) << "Runtime::UnloadNetwork(): failed to unload "
115 "network with ID:" << networkId << " because BeforeUnloadNetwork failed";
116 return Status::Failure;
120 std::lock_guard<std::mutex> lockGuard(m_Mutex);
122 if (m_LoadedNetworks.erase(networkId) == 0)
124 ARMNN_LOG(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!";
125 return Status::Failure;
128 if (m_ProfilingService.IsProfilingEnabled())
130 m_ProfilingService.IncrementCounterValue(armnn::profiling::NETWORK_UNLOADS);
134 for (auto&& context : m_BackendContexts)
136 context.second->AfterUnloadNetwork(networkId);
139 ARMNN_LOG(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId;
140 return Status::Success;
143 const std::shared_ptr<IProfiler> Runtime::GetProfiler(NetworkId networkId) const
145 auto it = m_LoadedNetworks.find(networkId);
146 if (it != m_LoadedNetworks.end())
148 auto& loadedNetwork = it->second;
149 return loadedNetwork->GetProfiler();
155 void Runtime::ReportStructure() // armnn::profiling::IProfilingService& profilingService as param
157 // No-op for the time being, but this may be useful in future to have the profilingService available
158 // if (profilingService.IsProfilingEnabled()){}
160 LoadedNetworks::iterator it = m_LoadedNetworks.begin();
161 while (it != m_LoadedNetworks.end())
163 auto& loadedNetwork = it->second;
164 loadedNetwork->SendNetworkStructure();
165 // Increment the Iterator to point to next entry
170 Runtime::Runtime(const CreationOptions& options)
171 : m_NetworkIdCounter(0),
172 m_ProfilingService(*this)
174 ARMNN_LOG(info) << "ArmNN v" << ARMNN_VERSION << "\n";
176 if ( options.m_ProfilingOptions.m_TimelineEnabled && !options.m_ProfilingOptions.m_EnableProfiling )
178 throw RuntimeException("It is not possible to enable timeline reporting without profiling being enabled");
181 // Load any available/compatible dynamic backend before the runtime
182 // goes through the backend registry
183 LoadDynamicBackends(options.m_DynamicBackendsPath);
185 BackendIdSet supportedBackends;
186 for (const auto& id : BackendRegistryInstance().GetBackendIds())
188 // Store backend contexts for the supported ones
190 auto factoryFun = BackendRegistryInstance().GetFactory(id);
191 auto backend = factoryFun();
192 ARMNN_ASSERT(backend.get() != nullptr);
194 auto context = backend->CreateBackendContext(options);
196 // backends are allowed to return nullptrs if they
197 // don't wish to create a backend specific context
200 m_BackendContexts.emplace(std::make_pair(id, std::move(context)));
202 supportedBackends.emplace(id);
204 unique_ptr<armnn::profiling::IBackendProfiling> profilingIface =
205 std::make_unique<armnn::profiling::BackendProfiling>(armnn::profiling::BackendProfiling(
206 options, m_ProfilingService, id));
208 // Backends may also provide a profiling context. Ask for it now.
209 auto profilingContext = backend->CreateBackendProfilingContext(options, profilingIface);
210 // Backends that don't support profiling will return a null profiling context.
211 if (profilingContext)
213 // Pass the context onto the profiling service.
214 m_ProfilingService.AddBackendProfilingContext(id, profilingContext);
217 catch (const BackendUnavailableException&)
219 // Ignore backends which are unavailable
223 // pass configuration info to the profiling service
224 m_ProfilingService.ConfigureProfilingService(options.m_ProfilingOptions);
226 m_DeviceSpec.AddSupportedBackends(supportedBackends);
231 std::vector<int> networkIDs;
234 // Coverity fix: The following code may throw an exception of type std::length_error.
235 std::transform(m_LoadedNetworks.begin(), m_LoadedNetworks.end(),
236 std::back_inserter(networkIDs),
237 [](const auto &pair) { return pair.first; });
239 catch (const std::exception& e)
241 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
242 // exception of type std::length_error.
243 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
244 std::cerr << "WARNING: An error has occurred when getting the IDs of the networks to unload: " << e.what()
245 << "\nSome of the loaded networks may not be unloaded" << std::endl;
247 // We then proceed to unload all the networks which IDs have been appended to the list
248 // up to the point the exception was thrown (if any).
250 for (auto networkID : networkIDs)
254 // Coverity fix: UnloadNetwork() may throw an exception of type std::length_error,
255 // boost::log::v2s_mt_posix::odr_violation or boost::log::v2s_mt_posix::system_error
256 UnloadNetwork(networkID);
258 catch (const std::exception& e)
260 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
261 // exception of type std::length_error.
262 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
263 std::cerr << "WARNING: An error has occurred when unloading network " << networkID << ": " << e.what()
268 // Clear all dynamic backends.
269 DynamicBackendUtils::DeregisterDynamicBackends(m_DeviceSpec.GetDynamicBackends());
270 m_DeviceSpec.ClearDynamicBackends();
271 m_BackendContexts.clear();
274 LoadedNetwork* Runtime::GetLoadedNetworkPtr(NetworkId networkId) const
276 std::lock_guard<std::mutex> lockGuard(m_Mutex);
277 return m_LoadedNetworks.at(networkId).get();
280 TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
282 return GetLoadedNetworkPtr(networkId)->GetInputTensorInfo(layerId);
285 TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
287 return GetLoadedNetworkPtr(networkId)->GetOutputTensorInfo(layerId);
291 Status Runtime::EnqueueWorkload(NetworkId networkId,
292 const InputTensors& inputTensors,
293 const OutputTensors& outputTensors)
295 LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
297 static thread_local NetworkId lastId = networkId;
298 if (lastId != networkId)
300 LoadedNetworkFuncSafe(lastId, [](LoadedNetwork* network)
302 network->FreeWorkingMemory();
307 return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors);
310 void Runtime::RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func)
312 LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
313 loadedNetwork->RegisterDebugCallback(func);
316 void Runtime::LoadDynamicBackends(const std::string& overrideBackendPath)
318 // Get the paths where to load the dynamic backends from
319 std::vector<std::string> backendPaths = DynamicBackendUtils::GetBackendPaths(overrideBackendPath);
321 // Get the shared objects to try to load as dynamic backends
322 std::vector<std::string> sharedObjects = DynamicBackendUtils::GetSharedObjects(backendPaths);
324 // Create a list of dynamic backends
325 m_DynamicBackends = DynamicBackendUtils::CreateDynamicBackends(sharedObjects);
327 // Register the dynamic backends in the backend registry
328 BackendIdSet registeredBackendIds = DynamicBackendUtils::RegisterDynamicBackends(m_DynamicBackends);
330 // Add the registered dynamic backend ids to the list of supported backends
331 m_DeviceSpec.AddSupportedBackends(registeredBackendIds, true);