IVGCVSW-1946: Remove armnn/src from the include paths
[platform/upstream/armnn.git] / src / armnn / Runtime.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "Runtime.hpp"
6
7 #include <armnn/Version.hpp>
8 #include <backendsCommon/BackendRegistry.hpp>
9 #include <backendsCommon/BackendContextRegistry.hpp>
10
11 #include <iostream>
12
13 #include <boost/log/trivial.hpp>
14 #include <boost/polymorphic_cast.hpp>
15
16 using namespace armnn;
17 using namespace std;
18
19 namespace armnn
20 {
21
22 IRuntime* IRuntime::CreateRaw(const CreationOptions& options)
23 {
24     return new Runtime(options);
25 }
26
27 IRuntimePtr IRuntime::Create(const CreationOptions& options)
28 {
29     return IRuntimePtr(CreateRaw(options), &IRuntime::Destroy);
30 }
31
32 void IRuntime::Destroy(IRuntime* runtime)
33 {
34     delete boost::polymorphic_downcast<Runtime*>(runtime);
35 }
36
37 int Runtime::GenerateNetworkId()
38 {
39     return m_NetworkIdCounter++;
40 }
41
42 Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork)
43 {
44     std::string ignoredErrorMessage;
45     return LoadNetwork(networkIdOut, std::move(inNetwork), ignoredErrorMessage);
46 }
47
48 Status Runtime::LoadNetwork(NetworkId& networkIdOut,
49                             IOptimizedNetworkPtr inNetwork,
50                             std::string & errorMessage)
51 {
52     IOptimizedNetwork* rawNetwork = inNetwork.release();
53     unique_ptr<LoadedNetwork> loadedNetwork = LoadedNetwork::MakeLoadedNetwork(
54         std::unique_ptr<OptimizedNetwork>(boost::polymorphic_downcast<OptimizedNetwork*>(rawNetwork)),
55         m_Options,
56         errorMessage);
57
58     if (!loadedNetwork)
59     {
60         return Status::Failure;
61     }
62
63     networkIdOut = GenerateNetworkId();
64
65     {
66         std::lock_guard<std::mutex> lockGuard(m_Mutex);
67
68         // Stores the network
69         m_LoadedNetworks[networkIdOut] = std::move(loadedNetwork);
70     }
71
72     return Status::Success;
73 }
74
75 Status Runtime::UnloadNetwork(NetworkId networkId)
76 {
77     {
78         std::lock_guard<std::mutex> lockGuard(m_Mutex);
79
80         if (m_LoadedNetworks.erase(networkId) == 0)
81         {
82             BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!";
83             return Status::Failure;
84         }
85     }
86
87     BOOST_LOG_TRIVIAL(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId;
88     return Status::Success;
89 }
90
91 const std::shared_ptr<IProfiler> Runtime::GetProfiler(NetworkId networkId) const
92 {
93     auto it = m_LoadedNetworks.find(networkId);
94     if (it != m_LoadedNetworks.end())
95     {
96         auto& loadedNetwork = it->second;
97         return loadedNetwork->GetProfiler();
98     }
99
100     return nullptr;
101 }
102
103 Runtime::Runtime(const CreationOptions& options)
104     : m_Options{options}
105     , m_NetworkIdCounter(0)
106     , m_DeviceSpec{BackendRegistryInstance().GetBackendIds()}
107 {
108     BOOST_LOG_TRIVIAL(info) << "ArmNN v" << ARMNN_VERSION << "\n";
109
110     for (const auto& id : BackendContextRegistryInstance().GetBackendIds())
111     {
112         // Store backend contexts for the supported ones
113         if (m_DeviceSpec.GetSupportedBackends().count(id) > 0)
114         {
115             // Don't throw an exception, rather return a dummy factory if not
116             // found.
117             auto factoryFun = BackendContextRegistryInstance().GetFactory(
118                 id, [](const CreationOptions&) { return IBackendContextUniquePtr(); }
119             );
120
121             m_BackendContexts.emplace(std::make_pair(id, factoryFun(options)));
122         }
123     }
124 }
125
126 Runtime::~Runtime()
127 {
128     std::vector<int> networkIDs;
129     try
130     {
131         // Coverity fix: The following code may throw an exception of type std::length_error.
132         std::transform(m_LoadedNetworks.begin(), m_LoadedNetworks.end(),
133                        std::back_inserter(networkIDs),
134                        [](const auto &pair) { return pair.first; });
135     }
136     catch (const std::exception& e)
137     {
138         // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
139         // exception of type std::length_error.
140         // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
141         std::cerr << "WARNING: An error has occurred when getting the IDs of the networks to unload: " << e.what()
142                   << "\nSome of the loaded networks may not be unloaded" << std::endl;
143     }
144     // We then proceed to unload all the networks which IDs have been appended to the list
145     // up to the point the exception was thrown (if any).
146
147     for (auto networkID : networkIDs)
148     {
149         try
150         {
151             // Coverity fix: UnloadNetwork() may throw an exception of type std::length_error,
152             // boost::log::v2s_mt_posix::odr_violation or boost::log::v2s_mt_posix::system_error
153             UnloadNetwork(networkID);
154         }
155         catch (const std::exception& e)
156         {
157             // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
158             // exception of type std::length_error.
159             // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
160             std::cerr << "WARNING: An error has occurred when unloading network " << networkID << ": " << e.what()
161                       << std::endl;
162         }
163     }
164 }
165
166 LoadedNetwork* Runtime::GetLoadedNetworkPtr(NetworkId networkId) const
167 {
168     std::lock_guard<std::mutex> lockGuard(m_Mutex);
169     return m_LoadedNetworks.at(networkId).get();
170 }
171
172 TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
173 {
174     return GetLoadedNetworkPtr(networkId)->GetInputTensorInfo(layerId);
175 }
176
177 TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
178 {
179     return GetLoadedNetworkPtr(networkId)->GetOutputTensorInfo(layerId);
180 }
181
182
183 Status Runtime::EnqueueWorkload(NetworkId networkId,
184                                 const InputTensors& inputTensors,
185                                 const OutputTensors& outputTensors)
186 {
187     LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
188
189     static thread_local NetworkId lastId = networkId;
190     if (lastId != networkId)
191     {
192         LoadedNetworkFuncSafe(lastId, [](LoadedNetwork* network)
193             {
194                 network->FreeWorkingMemory();
195             });
196     }
197     lastId=networkId;
198
199     return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors);
200 }
201
202 }