55524a4dfe4d9c982dac9bb2d80cfc8ac0fe3b36
[platform/upstream/armnn.git] / src / profiling / test / ProfilingTests.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "../CommandHandlerKey.hpp"
7 #include "../CommandHandlerFunctor.hpp"
8 #include "../CommandHandlerRegistry.hpp"
9 #include "../EncodeVersion.hpp"
10 #include "../Holder.hpp"
11 #include "../Packet.hpp"
12 #include "../PacketVersionResolver.hpp"
13 #include "../ProfilingService.hpp"
14 #include "../ProfilingStateMachine.hpp"
15 #include "../PeriodicCounterSelectionCommandHandler.hpp"
16 #include "../ProfilingUtils.hpp"
17 #include "../SocketProfilingConnection.hpp"
18 #include "../IPeriodicCounterCapture.hpp"
19 #include "SendCounterPacketTests.hpp"
20
21 #include <Runtime.hpp>
22
23
24 #include <boost/test/unit_test.hpp>
25 #include <boost/numeric/conversion/cast.hpp>
26
27 #include <cstdint>
28 #include <cstring>
29 #include <limits>
30 #include <map>
31 #include <random>
32 #include <thread>
33
34 BOOST_AUTO_TEST_SUITE(ExternalProfiling)
35
36 using namespace armnn::profiling;
37
38 BOOST_AUTO_TEST_CASE(CheckCommandHandlerKeyComparisons)
39 {
40     CommandHandlerKey testKey0(1, 1);
41     CommandHandlerKey testKey1(1, 1);
42     CommandHandlerKey testKey2(1, 1);
43     CommandHandlerKey testKey3(0, 0);
44     CommandHandlerKey testKey4(2, 2);
45     CommandHandlerKey testKey5(0, 2);
46
47     BOOST_CHECK(testKey1<testKey4);
48     BOOST_CHECK(testKey1>testKey3);
49     BOOST_CHECK(testKey1<=testKey4);
50     BOOST_CHECK(testKey1>=testKey3);
51     BOOST_CHECK(testKey1<=testKey2);
52     BOOST_CHECK(testKey1>=testKey2);
53     BOOST_CHECK(testKey1==testKey2);
54     BOOST_CHECK(testKey1==testKey1);
55
56     BOOST_CHECK(!(testKey1==testKey5));
57     BOOST_CHECK(!(testKey1!=testKey1));
58     BOOST_CHECK(testKey1!=testKey5);
59
60     BOOST_CHECK(testKey1==testKey2 && testKey2==testKey1);
61     BOOST_CHECK(testKey0==testKey1 && testKey1==testKey2 && testKey0==testKey2);
62
63     BOOST_CHECK(testKey1.GetPacketId()==1);
64     BOOST_CHECK(testKey1.GetVersion()==1);
65
66     std::vector<CommandHandlerKey> vect =
67     {
68         CommandHandlerKey(0,1), CommandHandlerKey(2,0), CommandHandlerKey(1,0),
69         CommandHandlerKey(2,1), CommandHandlerKey(1,1), CommandHandlerKey(0,1),
70         CommandHandlerKey(2,0), CommandHandlerKey(0,0)
71     };
72
73     std::sort(vect.begin(), vect.end());
74
75     std::vector<CommandHandlerKey> expectedVect =
76     {
77         CommandHandlerKey(0,0), CommandHandlerKey(0,1), CommandHandlerKey(0,1),
78         CommandHandlerKey(1,0), CommandHandlerKey(1,1), CommandHandlerKey(2,0),
79         CommandHandlerKey(2,0), CommandHandlerKey(2,1)
80     };
81
82     BOOST_CHECK(vect == expectedVect);
83 }
84
85 BOOST_AUTO_TEST_CASE(CheckEncodeVersion)
86 {
87     Version version1(12);
88
89     BOOST_CHECK(version1.GetMajor() == 0);
90     BOOST_CHECK(version1.GetMinor() == 0);
91     BOOST_CHECK(version1.GetPatch() == 12);
92
93     Version version2(4108);
94
95     BOOST_CHECK(version2.GetMajor() == 0);
96     BOOST_CHECK(version2.GetMinor() == 1);
97     BOOST_CHECK(version2.GetPatch() == 12);
98
99     Version version3(4198412);
100
101     BOOST_CHECK(version3.GetMajor() == 1);
102     BOOST_CHECK(version3.GetMinor() == 1);
103     BOOST_CHECK(version3.GetPatch() == 12);
104
105     Version version4(0);
106
107     BOOST_CHECK(version4.GetMajor() == 0);
108     BOOST_CHECK(version4.GetMinor() == 0);
109     BOOST_CHECK(version4.GetPatch() == 0);
110
111     Version version5(1, 0, 0);
112     BOOST_CHECK(version5.GetEncodedValue() == 4194304);
113 }
114
115 BOOST_AUTO_TEST_CASE(CheckPacketClass)
116 {
117     uint32_t length = 4;
118     std::unique_ptr<char[]> packetData0 = std::make_unique<char[]>(length);
119     std::unique_ptr<char[]> packetData1 = std::make_unique<char[]>(0);
120     std::unique_ptr<char[]> nullPacketData;
121
122     Packet packetTest0(472580096, length, packetData0);
123
124     BOOST_CHECK(packetTest0.GetHeader() == 472580096);
125     BOOST_CHECK(packetTest0.GetPacketFamily() == 7);
126     BOOST_CHECK(packetTest0.GetPacketId() == 43);
127     BOOST_CHECK(packetTest0.GetLength() == length);
128     BOOST_CHECK(packetTest0.GetPacketType() == 3);
129     BOOST_CHECK(packetTest0.GetPacketClass() == 5);
130
131     BOOST_CHECK_THROW(Packet packetTest1(472580096, 0, packetData1), armnn::Exception);
132     BOOST_CHECK_NO_THROW(Packet packetTest2(472580096, 0, nullPacketData));
133
134     Packet packetTest3(472580096, 0, nullPacketData);
135     BOOST_CHECK(packetTest3.GetLength() == 0);
136     BOOST_CHECK(packetTest3.GetData() == nullptr);
137
138     const char* packetTest0Data = packetTest0.GetData();
139     Packet packetTest4(std::move(packetTest0));
140
141     BOOST_CHECK(packetTest0.GetData() == nullptr);
142     BOOST_CHECK(packetTest4.GetData() == packetTest0Data);
143
144     BOOST_CHECK(packetTest4.GetHeader() == 472580096);
145     BOOST_CHECK(packetTest4.GetPacketFamily() == 7);
146     BOOST_CHECK(packetTest4.GetPacketId() == 43);
147     BOOST_CHECK(packetTest4.GetLength() == length);
148     BOOST_CHECK(packetTest4.GetPacketType() == 3);
149     BOOST_CHECK(packetTest4.GetPacketClass() == 5);
150 }
151
152 // Create Derived Classes
153 class TestFunctorA : public CommandHandlerFunctor
154 {
155 public:
156     using CommandHandlerFunctor::CommandHandlerFunctor;
157
158     int GetCount() { return m_Count; }
159
160     void operator()(const Packet& packet) override
161     {
162         m_Count++;
163     }
164
165 private:
166     int m_Count = 0;
167 };
168
169 class TestFunctorB : public TestFunctorA
170 {
171     using TestFunctorA::TestFunctorA;
172 };
173
174 class TestFunctorC : public TestFunctorA
175 {
176     using TestFunctorA::TestFunctorA;
177 };
178
179 BOOST_AUTO_TEST_CASE(CheckCommandHandlerFunctor)
180 {
181     // Hard code the version as it will be the same during a single profiling session
182     uint32_t version = 1;
183
184     TestFunctorA testFunctorA(461, version);
185     TestFunctorB testFunctorB(963, version);
186     TestFunctorC testFunctorC(983, version);
187
188     CommandHandlerKey keyA(testFunctorA.GetPacketId(), testFunctorA.GetVersion());
189     CommandHandlerKey keyB(testFunctorB.GetPacketId(), testFunctorB.GetVersion());
190     CommandHandlerKey keyC(testFunctorC.GetPacketId(), testFunctorC.GetVersion());
191
192     // Create the unwrapped map to simulate the Command Handler Registry
193     std::map<CommandHandlerKey, CommandHandlerFunctor*> registry;
194
195     registry.insert(std::make_pair(keyB, &testFunctorB));
196     registry.insert(std::make_pair(keyA, &testFunctorA));
197     registry.insert(std::make_pair(keyC, &testFunctorC));
198
199     // Check the order of the map is correct
200     auto it = registry.begin();
201     BOOST_CHECK(it->first==keyA);
202     it++;
203     BOOST_CHECK(it->first==keyB);
204     it++;
205     BOOST_CHECK(it->first==keyC);
206
207     std::unique_ptr<char[]> packetDataA;
208     std::unique_ptr<char[]> packetDataB;
209     std::unique_ptr<char[]> packetDataC;
210
211     Packet packetA(500000000, 0, packetDataA);
212     Packet packetB(600000000, 0, packetDataB);
213     Packet packetC(400000000, 0, packetDataC);
214
215     // Check the correct operator of derived class is called
216     registry.at(CommandHandlerKey(packetA.GetPacketId(), version))->operator()(packetA);
217     BOOST_CHECK(testFunctorA.GetCount() == 1);
218     BOOST_CHECK(testFunctorB.GetCount() == 0);
219     BOOST_CHECK(testFunctorC.GetCount() == 0);
220
221     registry.at(CommandHandlerKey(packetB.GetPacketId(), version))->operator()(packetB);
222     BOOST_CHECK(testFunctorA.GetCount() == 1);
223     BOOST_CHECK(testFunctorB.GetCount() == 1);
224     BOOST_CHECK(testFunctorC.GetCount() == 0);
225
226     registry.at(CommandHandlerKey(packetC.GetPacketId(), version))->operator()(packetC);
227     BOOST_CHECK(testFunctorA.GetCount() == 1);
228     BOOST_CHECK(testFunctorB.GetCount() == 1);
229     BOOST_CHECK(testFunctorC.GetCount() == 1);
230 }
231
232 BOOST_AUTO_TEST_CASE(CheckCommandHandlerRegistry)
233 {
234     // Hard code the version as it will be the same during a single profiling session
235     uint32_t version = 1;
236
237     TestFunctorA testFunctorA(461, version);
238     TestFunctorB testFunctorB(963, version);
239     TestFunctorC testFunctorC(983, version);
240
241     // Create the Command Handler Registry
242     CommandHandlerRegistry registry;
243
244     // Register multiple different derived classes
245     registry.RegisterFunctor(&testFunctorA, testFunctorA.GetPacketId(), testFunctorA.GetVersion());
246     registry.RegisterFunctor(&testFunctorB, testFunctorB.GetPacketId(), testFunctorB.GetVersion());
247     registry.RegisterFunctor(&testFunctorC, testFunctorC.GetPacketId(), testFunctorC.GetVersion());
248
249     std::unique_ptr<char[]> packetDataA;
250     std::unique_ptr<char[]> packetDataB;
251     std::unique_ptr<char[]> packetDataC;
252
253     Packet packetA(500000000, 0, packetDataA);
254     Packet packetB(600000000, 0, packetDataB);
255     Packet packetC(400000000, 0, packetDataC);
256
257     // Check the correct operator of derived class is called
258     registry.GetFunctor(packetA.GetPacketId(), version)->operator()(packetA);
259     BOOST_CHECK(testFunctorA.GetCount() == 1);
260     BOOST_CHECK(testFunctorB.GetCount() == 0);
261     BOOST_CHECK(testFunctorC.GetCount() == 0);
262
263     registry.GetFunctor(packetB.GetPacketId(), version)->operator()(packetB);
264     BOOST_CHECK(testFunctorA.GetCount() == 1);
265     BOOST_CHECK(testFunctorB.GetCount() == 1);
266     BOOST_CHECK(testFunctorC.GetCount() == 0);
267
268     registry.GetFunctor(packetC.GetPacketId(), version)->operator()(packetC);
269     BOOST_CHECK(testFunctorA.GetCount() == 1);
270     BOOST_CHECK(testFunctorB.GetCount() == 1);
271     BOOST_CHECK(testFunctorC.GetCount() == 1);
272
273     // Re-register an existing key with a new function
274     registry.RegisterFunctor(&testFunctorC, testFunctorA.GetPacketId(), version);
275     registry.GetFunctor(packetA.GetPacketId(), version)->operator()(packetC);
276     BOOST_CHECK(testFunctorA.GetCount() == 1);
277     BOOST_CHECK(testFunctorB.GetCount() == 1);
278     BOOST_CHECK(testFunctorC.GetCount() == 2);
279
280     // Check that non-existent key returns nullptr for its functor
281     BOOST_CHECK_THROW(registry.GetFunctor(0, 0), armnn::Exception);
282 }
283
284 BOOST_AUTO_TEST_CASE(CheckPacketVersionResolver)
285 {
286     // Set up random number generator for generating packetId values
287     std::random_device device;
288     std::mt19937 generator(device());
289     std::uniform_int_distribution<uint32_t> distribution(std::numeric_limits<uint32_t>::min(),
290                                                          std::numeric_limits<uint32_t>::max());
291
292     // NOTE: Expected version is always 1.0.0, regardless of packetId
293     const Version expectedVersion(1, 0, 0);
294
295     PacketVersionResolver packetVersionResolver;
296
297     constexpr unsigned int numTests = 10u;
298
299     for (unsigned int i = 0u; i < numTests; ++i)
300     {
301         const uint32_t packetId = distribution(generator);
302         Version resolvedVersion = packetVersionResolver.ResolvePacketVersion(packetId);
303
304         BOOST_TEST(resolvedVersion == expectedVersion);
305     }
306 }
307 void ProfilingCurrentStateThreadImpl(ProfilingStateMachine& states)
308 {
309     ProfilingState newState = ProfilingState::NotConnected;
310     states.GetCurrentState();
311     states.TransitionToState(newState);
312 }
313
314 BOOST_AUTO_TEST_CASE(CheckProfilingStateMachine)
315 {
316     ProfilingStateMachine profilingState1(ProfilingState::Uninitialised);
317     profilingState1.TransitionToState(ProfilingState::Uninitialised);
318     BOOST_CHECK(profilingState1.GetCurrentState() ==  ProfilingState::Uninitialised);
319
320     ProfilingStateMachine profilingState2(ProfilingState::Uninitialised);
321     profilingState2.TransitionToState(ProfilingState::NotConnected);
322     BOOST_CHECK(profilingState2.GetCurrentState() == ProfilingState::NotConnected);
323
324     ProfilingStateMachine profilingState3(ProfilingState::NotConnected);
325     profilingState3.TransitionToState(ProfilingState::NotConnected);
326     BOOST_CHECK(profilingState3.GetCurrentState() == ProfilingState::NotConnected);
327
328     ProfilingStateMachine profilingState4(ProfilingState::NotConnected);
329     profilingState4.TransitionToState(ProfilingState::WaitingForAck);
330     BOOST_CHECK(profilingState4.GetCurrentState() == ProfilingState::WaitingForAck);
331
332     ProfilingStateMachine profilingState5(ProfilingState::WaitingForAck);
333     profilingState5.TransitionToState(ProfilingState::WaitingForAck);
334     BOOST_CHECK(profilingState5.GetCurrentState() == ProfilingState::WaitingForAck);
335
336     ProfilingStateMachine profilingState6(ProfilingState::WaitingForAck);
337     profilingState6.TransitionToState(ProfilingState::Active);
338     BOOST_CHECK(profilingState6.GetCurrentState() == ProfilingState::Active);
339
340     ProfilingStateMachine profilingState7(ProfilingState::Active);
341     profilingState7.TransitionToState(ProfilingState::NotConnected);
342     BOOST_CHECK(profilingState7.GetCurrentState() == ProfilingState::NotConnected);
343
344     ProfilingStateMachine profilingState8(ProfilingState::Active);
345     profilingState8.TransitionToState(ProfilingState::Active);
346     BOOST_CHECK(profilingState8.GetCurrentState() == ProfilingState::Active);
347
348     ProfilingStateMachine profilingState9(ProfilingState::Uninitialised);
349     BOOST_CHECK_THROW(profilingState9.TransitionToState(ProfilingState::WaitingForAck),
350                       armnn::Exception);
351
352     ProfilingStateMachine profilingState10(ProfilingState::Uninitialised);
353     BOOST_CHECK_THROW(profilingState10.TransitionToState(ProfilingState::Active),
354                       armnn::Exception);
355
356     ProfilingStateMachine profilingState11(ProfilingState::NotConnected);
357     BOOST_CHECK_THROW(profilingState11.TransitionToState(ProfilingState::Uninitialised),
358                       armnn::Exception);
359
360     ProfilingStateMachine profilingState12(ProfilingState::NotConnected);
361     BOOST_CHECK_THROW(profilingState12.TransitionToState(ProfilingState::Active),
362                       armnn::Exception);
363
364     ProfilingStateMachine profilingState13(ProfilingState::WaitingForAck);
365     BOOST_CHECK_THROW(profilingState13.TransitionToState(ProfilingState::Uninitialised),
366                       armnn::Exception);
367
368     ProfilingStateMachine profilingState14(ProfilingState::WaitingForAck);
369     BOOST_CHECK_THROW(profilingState14.TransitionToState(ProfilingState::NotConnected),
370                       armnn::Exception);
371
372     ProfilingStateMachine profilingState15(ProfilingState::Active);
373     BOOST_CHECK_THROW(profilingState15.TransitionToState(ProfilingState::Uninitialised),
374                       armnn::Exception);
375
376     ProfilingStateMachine profilingState16(armnn::profiling::ProfilingState::Active);
377     BOOST_CHECK_THROW(profilingState16.TransitionToState(ProfilingState::WaitingForAck),
378                       armnn::Exception);
379
380     ProfilingStateMachine profilingState17(ProfilingState::Uninitialised);
381
382     std::thread thread1 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
383     std::thread thread2 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
384     std::thread thread3 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
385     std::thread thread4 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
386     std::thread thread5 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
387
388     thread1.join();
389     thread2.join();
390     thread3.join();
391     thread4.join();
392     thread5.join();
393
394     BOOST_TEST((profilingState17.GetCurrentState() == ProfilingState::NotConnected));
395 }
396
397 void CaptureDataWriteThreadImpl(Holder &holder, uint32_t capturePeriod, std::vector<uint16_t>& counterIds)
398 {
399     holder.SetCaptureData(capturePeriod, counterIds);
400 }
401
402 void CaptureDataReadThreadImpl(const Holder& holder, CaptureData& captureData)
403 {
404     captureData = holder.GetCaptureData();
405 }
406
407 BOOST_AUTO_TEST_CASE(CheckCaptureDataHolder)
408 {
409     std::map<uint32_t, std::vector<uint16_t>> periodIdMap;
410     std::vector<uint16_t> counterIds;
411     uint16_t numThreads = 50;
412     for (uint16_t i = 0; i < numThreads; ++i)
413     {
414         counterIds.emplace_back(i);
415         periodIdMap.insert(std::make_pair(i, counterIds));
416     }
417
418     // Check CaptureData functions
419     CaptureData capture;
420     BOOST_CHECK(capture.GetCapturePeriod() == 0);
421     BOOST_CHECK((capture.GetCounterIds()).empty());
422     capture.SetCapturePeriod(0);
423     capture.SetCounterIds(periodIdMap[0]);
424     BOOST_CHECK(capture.GetCapturePeriod() == 0);
425     BOOST_CHECK(capture.GetCounterIds() == periodIdMap[0]);
426
427     Holder holder;
428     BOOST_CHECK((holder.GetCaptureData()).GetCapturePeriod() == 0);
429     BOOST_CHECK(((holder.GetCaptureData()).GetCounterIds()).empty());
430
431     // Check Holder functions
432     std::thread thread1(CaptureDataWriteThreadImpl, std::ref(holder), 2, std::ref(periodIdMap[2]));
433     thread1.join();
434
435     BOOST_CHECK((holder.GetCaptureData()).GetCapturePeriod() == 2);
436     BOOST_CHECK((holder.GetCaptureData()).GetCounterIds() == periodIdMap[2]);
437
438     CaptureData captureData;
439     std::thread thread2(CaptureDataReadThreadImpl, std::ref(holder), std::ref(captureData));
440     thread2.join();
441     BOOST_CHECK(captureData.GetCounterIds() == periodIdMap[2]);
442
443     std::vector<std::thread> threadsVect;
444     for (int i = 0; i < numThreads; i+=2)
445     {
446         threadsVect.emplace_back(std::thread(CaptureDataWriteThreadImpl,
447                                  std::ref(holder),
448                                  i,
449                                  std::ref(periodIdMap[static_cast<uint16_t >(i)])));
450
451         threadsVect.emplace_back(std::thread(CaptureDataReadThreadImpl,
452                                  std::ref(holder),
453                                  std::ref(captureData)));
454     }
455
456     for (uint16_t i = 0; i < numThreads; ++i)
457     {
458         threadsVect[i].join();
459     }
460
461     std::vector<std::thread> readThreadsVect;
462     for (uint16_t i = 0; i < numThreads; ++i)
463     {
464         readThreadsVect.emplace_back(
465                 std::thread(CaptureDataReadThreadImpl, std::ref(holder), std::ref(captureData)));
466     }
467
468     for (uint16_t i = 0; i < numThreads; ++i)
469     {
470         readThreadsVect[i].join();
471     }
472
473     // Check CaptureData was written/read correctly from multiple threads
474     std::vector<uint16_t> captureIds = captureData.GetCounterIds();
475     uint32_t capturePeriod = captureData.GetCapturePeriod();
476
477     BOOST_CHECK(captureIds == periodIdMap[capturePeriod]);
478
479     std::vector<uint16_t> readIds = holder.GetCaptureData().GetCounterIds();
480     BOOST_CHECK(captureIds == readIds);
481 }
482
483 BOOST_AUTO_TEST_CASE(CaptureDataMethods)
484 {
485     // Check assignment operator
486     CaptureData assignableCaptureData;
487     std::vector<uint16_t> counterIds = {42, 29, 13};
488     assignableCaptureData.SetCapturePeriod(3);
489     assignableCaptureData.SetCounterIds(counterIds);
490
491     CaptureData secondCaptureData;
492
493     BOOST_CHECK(assignableCaptureData.GetCapturePeriod() == 3);
494     BOOST_CHECK(assignableCaptureData.GetCounterIds() == counterIds);
495
496     secondCaptureData = assignableCaptureData;
497     BOOST_CHECK(secondCaptureData.GetCapturePeriod() == 3);
498     BOOST_CHECK(secondCaptureData.GetCounterIds() == counterIds);
499
500     // Check copy constructor
501     CaptureData copyConstructedCaptureData(assignableCaptureData);
502
503     BOOST_CHECK(copyConstructedCaptureData.GetCapturePeriod() == 3);
504     BOOST_CHECK(copyConstructedCaptureData.GetCounterIds() == counterIds);
505 }
506
507 BOOST_AUTO_TEST_CASE(CheckProfilingServiceDisabled)
508 {
509     armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
510     ProfilingService service(options);
511     BOOST_CHECK(service.GetCurrentState() ==  ProfilingState::Uninitialised);
512     service.Run();
513     BOOST_CHECK(service.GetCurrentState() ==  ProfilingState::Uninitialised);
514 }
515
516 BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabled)
517 {
518     armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
519     options.m_EnableProfiling = true;
520     ProfilingService service(options);
521     BOOST_CHECK(service.GetCurrentState() ==  ProfilingState::NotConnected);
522     service.Run();
523     BOOST_CHECK(service.GetCurrentState() ==  ProfilingState::WaitingForAck);
524 }
525
526
527 BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabledRuntime)
528 {
529     armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
530     ProfilingService service(options);
531     BOOST_CHECK(service.GetCurrentState() ==  ProfilingState::Uninitialised);
532     service.Run();
533     BOOST_CHECK(service.GetCurrentState() ==  ProfilingState::Uninitialised);
534     service.m_Options.m_EnableProfiling = true;
535     service.Run();
536     BOOST_CHECK(service.GetCurrentState() ==  ProfilingState::NotConnected);
537     service.Run();
538     BOOST_CHECK(service.GetCurrentState() ==  ProfilingState::WaitingForAck);
539 }
540
541 void GetNextUidTestImpl(uint16_t& outUid)
542 {
543     outUid = GetNextUid();
544 }
545
546 BOOST_AUTO_TEST_CASE(GetNextUidTest)
547 {
548     uint16_t uid0 = 0;
549     uint16_t uid1 = 0;
550     uint16_t uid2 = 0;
551
552     std::thread thread1(GetNextUidTestImpl, std::ref(uid0));
553     std::thread thread2(GetNextUidTestImpl, std::ref(uid1));
554     std::thread thread3(GetNextUidTestImpl, std::ref(uid2));
555     thread1.join();
556     thread2.join();
557     thread3.join();
558
559     BOOST_TEST(uid0 > 0);
560     BOOST_TEST(uid1 > 0);
561     BOOST_TEST(uid2 > 0);
562     BOOST_TEST(uid0 != uid1);
563     BOOST_TEST(uid0 != uid2);
564     BOOST_TEST(uid1 != uid2);
565 }
566
567 BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData)
568 {
569     using boost::numeric_cast;
570
571     class TestCaptureThread : public IPeriodicCounterCapture
572     {
573         void Start() override {};
574     };
575
576     const uint32_t packetId = 0x40000;
577
578     uint32_t version = 1;
579     Holder holder;
580     TestCaptureThread captureThread;
581     MockBuffer mockBuffer(512);
582     SendCounterPacket sendCounterPacket(mockBuffer);
583
584     uint32_t sizeOfUint32 = numeric_cast<uint32_t>(sizeof(uint32_t));
585     uint32_t sizeOfUint16 = numeric_cast<uint32_t>(sizeof(uint16_t));
586
587     // Data with period and counters
588     uint32_t period1 = 10;
589     uint32_t dataLength1 = 8;
590     uint32_t offset = 0;
591
592     std::unique_ptr<char[]> uniqueData1 = std::make_unique<char[]>(dataLength1);
593     unsigned char* data1 = reinterpret_cast<unsigned char*>(uniqueData1.get());
594
595     WriteUint32(data1, offset, period1);
596     offset += sizeOfUint32;
597     WriteUint16(data1, offset, 4000);
598     offset += sizeOfUint16;
599     WriteUint16(data1, offset, 5000);
600
601     Packet packetA(packetId, dataLength1, uniqueData1);
602
603     PeriodicCounterSelectionCommandHandler commandHandler(packetId, version, holder, captureThread,
604                                                           sendCounterPacket);
605     commandHandler(packetA);
606
607     std::vector<uint16_t> counterIds = holder.GetCaptureData().GetCounterIds();
608
609     BOOST_TEST(holder.GetCaptureData().GetCapturePeriod() == period1);
610     BOOST_TEST(counterIds.size() == 2);
611     BOOST_TEST(counterIds[0] == 4000);
612     BOOST_TEST(counterIds[1] == 5000);
613
614     unsigned int size = 0;
615
616     const unsigned char* readBuffer = mockBuffer.GetReadBuffer(size);
617
618     offset = 0;
619
620     uint32_t headerWord0 = ReadUint32(readBuffer, offset);
621     offset += sizeOfUint32;
622     uint32_t headerWord1 = ReadUint32(readBuffer, offset);
623     offset += sizeOfUint32;
624     uint32_t period = ReadUint32(readBuffer, offset);
625
626     BOOST_TEST(((headerWord0 >> 26) & 0x3F) == 0);  // packet family
627     BOOST_TEST(((headerWord0 >> 16) & 0x3FF) == 4); // packet id
628     BOOST_TEST(headerWord1 == 8);                   // data lenght
629     BOOST_TEST(period == 10);                       // capture period
630
631     uint16_t counterId = 0;
632     offset += sizeOfUint32;
633     counterId = ReadUint16(readBuffer, offset);
634     BOOST_TEST(counterId == 4000);
635     offset += sizeOfUint16;
636     counterId = ReadUint16(readBuffer, offset);
637     BOOST_TEST(counterId == 5000);
638
639     // Data with period only
640     uint32_t period2 = 11;
641     uint32_t dataLength2 = 4;
642
643     std::unique_ptr<char[]> uniqueData2 = std::make_unique<char[]>(dataLength2);
644
645     WriteUint32(reinterpret_cast<unsigned char*>(uniqueData2.get()), 0, period2);
646
647     Packet packetB(packetId, dataLength2, uniqueData2);
648
649     commandHandler(packetB);
650
651     counterIds = holder.GetCaptureData().GetCounterIds();
652
653     BOOST_TEST(holder.GetCaptureData().GetCapturePeriod() == period2);
654     BOOST_TEST(counterIds.size() == 0);
655
656     readBuffer = mockBuffer.GetReadBuffer(size);
657
658     offset = 0;
659
660     headerWord0 = ReadUint32(readBuffer, offset);
661     offset += sizeOfUint32;
662     headerWord1 = ReadUint32(readBuffer, offset);
663     offset += sizeOfUint32;
664     period = ReadUint32(readBuffer, offset);
665
666     BOOST_TEST(((headerWord0 >> 26) & 0x3F) == 0);  // packet family
667     BOOST_TEST(((headerWord0 >> 16) & 0x3FF) == 4); // packet id
668     BOOST_TEST(headerWord1 == 4);                   // data lenght
669     BOOST_TEST(period == 11);                       // capture period
670
671 }
672
673 BOOST_AUTO_TEST_CASE(CheckSocketProfilingConnection)
674 {
675     // Check that creating a SocketProfilingConnection results in an exception as the Gator UDS doesn't exist.
676     BOOST_CHECK_THROW(new SocketProfilingConnection(), armnn::Exception);
677 }
678
679 BOOST_AUTO_TEST_SUITE_END()