IVGCVSW-4828 Call m_CounterDirectory.IsCounterRegistered in ProfilingService::IsCount...
[platform/upstream/armnn.git] / src / profiling / ProfilingService.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "ProfilingService.hpp"
7
8 #include <armnn/BackendId.hpp>
9 #include <armnn/Logging.hpp>
10 #include <common/include/SocketConnectionException.hpp>
11
12 #include <boost/format.hpp>
13
14 namespace armnn
15 {
16
17 namespace profiling
18 {
19
20 ProfilingGuidGenerator ProfilingService::m_GuidGenerator;
21
22 ProfilingDynamicGuid ProfilingService::GetNextGuid()
23 {
24     return m_GuidGenerator.NextGuid();
25 }
26
27 ProfilingStaticGuid ProfilingService::GetStaticId(const std::string& str)
28 {
29     return m_GuidGenerator.GenerateStaticId(str);
30 }
31
32 void ProfilingService::ResetExternalProfilingOptions(const ExternalProfilingOptions& options,
33                                                      bool resetProfilingService)
34 {
35     // Update the profiling options
36     m_Options = options;
37     m_TimelineReporting = options.m_TimelineEnabled;
38     m_ConnectionAcknowledgedCommandHandler.setTimelineEnabled(options.m_TimelineEnabled);
39
40     // Check if the profiling service needs to be reset
41     if (resetProfilingService)
42     {
43         // Reset the profiling service
44         Reset();
45     }
46 }
47
48 bool ProfilingService::IsProfilingEnabled() const
49 {
50     return m_Options.m_EnableProfiling;
51 }
52
53 ProfilingState ProfilingService::ConfigureProfilingService(
54         const ExternalProfilingOptions& options,
55         bool resetProfilingService)
56 {
57     ResetExternalProfilingOptions(options, resetProfilingService);
58     ProfilingState currentState = m_StateMachine.GetCurrentState();
59     if (options.m_EnableProfiling)
60     {
61         switch (currentState)
62         {
63             case ProfilingState::Uninitialised:
64                 Update(); // should transition to NotConnected
65                 Update(); // will either stay in NotConnected because there is no server
66                           // or will enter WaitingForAck.
67                 currentState = m_StateMachine.GetCurrentState();
68                 if (currentState == ProfilingState::WaitingForAck)
69                 {
70                     Update(); // poke it again to send out the metadata packet
71                 }
72                 currentState = m_StateMachine.GetCurrentState();
73                 return currentState;
74             case ProfilingState::NotConnected:
75                 Update(); // will either stay in NotConnected because there is no server
76                           // or will enter WaitingForAck
77                 currentState = m_StateMachine.GetCurrentState();
78                 if (currentState == ProfilingState::WaitingForAck)
79                 {
80                     Update(); // poke it again to send out the metadata packet
81                 }
82                 currentState = m_StateMachine.GetCurrentState();
83                 return currentState;
84             default:
85                 return currentState;
86         }
87     }
88     else
89     {
90         // Make sure profiling is shutdown
91         switch (currentState)
92         {
93             case ProfilingState::Uninitialised:
94             case ProfilingState::NotConnected:
95                 return currentState;
96             default:
97                 Stop();
98                 return m_StateMachine.GetCurrentState();
99         }
100     }
101 }
102
103 void ProfilingService::Update()
104 {
105     if (!m_Options.m_EnableProfiling)
106     {
107         // Don't run if profiling is disabled
108         return;
109     }
110
111     ProfilingState currentState = m_StateMachine.GetCurrentState();
112     switch (currentState)
113     {
114     case ProfilingState::Uninitialised:
115
116         // Initialize the profiling service
117         Initialize();
118
119         // Move to the next state
120         m_StateMachine.TransitionToState(ProfilingState::NotConnected);
121         break;
122     case ProfilingState::NotConnected:
123         // Stop the command thread (if running)
124         m_CommandHandler.Stop();
125
126         // Stop the send thread (if running)
127         m_SendThread.Stop(false);
128
129         // Stop the periodic counter capture thread (if running)
130         m_PeriodicCounterCapture.Stop();
131
132         // Reset any existing profiling connection
133         m_ProfilingConnection.reset();
134
135         try
136         {
137             // Setup the profiling connection
138             ARMNN_ASSERT(m_ProfilingConnectionFactory);
139             m_ProfilingConnection = m_ProfilingConnectionFactory->GetProfilingConnection(m_Options);
140         }
141         catch (const Exception& e)
142         {
143             ARMNN_LOG(warning) << "An error has occurred when creating the profiling connection: "
144                                        << e.what();
145         }
146         catch (const armnnProfiling::SocketConnectionException& e)
147         {
148             ARMNN_LOG(warning) << "An error has occurred when creating the profiling connection ["
149                                        << e.what() << "] on socket [" << e.GetSocketFd() << "].";
150         }
151
152         // Move to the next state
153         m_StateMachine.TransitionToState(m_ProfilingConnection
154                                          ? ProfilingState::WaitingForAck  // Profiling connection obtained, wait for ack
155                                          : ProfilingState::NotConnected); // Profiling connection failed, stay in the
156                                                                           // "NotConnected" state
157         break;
158     case ProfilingState::WaitingForAck:
159         ARMNN_ASSERT(m_ProfilingConnection);
160
161         // Start the command thread
162         m_CommandHandler.Start(*m_ProfilingConnection);
163
164         // Start the send thread, while in "WaitingForAck" state it'll send out a "Stream MetaData" packet waiting for
165         // a valid "Connection Acknowledged" packet confirming the connection
166         m_SendThread.Start(*m_ProfilingConnection);
167
168         // The connection acknowledged command handler will automatically transition the state to "Active" once a
169         // valid "Connection Acknowledged" packet has been received
170
171         break;
172     case ProfilingState::Active:
173
174         // The period counter capture thread is started by the Periodic Counter Selection command handler upon
175         // request by an external profiling service
176
177         break;
178     default:
179         throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1")
180                                           % static_cast<int>(currentState)));
181     }
182 }
183
184 void ProfilingService::Disconnect()
185 {
186     ProfilingState currentState = m_StateMachine.GetCurrentState();
187     switch (currentState)
188     {
189     case ProfilingState::Uninitialised:
190     case ProfilingState::NotConnected:
191     case ProfilingState::WaitingForAck:
192         return; // NOP
193     case ProfilingState::Active:
194         // Stop the command thread (if running)
195         Stop();
196
197         break;
198     default:
199         throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1")
200                                           % static_cast<int>(currentState)));
201     }
202 }
203
204 // Store a profiling context returned from a backend that support profiling, and register its counters
205 void ProfilingService::AddBackendProfilingContext(const BackendId backendId,
206     std::shared_ptr<armnn::profiling::IBackendProfilingContext> profilingContext)
207 {
208     ARMNN_ASSERT(profilingContext != nullptr);
209     // Register the backend counters
210     m_MaxGlobalCounterId = profilingContext->RegisterCounters(m_MaxGlobalCounterId);
211     m_BackendProfilingContexts.emplace(backendId, std::move(profilingContext));
212 }
213 const ICounterDirectory& ProfilingService::GetCounterDirectory() const
214 {
215     return m_CounterDirectory;
216 }
217
218 ICounterRegistry& ProfilingService::GetCounterRegistry()
219 {
220     return m_CounterDirectory;
221 }
222
223 ProfilingState ProfilingService::GetCurrentState() const
224 {
225     return m_StateMachine.GetCurrentState();
226 }
227
228 uint16_t ProfilingService::GetCounterCount() const
229 {
230     return m_CounterDirectory.GetCounterCount();
231 }
232
233 bool ProfilingService::IsCounterRegistered(uint16_t counterUid) const
234 {
235     return m_CounterDirectory.IsCounterRegistered(counterUid);
236 }
237
238 uint32_t ProfilingService::GetAbsoluteCounterValue(uint16_t counterUid) const
239 {
240     CheckCounterUid(counterUid);
241     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
242     ARMNN_ASSERT(counterValuePtr);
243     return counterValuePtr->load(std::memory_order::memory_order_relaxed);
244 }
245
246 uint32_t ProfilingService::GetDeltaCounterValue(uint16_t counterUid)
247 {
248     CheckCounterUid(counterUid);
249     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
250     ARMNN_ASSERT(counterValuePtr);
251     const uint32_t counterValue = counterValuePtr->load(std::memory_order::memory_order_relaxed);
252     SubtractCounterValue(counterUid, counterValue);
253     return counterValue;
254 }
255
256 const ICounterMappings& ProfilingService::GetCounterMappings() const
257 {
258     return m_CounterIdMap;
259 }
260
261 IRegisterCounterMapping& ProfilingService::GetCounterMappingRegistry()
262 {
263     return m_CounterIdMap;
264 }
265
266 CaptureData ProfilingService::GetCaptureData()
267 {
268     return m_Holder.GetCaptureData();
269 }
270
271 void ProfilingService::SetCaptureData(uint32_t capturePeriod,
272                                       const std::vector<uint16_t>& counterIds,
273                                       const std::set<BackendId>& activeBackends)
274 {
275     m_Holder.SetCaptureData(capturePeriod, counterIds, activeBackends);
276 }
277
278 void ProfilingService::SetCounterValue(uint16_t counterUid, uint32_t value)
279 {
280     CheckCounterUid(counterUid);
281     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
282     ARMNN_ASSERT(counterValuePtr);
283     counterValuePtr->store(value, std::memory_order::memory_order_relaxed);
284 }
285
286 uint32_t ProfilingService::AddCounterValue(uint16_t counterUid, uint32_t value)
287 {
288     CheckCounterUid(counterUid);
289     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
290     ARMNN_ASSERT(counterValuePtr);
291     return counterValuePtr->fetch_add(value, std::memory_order::memory_order_relaxed);
292 }
293
294 uint32_t ProfilingService::SubtractCounterValue(uint16_t counterUid, uint32_t value)
295 {
296     CheckCounterUid(counterUid);
297     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
298     ARMNN_ASSERT(counterValuePtr);
299     return counterValuePtr->fetch_sub(value, std::memory_order::memory_order_relaxed);
300 }
301
302 uint32_t ProfilingService::IncrementCounterValue(uint16_t counterUid)
303 {
304     CheckCounterUid(counterUid);
305     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
306     ARMNN_ASSERT(counterValuePtr);
307     return counterValuePtr->operator++(std::memory_order::memory_order_relaxed);
308 }
309
310 ProfilingDynamicGuid ProfilingService::NextGuid()
311 {
312     return ProfilingService::GetNextGuid();
313 }
314
315 ProfilingStaticGuid ProfilingService::GenerateStaticId(const std::string& str)
316 {
317     return ProfilingService::GetStaticId(str);
318 }
319
320 std::unique_ptr<ISendTimelinePacket> ProfilingService::GetSendTimelinePacket() const
321 {
322     return m_TimelinePacketWriterFactory.GetSendTimelinePacket();
323 }
324
325 void ProfilingService::Initialize()
326 {
327     // Register a category for the basic runtime counters
328     if (!m_CounterDirectory.IsCategoryRegistered("ArmNN_Runtime"))
329     {
330         m_CounterDirectory.RegisterCategory("ArmNN_Runtime");
331     }
332
333     // Register a counter for the number of Network loads
334     if (!m_CounterDirectory.IsCounterRegistered("Network loads"))
335     {
336         const Counter* loadedNetworksCounter =
337                 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
338                                                    armnn::profiling::NETWORK_LOADS,
339                                                    "ArmNN_Runtime",
340                                                    0,
341                                                    0,
342                                                    1.f,
343                                                    "Network loads",
344                                                    "The number of networks loaded at runtime",
345                                                    std::string("networks"));
346         ARMNN_ASSERT(loadedNetworksCounter);
347         InitializeCounterValue(loadedNetworksCounter->m_Uid);
348     }
349     // Register a counter for the number of unloaded networks
350     if (!m_CounterDirectory.IsCounterRegistered("Network unloads"))
351     {
352         const Counter* unloadedNetworksCounter =
353                 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
354                                                    armnn::profiling::NETWORK_UNLOADS,
355                                                    "ArmNN_Runtime",
356                                                    0,
357                                                    0,
358                                                    1.f,
359                                                    "Network unloads",
360                                                    "The number of networks unloaded at runtime",
361                                                    std::string("networks"));
362         ARMNN_ASSERT(unloadedNetworksCounter);
363         InitializeCounterValue(unloadedNetworksCounter->m_Uid);
364     }
365     // Register a counter for the number of registered backends
366     if (!m_CounterDirectory.IsCounterRegistered("Backends registered"))
367     {
368         const Counter* registeredBackendsCounter =
369                 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
370                                                    armnn::profiling::REGISTERED_BACKENDS,
371                                                    "ArmNN_Runtime",
372                                                    0,
373                                                    0,
374                                                    1.f,
375                                                    "Backends registered",
376                                                    "The number of registered backends",
377                                                    std::string("backends"));
378         ARMNN_ASSERT(registeredBackendsCounter);
379         InitializeCounterValue(registeredBackendsCounter->m_Uid);
380     }
381     // Register a counter for the number of registered backends
382     if (!m_CounterDirectory.IsCounterRegistered("Backends unregistered"))
383     {
384         const Counter* unregisteredBackendsCounter =
385                 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
386                                                    armnn::profiling::UNREGISTERED_BACKENDS,
387                                                    "ArmNN_Runtime",
388                                                    0,
389                                                    0,
390                                                    1.f,
391                                                    "Backends unregistered",
392                                                    "The number of unregistered backends",
393                                                    std::string("backends"));
394         ARMNN_ASSERT(unregisteredBackendsCounter);
395         InitializeCounterValue(unregisteredBackendsCounter->m_Uid);
396     }
397     // Register a counter for the number of inferences run
398     if (!m_CounterDirectory.IsCounterRegistered("Inferences run"))
399     {
400         const Counter* inferencesRunCounter =
401                 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
402                                                    armnn::profiling::INFERENCES_RUN,
403                                                    "ArmNN_Runtime",
404                                                    0,
405                                                    0,
406                                                    1.f,
407                                                    "Inferences run",
408                                                    "The number of inferences run",
409                                                    std::string("inferences"));
410         ARMNN_ASSERT(inferencesRunCounter);
411         InitializeCounterValue(inferencesRunCounter->m_Uid);
412     }
413 }
414
415 void ProfilingService::InitializeCounterValue(uint16_t counterUid)
416 {
417     // Increase the size of the counter index if necessary
418     if (counterUid >= m_CounterIndex.size())
419     {
420         m_CounterIndex.resize(boost::numeric_cast<size_t>(counterUid) + 1);
421     }
422
423     // Create a new atomic counter and add it to the list
424     m_CounterValues.emplace_back(0);
425
426     // Register the new counter to the counter index for quick access
427     std::atomic<uint32_t>* counterValuePtr = &(m_CounterValues.back());
428     m_CounterIndex.at(counterUid) = counterValuePtr;
429 }
430
/// Stops the service, then discards all counter values, the counter index and directory, the counter ID
/// mappings and the buffer manager contents, resets the state machine, drops every backend profiling
/// context and restores m_MaxGlobalCounterId to MAX_ARMNN_COUNTER (the starting point handed to backends'
/// RegisterCounters() by AddBackendProfilingContext).
void ProfilingService::Reset()
{
    // Stop the profiling service...
    Stop();

    // ...then delete all the counter data and configuration...
    m_CounterIndex.clear();
    m_CounterValues.clear();
    m_CounterDirectory.Clear();
    m_CounterIdMap.Reset();
    m_BufferManager.Reset();

    // ...finally reset the profiling state machine
    m_StateMachine.Reset();
    m_BackendProfilingContexts.clear();
    m_MaxGlobalCounterId = armnn::profiling::MAX_ARMNN_COUNTER;
}
448
/// Stops all profiling activity: halts the command handler, the periodic counter capture and the send
/// thread, closes (if open) and releases the profiling connection, and moves the state machine to
/// "NotConnected".
void ProfilingService::Stop()
{
    // The order in which we reset/stop the components is not trivial!
    // First stop the producing threads
    // Command Handler first as it is responsible for launching the Periodic Counter capture thread
    m_CommandHandler.Stop();
    m_PeriodicCounterCapture.Stop();
    // Then the consuming thread
    m_SendThread.Stop(false);

    // ...then close and destroy the profiling connection...
    if (m_ProfilingConnection != nullptr && m_ProfilingConnection->IsOpen())
    {
        m_ProfilingConnection->Close();
    }
    m_ProfilingConnection.reset();

    // ...then move to the "NotConnected" state
    m_StateMachine.TransitionToState(ProfilingState::NotConnected);
}
469
470 inline void ProfilingService::CheckCounterUid(uint16_t counterUid) const
471 {
472     if (!IsCounterRegistered(counterUid))
473     {
474         throw InvalidArgumentException(boost::str(boost::format("Counter UID %1% is not registered") % counterUid));
475     }
476 }
477
478 void ProfilingService::NotifyBackendsForTimelineReporting()
479 {
480     BackendProfilingContext::iterator it = m_BackendProfilingContexts.begin();
481     while (it != m_BackendProfilingContexts.end())
482     {
483         auto& backendProfilingContext = it->second;
484         backendProfilingContext->EnableTimelineReporting(m_TimelineReporting);
485         // Increment the Iterator to point to next entry
486         it++;
487     }
488 }
489
490 ProfilingService::~ProfilingService()
491 {
492     Stop();
493 }
494 } // namespace profiling
495
496 } // namespace armnn