2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
6 #include "ProfilingService.hpp"
8 #include <armnn/BackendId.hpp>
9 #include <armnn/Logging.hpp>
10 #include <common/include/SocketConnectionException.hpp>
12 #include <boost/format.hpp>
// Out-of-class definition of the GUID generator member used by GetNextGuid() and GetStaticId().
20 ProfilingGuidGenerator ProfilingService::m_GuidGenerator;
22 ProfilingDynamicGuid ProfilingService::GetNextGuid()
24 return m_GuidGenerator.NextGuid();
27 ProfilingStaticGuid ProfilingService::GetStaticId(const std::string& str)
29 return m_GuidGenerator.GenerateStaticId(str);
// Applies a new set of external profiling options and, when requested, resets the profiling service.
// Visible effects: m_TimelineReporting and the timeline flag of the connection-acknowledged command handler
// are both set from options.m_TimelineEnabled.
// NOTE(review): the embedded numbering skips 36 and 44-46, so the statement that stores the options and the
// one that performs the reset are not visible in this listing - confirm against the full file.
32 void ProfilingService::ResetExternalProfilingOptions(const ExternalProfilingOptions& options,
33 bool resetProfilingService)
35 // Update the profiling options
37 m_TimelineReporting = options.m_TimelineEnabled;
38 m_ConnectionAcknowledgedCommandHandler.setTimelineEnabled(options.m_TimelineEnabled);
40 // Check if the profiling service needs to be reset
41 if (resetProfilingService)
43 // Reset the profiling service
48 bool ProfilingService::IsProfilingEnabled() const
50 return m_Options.m_EnableProfiling;
// Applies the given options (via ResetExternalProfilingOptions) and then drives the service towards the
// state that matches them, returning a ProfilingState.
// When options.m_EnableProfiling is set, Update() is called repeatedly so that a service starting in
// Uninitialised or NotConnected can progress to WaitingForAck and send its metadata packet.
// NOTE(review): this listing omits many source lines (switch headers, braces, the per-case returns and most
// of the profiling-disabled branch), so the exact value returned on each path is not visible here -
// confirm against the full file.
53 ProfilingState ProfilingService::ConfigureProfilingService(
54 const ExternalProfilingOptions& options,
55 bool resetProfilingService)
57 ResetExternalProfilingOptions(options, resetProfilingService);
58 ProfilingState currentState = m_StateMachine.GetCurrentState();
59 if (options.m_EnableProfiling)
63 case ProfilingState::Uninitialised:
64 Update(); // should transition to NotConnected
65 Update(); // will either stay in NotConnected because there is no server
66 // or will enter WaitingForAck.
67 currentState = m_StateMachine.GetCurrentState();
68 if (currentState == ProfilingState::WaitingForAck)
70 Update(); // poke it again to send out the metadata packet
72 currentState = m_StateMachine.GetCurrentState();
74 case ProfilingState::NotConnected:
75 Update(); // will either stay in NotConnected because there is no server
76 // or will enter WaitingForAck
77 currentState = m_StateMachine.GetCurrentState();
78 if (currentState == ProfilingState::WaitingForAck)
80 Update(); // poke it again to send out the metadata packet
82 currentState = m_StateMachine.GetCurrentState();
90 // Make sure profiling is shutdown
93 case ProfilingState::Uninitialised:
94 case ProfilingState::NotConnected:
98 return m_StateMachine.GetCurrentState();
// Advances the profiling service by one step, depending on the state reported by m_StateMachine:
//  - Uninitialised: moves to NotConnected.
//  - NotConnected:  stops the command, send and periodic counter capture threads, drops any existing
//                   connection, asks the connection factory for a new one (failures are logged as
//                   warnings), then moves to WaitingForAck if a connection was obtained and stays in
//                   NotConnected otherwise.
//  - WaitingForAck: starts the command handler and the send thread on the current connection.
//  - Active:        nothing is started here (see the comment in that case).
//  - any other value: throws RuntimeException.
// The function begins with a check of m_Options.m_EnableProfiling.
// NOTE(review): this listing omits several source lines (braces, and apparently statements such as the
// early return, the initialisation call and the per-case breaks) - confirm against the full file.
103 void ProfilingService::Update()
105 if (!m_Options.m_EnableProfiling)
107 // Don't run if profiling is disabled
111 ProfilingState currentState = m_StateMachine.GetCurrentState();
112 switch (currentState)
114 case ProfilingState::Uninitialised:
116 // Initialize the profiling service
119 // Move to the next state
120 m_StateMachine.TransitionToState(ProfilingState::NotConnected);
122 case ProfilingState::NotConnected:
123 // Stop the command thread (if running)
124 m_CommandHandler.Stop();
126 // Stop the send thread (if running)
127 m_SendThread.Stop(false);
129 // Stop the periodic counter capture thread (if running)
130 m_PeriodicCounterCapture.Stop();
132 // Reset any existing profiling connection
133 m_ProfilingConnection.reset();
137 // Setup the profiling connection
138 ARMNN_ASSERT(m_ProfilingConnectionFactory);
139 m_ProfilingConnection = m_ProfilingConnectionFactory->GetProfilingConnection(m_Options);
141 catch (const Exception& e)
143 ARMNN_LOG(warning) << "An error has occurred when creating the profiling connection: "
146 catch (const armnnProfiling::SocketConnectionException& e)
148 ARMNN_LOG(warning) << "An error has occurred when creating the profiling connection ["
149 << e.what() << "] on socket [" << e.GetSocketFd() << "].";
152 // Move to the next state
153 m_StateMachine.TransitionToState(m_ProfilingConnection
154 ? ProfilingState::WaitingForAck // Profiling connection obtained, wait for ack
155 : ProfilingState::NotConnected); // Profiling connection failed, stay in the
156 // "NotConnected" state
158 case ProfilingState::WaitingForAck:
159 ARMNN_ASSERT(m_ProfilingConnection);
161 // Start the command thread
162 m_CommandHandler.Start(*m_ProfilingConnection);
164 // Start the send thread, while in "WaitingForAck" state it'll send out a "Stream MetaData" packet waiting for
165 // a valid "Connection Acknowledged" packet confirming the connection
166 m_SendThread.Start(*m_ProfilingConnection);
168 // The connection acknowledged command handler will automatically transition the state to "Active" once a
169 // valid "Connection Acknowledged" packet has been received
172 case ProfilingState::Active:
174 // The period counter capture thread is started by the Periodic Counter Selection command handler upon
175 // request by an external profiling service
// Boost.Format positional placeholders are written "%N%"; a bare "%1" at the end of the string is a
// malformed directive, so the closing '%' is required for the state value to be formatted.
179 throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1%")
180 % static_cast<int>(currentState)));
// Handles a disconnect request according to the state reported by m_StateMachine.
// Uninitialised, NotConnected and WaitingForAck share one branch; Active has its own branch; any other
// value results in a RuntimeException.
// NOTE(review): the statements under the case labels (embedded lines 192 and 195-198) are not visible in
// this listing - confirm against the full file what each branch does.
184 void ProfilingService::Disconnect()
186 ProfilingState currentState = m_StateMachine.GetCurrentState();
187 switch (currentState)
189 case ProfilingState::Uninitialised:
190 case ProfilingState::NotConnected:
191 case ProfilingState::WaitingForAck:
193 case ProfilingState::Active:
194 // Stop the command thread (if running)
// Boost.Format positional placeholders are written "%N%"; a bare "%1" at the end of the string is a
// malformed directive, so the closing '%' is required for the state value to be formatted.
199 throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1%")
200 % static_cast<int>(currentState)));
204 // Store a profiling context returned from a backend that support profiling, and register its counters
205 void ProfilingService::AddBackendProfilingContext(const BackendId backendId,
206 std::shared_ptr<armnn::profiling::IBackendProfilingContext> profilingContext)
208 ARMNN_ASSERT(profilingContext != nullptr);
209 // Register the backend counters
210 m_MaxGlobalCounterId = profilingContext->RegisterCounters(m_MaxGlobalCounterId);
211 m_BackendProfilingContexts.emplace(backendId, std::move(profilingContext));
213 const ICounterDirectory& ProfilingService::GetCounterDirectory() const
215 return m_CounterDirectory;
218 ICounterRegistry& ProfilingService::GetCounterRegistry()
220 return m_CounterDirectory;
223 ProfilingState ProfilingService::GetCurrentState() const
225 return m_StateMachine.GetCurrentState();
228 uint16_t ProfilingService::GetCounterCount() const
230 return m_CounterDirectory.GetCounterCount();
233 bool ProfilingService::IsCounterRegistered(uint16_t counterUid) const
235 return m_CounterDirectory.IsCounterRegistered(counterUid);
238 uint32_t ProfilingService::GetAbsoluteCounterValue(uint16_t counterUid) const
240 CheckCounterUid(counterUid);
241 std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
242 ARMNN_ASSERT(counterValuePtr);
243 return counterValuePtr->load(std::memory_order::memory_order_relaxed);
// Reads the current value of the counter identified by counterUid (relaxed load) and then subtracts that
// same amount from the counter through SubtractCounterValue, so increments made in between are kept.
// CheckCounterUid throws InvalidArgumentException when the UID is not registered.
// NOTE(review): the line following the subtraction (embedded number 253) is not visible in this listing;
// presumably it returns counterValue - confirm against the full file.
246 uint32_t ProfilingService::GetDeltaCounterValue(uint16_t counterUid)
248 CheckCounterUid(counterUid);
249 std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
250 ARMNN_ASSERT(counterValuePtr);
251 const uint32_t counterValue = counterValuePtr->load(std::memory_order::memory_order_relaxed);
252 SubtractCounterValue(counterUid, counterValue);
256 const ICounterMappings& ProfilingService::GetCounterMappings() const
258 return m_CounterIdMap;
261 IRegisterCounterMapping& ProfilingService::GetCounterMappingRegistry()
263 return m_CounterIdMap;
266 CaptureData ProfilingService::GetCaptureData()
268 return m_Holder.GetCaptureData();
271 void ProfilingService::SetCaptureData(uint32_t capturePeriod,
272 const std::vector<uint16_t>& counterIds,
273 const std::set<BackendId>& activeBackends)
275 m_Holder.SetCaptureData(capturePeriod, counterIds, activeBackends);
278 void ProfilingService::SetCounterValue(uint16_t counterUid, uint32_t value)
280 CheckCounterUid(counterUid);
281 std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
282 ARMNN_ASSERT(counterValuePtr);
283 counterValuePtr->store(value, std::memory_order::memory_order_relaxed);
286 uint32_t ProfilingService::AddCounterValue(uint16_t counterUid, uint32_t value)
288 CheckCounterUid(counterUid);
289 std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
290 ARMNN_ASSERT(counterValuePtr);
291 return counterValuePtr->fetch_add(value, std::memory_order::memory_order_relaxed);
294 uint32_t ProfilingService::SubtractCounterValue(uint16_t counterUid, uint32_t value)
296 CheckCounterUid(counterUid);
297 std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
298 ARMNN_ASSERT(counterValuePtr);
299 return counterValuePtr->fetch_sub(value, std::memory_order::memory_order_relaxed);
302 uint32_t ProfilingService::IncrementCounterValue(uint16_t counterUid)
304 CheckCounterUid(counterUid);
305 std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
306 ARMNN_ASSERT(counterValuePtr);
307 return counterValuePtr->operator++(std::memory_order::memory_order_relaxed);
310 ProfilingDynamicGuid ProfilingService::NextGuid()
312 return ProfilingService::GetNextGuid();
315 ProfilingStaticGuid ProfilingService::GenerateStaticId(const std::string& str)
317 return ProfilingService::GetStaticId(str);
320 std::unique_ptr<ISendTimelinePacket> ProfilingService::GetSendTimelinePacket() const
322 return m_TimelinePacketWriterFactory.GetSendTimelinePacket();
// Registers the basic runtime counters in the counter directory, skipping anything already registered:
// the "ArmNN_Runtime" category and the counters "Network loads", "Network unloads", "Backends registered",
// "Backends unregistered" and "Inferences run". For every counter registered here a value slot is created
// through InitializeCounterValue() using the UID of the returned Counter.
// NOTE(review): this listing omits braces and several RegisterCounter argument lines (e.g. embedded lines
// 339-343), so the full argument lists are not visible here - confirm against the full file.
325 void ProfilingService::Initialize()
327 // Register a category for the basic runtime counters
328 if (!m_CounterDirectory.IsCategoryRegistered("ArmNN_Runtime"))
330 m_CounterDirectory.RegisterCategory("ArmNN_Runtime");
333 // Register a counter for the number of Network loads
334 if (!m_CounterDirectory.IsCounterRegistered("Network loads"))
336 const Counter* loadedNetworksCounter =
337 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
338 armnn::profiling::NETWORK_LOADS,
344 "The number of networks loaded at runtime",
345 std::string("networks"));
346 ARMNN_ASSERT(loadedNetworksCounter);
347 InitializeCounterValue(loadedNetworksCounter->m_Uid);
349 // Register a counter for the number of unloaded networks
350 if (!m_CounterDirectory.IsCounterRegistered("Network unloads"))
352 const Counter* unloadedNetworksCounter =
353 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
354 armnn::profiling::NETWORK_UNLOADS,
360 "The number of networks unloaded at runtime",
361 std::string("networks"));
362 ARMNN_ASSERT(unloadedNetworksCounter);
363 InitializeCounterValue(unloadedNetworksCounter->m_Uid);
365 // Register a counter for the number of registered backends
366 if (!m_CounterDirectory.IsCounterRegistered("Backends registered"))
368 const Counter* registeredBackendsCounter =
369 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
370 armnn::profiling::REGISTERED_BACKENDS,
375 "Backends registered",
376 "The number of registered backends",
377 std::string("backends"));
378 ARMNN_ASSERT(registeredBackendsCounter);
379 InitializeCounterValue(registeredBackendsCounter->m_Uid);
381 // Register a counter for the number of unregistered backends
382 if (!m_CounterDirectory.IsCounterRegistered("Backends unregistered"))
384 const Counter* unregisteredBackendsCounter =
385 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
386 armnn::profiling::UNREGISTERED_BACKENDS,
391 "Backends unregistered",
392 "The number of unregistered backends",
393 std::string("backends"));
394 ARMNN_ASSERT(unregisteredBackendsCounter);
395 InitializeCounterValue(unregisteredBackendsCounter->m_Uid);
397 // Register a counter for the number of inferences run
398 if (!m_CounterDirectory.IsCounterRegistered("Inferences run"))
400 const Counter* inferencesRunCounter =
401 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
402 armnn::profiling::INFERENCES_RUN,
408 "The number of inferences run",
409 std::string("inferences"));
410 ARMNN_ASSERT(inferencesRunCounter);
411 InitializeCounterValue(inferencesRunCounter->m_Uid);
415 void ProfilingService::InitializeCounterValue(uint16_t counterUid)
417 // Increase the size of the counter index if necessary
418 if (counterUid >= m_CounterIndex.size())
420 m_CounterIndex.resize(boost::numeric_cast<size_t>(counterUid) + 1);
423 // Create a new atomic counter and add it to the list
424 m_CounterValues.emplace_back(0);
426 // Register the new counter to the counter index for quick access
427 std::atomic<uint32_t>* counterValuePtr = &(m_CounterValues.back());
428 m_CounterIndex.at(counterUid) = counterValuePtr;
// Returns the profiling service to a clean state: the counter index, counter values, counter directory,
// counter ID map and buffer manager are cleared/reset, the state machine is reset, the stored backend
// profiling contexts are dropped and m_MaxGlobalCounterId goes back to armnn::profiling::MAX_ARMNN_COUNTER.
// NOTE(review): the statement under "Stop the profiling service..." (embedded line 434) is not visible in
// this listing - confirm against the full file.
431 void ProfilingService::Reset()
433 // Stop the profiling service...
436 // ...then delete all the counter data and configuration...
437 m_CounterIndex.clear();
438 m_CounterValues.clear();
439 m_CounterDirectory.Clear();
440 m_CounterIdMap.Reset();
441 m_BufferManager.Reset();
443 // ...finally reset the profiling state machine
444 m_StateMachine.Reset();
445 m_BackendProfilingContexts.clear();
446 m_MaxGlobalCounterId = armnn::profiling::MAX_ARMNN_COUNTER;
449 void ProfilingService::Stop()
451 // The order in which we reset/stop the components is not trivial!
452 // First stop the producing threads
453 // Command Handler first as it is responsible for launching then Periodic Counter capture thread
454 m_CommandHandler.Stop();
455 m_PeriodicCounterCapture.Stop();
456 // The the consuming thread
457 m_SendThread.Stop(false);
459 // ...then close and destroy the profiling connection...
460 if (m_ProfilingConnection != nullptr && m_ProfilingConnection->IsOpen())
462 m_ProfilingConnection->Close();
464 m_ProfilingConnection.reset();
466 // ...then move to the "NotConnected" state
467 m_StateMachine.TransitionToState(ProfilingState::NotConnected);
470 inline void ProfilingService::CheckCounterUid(uint16_t counterUid) const
472 if (!IsCounterRegistered(counterUid))
474 throw InvalidArgumentException(boost::str(boost::format("Counter UID %1% is not registered") % counterUid));
478 void ProfilingService::NotifyBackendsForTimelineReporting()
480 BackendProfilingContext::iterator it = m_BackendProfilingContexts.begin();
481 while (it != m_BackendProfilingContexts.end())
483 auto& backendProfilingContext = it->second;
484 backendProfilingContext->EnableTimelineReporting(m_TimelineReporting);
485 // Increment the Iterator to point to next entry
// Destructor.
// NOTE(review): the body (embedded lines 491-493) is not visible in this listing - check the full file
// for what is torn down here.
490 ProfilingService::~ProfilingService()
494 } // namespace profiling