IVGCVSW-5301 Remove all boost::numeric_cast from armnn/src/profiling
[platform/upstream/armnn.git] / src / profiling / ProfilingService.cpp
1 //
2 // Copyright © 2019 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "ProfilingService.hpp"
7
8 #include <armnn/BackendId.hpp>
9 #include <armnn/Logging.hpp>
10 #include <armnn/utility/NumericCast.hpp>
11
12 #include <common/include/SocketConnectionException.hpp>
13
14 #include <boost/format.hpp>
15
16 namespace armnn
17 {
18
19 namespace profiling
20 {
21
22 ProfilingGuidGenerator ProfilingService::m_GuidGenerator;
23
24 ProfilingDynamicGuid ProfilingService::GetNextGuid()
25 {
26     return m_GuidGenerator.NextGuid();
27 }
28
29 ProfilingStaticGuid ProfilingService::GetStaticId(const std::string& str)
30 {
31     return m_GuidGenerator.GenerateStaticId(str);
32 }
33
34 void ProfilingService::ResetGuidGenerator()
35 {
36     m_GuidGenerator.Reset();
37 }
38
39 void ProfilingService::ResetExternalProfilingOptions(const ExternalProfilingOptions& options,
40                                                      bool resetProfilingService)
41 {
42     // Update the profiling options
43     m_Options = options;
44     m_TimelineReporting = options.m_TimelineEnabled;
45     m_ConnectionAcknowledgedCommandHandler.setTimelineEnabled(options.m_TimelineEnabled);
46
47     // Check if the profiling service needs to be reset
48     if (resetProfilingService)
49     {
50         // Reset the profiling service
51         Reset();
52     }
53 }
54
55 bool ProfilingService::IsProfilingEnabled() const
56 {
57     return m_Options.m_EnableProfiling;
58 }
59
60 ProfilingState ProfilingService::ConfigureProfilingService(
61         const ExternalProfilingOptions& options,
62         bool resetProfilingService)
63 {
64     ResetExternalProfilingOptions(options, resetProfilingService);
65     ProfilingState currentState = m_StateMachine.GetCurrentState();
66     if (options.m_EnableProfiling)
67     {
68         switch (currentState)
69         {
70             case ProfilingState::Uninitialised:
71                 Update(); // should transition to NotConnected
72                 Update(); // will either stay in NotConnected because there is no server
73                           // or will enter WaitingForAck.
74                 currentState = m_StateMachine.GetCurrentState();
75                 if (currentState == ProfilingState::WaitingForAck)
76                 {
77                     Update(); // poke it again to send out the metadata packet
78                 }
79                 currentState = m_StateMachine.GetCurrentState();
80                 return currentState;
81             case ProfilingState::NotConnected:
82                 Update(); // will either stay in NotConnected because there is no server
83                           // or will enter WaitingForAck
84                 currentState = m_StateMachine.GetCurrentState();
85                 if (currentState == ProfilingState::WaitingForAck)
86                 {
87                     Update(); // poke it again to send out the metadata packet
88                 }
89                 currentState = m_StateMachine.GetCurrentState();
90                 return currentState;
91             default:
92                 return currentState;
93         }
94     }
95     else
96     {
97         // Make sure profiling is shutdown
98         switch (currentState)
99         {
100             case ProfilingState::Uninitialised:
101             case ProfilingState::NotConnected:
102                 return currentState;
103             default:
104                 Stop();
105                 return m_StateMachine.GetCurrentState();
106         }
107     }
108 }
109
110 void ProfilingService::Update()
111 {
112     if (!m_Options.m_EnableProfiling)
113     {
114         // Don't run if profiling is disabled
115         return;
116     }
117
118     ProfilingState currentState = m_StateMachine.GetCurrentState();
119     switch (currentState)
120     {
121     case ProfilingState::Uninitialised:
122
123         // Initialize the profiling service
124         Initialize();
125
126         // Move to the next state
127         m_StateMachine.TransitionToState(ProfilingState::NotConnected);
128         break;
129     case ProfilingState::NotConnected:
130         // Stop the command thread (if running)
131         m_CommandHandler.Stop();
132
133         // Stop the send thread (if running)
134         m_SendThread.Stop(false);
135
136         // Stop the periodic counter capture thread (if running)
137         m_PeriodicCounterCapture.Stop();
138
139         // Reset any existing profiling connection
140         m_ProfilingConnection.reset();
141
142         try
143         {
144             // Setup the profiling connection
145             ARMNN_ASSERT(m_ProfilingConnectionFactory);
146             m_ProfilingConnection = m_ProfilingConnectionFactory->GetProfilingConnection(m_Options);
147         }
148         catch (const Exception& e)
149         {
150             ARMNN_LOG(warning) << "An error has occurred when creating the profiling connection: "
151                                        << e.what();
152         }
153         catch (const arm::pipe::SocketConnectionException& e)
154         {
155             ARMNN_LOG(warning) << "An error has occurred when creating the profiling connection ["
156                                        << e.what() << "] on socket [" << e.GetSocketFd() << "].";
157         }
158
159         // Move to the next state
160         m_StateMachine.TransitionToState(m_ProfilingConnection
161                                          ? ProfilingState::WaitingForAck  // Profiling connection obtained, wait for ack
162                                          : ProfilingState::NotConnected); // Profiling connection failed, stay in the
163                                                                           // "NotConnected" state
164         break;
165     case ProfilingState::WaitingForAck:
166         ARMNN_ASSERT(m_ProfilingConnection);
167
168         // Start the command thread
169         m_CommandHandler.Start(*m_ProfilingConnection);
170
171         // Start the send thread, while in "WaitingForAck" state it'll send out a "Stream MetaData" packet waiting for
172         // a valid "Connection Acknowledged" packet confirming the connection
173         m_SendThread.Start(*m_ProfilingConnection);
174
175         // The connection acknowledged command handler will automatically transition the state to "Active" once a
176         // valid "Connection Acknowledged" packet has been received
177
178         break;
179     case ProfilingState::Active:
180
181         // The period counter capture thread is started by the Periodic Counter Selection command handler upon
182         // request by an external profiling service
183
184         break;
185     default:
186         throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1")
187                                           % static_cast<int>(currentState)));
188     }
189 }
190
191 void ProfilingService::Disconnect()
192 {
193     ProfilingState currentState = m_StateMachine.GetCurrentState();
194     switch (currentState)
195     {
196     case ProfilingState::Uninitialised:
197     case ProfilingState::NotConnected:
198     case ProfilingState::WaitingForAck:
199         return; // NOP
200     case ProfilingState::Active:
201         // Stop the command thread (if running)
202         Stop();
203
204         break;
205     default:
206         throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1")
207                                           % static_cast<int>(currentState)));
208     }
209 }
210
211 // Store a profiling context returned from a backend that support profiling, and register its counters
212 void ProfilingService::AddBackendProfilingContext(const BackendId backendId,
213     std::shared_ptr<armnn::profiling::IBackendProfilingContext> profilingContext)
214 {
215     ARMNN_ASSERT(profilingContext != nullptr);
216     // Register the backend counters
217     m_MaxGlobalCounterId = profilingContext->RegisterCounters(m_MaxGlobalCounterId);
218     m_BackendProfilingContexts.emplace(backendId, std::move(profilingContext));
219 }
220 const ICounterDirectory& ProfilingService::GetCounterDirectory() const
221 {
222     return m_CounterDirectory;
223 }
224
225 ICounterRegistry& ProfilingService::GetCounterRegistry()
226 {
227     return m_CounterDirectory;
228 }
229
230 ProfilingState ProfilingService::GetCurrentState() const
231 {
232     return m_StateMachine.GetCurrentState();
233 }
234
235 uint16_t ProfilingService::GetCounterCount() const
236 {
237     return m_CounterDirectory.GetCounterCount();
238 }
239
240 bool ProfilingService::IsCounterRegistered(uint16_t counterUid) const
241 {
242     return m_CounterDirectory.IsCounterRegistered(counterUid);
243 }
244
245 uint32_t ProfilingService::GetAbsoluteCounterValue(uint16_t counterUid) const
246 {
247     CheckCounterUid(counterUid);
248     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
249     ARMNN_ASSERT(counterValuePtr);
250     return counterValuePtr->load(std::memory_order::memory_order_relaxed);
251 }
252
253 uint32_t ProfilingService::GetDeltaCounterValue(uint16_t counterUid)
254 {
255     CheckCounterUid(counterUid);
256     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
257     ARMNN_ASSERT(counterValuePtr);
258     const uint32_t counterValue = counterValuePtr->load(std::memory_order::memory_order_relaxed);
259     SubtractCounterValue(counterUid, counterValue);
260     return counterValue;
261 }
262
263 const ICounterMappings& ProfilingService::GetCounterMappings() const
264 {
265     return m_CounterIdMap;
266 }
267
268 IRegisterCounterMapping& ProfilingService::GetCounterMappingRegistry()
269 {
270     return m_CounterIdMap;
271 }
272
273 CaptureData ProfilingService::GetCaptureData()
274 {
275     return m_Holder.GetCaptureData();
276 }
277
278 void ProfilingService::SetCaptureData(uint32_t capturePeriod,
279                                       const std::vector<uint16_t>& counterIds,
280                                       const std::set<BackendId>& activeBackends)
281 {
282     m_Holder.SetCaptureData(capturePeriod, counterIds, activeBackends);
283 }
284
285 void ProfilingService::SetCounterValue(uint16_t counterUid, uint32_t value)
286 {
287     CheckCounterUid(counterUid);
288     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
289     ARMNN_ASSERT(counterValuePtr);
290     counterValuePtr->store(value, std::memory_order::memory_order_relaxed);
291 }
292
293 uint32_t ProfilingService::AddCounterValue(uint16_t counterUid, uint32_t value)
294 {
295     CheckCounterUid(counterUid);
296     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
297     ARMNN_ASSERT(counterValuePtr);
298     return counterValuePtr->fetch_add(value, std::memory_order::memory_order_relaxed);
299 }
300
301 uint32_t ProfilingService::SubtractCounterValue(uint16_t counterUid, uint32_t value)
302 {
303     CheckCounterUid(counterUid);
304     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
305     ARMNN_ASSERT(counterValuePtr);
306     return counterValuePtr->fetch_sub(value, std::memory_order::memory_order_relaxed);
307 }
308
309 uint32_t ProfilingService::IncrementCounterValue(uint16_t counterUid)
310 {
311     CheckCounterUid(counterUid);
312     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
313     ARMNN_ASSERT(counterValuePtr);
314     return counterValuePtr->operator++(std::memory_order::memory_order_relaxed);
315 }
316
317 ProfilingDynamicGuid ProfilingService::NextGuid()
318 {
319     return ProfilingService::GetNextGuid();
320 }
321
322 ProfilingStaticGuid ProfilingService::GenerateStaticId(const std::string& str)
323 {
324     return ProfilingService::GetStaticId(str);
325 }
326
327 std::unique_ptr<ISendTimelinePacket> ProfilingService::GetSendTimelinePacket() const
328 {
329     return m_TimelinePacketWriterFactory.GetSendTimelinePacket();
330 }
331
332 void ProfilingService::Initialize()
333 {
334     // Register a category for the basic runtime counters
335     if (!m_CounterDirectory.IsCategoryRegistered("ArmNN_Runtime"))
336     {
337         m_CounterDirectory.RegisterCategory("ArmNN_Runtime");
338     }
339
340     // Register a counter for the number of Network loads
341     if (!m_CounterDirectory.IsCounterRegistered("Network loads"))
342     {
343         const Counter* loadedNetworksCounter =
344                 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
345                                                    armnn::profiling::NETWORK_LOADS,
346                                                    "ArmNN_Runtime",
347                                                    0,
348                                                    0,
349                                                    1.f,
350                                                    "Network loads",
351                                                    "The number of networks loaded at runtime",
352                                                    std::string("networks"));
353         ARMNN_ASSERT(loadedNetworksCounter);
354         InitializeCounterValue(loadedNetworksCounter->m_Uid);
355     }
356     // Register a counter for the number of unloaded networks
357     if (!m_CounterDirectory.IsCounterRegistered("Network unloads"))
358     {
359         const Counter* unloadedNetworksCounter =
360                 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
361                                                    armnn::profiling::NETWORK_UNLOADS,
362                                                    "ArmNN_Runtime",
363                                                    0,
364                                                    0,
365                                                    1.f,
366                                                    "Network unloads",
367                                                    "The number of networks unloaded at runtime",
368                                                    std::string("networks"));
369         ARMNN_ASSERT(unloadedNetworksCounter);
370         InitializeCounterValue(unloadedNetworksCounter->m_Uid);
371     }
372     // Register a counter for the number of registered backends
373     if (!m_CounterDirectory.IsCounterRegistered("Backends registered"))
374     {
375         const Counter* registeredBackendsCounter =
376                 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
377                                                    armnn::profiling::REGISTERED_BACKENDS,
378                                                    "ArmNN_Runtime",
379                                                    0,
380                                                    0,
381                                                    1.f,
382                                                    "Backends registered",
383                                                    "The number of registered backends",
384                                                    std::string("backends"));
385         ARMNN_ASSERT(registeredBackendsCounter);
386         InitializeCounterValue(registeredBackendsCounter->m_Uid);
387
388         // Due to backends being registered before the profiling service becomes active,
389         // we need to set the counter to the correct value here
390         SetCounterValue(armnn::profiling::REGISTERED_BACKENDS, static_cast<uint32_t>(BackendRegistryInstance().Size()));
391     }
392     // Register a counter for the number of registered backends
393     if (!m_CounterDirectory.IsCounterRegistered("Backends unregistered"))
394     {
395         const Counter* unregisteredBackendsCounter =
396                 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
397                                                    armnn::profiling::UNREGISTERED_BACKENDS,
398                                                    "ArmNN_Runtime",
399                                                    0,
400                                                    0,
401                                                    1.f,
402                                                    "Backends unregistered",
403                                                    "The number of unregistered backends",
404                                                    std::string("backends"));
405         ARMNN_ASSERT(unregisteredBackendsCounter);
406         InitializeCounterValue(unregisteredBackendsCounter->m_Uid);
407     }
408     // Register a counter for the number of inferences run
409     if (!m_CounterDirectory.IsCounterRegistered("Inferences run"))
410     {
411         const Counter* inferencesRunCounter =
412                 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
413                                                    armnn::profiling::INFERENCES_RUN,
414                                                    "ArmNN_Runtime",
415                                                    0,
416                                                    0,
417                                                    1.f,
418                                                    "Inferences run",
419                                                    "The number of inferences run",
420                                                    std::string("inferences"));
421         ARMNN_ASSERT(inferencesRunCounter);
422         InitializeCounterValue(inferencesRunCounter->m_Uid);
423     }
424 }
425
426 void ProfilingService::InitializeCounterValue(uint16_t counterUid)
427 {
428     // Increase the size of the counter index if necessary
429     if (counterUid >= m_CounterIndex.size())
430     {
431         m_CounterIndex.resize(armnn::numeric_cast<size_t>(counterUid) + 1);
432     }
433
434     // Create a new atomic counter and add it to the list
435     m_CounterValues.emplace_back(0);
436
437     // Register the new counter to the counter index for quick access
438     std::atomic<uint32_t>* counterValuePtr = &(m_CounterValues.back());
439     m_CounterIndex.at(counterUid) = counterValuePtr;
440 }
441
// Fully resets the profiling service. It stops the service, discards all counter data and configuration,
// resets the profiling state machine, drops every backend profiling context, and sets m_MaxGlobalCounterId
// back to MAX_ARMNN_COUNTER.
void ProfilingService::Reset()
{
    // Stop the profiling service...
    Stop();

    // ...then delete all the counter data and configuration...
    // (m_CounterIndex holds pointers into m_CounterValues - see InitializeCounterValue - and is cleared first)
    m_CounterIndex.clear();
    m_CounterValues.clear();
    m_CounterDirectory.Clear();
    m_CounterIdMap.Reset();
    m_BufferManager.Reset();

    // ...then reset the profiling state machine...
    m_StateMachine.Reset();
    // ...and finally forget the backend profiling contexts and restore the counter id starting point handed to
    // backends in AddBackendProfilingContext (presumably the last id reserved for ArmNN's own counters)
    m_BackendProfilingContexts.clear();
    m_MaxGlobalCounterId = armnn::profiling::MAX_ARMNN_COUNTER;
}
459
460 void ProfilingService::Stop()
461 {
462     {   // only lock when we are updating the inference completed variable
463         std::unique_lock<std::mutex> lck(m_ServiceActiveMutex);
464         m_ServiceActive = false;
465     }
466     // The order in which we reset/stop the components is not trivial!
467     // First stop the producing threads
468     // Command Handler first as it is responsible for launching then Periodic Counter capture thread
469     m_CommandHandler.Stop();
470     m_PeriodicCounterCapture.Stop();
471     // The the consuming thread
472     m_SendThread.Stop(false);
473
474     // ...then close and destroy the profiling connection...
475     if (m_ProfilingConnection != nullptr && m_ProfilingConnection->IsOpen())
476     {
477         m_ProfilingConnection->Close();
478     }
479     m_ProfilingConnection.reset();
480
481     // ...then move to the "NotConnected" state
482     m_StateMachine.TransitionToState(ProfilingState::NotConnected);
483 }
484
485 inline void ProfilingService::CheckCounterUid(uint16_t counterUid) const
486 {
487     if (!IsCounterRegistered(counterUid))
488     {
489         throw InvalidArgumentException(boost::str(boost::format("Counter UID %1% is not registered") % counterUid));
490     }
491 }
492
493 void ProfilingService::NotifyBackendsForTimelineReporting()
494 {
495     BackendProfilingContext::iterator it = m_BackendProfilingContexts.begin();
496     while (it != m_BackendProfilingContexts.end())
497     {
498         auto& backendProfilingContext = it->second;
499         backendProfilingContext->EnableTimelineReporting(m_TimelineReporting);
500         // Increment the Iterator to point to next entry
501         it++;
502     }
503 }
504
505 void ProfilingService::NotifyProfilingServiceActive()
506 {
507     {   // only lock when we are updating the inference completed variable
508         std::unique_lock<std::mutex> lck(m_ServiceActiveMutex);
509         m_ServiceActive = true;
510     }
511     m_ServiceActiveConditionVariable.notify_one();
512 }
513
514 void ProfilingService::WaitForProfilingServiceActivation(unsigned int timeout)
515 {
516     std::unique_lock<std::mutex> lck(m_ServiceActiveMutex);
517
518     auto start = std::chrono::high_resolution_clock::now();
519     // Here we we will go back to sleep after a spurious wake up if
520     // m_InferenceCompleted is not yet true.
521     if (!m_ServiceActiveConditionVariable.wait_for(lck,
522                                                    std::chrono::milliseconds(timeout),
523                                                    [&]{return m_ServiceActive == true;}))
524     {
525         if (m_ServiceActive == true)
526         {
527             return;
528         }
529         auto finish = std::chrono::high_resolution_clock::now();
530         std::chrono::duration<double, std::milli> elapsed = finish - start;
531         std::stringstream ss;
532         ss << "Timed out waiting on profiling service activation for " << elapsed.count() << " ms";
533         ARMNN_LOG(warning) << ss.str();
534     }
535     return;
536 }
537
538 ProfilingService::~ProfilingService()
539 {
540     Stop();
541 }
542 } // namespace profiling
543
544 } // namespace armnn