2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
7 #include <armnn/Version.hpp>
8 #include <backendsCommon/BackendRegistry.hpp>
9 #include <backendsCommon/BackendContextRegistry.hpp>
13 #include <boost/log/trivial.hpp>
14 #include <boost/polymorphic_cast.hpp>
16 using namespace armnn;
22 IRuntime* IRuntime::CreateRaw(const CreationOptions& options)
24 return new Runtime(options);
27 IRuntimePtr IRuntime::Create(const CreationOptions& options)
29 return IRuntimePtr(CreateRaw(options), &IRuntime::Destroy);
32 void IRuntime::Destroy(IRuntime* runtime)
34 delete boost::polymorphic_downcast<Runtime*>(runtime);
37 int Runtime::GenerateNetworkId()
39 return m_NetworkIdCounter++;
42 Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork)
44 std::string ignoredErrorMessage;
45 return LoadNetwork(networkIdOut, std::move(inNetwork), ignoredErrorMessage);
// Loads an optimized network into this runtime: wraps it in a LoadedNetwork, assigns it a fresh id
// from GenerateNetworkId() (written to networkIdOut) and stores it in m_LoadedNetworks under that id.
//
// NOTE(review): this listing is incomplete. The MakeLoadedNetwork(...) call below ends on a trailing
// comma and the `return Status::Failure;` has no visible guard (the embedded line numbers jump from
// 54 to 60), and the braces are missing. Presumably errorMessage is the missing call argument and the
// failure return is guarded by a null check on loadedNetwork — confirm against the complete file.
48 Status Runtime::LoadNetwork(NetworkId& networkIdOut,
49 IOptimizedNetworkPtr inNetwork,
50 std::string & errorMessage)
// Release the pointer from inNetwork so it can be re-wrapped below as the concrete OptimizedNetwork type.
52 IOptimizedNetwork* rawNetwork = inNetwork.release();
53 unique_ptr<LoadedNetwork> loadedNetwork = LoadedNetwork::MakeLoadedNetwork(
54 std::unique_ptr<OptimizedNetwork>(boost::polymorphic_downcast<OptimizedNetwork*>(rawNetwork)),
// Failure path; the condition that selects it is not visible in this listing.
60 return Status::Failure;
// The id is generated before m_Mutex is taken.
63 networkIdOut = GenerateNetworkId();
// The map insertion is done while holding m_Mutex.
66 std::lock_guard<std::mutex> lockGuard(m_Mutex);
69 m_LoadedNetworks[networkIdOut] = std::move(loadedNetwork);
72 return Status::Success;
75 Status Runtime::UnloadNetwork(NetworkId networkId)
78 std::lock_guard<std::mutex> lockGuard(m_Mutex);
80 if (m_LoadedNetworks.erase(networkId) == 0)
82 BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!";
83 return Status::Failure;
87 BOOST_LOG_TRIVIAL(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId;
88 return Status::Success;
91 const std::shared_ptr<IProfiler> Runtime::GetProfiler(NetworkId networkId) const
93 auto it = m_LoadedNetworks.find(networkId);
94 if (it != m_LoadedNetworks.end())
96 auto& loadedNetwork = it->second;
97 return loadedNetwork->GetProfiler();
// Constructs the runtime: initializes the network id counter to 0, builds m_DeviceSpec from the ids
// reported by BackendRegistryInstance(), logs the ArmNN version, and then, for every backend id known
// to BackendContextRegistryInstance() that m_DeviceSpec lists as supported, creates a backend context
// from that backend's factory and stores it in m_BackendContexts.
//
// NOTE(review): this listing is incomplete. The initializer list below starts with a leading comma, so
// a first mem-initializer (embedded line 104) is not visible; the GetFactory(...) call is missing its
// closing `);`; and the braces are missing. Confirm against the complete file.
103 Runtime::Runtime(const CreationOptions& options)
105 , m_NetworkIdCounter(0)
106 , m_DeviceSpec{BackendRegistryInstance().GetBackendIds()}
108 BOOST_LOG_TRIVIAL(info) << "ArmNN v" << ARMNN_VERSION << "\n";
110 for (const auto& id : BackendContextRegistryInstance().GetBackendIds())
112 // Store backend contexts for the supported ones
113 if (m_DeviceSpec.GetSupportedBackends().count(id) > 0)
115 // Don't throw an exception, rather return a dummy factory if not
// The second argument is a fallback factory that returns a default-constructed IBackendContextUniquePtr.
// Presumably GetFactory() hands it back when no factory is registered for id — confirm.
117 auto factoryFun = BackendContextRegistryInstance().GetFactory(
118 id, [](const CreationOptions&) { return IBackendContextUniquePtr(); }
// The factory is invoked with the creation options and its result is stored under the backend id.
121 m_BackendContexts.emplace(std::make_pair(id, factoryFun(options)));
// NOTE(review): the signature that owns the statements below is not visible in this listing (the
// embedded line numbers jump from 121 to 128); presumably this is the body of Runtime::~Runtime() —
// confirm. The `try` keywords that the two `catch` clauses belong to, the braces, and the end of the
// final stream expression are likewise not visible.
//
// Visible behavior: the keys of m_LoadedNetworks are copied into networkIDs, then UnloadNetwork() is
// called for each collected id. The catch handlers shown report std::exception failures on std::cerr.
128 std::vector<int> networkIDs;
131 // Coverity fix: The following code may throw an exception of type std::length_error.
132 std::transform(m_LoadedNetworks.begin(), m_LoadedNetworks.end(),
133 std::back_inserter(networkIDs),
134 [](const auto &pair) { return pair.first; });
136 catch (const std::exception& e)
138 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
139 // exception of type std::length_error.
140 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
141 std::cerr << "WARNING: An error has occurred when getting the IDs of the networks to unload: " << e.what()
142 << "\nSome of the loaded networks may not be unloaded" << std::endl;
144 // We then proceed to unload all the networks whose IDs have been appended to the list
145 // up to the point the exception was thrown (if any).
// The ids are iterated from the copied list rather than from m_LoadedNetworks itself, which
// UnloadNetwork() erases from.
147 for (auto networkID : networkIDs)
151 // Coverity fix: UnloadNetwork() may throw an exception of type std::length_error,
152 // boost::log::v2s_mt_posix::odr_violation or boost::log::v2s_mt_posix::system_error
153 UnloadNetwork(networkID);
155 catch (const std::exception& e)
157 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
158 // exception of type std::length_error.
159 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
160 std::cerr << "WARNING: An error has occurred when unloading network " << networkID << ": " << e.what()
166 LoadedNetwork* Runtime::GetLoadedNetworkPtr(NetworkId networkId) const
168 std::lock_guard<std::mutex> lockGuard(m_Mutex);
169 return m_LoadedNetworks.at(networkId).get();
172 TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
174 return GetLoadedNetworkPtr(networkId)->GetInputTensorInfo(layerId);
177 TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
179 return GetLoadedNetworkPtr(networkId)->GetOutputTensorInfo(layerId);
183 Status Runtime::EnqueueWorkload(NetworkId networkId,
184 const InputTensors& inputTensors,
185 const OutputTensors& outputTensors)
187 LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
189 static thread_local NetworkId lastId = networkId;
190 if (lastId != networkId)
192 LoadedNetworkFuncSafe(lastId, [](LoadedNetwork* network)
194 network->FreeWorkingMemory();
199 return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors);