IVGCVSW-4171 Fix intermittent failure on FileOnlyProfilingDecoratorTests
[platform/upstream/armnn.git] / src / profiling / ProfilingService.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "ProfilingService.hpp"
7
8 #include <boost/log/trivial.hpp>
9 #include <boost/format.hpp>
10
11 namespace armnn
12 {
13
14 namespace profiling
15 {
16
17 void ProfilingService::ResetExternalProfilingOptions(const ExternalProfilingOptions& options,
18                                                      bool resetProfilingService)
19 {
20     // Update the profiling options
21     m_Options = options;
22
23     // Check if the profiling service needs to be reset
24     if (resetProfilingService)
25     {
26         // Reset the profiling service
27         Reset();
28     }
29 }
30
31 ProfilingState ProfilingService::ConfigureProfilingService(
32         const ExternalProfilingOptions& options,
33         bool resetProfilingService)
34 {
35     ResetExternalProfilingOptions(options, resetProfilingService);
36     ProfilingState currentState = m_StateMachine.GetCurrentState();
37     if (options.m_EnableProfiling)
38     {
39         switch (currentState)
40         {
41             case ProfilingState::Uninitialised:
42                 Update(); // should transition to NotConnected
43                 Update(); // will either stay in NotConnected because there is no server
44                           // or will enter WaitingForAck.
45                 currentState = m_StateMachine.GetCurrentState();
46                 if (currentState == ProfilingState::WaitingForAck)
47                 {
48                     Update(); // poke it again to send out the metadata packet
49                 }
50                 currentState = m_StateMachine.GetCurrentState();
51                 return currentState;
52             case ProfilingState::NotConnected:
53                 Update(); // will either stay in NotConnected because there is no server
54                           // or will enter WaitingForAck
55                 currentState = m_StateMachine.GetCurrentState();
56                 if (currentState == ProfilingState::WaitingForAck)
57                 {
58                     Update(); // poke it again to send out the metadata packet
59                 }
60                 currentState = m_StateMachine.GetCurrentState();
61                 return currentState;
62             default:
63                 return currentState;
64         }
65     }
66     else
67     {
68         // Make sure profiling is shutdown
69         switch (currentState)
70         {
71             case ProfilingState::Uninitialised:
72             case ProfilingState::NotConnected:
73                 return currentState;
74             default:
75                 Stop();
76                 return m_StateMachine.GetCurrentState();
77         }
78     }
79 }
80
81 void ProfilingService::Update()
82 {
83     if (!m_Options.m_EnableProfiling)
84     {
85         // Don't run if profiling is disabled
86         return;
87     }
88
89     ProfilingState currentState = m_StateMachine.GetCurrentState();
90     switch (currentState)
91     {
92     case ProfilingState::Uninitialised:
93
94         // Initialize the profiling service
95         Initialize();
96
97         // Move to the next state
98         m_StateMachine.TransitionToState(ProfilingState::NotConnected);
99         break;
100     case ProfilingState::NotConnected:
101         // Stop the command thread (if running)
102         m_CommandHandler.Stop();
103
104         // Stop the send thread (if running)
105         m_SendCounterPacket.Stop(false);
106
107         // Stop the periodic counter capture thread (if running)
108         m_PeriodicCounterCapture.Stop();
109
110         // Reset any existing profiling connection
111         m_ProfilingConnection.reset();
112
113         try
114         {
115             // Setup the profiling connection
116             BOOST_ASSERT(m_ProfilingConnectionFactory);
117             m_ProfilingConnection = m_ProfilingConnectionFactory->GetProfilingConnection(m_Options);
118         }
119         catch (const Exception& e)
120         {
121             BOOST_LOG_TRIVIAL(warning) << "An error has occurred when creating the profiling connection: "
122                                        << e.what() << std::endl;
123         }
124
125         // Move to the next state
126         m_StateMachine.TransitionToState(m_ProfilingConnection
127                                          ? ProfilingState::WaitingForAck  // Profiling connection obtained, wait for ack
128                                          : ProfilingState::NotConnected); // Profiling connection failed, stay in the
129                                                                           // "NotConnected" state
130         break;
131     case ProfilingState::WaitingForAck:
132         BOOST_ASSERT(m_ProfilingConnection);
133
134         // Start the command thread
135         m_CommandHandler.Start(*m_ProfilingConnection);
136
137         // Start the send thread, while in "WaitingForAck" state it'll send out a "Stream MetaData" packet waiting for
138         // a valid "Connection Acknowledged" packet confirming the connection
139         m_SendCounterPacket.Start(*m_ProfilingConnection);
140
141         // The connection acknowledged command handler will automatically transition the state to "Active" once a
142         // valid "Connection Acknowledged" packet has been received
143
144         break;
145     case ProfilingState::Active:
146
147         // The period counter capture thread is started by the Periodic Counter Selection command handler upon
148         // request by an external profiling service
149
150         break;
151     default:
152         throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1")
153                                           % static_cast<int>(currentState)));
154     }
155 }
156
157 void ProfilingService::Disconnect()
158 {
159     ProfilingState currentState = m_StateMachine.GetCurrentState();
160     switch (currentState)
161     {
162     case ProfilingState::Uninitialised:
163     case ProfilingState::NotConnected:
164     case ProfilingState::WaitingForAck:
165         return; // NOP
166     case ProfilingState::Active:
167         // Stop the command thread (if running)
168         Stop();
169
170         break;
171     default:
172         throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1")
173                                           % static_cast<int>(currentState)));
174     }
175 }
176
177 const ICounterDirectory& ProfilingService::GetCounterDirectory() const
178 {
179     return m_CounterDirectory;
180 }
181
182 ProfilingState ProfilingService::GetCurrentState() const
183 {
184     return m_StateMachine.GetCurrentState();
185 }
186
187 uint16_t ProfilingService::GetCounterCount() const
188 {
189     return m_CounterDirectory.GetCounterCount();
190 }
191
192 bool ProfilingService::IsCounterRegistered(uint16_t counterUid) const
193 {
194     return counterUid < m_CounterIndex.size();
195 }
196
197 uint32_t ProfilingService::GetCounterValue(uint16_t counterUid) const
198 {
199     CheckCounterUid(counterUid);
200     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
201     BOOST_ASSERT(counterValuePtr);
202     return counterValuePtr->load(std::memory_order::memory_order_relaxed);
203 }
204
205 void ProfilingService::SetCounterValue(uint16_t counterUid, uint32_t value)
206 {
207     CheckCounterUid(counterUid);
208     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
209     BOOST_ASSERT(counterValuePtr);
210     counterValuePtr->store(value, std::memory_order::memory_order_relaxed);
211 }
212
213 uint32_t ProfilingService::AddCounterValue(uint16_t counterUid, uint32_t value)
214 {
215     CheckCounterUid(counterUid);
216     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
217     BOOST_ASSERT(counterValuePtr);
218     return counterValuePtr->fetch_add(value, std::memory_order::memory_order_relaxed);
219 }
220
221 uint32_t ProfilingService::SubtractCounterValue(uint16_t counterUid, uint32_t value)
222 {
223     CheckCounterUid(counterUid);
224     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
225     BOOST_ASSERT(counterValuePtr);
226     return counterValuePtr->fetch_sub(value, std::memory_order::memory_order_relaxed);
227 }
228
229 uint32_t ProfilingService::IncrementCounterValue(uint16_t counterUid)
230 {
231     CheckCounterUid(counterUid);
232     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
233     BOOST_ASSERT(counterValuePtr);
234     return counterValuePtr->operator++(std::memory_order::memory_order_relaxed);
235 }
236
237 uint32_t ProfilingService::DecrementCounterValue(uint16_t counterUid)
238 {
239     CheckCounterUid(counterUid);
240     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
241     BOOST_ASSERT(counterValuePtr);
242     return counterValuePtr->operator--(std::memory_order::memory_order_relaxed);
243 }
244
245 ProfilingDynamicGuid ProfilingService::NextGuid()
246 {
247     return m_GuidGenerator.NextGuid();
248 }
249
250 ProfilingStaticGuid ProfilingService::GenerateStaticId(const std::string& str)
251 {
252     return m_GuidGenerator.GenerateStaticId(str);
253 }
254
255 std::unique_ptr<ISendTimelinePacket> ProfilingService::GetSendTimelinePacket() const
256 {
257     return m_TimelinePacketWriterFactory.GetSendTimelinePacket();
258 }
259
260 void ProfilingService::Initialize()
261 {
262     // Register a category for the basic runtime counters
263     if (!m_CounterDirectory.IsCategoryRegistered("ArmNN_Runtime"))
264     {
265         m_CounterDirectory.RegisterCategory("ArmNN_Runtime");
266     }
267
268     // Register a counter for the number of loaded networks
269     if (!m_CounterDirectory.IsCounterRegistered("Loaded networks"))
270     {
271         const Counter* loadedNetworksCounter =
272                 m_CounterDirectory.RegisterCounter("ArmNN_Runtime",
273                                                    0,
274                                                    0,
275                                                    1.f,
276                                                    "Loaded networks",
277                                                    "The number of networks loaded at runtime",
278                                                    std::string("networks"));
279         BOOST_ASSERT(loadedNetworksCounter);
280         InitializeCounterValue(loadedNetworksCounter->m_Uid);
281     }
282
283     // Register a counter for the number of registered backends
284     if (!m_CounterDirectory.IsCounterRegistered("Registered backends"))
285     {
286         const Counter* registeredBackendsCounter =
287                 m_CounterDirectory.RegisterCounter("ArmNN_Runtime",
288                                                    0,
289                                                    0,
290                                                    1.f,
291                                                    "Registered backends",
292                                                    "The number of registered backends",
293                                                    std::string("backends"));
294         BOOST_ASSERT(registeredBackendsCounter);
295         InitializeCounterValue(registeredBackendsCounter->m_Uid);
296     }
297
298     // Register a counter for the number of inferences run
299     if (!m_CounterDirectory.IsCounterRegistered("Inferences run"))
300     {
301         const Counter* inferencesRunCounter =
302                 m_CounterDirectory.RegisterCounter("ArmNN_Runtime",
303                                                    0,
304                                                    0,
305                                                    1.f,
306                                                    "Inferences run",
307                                                    "The number of inferences run",
308                                                    std::string("inferences"));
309         BOOST_ASSERT(inferencesRunCounter);
310         InitializeCounterValue(inferencesRunCounter->m_Uid);
311     }
312 }
313
314 void ProfilingService::InitializeCounterValue(uint16_t counterUid)
315 {
316     // Increase the size of the counter index if necessary
317     if (counterUid >= m_CounterIndex.size())
318     {
319         m_CounterIndex.resize(boost::numeric_cast<size_t>(counterUid) + 1);
320     }
321
322     // Create a new atomic counter and add it to the list
323     m_CounterValues.emplace_back(0);
324
325     // Register the new counter to the counter index for quick access
326     std::atomic<uint32_t>* counterValuePtr = &(m_CounterValues.back());
327     m_CounterIndex.at(counterUid) = counterValuePtr;
328 }
329
330 void ProfilingService::Reset()
331 {
332     // Stop the profiling service...
333     Stop();
334
335     // ...then delete all the counter data and configuration...
336     m_CounterIndex.clear();
337     m_CounterValues.clear();
338     m_CounterDirectory.Clear();
339
340     // ...finally reset the profiling state machine
341     m_StateMachine.Reset();
342 }
343
344 void ProfilingService::Stop()
345 {
346     // The order in which we reset/stop the components is not trivial!
347
348     // First stop the threads (Command Handler first)...
349     m_CommandHandler.Stop();
350     m_SendCounterPacket.Stop(false);
351     m_PeriodicCounterCapture.Stop();
352
353     // ...then close and destroy the profiling connection...
354     if (m_ProfilingConnection != nullptr && m_ProfilingConnection->IsOpen())
355     {
356         m_ProfilingConnection->Close();
357     }
358     m_ProfilingConnection.reset();
359
360     // ...then move to the "NotConnected" state
361     m_StateMachine.TransitionToState(ProfilingState::NotConnected);
362 }
363
364 inline void ProfilingService::CheckCounterUid(uint16_t counterUid) const
365 {
366     if (!IsCounterRegistered(counterUid))
367     {
368         throw InvalidArgumentException(boost::str(boost::format("Counter UID %1% is not registered") % counterUid));
369     }
370 }
371
372 ProfilingService::~ProfilingService()
373 {
374     Stop();
375 }
376
377 } // namespace profiling
378
379 } // namespace armnn