Release 18.08
[platform/upstream/armnn.git] / src / armnn / Runtime.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #include "Runtime.hpp"
6
7 #include "armnn/Version.hpp"
8
9 #include <iostream>
10
11 #ifdef ARMCOMPUTECL_ENABLED
12 #include <arm_compute/core/CL/OpenCL.h>
13 #include <arm_compute/core/CL/CLKernelLibrary.h>
14 #include <arm_compute/runtime/CL/CLScheduler.h>
15 #endif
16
17 #include <boost/log/trivial.hpp>
18 #include <boost/polymorphic_cast.hpp>
19
20 using namespace armnn;
21 using namespace std;
22
23 namespace armnn
24 {
25
26 IRuntime* IRuntime::CreateRaw(const CreationOptions& options)
27 {
28     return new Runtime(options);
29 }
30
31 IRuntimePtr IRuntime::Create(const CreationOptions& options)
32 {
33     return IRuntimePtr(CreateRaw(options), &IRuntime::Destroy);
34 }
35
36 void IRuntime::Destroy(IRuntime* runtime)
37 {
38     delete boost::polymorphic_downcast<Runtime*>(runtime);
39 }
40
41 int Runtime::GenerateNetworkId()
42 {
43     return m_NetworkIdCounter++;
44 }
45
46 Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork)
47 {
48     std::string ignoredErrorMessage;
49     return LoadNetwork(networkIdOut, std::move(inNetwork), ignoredErrorMessage);
50 }
51
52 Status Runtime::LoadNetwork(NetworkId& networkIdOut,
53                             IOptimizedNetworkPtr inNetwork,
54                             std::string & errorMessage)
55 {
56     IOptimizedNetwork* rawNetwork = inNetwork.release();
57     unique_ptr<LoadedNetwork> loadedNetwork = LoadedNetwork::MakeLoadedNetwork(
58         std::unique_ptr<OptimizedNetwork>(boost::polymorphic_downcast<OptimizedNetwork*>(rawNetwork)),
59         errorMessage);
60
61     if (!loadedNetwork)
62     {
63         return Status::Failure;
64     }
65
66     networkIdOut = GenerateNetworkId();
67
68     {
69         std::lock_guard<std::mutex> lockGuard(m_Mutex);
70
71         // Stores the network
72         m_LoadedNetworks[networkIdOut] = std::move(loadedNetwork);
73     }
74
75     return Status::Success;
76 }
77
78 Status Runtime::UnloadNetwork(NetworkId networkId)
79 {
80 #ifdef ARMCOMPUTECL_ENABLED
81     if (arm_compute::CLScheduler::get().context()() != NULL)
82     {
83         // Waits for all queued CL requests to finish before unloading the network they may be using.
84         try
85         {
86             // Coverity fix: arm_compute::CLScheduler::sync() may throw an exception of type cl::Error.
87             arm_compute::CLScheduler::get().sync();
88         }
89         catch (const cl::Error&)
90         {
91             BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): an error occurred while waiting for "
92                                           "the queued CL requests to finish";
93             return Status::Failure;
94         }
95     }
96 #endif
97
98     {
99         std::lock_guard<std::mutex> lockGuard(m_Mutex);
100
101         if (m_LoadedNetworks.erase(networkId) == 0)
102         {
103             BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!";
104             return Status::Failure;
105         }
106
107 #ifdef ARMCOMPUTECL_ENABLED
108         if (arm_compute::CLScheduler::get().context()() != NULL && m_LoadedNetworks.empty())
109         {
110             // There are no loaded networks left, so clear the CL cache to free up memory
111             m_ClContextControl.ClearClCache();
112         }
113 #endif
114     }
115
116     BOOST_LOG_TRIVIAL(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId;
117     return Status::Success;
118 }
119
120 const std::shared_ptr<IProfiler> Runtime::GetProfiler(NetworkId networkId) const
121 {
122     auto it = m_LoadedNetworks.find(networkId);
123     if (it != m_LoadedNetworks.end())
124     {
125         auto& loadedNetwork = it->second;
126         return loadedNetwork->GetProfiler();
127     }
128
129     return nullptr;
130 }
131
132 Runtime::Runtime(const CreationOptions& options)
133     : m_ClContextControl(options.m_GpuAccTunedParameters.get(),
134                          options.m_EnableGpuProfiling)
135     , m_NetworkIdCounter(0)
136 {
137     BOOST_LOG_TRIVIAL(info) << "ArmNN v" << ARMNN_VERSION << "\n";
138
139     m_DeviceSpec.m_SupportedComputeDevices.insert(armnn::Compute::CpuRef);
140     #if ARMCOMPUTECL_ENABLED
141         m_DeviceSpec.m_SupportedComputeDevices.insert(armnn::Compute::GpuAcc);
142     #endif
143     #if ARMCOMPUTENEON_ENABLED
144         m_DeviceSpec.m_SupportedComputeDevices.insert(armnn::Compute::CpuAcc);
145     #endif
146 }
147
148 Runtime::~Runtime()
149 {
150     std::vector<int> networkIDs;
151     try
152     {
153         // Coverity fix: The following code may throw an exception of type std::length_error.
154         std::transform(m_LoadedNetworks.begin(), m_LoadedNetworks.end(),
155                        std::back_inserter(networkIDs),
156                        [](const auto &pair) { return pair.first; });
157     }
158     catch (const std::exception& e)
159     {
160         // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
161         // exception of type std::length_error.
162         // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
163         std::cerr << "WARNING: An error has occurred when getting the IDs of the networks to unload: " << e.what()
164                   << "\nSome of the loaded networks may not be unloaded" << std::endl;
165     }
166     // We then proceed to unload all the networks which IDs have been appended to the list
167     // up to the point the exception was thrown (if any).
168
169     for (auto networkID : networkIDs)
170     {
171         try
172         {
173             // Coverity fix: UnloadNetwork() may throw an exception of type std::length_error,
174             // boost::log::v2s_mt_posix::odr_violation or boost::log::v2s_mt_posix::system_error
175             UnloadNetwork(networkID);
176         }
177         catch (const std::exception& e)
178         {
179             // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
180             // exception of type std::length_error.
181             // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
182             std::cerr << "WARNING: An error has occurred when unloading network " << networkID << ": " << e.what()
183                       << std::endl;
184         }
185     }
186 }
187
188 LoadedNetwork* Runtime::GetLoadedNetworkPtr(NetworkId networkId) const
189 {
190     std::lock_guard<std::mutex> lockGuard(m_Mutex);
191     return m_LoadedNetworks.at(networkId).get();
192 }
193
194 TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
195 {
196     return GetLoadedNetworkPtr(networkId)->GetInputTensorInfo(layerId);
197 }
198
199 TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
200 {
201     return GetLoadedNetworkPtr(networkId)->GetOutputTensorInfo(layerId);
202 }
203
204 Status Runtime::EnqueueWorkload(NetworkId networkId,
205                                 const InputTensors& inputTensors,
206                                 const OutputTensors& outputTensors)
207 {
208     LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
209     return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors);
210 }
211
212 }