2 // Copyright © 2019 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
6 #include "ProfilingService.hpp"
8 #include <armnn/BackendId.hpp>
9 #include <armnn/Logging.hpp>
10 #include <armnn/utility/NumericCast.hpp>
12 #include <common/include/SocketConnectionException.hpp>
14 #include <boost/format.hpp>
// Definition of the static GUID generator shared by every ProfilingService instance; it backs
// GetNextGuid(), GetStaticId() and ResetGuidGenerator().
22 ProfilingGuidGenerator ProfilingService::m_GuidGenerator;
24 ProfilingDynamicGuid ProfilingService::GetNextGuid()
26 return m_GuidGenerator.NextGuid();
29 ProfilingStaticGuid ProfilingService::GetStaticId(const std::string& str)
31 return m_GuidGenerator.GenerateStaticId(str);
34 void ProfilingService::ResetGuidGenerator()
36 m_GuidGenerator.Reset();
// Applies a new set of external profiling options.
// The timeline-enabled flag is cached in m_TimelineReporting and forwarded to the
// connection-acknowledged command handler. When resetProfilingService is true, the branch below
// resets the profiling service.
// NOTE(review): the "Update the profiling options" step is expected to store `options` into
// m_Options (IsProfilingEnabled() and Update() read m_Options) — confirm.
39 void ProfilingService::ResetExternalProfilingOptions(const ExternalProfilingOptions& options,
40 bool resetProfilingService)
42 // Update the profiling options
44 m_TimelineReporting = options.m_TimelineEnabled;
45 m_ConnectionAcknowledgedCommandHandler.setTimelineEnabled(options.m_TimelineEnabled);
47 // Check if the profiling service needs to be reset
48 if (resetProfilingService)
50 // Reset the profiling service
55 bool ProfilingService::IsProfilingEnabled() const
57 return m_Options.m_EnableProfiling;
// Applies `options` via ResetExternalProfilingOptions() and then, if profiling is enabled, drives
// the state machine synchronously with a bounded number of Update() calls. From Uninitialised or
// NotConnected the service can therefore reach WaitingForAck, and be poked once more to send its
// metadata packet, before this call returns.
// If profiling is disabled, the "Make sure profiling is shutdown" branch handles the current state.
// Uninitialised and NotConnected are listed there (presumably as no-ops) — confirm against the
// full switch.
// @return the state machine's state after configuration.
// NOTE(review): confirm whether the Uninitialised case returns or falls through into the
// NotConnected case; the per-case exits are not visible here.
60 ProfilingState ProfilingService::ConfigureProfilingService(
61 const ExternalProfilingOptions& options,
62 bool resetProfilingService)
64 ResetExternalProfilingOptions(options, resetProfilingService);
65 ProfilingState currentState = m_StateMachine.GetCurrentState();
66 if (options.m_EnableProfiling)
70 case ProfilingState::Uninitialised:
71 Update(); // should transition to NotConnected
72 Update(); // will either stay in NotConnected because there is no server
73 // or will enter WaitingForAck.
74 currentState = m_StateMachine.GetCurrentState();
75 if (currentState == ProfilingState::WaitingForAck)
77 Update(); // poke it again to send out the metadata packet
79 currentState = m_StateMachine.GetCurrentState();
81 case ProfilingState::NotConnected:
82 Update(); // will either stay in NotConnected because there is no server
83 // or will enter WaitingForAck
84 currentState = m_StateMachine.GetCurrentState();
85 if (currentState == ProfilingState::WaitingForAck)
87 Update(); // poke it again to send out the metadata packet
89 currentState = m_StateMachine.GetCurrentState();
97 // Make sure profiling is shutdown
100 case ProfilingState::Uninitialised:
101 case ProfilingState::NotConnected:
105 return m_StateMachine.GetCurrentState();
// Advances the profiling state machine by one step. It does not run when m_Options disables
// profiling.
//   Uninitialised -> initialises the service, then moves to NotConnected.
//   NotConnected  -> stops the command, send and periodic-capture threads and drops any existing
//                    connection. It then asks the connection factory for a new one (creation
//                    errors are logged as warnings) and moves to WaitingForAck if a connection
//                    was obtained, otherwise stays NotConnected.
//   WaitingForAck -> starts the command and send threads on the current connection. The
//                    connection-acknowledged handler later moves the state to Active.
//   Active        -> no work here; periodic capture is started by the Periodic Counter
//                    Selection command handler.
// Any other state is meant to raise RuntimeException (see the note at the throw).
110 void ProfilingService::Update()
112 if (!m_Options.m_EnableProfiling)
114 // Don't run if profiling is disabled
118 ProfilingState currentState = m_StateMachine.GetCurrentState();
119 switch (currentState)
121 case ProfilingState::Uninitialised:
123 // Initialize the profiling service
126 // Move to the next state
127 m_StateMachine.TransitionToState(ProfilingState::NotConnected);
129 case ProfilingState::NotConnected:
130 // Stop the command thread (if running)
131 m_CommandHandler.Stop();
133 // Stop the send thread (if running)
134 m_SendThread.Stop(false);
136 // Stop the periodic counter capture thread (if running)
137 m_PeriodicCounterCapture.Stop();
139 // Reset any existing profiling connection
140 m_ProfilingConnection.reset();
144 // Setup the profiling connection
145 ARMNN_ASSERT(m_ProfilingConnectionFactory);
146 m_ProfilingConnection = m_ProfilingConnectionFactory->GetProfilingConnection(m_Options);
// NOTE(review): if arm::pipe::SocketConnectionException derives from armnn::Exception, the second
// handler below is unreachable and the socket fd is never logged — confirm the hierarchy.
148 catch (const Exception& e)
150 ARMNN_LOG(warning) << "An error has occurred when creating the profiling connection: "
153 catch (const arm::pipe::SocketConnectionException& e)
155 ARMNN_LOG(warning) << "An error has occurred when creating the profiling connection ["
156 << e.what() << "] on socket [" << e.GetSocketFd() << "].";
159 // Move to the next state
160 m_StateMachine.TransitionToState(m_ProfilingConnection
161 ? ProfilingState::WaitingForAck // Profiling connection obtained, wait for ack
162 : ProfilingState::NotConnected); // Profiling connection failed, stay in the
163 // "NotConnected" state
165 case ProfilingState::WaitingForAck:
166 ARMNN_ASSERT(m_ProfilingConnection);
168 // Start the command thread
169 m_CommandHandler.Start(*m_ProfilingConnection);
171 // Start the send thread, while in "WaitingForAck" state it'll send out a "Stream MetaData" packet waiting for
172 // a valid "Connection Acknowledged" packet confirming the connection
173 m_SendThread.Start(*m_ProfilingConnection);
175 // The connection acknowledged command handler will automatically transition the state to "Active" once a
176 // valid "Connection Acknowledged" packet has been received
179 case ProfilingState::Active:
181 // The periodic counter capture thread is started by the Periodic Counter Selection command handler upon
182 // request by an external profiling service
// NOTE(review): "%1" lacks the closing '%' of a boost::format positional directive (compare
// "%1%" in CheckCounterUid). boost::format will likely reject it as malformed and throw its own
// format error instead of this RuntimeException — consider "%1%".
186 throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1")
187 % static_cast<int>(currentState)));
// Tears down an active profiling session.
// Uninitialised, NotConnected and WaitingForAck share one branch, presumably a no-op — confirm.
// Active stops the running components. Any other state is meant to raise RuntimeException.
// NOTE(review): "%1" below lacks the closing '%' of a boost::format positional directive ("%1%").
// boost::format will likely throw its own format error instead of this RuntimeException.
191 void ProfilingService::Disconnect()
193 ProfilingState currentState = m_StateMachine.GetCurrentState();
194 switch (currentState)
196 case ProfilingState::Uninitialised:
197 case ProfilingState::NotConnected:
198 case ProfilingState::WaitingForAck:
200 case ProfilingState::Active:
201 // Stop the command thread (if running)
206 throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1")
207 % static_cast<int>(currentState)));
211 // Store a profiling context returned from a backend that support profiling, and register its counters
212 void ProfilingService::AddBackendProfilingContext(const BackendId backendId,
213 std::shared_ptr<armnn::profiling::IBackendProfilingContext> profilingContext)
215 ARMNN_ASSERT(profilingContext != nullptr);
216 // Register the backend counters
217 m_MaxGlobalCounterId = profilingContext->RegisterCounters(m_MaxGlobalCounterId);
218 m_BackendProfilingContexts.emplace(backendId, std::move(profilingContext));
220 const ICounterDirectory& ProfilingService::GetCounterDirectory() const
222 return m_CounterDirectory;
225 ICounterRegistry& ProfilingService::GetCounterRegistry()
227 return m_CounterDirectory;
230 ProfilingState ProfilingService::GetCurrentState() const
232 return m_StateMachine.GetCurrentState();
235 uint16_t ProfilingService::GetCounterCount() const
237 return m_CounterDirectory.GetCounterCount();
240 bool ProfilingService::IsCounterRegistered(uint16_t counterUid) const
242 return m_CounterDirectory.IsCounterRegistered(counterUid);
245 uint32_t ProfilingService::GetAbsoluteCounterValue(uint16_t counterUid) const
247 CheckCounterUid(counterUid);
248 std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
249 ARMNN_ASSERT(counterValuePtr);
250 return counterValuePtr->load(std::memory_order::memory_order_relaxed);
// Reads a registered counter's current value and then subtracts that amount from the counter, so
// a later call only sees what was added afterwards. The read and the subtraction are separate
// atomic operations, so increments that land in between survive for the next call rather than
// being lost.
// Throws InvalidArgumentException (via CheckCounterUid) for an unregistered UID.
// NOTE(review): presumably returns counterValue (the value read before the subtraction) — confirm.
// NOTE(review): std::memory_order::memory_order_relaxed does not compile under C++20's scoped
// memory_order; prefer std::memory_order_relaxed.
253 uint32_t ProfilingService::GetDeltaCounterValue(uint16_t counterUid)
255 CheckCounterUid(counterUid);
256 std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
257 ARMNN_ASSERT(counterValuePtr);
258 const uint32_t counterValue = counterValuePtr->load(std::memory_order::memory_order_relaxed);
259 SubtractCounterValue(counterUid, counterValue);
263 const ICounterMappings& ProfilingService::GetCounterMappings() const
265 return m_CounterIdMap;
268 IRegisterCounterMapping& ProfilingService::GetCounterMappingRegistry()
270 return m_CounterIdMap;
273 CaptureData ProfilingService::GetCaptureData()
275 return m_Holder.GetCaptureData();
278 void ProfilingService::SetCaptureData(uint32_t capturePeriod,
279 const std::vector<uint16_t>& counterIds,
280 const std::set<BackendId>& activeBackends)
282 m_Holder.SetCaptureData(capturePeriod, counterIds, activeBackends);
285 void ProfilingService::SetCounterValue(uint16_t counterUid, uint32_t value)
287 CheckCounterUid(counterUid);
288 std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
289 ARMNN_ASSERT(counterValuePtr);
290 counterValuePtr->store(value, std::memory_order::memory_order_relaxed);
293 uint32_t ProfilingService::AddCounterValue(uint16_t counterUid, uint32_t value)
295 CheckCounterUid(counterUid);
296 std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
297 ARMNN_ASSERT(counterValuePtr);
298 return counterValuePtr->fetch_add(value, std::memory_order::memory_order_relaxed);
301 uint32_t ProfilingService::SubtractCounterValue(uint16_t counterUid, uint32_t value)
303 CheckCounterUid(counterUid);
304 std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
305 ARMNN_ASSERT(counterValuePtr);
306 return counterValuePtr->fetch_sub(value, std::memory_order::memory_order_relaxed);
309 uint32_t ProfilingService::IncrementCounterValue(uint16_t counterUid)
311 CheckCounterUid(counterUid);
312 std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
313 ARMNN_ASSERT(counterValuePtr);
314 return counterValuePtr->operator++(std::memory_order::memory_order_relaxed);
317 ProfilingDynamicGuid ProfilingService::NextGuid()
319 return ProfilingService::GetNextGuid();
322 ProfilingStaticGuid ProfilingService::GenerateStaticId(const std::string& str)
324 return ProfilingService::GetStaticId(str);
327 std::unique_ptr<ISendTimelinePacket> ProfilingService::GetSendTimelinePacket() const
329 return m_TimelinePacketWriterFactory.GetSendTimelinePacket();
// Registers the "ArmNN_Runtime" category and the built-in runtime counters with the counter
// directory: network loads/unloads, registered/unregistered backends and inferences run.
// Each counter is registered only if its name is not already present, and every newly registered
// counter gets a value slot via InitializeCounterValue(). The registered-backends counter is then
// set to the current BackendRegistry size.
332 void ProfilingService::Initialize()
334 // Register a category for the basic runtime counters
335 if (!m_CounterDirectory.IsCategoryRegistered("ArmNN_Runtime"))
337 m_CounterDirectory.RegisterCategory("ArmNN_Runtime");
340 // Register a counter for the number of Network loads
341 if (!m_CounterDirectory.IsCounterRegistered("Network loads"))
343 const Counter* loadedNetworksCounter =
344 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
345 armnn::profiling::NETWORK_LOADS,
351 "The number of networks loaded at runtime",
352 std::string("networks"));
353 ARMNN_ASSERT(loadedNetworksCounter);
354 InitializeCounterValue(loadedNetworksCounter->m_Uid);
356 // Register a counter for the number of unloaded networks
357 if (!m_CounterDirectory.IsCounterRegistered("Network unloads"))
359 const Counter* unloadedNetworksCounter =
360 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
361 armnn::profiling::NETWORK_UNLOADS,
367 "The number of networks unloaded at runtime",
368 std::string("networks"));
369 ARMNN_ASSERT(unloadedNetworksCounter);
370 InitializeCounterValue(unloadedNetworksCounter->m_Uid);
372 // Register a counter for the number of registered backends
373 if (!m_CounterDirectory.IsCounterRegistered("Backends registered"))
375 const Counter* registeredBackendsCounter =
376 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
377 armnn::profiling::REGISTERED_BACKENDS,
382 "Backends registered",
383 "The number of registered backends",
384 std::string("backends"));
385 ARMNN_ASSERT(registeredBackendsCounter);
386 InitializeCounterValue(registeredBackendsCounter->m_Uid);
388 // Due to backends being registered before the profiling service becomes active,
389 // we need to set the counter to the correct value here
390 SetCounterValue(armnn::profiling::REGISTERED_BACKENDS, static_cast<uint32_t>(BackendRegistryInstance().Size()));
392 // Register a counter for the number of unregistered backends
393 if (!m_CounterDirectory.IsCounterRegistered("Backends unregistered"))
395 const Counter* unregisteredBackendsCounter =
396 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
397 armnn::profiling::UNREGISTERED_BACKENDS,
402 "Backends unregistered",
403 "The number of unregistered backends",
404 std::string("backends"));
405 ARMNN_ASSERT(unregisteredBackendsCounter);
406 InitializeCounterValue(unregisteredBackendsCounter->m_Uid);
408 // Register a counter for the number of inferences run
409 if (!m_CounterDirectory.IsCounterRegistered("Inferences run"))
411 const Counter* inferencesRunCounter =
412 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
413 armnn::profiling::INFERENCES_RUN,
419 "The number of inferences run",
420 std::string("inferences"));
421 ARMNN_ASSERT(inferencesRunCounter);
422 InitializeCounterValue(inferencesRunCounter->m_Uid);
426 void ProfilingService::InitializeCounterValue(uint16_t counterUid)
428 // Increase the size of the counter index if necessary
429 if (counterUid >= m_CounterIndex.size())
431 m_CounterIndex.resize(armnn::numeric_cast<size_t>(counterUid) + 1);
434 // Create a new atomic counter and add it to the list
435 m_CounterValues.emplace_back(0);
437 // Register the new counter to the counter index for quick access
438 std::atomic<uint32_t>* counterValuePtr = &(m_CounterValues.back());
439 m_CounterIndex.at(counterUid) = counterValuePtr;
// Returns the service to a pristine configuration, in this order:
//   - stops the service (per the first step's comment);
//   - discards the counter value slots and index, the counter directory contents and the counter
//     id map, and resets the buffer manager;
//   - resets the state machine;
//   - drops all backend profiling contexts;
//   - restores the global counter id ceiling to MAX_ARMNN_COUNTER.
442 void ProfilingService::Reset()
444 // Stop the profiling service...
447 // ...then delete all the counter data and configuration...
448 m_CounterIndex.clear();
449 m_CounterValues.clear();
450 m_CounterDirectory.Clear();
451 m_CounterIdMap.Reset();
452 m_BufferManager.Reset();
454 // ...then reset the profiling state machine, drop the backend contexts and restore the counter id ceiling
455 m_StateMachine.Reset();
456 m_BackendProfilingContexts.clear();
457 m_MaxGlobalCounterId = armnn::profiling::MAX_ARMNN_COUNTER;
// Stops the profiling service, in this order:
//   - clears m_ServiceActive under m_ServiceActiveMutex;
//   - stops the command handler, then the periodic counter capture, then the send thread;
//   - closes the profiling connection if it is open, and releases it;
//   - moves the state machine to NotConnected.
460 void ProfilingService::Stop()
462 { // only hold the lock while updating m_ServiceActive
463 std::unique_lock<std::mutex> lck(m_ServiceActiveMutex);
464 m_ServiceActive = false;
466 // The order in which we reset/stop the components is not trivial!
467 // First stop the producing threads
468 // Command Handler first as it is responsible for launching the Periodic Counter capture thread
469 m_CommandHandler.Stop();
470 m_PeriodicCounterCapture.Stop();
471 // Then the consuming thread
472 m_SendThread.Stop(false);
474 // ...then close and destroy the profiling connection...
475 if (m_ProfilingConnection != nullptr && m_ProfilingConnection->IsOpen())
477 m_ProfilingConnection->Close();
479 m_ProfilingConnection.reset();
481 // ...then move to the "NotConnected" state
482 m_StateMachine.TransitionToState(ProfilingState::NotConnected);
485 inline void ProfilingService::CheckCounterUid(uint16_t counterUid) const
487 if (!IsCounterRegistered(counterUid))
489 throw InvalidArgumentException(boost::str(boost::format("Counter UID %1% is not registered") % counterUid));
493 void ProfilingService::NotifyBackendsForTimelineReporting()
495 BackendProfilingContext::iterator it = m_BackendProfilingContexts.begin();
496 while (it != m_BackendProfilingContexts.end())
498 auto& backendProfilingContext = it->second;
499 backendProfilingContext->EnableTimelineReporting(m_TimelineReporting);
500 // Increment the Iterator to point to next entry
505 void ProfilingService::NotifyProfilingServiceActive()
507 { // only lock when we are updating the inference completed variable
508 std::unique_lock<std::mutex> lck(m_ServiceActiveMutex);
509 m_ServiceActive = true;
511 m_ServiceActiveConditionVariable.notify_one();
// Blocks for up to `timeout` milliseconds until m_ServiceActive becomes true (it is set by
// NotifyProfilingServiceActive()). On timeout, logs a warning reporting the elapsed wait time.
// NOTE(review): wait_for with a predicate already returns the predicate's final value, so
// re-checking m_ServiceActive after a false return looks redundant — confirm and simplify.
// NOTE(review): std::chrono::steady_clock is the usual choice for measuring elapsed time;
// high_resolution_clock may not be monotonic.
514 void ProfilingService::WaitForProfilingServiceActivation(unsigned int timeout)
516 std::unique_lock<std::mutex> lck(m_ServiceActiveMutex);
518 auto start = std::chrono::high_resolution_clock::now();
519 // Here we will go back to sleep after a spurious wake up if
520 // m_ServiceActive is not yet true.
521 if (!m_ServiceActiveConditionVariable.wait_for(lck,
522 std::chrono::milliseconds(timeout),
523 [&]{return m_ServiceActive == true;}))
525 if (m_ServiceActive == true)
529 auto finish = std::chrono::high_resolution_clock::now();
530 std::chrono::duration<double, std::milli> elapsed = finish - start;
531 std::stringstream ss;
532 ss << "Timed out waiting on profiling service activation for " << elapsed.count() << " ms";
533 ARMNN_LOG(warning) << ss.str();
// Destructor.
// NOTE(review): the body is expected to stop the service, so that worker threads are stopped and
// the connection closed before members are destroyed — confirm.
538 ProfilingService::~ProfilingService()
542 } // namespace profiling