ArmNN  NotReleased
ProfilingService.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "BufferManager.hpp"
9 #include "CommandHandler.hpp"
11 #include "CounterDirectory.hpp"
12 #include "CounterIdMap.hpp"
13 #include "ICounterRegistry.hpp"
14 #include "ICounterValues.hpp"
15 #include "IProfilingService.hpp"
23 #include "SendCounterPacket.hpp"
24 #include "SendThread.hpp"
25 #include "SendTimelinePacket.hpp"
28 
29 namespace armnn
30 {
31 
32 namespace profiling
33 {
34 // Static constants describing ArmNN's counter UIDs (NETWORK_LOADS .. INFERENCES_RUN).
35 static const uint16_t NETWORK_LOADS = 0;
36 static const uint16_t NETWORK_UNLOADS = 1;
37 static const uint16_t REGISTERED_BACKENDS = 2;
38 static const uint16_t UNREGISTERED_BACKENDS = 3;
39 static const uint16_t INFERENCES_RUN = 4;
40 static const uint16_t MAX_ARMNN_COUNTER = INFERENCES_RUN; // Highest ArmNN counter UID; passed to PeriodicCounterSelectionCommandHandler by the constructor. NOTE: must be updated if a UID is added after INFERENCES_RUN.
41 
43 {
44 public:
46  using IProfilingConnectionFactoryPtr = std::unique_ptr<IProfilingConnectionFactory>;
47  using IProfilingConnectionPtr = std::unique_ptr<IProfilingConnection>;
48  using CounterIndices = std::vector<std::atomic<uint32_t>*>;
49  using CounterValues = std::list<std::atomic<uint32_t>>;
50 
51  // Getter for the singleton instance
53  {
54  static ProfilingService instance;
55  return instance;
56  }
57 
58  // Resets the profiling options, optionally clears the profiling service entirely
59  void ResetExternalProfilingOptions(const ExternalProfilingOptions& options, bool resetProfilingService = false);
61  bool resetProfilingService = false);
62 
63 
64  // Updates the profiling service, making it transition to a new state if necessary
65  void Update();
66 
67  // Disconnects the profiling service from the external server
68  void Disconnect();
69 
70  // Store a profiling context returned from a backend that support profiling.
71  void AddBackendProfilingContext(const BackendId backendId,
72  std::shared_ptr<armnn::profiling::IBackendProfilingContext> profilingContext);
73 
77  bool IsCounterRegistered(uint16_t counterUid) const override;
78  uint32_t GetCounterValue(uint16_t counterUid) const override;
79  uint16_t GetCounterCount() const override;
80  // counter global/backend mapping functions
81  const ICounterMappings& GetCounterMappings() const override;
83 
84  // Getters for the profiling service state
85  bool IsProfilingEnabled() const override;
86 
87  CaptureData GetCaptureData() override;
88  void SetCaptureData(uint32_t capturePeriod,
89  const std::vector<uint16_t>& counterIds,
90  const std::set<BackendId>& activeBackends);
91 
92  // Setters for the profiling service state
93  void SetCounterValue(uint16_t counterUid, uint32_t value) override;
94  uint32_t AddCounterValue(uint16_t counterUid, uint32_t value) override;
95  uint32_t SubtractCounterValue(uint16_t counterUid, uint32_t value) override;
96  uint32_t IncrementCounterValue(uint16_t counterUid) override;
97 
98  // IProfilingGuidGenerator functions
100  ProfilingDynamicGuid NextGuid() override;
102  ProfilingStaticGuid GenerateStaticId(const std::string& str) override;
103 
104  std::unique_ptr<ISendTimelinePacket> GetSendTimelinePacket() const override;
105 
107  {
108  return m_SendCounterPacket;
109  }
110 
// Returns the m_EnableProfiling flag from the options currently stored in this service (m_Options).
// NOTE(review): reads no mutable state and could be declared const; compare with IsProfilingEnabled() const override, whose implementation is not visible here.
112  bool IsEnabled() { return m_Options.m_EnableProfiling; }
113 
114 private:
115  // Copy/move constructors/destructors and copy/move assignment operators are deleted
116  ProfilingService(const ProfilingService&) = delete;
118  ProfilingService& operator=(const ProfilingService&) = delete;
119  ProfilingService& operator=(ProfilingService&&) = delete;
120 
121  // Initialization/reset functions
122  void Initialize();
123  void InitializeCounterValue(uint16_t counterUid);
124  void Reset();
125  void Stop();
126 
127  // Helper function
128  void CheckCounterUid(uint16_t counterUid) const;
129 
130  // Profiling service components
131  ExternalProfilingOptions m_Options;
132  CounterDirectory m_CounterDirectory;
133  CounterIdMap m_CounterIdMap;
134  IProfilingConnectionFactoryPtr m_ProfilingConnectionFactory;
135  IProfilingConnectionPtr m_ProfilingConnection;
136  ProfilingStateMachine m_StateMachine;
137  CounterIndices m_CounterIndex;
138  CounterValues m_CounterValues;
139  CommandHandlerRegistry m_CommandHandlerRegistry;
140  PacketVersionResolver m_PacketVersionResolver;
141  CommandHandler m_CommandHandler;
142  BufferManager m_BufferManager;
143  SendCounterPacket m_SendCounterPacket;
144  SendThread m_SendThread;
145  SendTimelinePacket m_SendTimelinePacket;
146  Holder m_Holder;
147  PeriodicCounterCapture m_PeriodicCounterCapture;
148  ConnectionAcknowledgedCommandHandler m_ConnectionAcknowledgedCommandHandler;
149  RequestCounterDirectoryCommandHandler m_RequestCounterDirectoryCommandHandler;
150  PeriodicCounterSelectionCommandHandler m_PeriodicCounterSelectionCommandHandler;
151  PerJobCounterSelectionCommandHandler m_PerJobCounterSelectionCommandHandler;
152  ProfilingGuidGenerator m_GuidGenerator;
153  TimelinePacketWriterFactory m_TimelinePacketWriterFactory;
154  std::unordered_map<BackendId,
155  std::shared_ptr<armnn::profiling::IBackendProfilingContext>> m_BackendProfilingContexts;
156  uint16_t m_MaxGlobalCounterId;
157 
158 protected:
159  // Default constructor/destructor kept protected for testing
161  : m_Options()
162  , m_CounterDirectory()
163  , m_ProfilingConnectionFactory(new ProfilingConnectionFactory())
164  , m_ProfilingConnection()
165  , m_StateMachine()
166  , m_CounterIndex()
167  , m_CounterValues()
168  , m_CommandHandlerRegistry()
169  , m_PacketVersionResolver()
170  , m_CommandHandler(1000,
171  false,
172  m_CommandHandlerRegistry,
173  m_PacketVersionResolver)
174  , m_BufferManager()
175  , m_SendCounterPacket(m_BufferManager)
176  , m_SendThread(m_StateMachine, m_BufferManager, m_SendCounterPacket)
177  , m_SendTimelinePacket(m_BufferManager)
178  , m_PeriodicCounterCapture(m_Holder, m_SendCounterPacket, *this, m_CounterIdMap, m_BackendProfilingContexts)
179  , m_ConnectionAcknowledgedCommandHandler(0,
180  1,
181  m_PacketVersionResolver.ResolvePacketVersion(0, 1).GetEncodedValue(),
182  m_CounterDirectory,
183  m_SendCounterPacket,
184  m_SendTimelinePacket,
185  m_StateMachine)
186  , m_RequestCounterDirectoryCommandHandler(0,
187  3,
188  m_PacketVersionResolver.ResolvePacketVersion(0, 3).GetEncodedValue(),
189  m_CounterDirectory,
190  m_SendCounterPacket,
191  m_SendTimelinePacket,
192  m_StateMachine)
193  , m_PeriodicCounterSelectionCommandHandler(0,
194  4,
195  m_PacketVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue(),
196  m_BackendProfilingContexts,
197  m_CounterIdMap,
198  m_Holder,
199  MAX_ARMNN_COUNTER,
200  m_PeriodicCounterCapture,
201  *this,
202  m_SendCounterPacket,
203  m_StateMachine)
204  , m_PerJobCounterSelectionCommandHandler(0,
205  5,
206  m_PacketVersionResolver.ResolvePacketVersion(0, 5).GetEncodedValue(),
207  m_StateMachine)
208  , m_TimelinePacketWriterFactory(m_BufferManager)
209  , m_MaxGlobalCounterId(armnn::profiling::INFERENCES_RUN)
210  {
211  // Register the "Connection Acknowledged" command handler
212  m_CommandHandlerRegistry.RegisterFunctor(&m_ConnectionAcknowledgedCommandHandler);
213 
214  // Register the "Request Counter Directory" command handler
215  m_CommandHandlerRegistry.RegisterFunctor(&m_RequestCounterDirectoryCommandHandler);
216 
217  // Register the "Periodic Counter Selection" command handler
218  m_CommandHandlerRegistry.RegisterFunctor(&m_PeriodicCounterSelectionCommandHandler);
219 
220  // Register the "Per-Job Counter Selection" command handler
221  m_CommandHandlerRegistry.RegisterFunctor(&m_PerJobCounterSelectionCommandHandler);
222  }
224 
225  // Protected methods for testing
229  {
230  BOOST_ASSERT(instance.m_ProfilingConnectionFactory);
231  BOOST_ASSERT(other);
232 
233  backup = instance.m_ProfilingConnectionFactory.release();
234  instance.m_ProfilingConnectionFactory.reset(other);
235  }
237  {
238  return instance.m_ProfilingConnection.get();
239  }
241  {
242  instance.m_StateMachine.TransitionToState(newState);
243  }
// Test helper: forwards to instance.m_SendThread.WaitForPacketSent(timeout) and returns its result.
// Operates on the `instance` argument, not on *this.
// NOTE(review): presumably waits up to `timeout` for the send thread to emit a packet; the timeout units
// (assumed milliseconds) and the meaning of the return value are defined in SendThread.cpp — confirm there.
244  bool WaitForPacketSent(ProfilingService& instance, uint32_t timeout = 1000)
245  {
246  return instance.m_SendThread.WaitForPacketSent(timeout);
247  }
248 
250  {
251  return instance.m_BufferManager;
252  }
253 };
254 
255 } // namespace profiling
256 
257 } // namespace armnn
uint32_t GetCounterValue(uint16_t counterUid) const override
uint32_t SubtractCounterValue(uint16_t counterUid, uint32_t value) override
uint32_t AddCounterValue(uint16_t counterUid, uint32_t value) override
BufferManager & GetBufferManager(ProfilingService &instance)
bool WaitForPacketSent(uint32_t timeout)
Definition: SendThread.cpp:263
const ICounterMappings & GetCounterMappings() const override
DataLayout::NHWC false
IRegisterCounterMapping & GetCounterMappingRegistry()
ProfilingStaticGuid GenerateStaticId(const std::string &str) override
Create a ProfilingStaticGuid based on a hash of the string.
IProfilingConnection * GetProfilingConnection(ProfilingService &instance)
void TransitionToState(ProfilingState newState)
void TransitionToState(ProfilingService &instance, ProfilingState newState)
ProfilingDynamicGuid NextGuid() override
Return the next random Guid in the sequence.
std::list< std::atomic< uint32_t > > CounterValues
void SetCounterValue(uint16_t counterUid, uint32_t value) override
void AddBackendProfilingContext(const BackendId backendId, std::shared_ptr< armnn::profiling::IBackendProfilingContext > profilingContext)
const ICounterDirectory & GetCounterDirectory() const
bool IsCounterRegistered(uint16_t counterUid) const override
bool IsProfilingEnabled() const override
std::unique_ptr< IProfilingConnection > IProfilingConnectionPtr
bool IsEnabled()
Check whether profiling is enabled.
std::unique_ptr< IProfilingConnectionFactory > IProfilingConnectionFactoryPtr
void ResetExternalProfilingOptions(const ExternalProfilingOptions &options, bool resetProfilingService=false)
void RegisterFunctor(CommandHandlerFunctor *functor, uint32_t familyId, uint32_t packetId, uint32_t version)
std::vector< std::atomic< uint32_t > * > CounterIndices
uint32_t IncrementCounterValue(uint16_t counterUid) override
static ProfilingService & Instance()
bool WaitForPacketSent(ProfilingService &instance, uint32_t timeout=1000)
void SwapProfilingConnectionFactory(ProfilingService &instance, IProfilingConnectionFactory *other, IProfilingConnectionFactory *&backup)
Strongly typed guids to distinguish between those generated at runtime, and those that are statically...
Definition: Types.hpp:291
uint16_t GetCounterCount() const override
ProfilingState GetCurrentState() const
std::unique_ptr< ISendTimelinePacket > GetSendTimelinePacket() const override
ProfilingState ConfigureProfilingService(const ExternalProfilingOptions &options, bool resetProfilingService=false)
armnn::Runtime::CreationOptions::ExternalProfilingOptions options
void SetCaptureData(uint32_t capturePeriod, const std::vector< uint16_t > &counterIds, const std::set< BackendId > &activeBackends)
ISendCounterPacket & GetSendCounterPacket() override