IVGCVSW-4073 Send stream info in the ConnectionAcknowledgedCommandHandler
[platform/upstream/armnn.git] / src / profiling / test / ProfilingTests.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "ProfilingTests.hpp"
7
8 #include <CommandHandler.hpp>
9 #include <CommandHandlerKey.hpp>
10 #include <CommandHandlerRegistry.hpp>
11 #include <ConnectionAcknowledgedCommandHandler.hpp>
12 #include <CounterDirectory.hpp>
13 #include <EncodeVersion.hpp>
14 #include <Holder.hpp>
15 #include <ICounterValues.hpp>
16 #include <Packet.hpp>
17 #include <PacketVersionResolver.hpp>
18 #include <PeriodicCounterCapture.hpp>
19 #include <PeriodicCounterSelectionCommandHandler.hpp>
20 #include <ProfilingStateMachine.hpp>
21 #include <ProfilingUtils.hpp>
22 #include <RequestCounterDirectoryCommandHandler.hpp>
23 #include <Runtime.hpp>
24 #include <SocketProfilingConnection.hpp>
25 #include <SendCounterPacket.hpp>
26 #include <SendTimelinePacket.hpp>
27
28 #include <armnn/Conversion.hpp>
29
30 #include <armnn/Utils.hpp>
31
32 #include <boost/algorithm/string.hpp>
33 #include <boost/numeric/conversion/cast.hpp>
34
35 #include <cstdint>
36 #include <cstring>
37 #include <iostream>
38 #include <limits>
39 #include <map>
40 #include <random>
41
42 using namespace armnn::profiling;
43
44 BOOST_AUTO_TEST_SUITE(ExternalProfiling)
45
46 BOOST_AUTO_TEST_CASE(CheckCommandHandlerKeyComparisons)
47 {
48     CommandHandlerKey testKey1_0(1, 1, 1);
49     CommandHandlerKey testKey1_1(1, 1, 1);
50     CommandHandlerKey testKey1_2(1, 2, 1);
51
52     CommandHandlerKey testKey0(0, 1, 1);
53     CommandHandlerKey testKey1(0, 1, 1);
54     CommandHandlerKey testKey2(0, 1, 1);
55     CommandHandlerKey testKey3(0, 0, 0);
56     CommandHandlerKey testKey4(0, 2, 2);
57     CommandHandlerKey testKey5(0, 0, 2);
58
59     BOOST_CHECK(testKey1_0 > testKey0);
60     BOOST_CHECK(testKey1_0 == testKey1_1);
61     BOOST_CHECK(testKey1_0 < testKey1_2);
62
63     BOOST_CHECK(testKey1 < testKey4);
64     BOOST_CHECK(testKey1 > testKey3);
65     BOOST_CHECK(testKey1 <= testKey4);
66     BOOST_CHECK(testKey1 >= testKey3);
67     BOOST_CHECK(testKey1 <= testKey2);
68     BOOST_CHECK(testKey1 >= testKey2);
69     BOOST_CHECK(testKey1 == testKey2);
70     BOOST_CHECK(testKey1 == testKey1);
71
72     BOOST_CHECK(!(testKey1 == testKey5));
73     BOOST_CHECK(!(testKey1 != testKey1));
74     BOOST_CHECK(testKey1 != testKey5);
75
76     BOOST_CHECK(testKey1 == testKey2 && testKey2 == testKey1);
77     BOOST_CHECK(testKey0 == testKey1 && testKey1 == testKey2 && testKey0 == testKey2);
78
79     BOOST_CHECK(testKey1.GetPacketId() == 1);
80     BOOST_CHECK(testKey1.GetVersion() == 1);
81
82     std::vector<CommandHandlerKey> vect = { CommandHandlerKey(0, 0, 1), CommandHandlerKey(0, 2, 0),
83                                             CommandHandlerKey(0, 1, 0), CommandHandlerKey(0, 2, 1),
84                                             CommandHandlerKey(0, 1, 1), CommandHandlerKey(0, 0, 1),
85                                             CommandHandlerKey(0, 2, 0), CommandHandlerKey(0, 0, 0) };
86
87     std::sort(vect.begin(), vect.end());
88
89     std::vector<CommandHandlerKey> expectedVect = { CommandHandlerKey(0, 0, 0), CommandHandlerKey(0, 0, 1),
90                                                     CommandHandlerKey(0, 0, 1), CommandHandlerKey(0, 1, 0),
91                                                     CommandHandlerKey(0, 1, 1), CommandHandlerKey(0, 2, 0),
92                                                     CommandHandlerKey(0, 2, 0), CommandHandlerKey(0, 2, 1) };
93
94     BOOST_CHECK(vect == expectedVect);
95 }
96
97 BOOST_AUTO_TEST_CASE(CheckPacketKeyComparisons)
98 {
99     PacketKey key0(0, 0);
100     PacketKey key1(0, 0);
101     PacketKey key2(0, 1);
102     PacketKey key3(0, 2);
103     PacketKey key4(1, 0);
104     PacketKey key5(1, 0);
105     PacketKey key6(1, 1);
106
107     BOOST_CHECK(!(key0 < key1));
108     BOOST_CHECK(!(key0 > key1));
109     BOOST_CHECK(key0 <= key1);
110     BOOST_CHECK(key0 >= key1);
111     BOOST_CHECK(key0 == key1);
112     BOOST_CHECK(key0 < key2);
113     BOOST_CHECK(key2 < key3);
114     BOOST_CHECK(key3 > key0);
115     BOOST_CHECK(key4 == key5);
116     BOOST_CHECK(key4 > key0);
117     BOOST_CHECK(key5 < key6);
118     BOOST_CHECK(key5 <= key6);
119     BOOST_CHECK(key5 != key6);
120 }
121
122 BOOST_AUTO_TEST_CASE(CheckCommandHandler)
123 {
124     PacketVersionResolver packetVersionResolver;
125     ProfilingStateMachine profilingStateMachine;
126
127     TestProfilingConnectionBase testProfilingConnectionBase;
128     TestProfilingConnectionTimeoutError testProfilingConnectionTimeOutError;
129     TestProfilingConnectionArmnnError testProfilingConnectionArmnnError;
130     CounterDirectory counterDirectory;
131     MockBufferManager mockBuffer(1024);
132     SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer);
133     SendTimelinePacket sendTimelinePacket(mockBuffer);
134
135     ConnectionAcknowledgedCommandHandler connectionAcknowledgedCommandHandler(0, 1, 4194304, counterDirectory,
136                                                                               sendCounterPacket, sendTimelinePacket,
137                                                                               profilingStateMachine);
138     CommandHandlerRegistry commandHandlerRegistry;
139
140     commandHandlerRegistry.RegisterFunctor(&connectionAcknowledgedCommandHandler);
141
142     profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
143     profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
144
145     CommandHandler commandHandler0(1, true, commandHandlerRegistry, packetVersionResolver);
146
147     commandHandler0.Start(testProfilingConnectionBase);
148     commandHandler0.Start(testProfilingConnectionBase);
149     commandHandler0.Start(testProfilingConnectionBase);
150
151     commandHandler0.Stop();
152
153     BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active);
154
155     profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
156     profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
157     // commandHandler1 should give up after one timeout
158     CommandHandler commandHandler1(10, true, commandHandlerRegistry, packetVersionResolver);
159
160     commandHandler1.Start(testProfilingConnectionTimeOutError);
161
162     std::this_thread::sleep_for(std::chrono::milliseconds(100));
163
164     BOOST_CHECK(!commandHandler1.IsRunning());
165     commandHandler1.Stop();
166
167     BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck);
168     // Now commandHandler1 should persist after a timeout
169     commandHandler1.SetStopAfterTimeout(false);
170     commandHandler1.Start(testProfilingConnectionTimeOutError);
171
172     for (int i = 0; i < 100; i++)
173     {
174         if (profilingStateMachine.GetCurrentState() == ProfilingState::Active)
175         {
176             break;
177         }
178
179         std::this_thread::sleep_for(std::chrono::milliseconds(100));
180     }
181
182     commandHandler1.Stop();
183
184     BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active);
185
186     CommandHandler commandHandler2(100, false, commandHandlerRegistry, packetVersionResolver);
187
188     commandHandler2.Start(testProfilingConnectionArmnnError);
189
190     // commandHandler2 should not stop once it encounters a non timing error
191     std::this_thread::sleep_for(std::chrono::milliseconds(500));
192
193     BOOST_CHECK(commandHandler2.IsRunning());
194     commandHandler2.Stop();
195 }
196
197 BOOST_AUTO_TEST_CASE(CheckEncodeVersion)
198 {
199     Version version1(12);
200
201     BOOST_CHECK(version1.GetMajor() == 0);
202     BOOST_CHECK(version1.GetMinor() == 0);
203     BOOST_CHECK(version1.GetPatch() == 12);
204
205     Version version2(4108);
206
207     BOOST_CHECK(version2.GetMajor() == 0);
208     BOOST_CHECK(version2.GetMinor() == 1);
209     BOOST_CHECK(version2.GetPatch() == 12);
210
211     Version version3(4198412);
212
213     BOOST_CHECK(version3.GetMajor() == 1);
214     BOOST_CHECK(version3.GetMinor() == 1);
215     BOOST_CHECK(version3.GetPatch() == 12);
216
217     Version version4(0);
218
219     BOOST_CHECK(version4.GetMajor() == 0);
220     BOOST_CHECK(version4.GetMinor() == 0);
221     BOOST_CHECK(version4.GetPatch() == 0);
222
223     Version version5(1, 0, 0);
224     BOOST_CHECK(version5.GetEncodedValue() == 4194304);
225 }
226
227 BOOST_AUTO_TEST_CASE(CheckPacketClass)
228 {
229     uint32_t length                              = 4;
230     std::unique_ptr<unsigned char[]> packetData0 = std::make_unique<unsigned char[]>(length);
231     std::unique_ptr<unsigned char[]> packetData1 = std::make_unique<unsigned char[]>(0);
232     std::unique_ptr<unsigned char[]> nullPacketData;
233
234     Packet packetTest0(472580096, length, packetData0);
235
236     BOOST_CHECK(packetTest0.GetHeader() == 472580096);
237     BOOST_CHECK(packetTest0.GetPacketFamily() == 7);
238     BOOST_CHECK(packetTest0.GetPacketId() == 43);
239     BOOST_CHECK(packetTest0.GetLength() == length);
240     BOOST_CHECK(packetTest0.GetPacketType() == 3);
241     BOOST_CHECK(packetTest0.GetPacketClass() == 5);
242
243     BOOST_CHECK_THROW(Packet packetTest1(472580096, 0, packetData1), armnn::Exception);
244     BOOST_CHECK_NO_THROW(Packet packetTest2(472580096, 0, nullPacketData));
245
246     Packet packetTest3(472580096, 0, nullPacketData);
247     BOOST_CHECK(packetTest3.GetLength() == 0);
248     BOOST_CHECK(packetTest3.GetData() == nullptr);
249
250     const unsigned char* packetTest0Data = packetTest0.GetData();
251     Packet packetTest4(std::move(packetTest0));
252
253     BOOST_CHECK(packetTest0.GetData() == nullptr);
254     BOOST_CHECK(packetTest4.GetData() == packetTest0Data);
255
256     BOOST_CHECK(packetTest4.GetHeader() == 472580096);
257     BOOST_CHECK(packetTest4.GetPacketFamily() == 7);
258     BOOST_CHECK(packetTest4.GetPacketId() == 43);
259     BOOST_CHECK(packetTest4.GetLength() == length);
260     BOOST_CHECK(packetTest4.GetPacketType() == 3);
261     BOOST_CHECK(packetTest4.GetPacketClass() == 5);
262 }
263
264 BOOST_AUTO_TEST_CASE(CheckCommandHandlerFunctor)
265 {
266     // Hard code the version as it will be the same during a single profiling session
267     uint32_t version = 1;
268
269     TestFunctorA testFunctorA(7, 461, version);
270     TestFunctorB testFunctorB(8, 963, version);
271     TestFunctorC testFunctorC(5, 983, version);
272
273     CommandHandlerKey keyA(testFunctorA.GetFamilyId(), testFunctorA.GetPacketId(), testFunctorA.GetVersion());
274     CommandHandlerKey keyB(testFunctorB.GetFamilyId(), testFunctorB.GetPacketId(), testFunctorB.GetVersion());
275     CommandHandlerKey keyC(testFunctorC.GetFamilyId(), testFunctorC.GetPacketId(), testFunctorC.GetVersion());
276
277     // Create the unwrapped map to simulate the Command Handler Registry
278     std::map<CommandHandlerKey, CommandHandlerFunctor*> registry;
279
280     registry.insert(std::make_pair(keyB, &testFunctorB));
281     registry.insert(std::make_pair(keyA, &testFunctorA));
282     registry.insert(std::make_pair(keyC, &testFunctorC));
283
284     // Check the order of the map is correct
285     auto it = registry.begin();
286     BOOST_CHECK(it->first == keyC);    // familyId == 5
287     it++;
288     BOOST_CHECK(it->first == keyA);    // familyId == 7
289     it++;
290     BOOST_CHECK(it->first == keyB);    // familyId == 8
291
292     std::unique_ptr<unsigned char[]> packetDataA;
293     std::unique_ptr<unsigned char[]> packetDataB;
294     std::unique_ptr<unsigned char[]> packetDataC;
295
296     Packet packetA(500000000, 0, packetDataA);
297     Packet packetB(600000000, 0, packetDataB);
298     Packet packetC(400000000, 0, packetDataC);
299
300     // Check the correct operator of derived class is called
301     registry.at(CommandHandlerKey(packetA.GetPacketFamily(), packetA.GetPacketId(), version))->operator()(packetA);
302     BOOST_CHECK(testFunctorA.GetCount() == 1);
303     BOOST_CHECK(testFunctorB.GetCount() == 0);
304     BOOST_CHECK(testFunctorC.GetCount() == 0);
305
306     registry.at(CommandHandlerKey(packetB.GetPacketFamily(), packetB.GetPacketId(), version))->operator()(packetB);
307     BOOST_CHECK(testFunctorA.GetCount() == 1);
308     BOOST_CHECK(testFunctorB.GetCount() == 1);
309     BOOST_CHECK(testFunctorC.GetCount() == 0);
310
311     registry.at(CommandHandlerKey(packetC.GetPacketFamily(), packetC.GetPacketId(), version))->operator()(packetC);
312     BOOST_CHECK(testFunctorA.GetCount() == 1);
313     BOOST_CHECK(testFunctorB.GetCount() == 1);
314     BOOST_CHECK(testFunctorC.GetCount() == 1);
315 }
316
317 BOOST_AUTO_TEST_CASE(CheckCommandHandlerRegistry)
318 {
319     // Hard code the version as it will be the same during a single profiling session
320     uint32_t version = 1;
321
322     TestFunctorA testFunctorA(7, 461, version);
323     TestFunctorB testFunctorB(8, 963, version);
324     TestFunctorC testFunctorC(5, 983, version);
325
326     // Create the Command Handler Registry
327     CommandHandlerRegistry registry;
328
329     // Register multiple different derived classes
330     registry.RegisterFunctor(&testFunctorA);
331     registry.RegisterFunctor(&testFunctorB);
332     registry.RegisterFunctor(&testFunctorC);
333
334     std::unique_ptr<unsigned char[]> packetDataA;
335     std::unique_ptr<unsigned char[]> packetDataB;
336     std::unique_ptr<unsigned char[]> packetDataC;
337
338     Packet packetA(500000000, 0, packetDataA);
339     Packet packetB(600000000, 0, packetDataB);
340     Packet packetC(400000000, 0, packetDataC);
341
342     // Check the correct operator of derived class is called
343     registry.GetFunctor(packetA.GetPacketFamily(), packetA.GetPacketId(), version)->operator()(packetA);
344     BOOST_CHECK(testFunctorA.GetCount() == 1);
345     BOOST_CHECK(testFunctorB.GetCount() == 0);
346     BOOST_CHECK(testFunctorC.GetCount() == 0);
347
348     registry.GetFunctor(packetB.GetPacketFamily(), packetB.GetPacketId(), version)->operator()(packetB);
349     BOOST_CHECK(testFunctorA.GetCount() == 1);
350     BOOST_CHECK(testFunctorB.GetCount() == 1);
351     BOOST_CHECK(testFunctorC.GetCount() == 0);
352
353     registry.GetFunctor(packetC.GetPacketFamily(), packetC.GetPacketId(), version)->operator()(packetC);
354     BOOST_CHECK(testFunctorA.GetCount() == 1);
355     BOOST_CHECK(testFunctorB.GetCount() == 1);
356     BOOST_CHECK(testFunctorC.GetCount() == 1);
357
358     // Re-register an existing key with a new function
359     registry.RegisterFunctor(&testFunctorC, testFunctorA.GetFamilyId(), testFunctorA.GetPacketId(), version);
360     registry.GetFunctor(packetA.GetPacketFamily(), packetA.GetPacketId(), version)->operator()(packetC);
361     BOOST_CHECK(testFunctorA.GetCount() == 1);
362     BOOST_CHECK(testFunctorB.GetCount() == 1);
363     BOOST_CHECK(testFunctorC.GetCount() == 2);
364
365     // Check that non-existent key returns nullptr for its functor
366     BOOST_CHECK_THROW(registry.GetFunctor(0, 0, 0), armnn::Exception);
367 }
368
369 BOOST_AUTO_TEST_CASE(CheckPacketVersionResolver)
370 {
371     // Set up random number generator for generating packetId values
372     std::random_device device;
373     std::mt19937 generator(device());
374     std::uniform_int_distribution<uint32_t> distribution(std::numeric_limits<uint32_t>::min(),
375                                                          std::numeric_limits<uint32_t>::max());
376
377     // NOTE: Expected version is always 1.0.0, regardless of packetId
378     const Version expectedVersion(1, 0, 0);
379
380     PacketVersionResolver packetVersionResolver;
381
382     constexpr unsigned int numTests = 10u;
383
384     for (unsigned int i = 0u; i < numTests; ++i)
385     {
386         const uint32_t familyId = distribution(generator);
387         const uint32_t packetId = distribution(generator);
388         Version resolvedVersion = packetVersionResolver.ResolvePacketVersion(familyId, packetId);
389
390         BOOST_TEST(resolvedVersion == expectedVersion);
391     }
392 }
393
394 void ProfilingCurrentStateThreadImpl(ProfilingStateMachine& states)
395 {
396     ProfilingState newState = ProfilingState::NotConnected;
397     states.GetCurrentState();
398     states.TransitionToState(newState);
399 }
400
401 BOOST_AUTO_TEST_CASE(CheckProfilingStateMachine)
402 {
403     ProfilingStateMachine profilingState1(ProfilingState::Uninitialised);
404     profilingState1.TransitionToState(ProfilingState::Uninitialised);
405     BOOST_CHECK(profilingState1.GetCurrentState() == ProfilingState::Uninitialised);
406
407     ProfilingStateMachine profilingState2(ProfilingState::Uninitialised);
408     profilingState2.TransitionToState(ProfilingState::NotConnected);
409     BOOST_CHECK(profilingState2.GetCurrentState() == ProfilingState::NotConnected);
410
411     ProfilingStateMachine profilingState3(ProfilingState::NotConnected);
412     profilingState3.TransitionToState(ProfilingState::NotConnected);
413     BOOST_CHECK(profilingState3.GetCurrentState() == ProfilingState::NotConnected);
414
415     ProfilingStateMachine profilingState4(ProfilingState::NotConnected);
416     profilingState4.TransitionToState(ProfilingState::WaitingForAck);
417     BOOST_CHECK(profilingState4.GetCurrentState() == ProfilingState::WaitingForAck);
418
419     ProfilingStateMachine profilingState5(ProfilingState::WaitingForAck);
420     profilingState5.TransitionToState(ProfilingState::WaitingForAck);
421     BOOST_CHECK(profilingState5.GetCurrentState() == ProfilingState::WaitingForAck);
422
423     ProfilingStateMachine profilingState6(ProfilingState::WaitingForAck);
424     profilingState6.TransitionToState(ProfilingState::Active);
425     BOOST_CHECK(profilingState6.GetCurrentState() == ProfilingState::Active);
426
427     ProfilingStateMachine profilingState7(ProfilingState::Active);
428     profilingState7.TransitionToState(ProfilingState::NotConnected);
429     BOOST_CHECK(profilingState7.GetCurrentState() == ProfilingState::NotConnected);
430
431     ProfilingStateMachine profilingState8(ProfilingState::Active);
432     profilingState8.TransitionToState(ProfilingState::Active);
433     BOOST_CHECK(profilingState8.GetCurrentState() == ProfilingState::Active);
434
435     ProfilingStateMachine profilingState9(ProfilingState::Uninitialised);
436     BOOST_CHECK_THROW(profilingState9.TransitionToState(ProfilingState::WaitingForAck), armnn::Exception);
437
438     ProfilingStateMachine profilingState10(ProfilingState::Uninitialised);
439     BOOST_CHECK_THROW(profilingState10.TransitionToState(ProfilingState::Active), armnn::Exception);
440
441     ProfilingStateMachine profilingState11(ProfilingState::NotConnected);
442     BOOST_CHECK_THROW(profilingState11.TransitionToState(ProfilingState::Uninitialised), armnn::Exception);
443
444     ProfilingStateMachine profilingState12(ProfilingState::NotConnected);
445     BOOST_CHECK_THROW(profilingState12.TransitionToState(ProfilingState::Active), armnn::Exception);
446
447     ProfilingStateMachine profilingState13(ProfilingState::WaitingForAck);
448     BOOST_CHECK_THROW(profilingState13.TransitionToState(ProfilingState::Uninitialised), armnn::Exception);
449
450     ProfilingStateMachine profilingState14(ProfilingState::WaitingForAck);
451     profilingState14.TransitionToState(ProfilingState::NotConnected);
452     BOOST_CHECK(profilingState14.GetCurrentState() == ProfilingState::NotConnected);
453
454     ProfilingStateMachine profilingState15(ProfilingState::Active);
455     BOOST_CHECK_THROW(profilingState15.TransitionToState(ProfilingState::Uninitialised), armnn::Exception);
456
457     ProfilingStateMachine profilingState16(armnn::profiling::ProfilingState::Active);
458     BOOST_CHECK_THROW(profilingState16.TransitionToState(ProfilingState::WaitingForAck), armnn::Exception);
459
460     ProfilingStateMachine profilingState17(ProfilingState::Uninitialised);
461
462     std::thread thread1(ProfilingCurrentStateThreadImpl, std::ref(profilingState17));
463     std::thread thread2(ProfilingCurrentStateThreadImpl, std::ref(profilingState17));
464     std::thread thread3(ProfilingCurrentStateThreadImpl, std::ref(profilingState17));
465     std::thread thread4(ProfilingCurrentStateThreadImpl, std::ref(profilingState17));
466     std::thread thread5(ProfilingCurrentStateThreadImpl, std::ref(profilingState17));
467
468     thread1.join();
469     thread2.join();
470     thread3.join();
471     thread4.join();
472     thread5.join();
473
474     BOOST_TEST((profilingState17.GetCurrentState() == ProfilingState::NotConnected));
475 }
476
477 void CaptureDataWriteThreadImpl(Holder& holder, uint32_t capturePeriod, const std::vector<uint16_t>& counterIds)
478 {
479     holder.SetCaptureData(capturePeriod, counterIds);
480 }
481
482 void CaptureDataReadThreadImpl(const Holder& holder, CaptureData& captureData)
483 {
484     captureData = holder.GetCaptureData();
485 }
486
487 BOOST_AUTO_TEST_CASE(CheckCaptureDataHolder)
488 {
489     std::map<uint32_t, std::vector<uint16_t>> periodIdMap;
490     std::vector<uint16_t> counterIds;
491     uint32_t numThreads = 10;
492     for (uint32_t i = 0; i < numThreads; ++i)
493     {
494         counterIds.emplace_back(i);
495         periodIdMap.insert(std::make_pair(i, counterIds));
496     }
497
498     // Verify the read and write threads set the holder correctly
499     // and retrieve the expected values
500     Holder holder;
501     BOOST_CHECK((holder.GetCaptureData()).GetCapturePeriod() == 0);
502     BOOST_CHECK(((holder.GetCaptureData()).GetCounterIds()).empty());
503
504     // Check Holder functions
505     std::thread thread1(CaptureDataWriteThreadImpl, std::ref(holder), 2, std::ref(periodIdMap[2]));
506     thread1.join();
507     BOOST_CHECK((holder.GetCaptureData()).GetCapturePeriod() == 2);
508     BOOST_CHECK((holder.GetCaptureData()).GetCounterIds() == periodIdMap[2]);
509     // NOTE: now that we have some initial values in the holder we don't have to worry
510     //       in the multi-threaded section below about a read thread accessing the holder
511     //       before any write thread has gotten to it so we read period = 0, counterIds empty
512     //       instead of period = 0, counterIds = {0} as will the case when write thread 0
513     //       has executed.
514
515     CaptureData captureData;
516     std::thread thread2(CaptureDataReadThreadImpl, std::ref(holder), std::ref(captureData));
517     thread2.join();
518     BOOST_CHECK(captureData.GetCapturePeriod() == 2);
519     BOOST_CHECK(captureData.GetCounterIds() == periodIdMap[2]);
520
521     std::map<uint32_t, CaptureData> captureDataIdMap;
522     for (uint32_t i = 0; i < numThreads; ++i)
523     {
524         CaptureData perThreadCaptureData;
525         captureDataIdMap.insert(std::make_pair(i, perThreadCaptureData));
526     }
527
528     std::vector<std::thread> threadsVect;
529     std::vector<std::thread> readThreadsVect;
530     for (uint32_t i = 0; i < numThreads; ++i)
531     {
532         threadsVect.emplace_back(
533             std::thread(CaptureDataWriteThreadImpl, std::ref(holder), i, std::ref(periodIdMap[i])));
534
535         // Verify that the CaptureData goes into the thread in a virgin state
536         BOOST_CHECK(captureDataIdMap.at(i).GetCapturePeriod() == 0);
537         BOOST_CHECK(captureDataIdMap.at(i).GetCounterIds().empty());
538         readThreadsVect.emplace_back(
539             std::thread(CaptureDataReadThreadImpl, std::ref(holder), std::ref(captureDataIdMap.at(i))));
540     }
541
542     for (uint32_t i = 0; i < numThreads; ++i)
543     {
544         threadsVect[i].join();
545         readThreadsVect[i].join();
546     }
547
548     // Look at the CaptureData that each read thread has filled
549     // the capture period it read should match the counter ids entry
550     for (uint32_t i = 0; i < numThreads; ++i)
551     {
552         CaptureData perThreadCaptureData = captureDataIdMap.at(i);
553         BOOST_CHECK(perThreadCaptureData.GetCounterIds() == periodIdMap.at(perThreadCaptureData.GetCapturePeriod()));
554     }
555 }
556
557 BOOST_AUTO_TEST_CASE(CaptureDataMethods)
558 {
559     // Check CaptureData setter and getter functions
560     std::vector<uint16_t> counterIds = { 42, 29, 13 };
561     CaptureData captureData;
562     BOOST_CHECK(captureData.GetCapturePeriod() == 0);
563     BOOST_CHECK((captureData.GetCounterIds()).empty());
564     captureData.SetCapturePeriod(150);
565     captureData.SetCounterIds(counterIds);
566     BOOST_CHECK(captureData.GetCapturePeriod() == 150);
567     BOOST_CHECK(captureData.GetCounterIds() == counterIds);
568
569     // Check assignment operator
570     CaptureData secondCaptureData;
571
572     secondCaptureData = captureData;
573     BOOST_CHECK(secondCaptureData.GetCapturePeriod() == 150);
574     BOOST_CHECK(secondCaptureData.GetCounterIds() == counterIds);
575
576     // Check copy constructor
577     CaptureData copyConstructedCaptureData(captureData);
578
579     BOOST_CHECK(copyConstructedCaptureData.GetCapturePeriod() == 150);
580     BOOST_CHECK(copyConstructedCaptureData.GetCounterIds() == counterIds);
581 }
582
583 BOOST_AUTO_TEST_CASE(CheckProfilingServiceDisabled)
584 {
585     armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
586     ProfilingService& profilingService = ProfilingService::Instance();
587     profilingService.ResetExternalProfilingOptions(options, true);
588     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
589     profilingService.Update();
590     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
591 }
592
593 BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabled)
594 {
595     // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
596     LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning);
597
598     armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
599     options.m_EnableProfiling          = true;
600     ProfilingService& profilingService = ProfilingService::Instance();
601     profilingService.ResetExternalProfilingOptions(options, true);
602     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
603     profilingService.Update();
604     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
605
606     // Redirect the output to a local stream so that we can parse the warning message
607     std::stringstream ss;
608     StreamRedirector streamRedirector(std::cout, ss.rdbuf());
609     profilingService.Update();
610     BOOST_CHECK(boost::contains(ss.str(), "Cannot connect to stream socket: Connection refused"));
611 }
612
613 BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabledRuntime)
614 {
615     // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
616     LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning);
617
618     armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
619     ProfilingService& profilingService = ProfilingService::Instance();
620     profilingService.ResetExternalProfilingOptions(options, true);
621     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
622     profilingService.Update();
623     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
624     options.m_EnableProfiling = true;
625     profilingService.ResetExternalProfilingOptions(options);
626     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
627     profilingService.Update();
628     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
629
630     // Redirect the output to a local stream so that we can parse the warning message
631     std::stringstream ss;
632     StreamRedirector streamRedirector(std::cout, ss.rdbuf());
633     profilingService.Update();
634     BOOST_CHECK(boost::contains(ss.str(), "Cannot connect to stream socket: Connection refused"));
635 }
636
637 BOOST_AUTO_TEST_CASE(CheckProfilingServiceCounterDirectory)
638 {
639     armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
640     ProfilingService& profilingService = ProfilingService::Instance();
641     profilingService.ResetExternalProfilingOptions(options, true);
642
643     const ICounterDirectory& counterDirectory0 = profilingService.GetCounterDirectory();
644     BOOST_CHECK(counterDirectory0.GetCounterCount() == 0);
645     profilingService.Update();
646     BOOST_CHECK(counterDirectory0.GetCounterCount() == 0);
647
648     options.m_EnableProfiling = true;
649     profilingService.ResetExternalProfilingOptions(options);
650
651     const ICounterDirectory& counterDirectory1 = profilingService.GetCounterDirectory();
652     BOOST_CHECK(counterDirectory1.GetCounterCount() == 0);
653     profilingService.Update();
654     BOOST_CHECK(counterDirectory1.GetCounterCount() != 0);
655 }
656
657 BOOST_AUTO_TEST_CASE(CheckProfilingServiceCounterValues)
658 {
659     armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
660     options.m_EnableProfiling          = true;
661     ProfilingService& profilingService = ProfilingService::Instance();
662     profilingService.ResetExternalProfilingOptions(options, true);
663
664     profilingService.Update();
665     const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory();
666     const Counters& counters                  = counterDirectory.GetCounters();
667     BOOST_CHECK(!counters.empty());
668
669     // Get the UID of the first counter for testing
670     uint16_t counterUid = counters.begin()->first;
671
672     ProfilingService* profilingServicePtr = &profilingService;
673     std::vector<std::thread> writers;
674
675     for (int i = 0; i < 100; ++i)
676     {
677         // Increment and decrement the first counter
678         writers.push_back(std::thread(&ProfilingService::IncrementCounterValue, profilingServicePtr, counterUid));
679         writers.push_back(std::thread(&ProfilingService::DecrementCounterValue, profilingServicePtr, counterUid));
680         // Add 10 and subtract 5 from the first counter
681         writers.push_back(std::thread(&ProfilingService::AddCounterValue, profilingServicePtr, counterUid, 10));
682         writers.push_back(std::thread(&ProfilingService::SubtractCounterValue, profilingServicePtr, counterUid, 5));
683     }
684
685     std::for_each(writers.begin(), writers.end(), mem_fn(&std::thread::join));
686
687     uint32_t counterValue = 0;
688     BOOST_CHECK_NO_THROW(counterValue = profilingService.GetCounterValue(counterUid));
689     BOOST_CHECK(counterValue == 500);
690
691     BOOST_CHECK_NO_THROW(profilingService.SetCounterValue(counterUid, 0));
692     BOOST_CHECK_NO_THROW(counterValue = profilingService.GetCounterValue(counterUid));
693     BOOST_CHECK(counterValue == 0);
694 }
695
696 BOOST_AUTO_TEST_CASE(CheckProfilingObjectUids)
697 {
698     uint16_t uid = 0;
699     BOOST_CHECK_NO_THROW(uid = GetNextUid());
700     BOOST_CHECK(uid >= 1);
701
702     uint16_t nextUid = 0;
703     BOOST_CHECK_NO_THROW(nextUid = GetNextUid());
704     BOOST_CHECK(nextUid > uid);
705
706     std::vector<uint16_t> counterUids;
707     BOOST_CHECK_NO_THROW(counterUids = GetNextCounterUids(0));
708     BOOST_CHECK(counterUids.size() == 1);
709     BOOST_CHECK(counterUids[0] >= 0);
710
711     std::vector<uint16_t> nextCounterUids;
712     BOOST_CHECK_NO_THROW(nextCounterUids = GetNextCounterUids(1));
713     BOOST_CHECK(nextCounterUids.size() == 1);
714     BOOST_CHECK(nextCounterUids[0] > counterUids[0]);
715
716     std::vector<uint16_t> counterUidsMultiCore;
717     uint16_t numberOfCores = 13;
718     BOOST_CHECK_NO_THROW(counterUidsMultiCore = GetNextCounterUids(numberOfCores));
719     BOOST_CHECK(counterUidsMultiCore.size() == numberOfCores);
720     BOOST_CHECK(counterUidsMultiCore.front() >= nextCounterUids[0]);
721     for (size_t i = 1; i < numberOfCores; i++)
722     {
723         BOOST_CHECK(counterUidsMultiCore[i] == counterUidsMultiCore[i - 1] + 1);
724     }
725     BOOST_CHECK(counterUidsMultiCore.back() == counterUidsMultiCore.front() + numberOfCores - 1);
726 }
727
728 BOOST_AUTO_TEST_CASE(CheckCounterDirectoryRegisterCategory)
729 {
730     CounterDirectory counterDirectory;
731     BOOST_CHECK(counterDirectory.GetCategoryCount() == 0);
732     BOOST_CHECK(counterDirectory.GetDeviceCount() == 0);
733     BOOST_CHECK(counterDirectory.GetCounterSetCount() == 0);
734     BOOST_CHECK(counterDirectory.GetCounterCount() == 0);
735
736     // Register a category with an invalid name
737     const Category* noCategory = nullptr;
738     BOOST_CHECK_THROW(noCategory = counterDirectory.RegisterCategory(""), armnn::InvalidArgumentException);
739     BOOST_CHECK(counterDirectory.GetCategoryCount() == 0);
740     BOOST_CHECK(!noCategory);
741
742     // Register a category with an invalid name
743     BOOST_CHECK_THROW(noCategory = counterDirectory.RegisterCategory("invalid category"),
744                       armnn::InvalidArgumentException);
745     BOOST_CHECK(counterDirectory.GetCategoryCount() == 0);
746     BOOST_CHECK(!noCategory);
747
748     // Register a new category
749     const std::string categoryName = "some_category";
750     const Category* category       = nullptr;
751     BOOST_CHECK_NO_THROW(category = counterDirectory.RegisterCategory(categoryName));
752     BOOST_CHECK(counterDirectory.GetCategoryCount() == 1);
753     BOOST_CHECK(category);
754     BOOST_CHECK(category->m_Name == categoryName);
755     BOOST_CHECK(category->m_Counters.empty());
756     BOOST_CHECK(category->m_DeviceUid == 0);
757     BOOST_CHECK(category->m_CounterSetUid == 0);
758
759     // Get the registered category
760     const Category* registeredCategory = counterDirectory.GetCategory(categoryName);
761     BOOST_CHECK(counterDirectory.GetCategoryCount() == 1);
762     BOOST_CHECK(registeredCategory);
763     BOOST_CHECK(registeredCategory == category);
764
765     // Try to get a category not registered
766     const Category* notRegisteredCategory = counterDirectory.GetCategory("not_registered_category");
767     BOOST_CHECK(counterDirectory.GetCategoryCount() == 1);
768     BOOST_CHECK(!notRegisteredCategory);
769
770     // Register a category already registered
771     const Category* anotherCategory = nullptr;
772     BOOST_CHECK_THROW(anotherCategory = counterDirectory.RegisterCategory(categoryName),
773                       armnn::InvalidArgumentException);
774     BOOST_CHECK(counterDirectory.GetCategoryCount() == 1);
775     BOOST_CHECK(!anotherCategory);
776
777     // Register a device for testing
778     const std::string deviceName = "some_device";
779     const Device* device         = nullptr;
780     BOOST_CHECK_NO_THROW(device = counterDirectory.RegisterDevice(deviceName));
781     BOOST_CHECK(counterDirectory.GetDeviceCount() == 1);
782     BOOST_CHECK(device);
783     BOOST_CHECK(device->m_Uid >= 1);
784     BOOST_CHECK(device->m_Name == deviceName);
785     BOOST_CHECK(device->m_Cores == 0);
786
787     // Register a new category not associated to any device
788     const std::string categoryWoDeviceName = "some_category_without_device";
789     const Category* categoryWoDevice       = nullptr;
790     BOOST_CHECK_NO_THROW(categoryWoDevice = counterDirectory.RegisterCategory(categoryWoDeviceName, 0));
791     BOOST_CHECK(counterDirectory.GetCategoryCount() == 2);
792     BOOST_CHECK(categoryWoDevice);
793     BOOST_CHECK(categoryWoDevice->m_Name == categoryWoDeviceName);
794     BOOST_CHECK(categoryWoDevice->m_Counters.empty());
795     BOOST_CHECK(categoryWoDevice->m_DeviceUid == 0);
796     BOOST_CHECK(categoryWoDevice->m_CounterSetUid == 0);
797
798     // Register a new category associated to an invalid device
799     const std::string categoryWInvalidDeviceName = "some_category_with_invalid_device";
800
801     ARMNN_NO_CONVERSION_WARN_BEGIN
802     uint16_t invalidDeviceUid = device->m_Uid + 10;
803     ARMNN_NO_CONVERSION_WARN_END
804
805     const Category* categoryWInvalidDevice = nullptr;
806     BOOST_CHECK_THROW(categoryWInvalidDevice =
807                           counterDirectory.RegisterCategory(categoryWInvalidDeviceName, invalidDeviceUid),
808                       armnn::InvalidArgumentException);
809     BOOST_CHECK(counterDirectory.GetCategoryCount() == 2);
810     BOOST_CHECK(!categoryWInvalidDevice);
811
812     // Register a new category associated to a valid device
813     const std::string categoryWValidDeviceName = "some_category_with_valid_device";
814     const Category* categoryWValidDevice       = nullptr;
815     BOOST_CHECK_NO_THROW(categoryWValidDevice =
816                              counterDirectory.RegisterCategory(categoryWValidDeviceName, device->m_Uid));
817     BOOST_CHECK(counterDirectory.GetCategoryCount() == 3);
818     BOOST_CHECK(categoryWValidDevice);
819     BOOST_CHECK(categoryWValidDevice != category);
820     BOOST_CHECK(categoryWValidDevice->m_Name == categoryWValidDeviceName);
821     BOOST_CHECK(categoryWValidDevice->m_DeviceUid == device->m_Uid);
822     BOOST_CHECK(categoryWValidDevice->m_CounterSetUid == 0);
823
824     // Register a counter set for testing
825     const std::string counterSetName = "some_counter_set";
826     const CounterSet* counterSet     = nullptr;
827     BOOST_CHECK_NO_THROW(counterSet = counterDirectory.RegisterCounterSet(counterSetName));
828     BOOST_CHECK(counterDirectory.GetCounterSetCount() == 1);
829     BOOST_CHECK(counterSet);
830     BOOST_CHECK(counterSet->m_Uid >= 1);
831     BOOST_CHECK(counterSet->m_Name == counterSetName);
832     BOOST_CHECK(counterSet->m_Count == 0);
833
834     // Register a new category not associated to any counter set
835     const std::string categoryWoCounterSetName = "some_category_without_counter_set";
836     const Category* categoryWoCounterSet       = nullptr;
837     BOOST_CHECK_NO_THROW(categoryWoCounterSet =
838                              counterDirectory.RegisterCategory(categoryWoCounterSetName, armnn::EmptyOptional(), 0));
839     BOOST_CHECK(counterDirectory.GetCategoryCount() == 4);
840     BOOST_CHECK(categoryWoCounterSet);
841     BOOST_CHECK(categoryWoCounterSet->m_Name == categoryWoCounterSetName);
842     BOOST_CHECK(categoryWoCounterSet->m_DeviceUid == 0);
843     BOOST_CHECK(categoryWoCounterSet->m_CounterSetUid == 0);
844
845     // Register a new category associated to an invalid counter set
846     const std::string categoryWInvalidCounterSetName = "some_category_with_invalid_counter_set";
847
848     ARMNN_NO_CONVERSION_WARN_BEGIN
849     uint16_t invalidCunterSetUid = counterSet->m_Uid + 10;
850     ARMNN_NO_CONVERSION_WARN_END
851
852     const Category* categoryWInvalidCounterSet = nullptr;
853     BOOST_CHECK_THROW(categoryWInvalidCounterSet = counterDirectory.RegisterCategory(
854                           categoryWInvalidCounterSetName, armnn::EmptyOptional(), invalidCunterSetUid),
855                       armnn::InvalidArgumentException);
856     BOOST_CHECK(counterDirectory.GetCategoryCount() == 4);
857     BOOST_CHECK(!categoryWInvalidCounterSet);
858
859     // Register a new category associated to a valid counter set
860     const std::string categoryWValidCounterSetName = "some_category_with_valid_counter_set";
861     const Category* categoryWValidCounterSet       = nullptr;
862     BOOST_CHECK_NO_THROW(categoryWValidCounterSet = counterDirectory.RegisterCategory(
863                              categoryWValidCounterSetName, armnn::EmptyOptional(), counterSet->m_Uid));
864     BOOST_CHECK(counterDirectory.GetCategoryCount() == 5);
865     BOOST_CHECK(categoryWValidCounterSet);
866     BOOST_CHECK(categoryWValidCounterSet != category);
867     BOOST_CHECK(categoryWValidCounterSet->m_Name == categoryWValidCounterSetName);
868     BOOST_CHECK(categoryWValidCounterSet->m_DeviceUid == 0);
869     BOOST_CHECK(categoryWValidCounterSet->m_CounterSetUid == counterSet->m_Uid);
870
871     // Register a new category associated to a valid device and counter set
872     const std::string categoryWValidDeviceAndValidCounterSetName = "some_category_with_valid_device_and_counter_set";
873     const Category* categoryWValidDeviceAndValidCounterSet       = nullptr;
874     BOOST_CHECK_NO_THROW(categoryWValidDeviceAndValidCounterSet = counterDirectory.RegisterCategory(
875                              categoryWValidDeviceAndValidCounterSetName, device->m_Uid, counterSet->m_Uid));
876     BOOST_CHECK(counterDirectory.GetCategoryCount() == 6);
877     BOOST_CHECK(categoryWValidDeviceAndValidCounterSet);
878     BOOST_CHECK(categoryWValidDeviceAndValidCounterSet != category);
879     BOOST_CHECK(categoryWValidDeviceAndValidCounterSet->m_Name == categoryWValidDeviceAndValidCounterSetName);
880     BOOST_CHECK(categoryWValidDeviceAndValidCounterSet->m_DeviceUid == device->m_Uid);
881     BOOST_CHECK(categoryWValidDeviceAndValidCounterSet->m_CounterSetUid == counterSet->m_Uid);
882 }
883
884 BOOST_AUTO_TEST_CASE(CheckCounterDirectoryRegisterDevice)
885 {
886     CounterDirectory counterDirectory;
887     BOOST_CHECK(counterDirectory.GetCategoryCount() == 0);
888     BOOST_CHECK(counterDirectory.GetDeviceCount() == 0);
889     BOOST_CHECK(counterDirectory.GetCounterSetCount() == 0);
890     BOOST_CHECK(counterDirectory.GetCounterCount() == 0);
891
892     // Register a device with an invalid name
893     const Device* noDevice = nullptr;
894     BOOST_CHECK_THROW(noDevice = counterDirectory.RegisterDevice(""), armnn::InvalidArgumentException);
895     BOOST_CHECK(counterDirectory.GetDeviceCount() == 0);
896     BOOST_CHECK(!noDevice);
897
898     // Register a device with an invalid name
899     BOOST_CHECK_THROW(noDevice = counterDirectory.RegisterDevice("inv@lid nam€"), armnn::InvalidArgumentException);
900     BOOST_CHECK(counterDirectory.GetDeviceCount() == 0);
901     BOOST_CHECK(!noDevice);
902
903     // Register a new device with no cores or parent category
904     const std::string deviceName = "some_device";
905     const Device* device         = nullptr;
906     BOOST_CHECK_NO_THROW(device = counterDirectory.RegisterDevice(deviceName));
907     BOOST_CHECK(counterDirectory.GetDeviceCount() == 1);
908     BOOST_CHECK(device);
909     BOOST_CHECK(device->m_Name == deviceName);
910     BOOST_CHECK(device->m_Uid >= 1);
911     BOOST_CHECK(device->m_Cores == 0);
912
913     // Try getting an unregistered device
914     const Device* unregisteredDevice = counterDirectory.GetDevice(9999);
915     BOOST_CHECK(!unregisteredDevice);
916
917     // Get the registered device
918     const Device* registeredDevice = counterDirectory.GetDevice(device->m_Uid);
919     BOOST_CHECK(counterDirectory.GetDeviceCount() == 1);
920     BOOST_CHECK(registeredDevice);
921     BOOST_CHECK(registeredDevice == device);
922
923     // Register a device with the name of a device already registered
924     const Device* deviceSameName = nullptr;
925     BOOST_CHECK_THROW(deviceSameName = counterDirectory.RegisterDevice(deviceName), armnn::InvalidArgumentException);
926     BOOST_CHECK(counterDirectory.GetDeviceCount() == 1);
927     BOOST_CHECK(!deviceSameName);
928
929     // Register a new device with cores and no parent category
930     const std::string deviceWCoresName = "some_device_with_cores";
931     const Device* deviceWCores         = nullptr;
932     BOOST_CHECK_NO_THROW(deviceWCores = counterDirectory.RegisterDevice(deviceWCoresName, 2));
933     BOOST_CHECK(counterDirectory.GetDeviceCount() == 2);
934     BOOST_CHECK(deviceWCores);
935     BOOST_CHECK(deviceWCores->m_Name == deviceWCoresName);
936     BOOST_CHECK(deviceWCores->m_Uid >= 1);
937     BOOST_CHECK(deviceWCores->m_Uid > device->m_Uid);
938     BOOST_CHECK(deviceWCores->m_Cores == 2);
939
940     // Get the registered device
941     const Device* registeredDeviceWCores = counterDirectory.GetDevice(deviceWCores->m_Uid);
942     BOOST_CHECK(counterDirectory.GetDeviceCount() == 2);
943     BOOST_CHECK(registeredDeviceWCores);
944     BOOST_CHECK(registeredDeviceWCores == deviceWCores);
945     BOOST_CHECK(registeredDeviceWCores != device);
946
947     // Register a new device with cores and invalid parent category
948     const std::string deviceWCoresWInvalidParentCategoryName = "some_device_with_cores_with_invalid_parent_category";
949     const Device* deviceWCoresWInvalidParentCategory         = nullptr;
950     BOOST_CHECK_THROW(deviceWCoresWInvalidParentCategory =
951                           counterDirectory.RegisterDevice(deviceWCoresWInvalidParentCategoryName, 3, std::string("")),
952                       armnn::InvalidArgumentException);
953     BOOST_CHECK(counterDirectory.GetDeviceCount() == 2);
954     BOOST_CHECK(!deviceWCoresWInvalidParentCategory);
955
956     // Register a new device with cores and invalid parent category
957     const std::string deviceWCoresWInvalidParentCategoryName2 = "some_device_with_cores_with_invalid_parent_category2";
958     const Device* deviceWCoresWInvalidParentCategory2         = nullptr;
959     BOOST_CHECK_THROW(deviceWCoresWInvalidParentCategory2 = counterDirectory.RegisterDevice(
960                           deviceWCoresWInvalidParentCategoryName2, 3, std::string("invalid_parent_category")),
961                       armnn::InvalidArgumentException);
962     BOOST_CHECK(counterDirectory.GetDeviceCount() == 2);
963     BOOST_CHECK(!deviceWCoresWInvalidParentCategory2);
964
965     // Register a category for testing
966     const std::string categoryName = "some_category";
967     const Category* category       = nullptr;
968     BOOST_CHECK_NO_THROW(category = counterDirectory.RegisterCategory(categoryName));
969     BOOST_CHECK(counterDirectory.GetCategoryCount() == 1);
970     BOOST_CHECK(category);
971     BOOST_CHECK(category->m_Name == categoryName);
972     BOOST_CHECK(category->m_Counters.empty());
973     BOOST_CHECK(category->m_DeviceUid == 0);
974     BOOST_CHECK(category->m_CounterSetUid == 0);
975
976     // Register a new device with cores and valid parent category
977     const std::string deviceWCoresWValidParentCategoryName = "some_device_with_cores_with_valid_parent_category";
978     const Device* deviceWCoresWValidParentCategory         = nullptr;
979     BOOST_CHECK_NO_THROW(deviceWCoresWValidParentCategory =
980                              counterDirectory.RegisterDevice(deviceWCoresWValidParentCategoryName, 4, categoryName));
981     BOOST_CHECK(counterDirectory.GetDeviceCount() == 3);
982     BOOST_CHECK(deviceWCoresWValidParentCategory);
983     BOOST_CHECK(deviceWCoresWValidParentCategory->m_Name == deviceWCoresWValidParentCategoryName);
984     BOOST_CHECK(deviceWCoresWValidParentCategory->m_Uid >= 1);
985     BOOST_CHECK(deviceWCoresWValidParentCategory->m_Uid > device->m_Uid);
986     BOOST_CHECK(deviceWCoresWValidParentCategory->m_Uid > deviceWCores->m_Uid);
987     BOOST_CHECK(deviceWCoresWValidParentCategory->m_Cores == 4);
988     BOOST_CHECK(category->m_DeviceUid == deviceWCoresWValidParentCategory->m_Uid);
989
990     // Register a device associated to a category already associated to a different device
991     const std::string deviceSameCategoryName = "some_device_with_invalid_parent_category";
992     const Device* deviceSameCategory         = nullptr;
993     BOOST_CHECK_THROW(deviceSameCategory = counterDirectory.RegisterDevice(deviceSameCategoryName, 0, categoryName),
994                       armnn::InvalidArgumentException);
995     BOOST_CHECK(counterDirectory.GetDeviceCount() == 3);
996     BOOST_CHECK(!deviceSameCategory);
997 }
998
999 BOOST_AUTO_TEST_CASE(CheckCounterDirectoryRegisterCounterSet)
1000 {
1001     CounterDirectory counterDirectory;
1002     BOOST_CHECK(counterDirectory.GetCategoryCount() == 0);
1003     BOOST_CHECK(counterDirectory.GetDeviceCount() == 0);
1004     BOOST_CHECK(counterDirectory.GetCounterSetCount() == 0);
1005     BOOST_CHECK(counterDirectory.GetCounterCount() == 0);
1006
1007     // Register a counter set with an invalid name
1008     const CounterSet* noCounterSet = nullptr;
1009     BOOST_CHECK_THROW(noCounterSet = counterDirectory.RegisterCounterSet(""), armnn::InvalidArgumentException);
1010     BOOST_CHECK(counterDirectory.GetCounterSetCount() == 0);
1011     BOOST_CHECK(!noCounterSet);
1012
1013     // Register a counter set with an invalid name
1014     BOOST_CHECK_THROW(noCounterSet = counterDirectory.RegisterCounterSet("invalid name"),
1015                       armnn::InvalidArgumentException);
1016     BOOST_CHECK(counterDirectory.GetCounterSetCount() == 0);
1017     BOOST_CHECK(!noCounterSet);
1018
1019     // Register a new counter set with no count or parent category
1020     const std::string counterSetName = "some_counter_set";
1021     const CounterSet* counterSet     = nullptr;
1022     BOOST_CHECK_NO_THROW(counterSet = counterDirectory.RegisterCounterSet(counterSetName));
1023     BOOST_CHECK(counterDirectory.GetCounterSetCount() == 1);
1024     BOOST_CHECK(counterSet);
1025     BOOST_CHECK(counterSet->m_Name == counterSetName);
1026     BOOST_CHECK(counterSet->m_Uid >= 1);
1027     BOOST_CHECK(counterSet->m_Count == 0);
1028
1029     // Try getting an unregistered counter set
1030     const CounterSet* unregisteredCounterSet = counterDirectory.GetCounterSet(9999);
1031     BOOST_CHECK(!unregisteredCounterSet);
1032
1033     // Get the registered counter set
1034     const CounterSet* registeredCounterSet = counterDirectory.GetCounterSet(counterSet->m_Uid);
1035     BOOST_CHECK(counterDirectory.GetCounterSetCount() == 1);
1036     BOOST_CHECK(registeredCounterSet);
1037     BOOST_CHECK(registeredCounterSet == counterSet);
1038
1039     // Register a counter set with the name of a counter set already registered
1040     const CounterSet* counterSetSameName = nullptr;
1041     BOOST_CHECK_THROW(counterSetSameName = counterDirectory.RegisterCounterSet(counterSetName),
1042                       armnn::InvalidArgumentException);
1043     BOOST_CHECK(counterDirectory.GetCounterSetCount() == 1);
1044     BOOST_CHECK(!counterSetSameName);
1045
1046     // Register a new counter set with count and no parent category
1047     const std::string counterSetWCountName = "some_counter_set_with_count";
1048     const CounterSet* counterSetWCount     = nullptr;
1049     BOOST_CHECK_NO_THROW(counterSetWCount = counterDirectory.RegisterCounterSet(counterSetWCountName, 37));
1050     BOOST_CHECK(counterDirectory.GetCounterSetCount() == 2);
1051     BOOST_CHECK(counterSetWCount);
1052     BOOST_CHECK(counterSetWCount->m_Name == counterSetWCountName);
1053     BOOST_CHECK(counterSetWCount->m_Uid >= 1);
1054     BOOST_CHECK(counterSetWCount->m_Uid > counterSet->m_Uid);
1055     BOOST_CHECK(counterSetWCount->m_Count == 37);
1056
1057     // Get the registered counter set
1058     const CounterSet* registeredCounterSetWCount = counterDirectory.GetCounterSet(counterSetWCount->m_Uid);
1059     BOOST_CHECK(counterDirectory.GetCounterSetCount() == 2);
1060     BOOST_CHECK(registeredCounterSetWCount);
1061     BOOST_CHECK(registeredCounterSetWCount == counterSetWCount);
1062     BOOST_CHECK(registeredCounterSetWCount != counterSet);
1063
1064     // Register a new counter set with count and invalid parent category
1065     const std::string counterSetWCountWInvalidParentCategoryName = "some_counter_set_with_count_"
1066                                                                    "with_invalid_parent_category";
1067     const CounterSet* counterSetWCountWInvalidParentCategory = nullptr;
1068     BOOST_CHECK_THROW(counterSetWCountWInvalidParentCategory = counterDirectory.RegisterCounterSet(
1069                           counterSetWCountWInvalidParentCategoryName, 42, std::string("")),
1070                       armnn::InvalidArgumentException);
1071     BOOST_CHECK(counterDirectory.GetCounterSetCount() == 2);
1072     BOOST_CHECK(!counterSetWCountWInvalidParentCategory);
1073
1074     // Register a new counter set with count and invalid parent category
1075     const std::string counterSetWCountWInvalidParentCategoryName2 = "some_counter_set_with_count_"
1076                                                                     "with_invalid_parent_category2";
1077     const CounterSet* counterSetWCountWInvalidParentCategory2 = nullptr;
1078     BOOST_CHECK_THROW(counterSetWCountWInvalidParentCategory2 = counterDirectory.RegisterCounterSet(
1079                           counterSetWCountWInvalidParentCategoryName2, 42, std::string("invalid_parent_category")),
1080                       armnn::InvalidArgumentException);
1081     BOOST_CHECK(counterDirectory.GetCounterSetCount() == 2);
1082     BOOST_CHECK(!counterSetWCountWInvalidParentCategory2);
1083
1084     // Register a category for testing
1085     const std::string categoryName = "some_category";
1086     const Category* category       = nullptr;
1087     BOOST_CHECK_NO_THROW(category = counterDirectory.RegisterCategory(categoryName));
1088     BOOST_CHECK(counterDirectory.GetCategoryCount() == 1);
1089     BOOST_CHECK(category);
1090     BOOST_CHECK(category->m_Name == categoryName);
1091     BOOST_CHECK(category->m_Counters.empty());
1092     BOOST_CHECK(category->m_DeviceUid == 0);
1093     BOOST_CHECK(category->m_CounterSetUid == 0);
1094
1095     // Register a new counter set with count and valid parent category
1096     const std::string counterSetWCountWValidParentCategoryName = "some_counter_set_with_count_"
1097                                                                  "with_valid_parent_category";
1098     const CounterSet* counterSetWCountWValidParentCategory = nullptr;
1099     BOOST_CHECK_NO_THROW(counterSetWCountWValidParentCategory = counterDirectory.RegisterCounterSet(
1100                              counterSetWCountWValidParentCategoryName, 42, categoryName));
1101     BOOST_CHECK(counterDirectory.GetCounterSetCount() == 3);
1102     BOOST_CHECK(counterSetWCountWValidParentCategory);
1103     BOOST_CHECK(counterSetWCountWValidParentCategory->m_Name == counterSetWCountWValidParentCategoryName);
1104     BOOST_CHECK(counterSetWCountWValidParentCategory->m_Uid >= 1);
1105     BOOST_CHECK(counterSetWCountWValidParentCategory->m_Uid > counterSet->m_Uid);
1106     BOOST_CHECK(counterSetWCountWValidParentCategory->m_Uid > counterSetWCount->m_Uid);
1107     BOOST_CHECK(counterSetWCountWValidParentCategory->m_Count == 42);
1108     BOOST_CHECK(category->m_CounterSetUid == counterSetWCountWValidParentCategory->m_Uid);
1109
1110     // Register a counter set associated to a category already associated to a different counter set
1111     const std::string counterSetSameCategoryName = "some_counter_set_with_invalid_parent_category";
1112     const CounterSet* counterSetSameCategory     = nullptr;
1113     BOOST_CHECK_THROW(counterSetSameCategory =
1114                           counterDirectory.RegisterCounterSet(counterSetSameCategoryName, 0, categoryName),
1115                       armnn::InvalidArgumentException);
1116     BOOST_CHECK(counterDirectory.GetCounterSetCount() == 3);
1117     BOOST_CHECK(!counterSetSameCategory);
1118 }
1119
1120 BOOST_AUTO_TEST_CASE(CheckCounterDirectoryRegisterCounter)
1121 {
1122     CounterDirectory counterDirectory;
1123     BOOST_CHECK(counterDirectory.GetCategoryCount() == 0);
1124     BOOST_CHECK(counterDirectory.GetDeviceCount() == 0);
1125     BOOST_CHECK(counterDirectory.GetCounterSetCount() == 0);
1126     BOOST_CHECK(counterDirectory.GetCounterCount() == 0);
1127
1128     // Register a counter with an invalid parent category name
1129     const Counter* noCounter = nullptr;
1130     BOOST_CHECK_THROW(noCounter =
1131                           counterDirectory.RegisterCounter("", 0, 1, 123.45f, "valid name", "valid description"),
1132                       armnn::InvalidArgumentException);
1133     BOOST_CHECK(counterDirectory.GetCounterCount() == 0);
1134     BOOST_CHECK(!noCounter);
1135
1136     // Register a counter with an invalid parent category name
1137     BOOST_CHECK_THROW(noCounter = counterDirectory.RegisterCounter("invalid parent category", 0, 1, 123.45f,
1138                                                                    "valid name", "valid description"),
1139                       armnn::InvalidArgumentException);
1140     BOOST_CHECK(counterDirectory.GetCounterCount() == 0);
1141     BOOST_CHECK(!noCounter);
1142
1143     // Register a counter with an invalid class
1144     BOOST_CHECK_THROW(noCounter = counterDirectory.RegisterCounter("valid_parent_category", 2, 1, 123.45f, "valid name",
1145                                                                    "valid description"),
1146                       armnn::InvalidArgumentException);
1147     BOOST_CHECK(counterDirectory.GetCounterCount() == 0);
1148     BOOST_CHECK(!noCounter);
1149
1150     // Register a counter with an invalid interpolation
1151     BOOST_CHECK_THROW(noCounter = counterDirectory.RegisterCounter("valid_parent_category", 0, 3, 123.45f, "valid name",
1152                                                                    "valid description"),
1153                       armnn::InvalidArgumentException);
1154     BOOST_CHECK(counterDirectory.GetCounterCount() == 0);
1155     BOOST_CHECK(!noCounter);
1156
1157     // Register a counter with an invalid multiplier
1158     BOOST_CHECK_THROW(noCounter = counterDirectory.RegisterCounter("valid_parent_category", 0, 1, .0f, "valid name",
1159                                                                    "valid description"),
1160                       armnn::InvalidArgumentException);
1161     BOOST_CHECK(counterDirectory.GetCounterCount() == 0);
1162     BOOST_CHECK(!noCounter);
1163
1164     // Register a counter with an invalid name
1165     BOOST_CHECK_THROW(
1166         noCounter = counterDirectory.RegisterCounter("valid_parent_category", 0, 1, 123.45f, "", "valid description"),
1167         armnn::InvalidArgumentException);
1168     BOOST_CHECK(counterDirectory.GetCounterCount() == 0);
1169     BOOST_CHECK(!noCounter);
1170
1171     // Register a counter with an invalid name
1172     BOOST_CHECK_THROW(noCounter = counterDirectory.RegisterCounter("valid_parent_category", 0, 1, 123.45f,
1173                                                                    "invalid nam€", "valid description"),
1174                       armnn::InvalidArgumentException);
1175     BOOST_CHECK(counterDirectory.GetCounterCount() == 0);
1176     BOOST_CHECK(!noCounter);
1177
1178     // Register a counter with an invalid description
1179     BOOST_CHECK_THROW(noCounter =
1180                           counterDirectory.RegisterCounter("valid_parent_category", 0, 1, 123.45f, "valid name", ""),
1181                       armnn::InvalidArgumentException);
1182     BOOST_CHECK(counterDirectory.GetCounterCount() == 0);
1183     BOOST_CHECK(!noCounter);
1184
1185     // Register a counter with an invalid description
1186     BOOST_CHECK_THROW(noCounter = counterDirectory.RegisterCounter("valid_parent_category", 0, 1, 123.45f, "valid name",
1187                                                                    "inv@lid description"),
1188                       armnn::InvalidArgumentException);
1189     BOOST_CHECK(counterDirectory.GetCounterCount() == 0);
1190     BOOST_CHECK(!noCounter);
1191
1192     // Register a counter with an invalid unit2
1193     BOOST_CHECK_THROW(noCounter = counterDirectory.RegisterCounter("valid_parent_category", 0, 1, 123.45f, "valid name",
1194                                                                    "valid description", std::string("Mb/s2")),
1195                       armnn::InvalidArgumentException);
1196     BOOST_CHECK(counterDirectory.GetCounterCount() == 0);
1197     BOOST_CHECK(!noCounter);
1198
1199     // Register a counter with a non-existing parent category name
1200     BOOST_CHECK_THROW(noCounter = counterDirectory.RegisterCounter("invalid_parent_category", 0, 1, 123.45f,
1201                                                                    "valid name", "valid description"),
1202                       armnn::InvalidArgumentException);
1203     BOOST_CHECK(counterDirectory.GetCounterCount() == 0);
1204     BOOST_CHECK(!noCounter);
1205
1206     // Try getting an unregistered counter
1207     const Counter* unregisteredCounter = counterDirectory.GetCounter(9999);
1208     BOOST_CHECK(!unregisteredCounter);
1209
1210     // Register a category for testing
1211     const std::string categoryName = "some_category";
1212     const Category* category       = nullptr;
1213     BOOST_CHECK_NO_THROW(category = counterDirectory.RegisterCategory(categoryName));
1214     BOOST_CHECK(counterDirectory.GetCategoryCount() == 1);
1215     BOOST_CHECK(category);
1216     BOOST_CHECK(category->m_Name == categoryName);
1217     BOOST_CHECK(category->m_Counters.empty());
1218     BOOST_CHECK(category->m_DeviceUid == 0);
1219     BOOST_CHECK(category->m_CounterSetUid == 0);
1220
1221     // Register a counter with a valid parent category name
1222     const Counter* counter = nullptr;
1223     BOOST_CHECK_NO_THROW(
1224         counter = counterDirectory.RegisterCounter(categoryName, 0, 1, 123.45f, "valid name", "valid description"));
1225     BOOST_CHECK(counterDirectory.GetCounterCount() == 1);
1226     BOOST_CHECK(counter);
1227     BOOST_CHECK(counter->m_Uid >= 0);
1228     BOOST_CHECK(counter->m_MaxCounterUid == counter->m_Uid);
1229     BOOST_CHECK(counter->m_Class == 0);
1230     BOOST_CHECK(counter->m_Interpolation == 1);
1231     BOOST_CHECK(counter->m_Multiplier == 123.45f);
1232     BOOST_CHECK(counter->m_Name == "valid name");
1233     BOOST_CHECK(counter->m_Description == "valid description");
1234     BOOST_CHECK(counter->m_Units == "");
1235     BOOST_CHECK(counter->m_DeviceUid == 0);
1236     BOOST_CHECK(counter->m_CounterSetUid == 0);
1237     BOOST_CHECK(category->m_Counters.size() == 1);
1238     BOOST_CHECK(category->m_Counters.back() == counter->m_Uid);
1239
1240     // Register a counter with a name of a counter already registered for the given parent category name
1241     const Counter* counterSameName = nullptr;
1242     BOOST_CHECK_THROW(counterSameName =
1243                           counterDirectory.RegisterCounter(categoryName, 0, 0, 1.0f, "valid name", "valid description"),
1244                       armnn::InvalidArgumentException);
1245     BOOST_CHECK(counterDirectory.GetCounterCount() == 1);
1246     BOOST_CHECK(!counterSameName);
1247
1248     // Register a counter with a valid parent category name and units
1249     const Counter* counterWUnits = nullptr;
1250     BOOST_CHECK_NO_THROW(counterWUnits = counterDirectory.RegisterCounter(categoryName, 0, 1, 123.45f, "valid name 2",
1251                                                                           "valid description",
1252                                                                           std::string("Mnnsq2")));    // Units
1253     BOOST_CHECK(counterDirectory.GetCounterCount() == 2);
1254     BOOST_CHECK(counterWUnits);
1255     BOOST_CHECK(counterWUnits->m_Uid >= 0);
1256     BOOST_CHECK(counterWUnits->m_Uid > counter->m_Uid);
1257     BOOST_CHECK(counterWUnits->m_MaxCounterUid == counterWUnits->m_Uid);
1258     BOOST_CHECK(counterWUnits->m_Class == 0);
1259     BOOST_CHECK(counterWUnits->m_Interpolation == 1);
1260     BOOST_CHECK(counterWUnits->m_Multiplier == 123.45f);
1261     BOOST_CHECK(counterWUnits->m_Name == "valid name 2");
1262     BOOST_CHECK(counterWUnits->m_Description == "valid description");
1263     BOOST_CHECK(counterWUnits->m_Units == "Mnnsq2");
1264     BOOST_CHECK(counterWUnits->m_DeviceUid == 0);
1265     BOOST_CHECK(counterWUnits->m_CounterSetUid == 0);
1266     BOOST_CHECK(category->m_Counters.size() == 2);
1267     BOOST_CHECK(category->m_Counters.back() == counterWUnits->m_Uid);
1268
1269     // Register a counter with a valid parent category name and not associated with a device
1270     const Counter* counterWoDevice = nullptr;
1271     BOOST_CHECK_NO_THROW(counterWoDevice = counterDirectory.RegisterCounter(
1272                              categoryName, 0, 1, 123.45f, "valid name 3", "valid description",
1273                              armnn::EmptyOptional(),    // Units
1274                              armnn::EmptyOptional(),    // Number of cores
1275                              0));                       // Device UID
1276     BOOST_CHECK(counterDirectory.GetCounterCount() == 3);
1277     BOOST_CHECK(counterWoDevice);
1278     BOOST_CHECK(counterWoDevice->m_Uid >= 0);
1279     BOOST_CHECK(counterWoDevice->m_Uid > counter->m_Uid);
1280     BOOST_CHECK(counterWoDevice->m_MaxCounterUid == counterWoDevice->m_Uid);
1281     BOOST_CHECK(counterWoDevice->m_Class == 0);
1282     BOOST_CHECK(counterWoDevice->m_Interpolation == 1);
1283     BOOST_CHECK(counterWoDevice->m_Multiplier == 123.45f);
1284     BOOST_CHECK(counterWoDevice->m_Name == "valid name 3");
1285     BOOST_CHECK(counterWoDevice->m_Description == "valid description");
1286     BOOST_CHECK(counterWoDevice->m_Units == "");
1287     BOOST_CHECK(counterWoDevice->m_DeviceUid == 0);
1288     BOOST_CHECK(counterWoDevice->m_CounterSetUid == 0);
1289     BOOST_CHECK(category->m_Counters.size() == 3);
1290     BOOST_CHECK(category->m_Counters.back() == counterWoDevice->m_Uid);
1291
1292     // Register a counter with a valid parent category name and associated to an invalid device
1293     BOOST_CHECK_THROW(noCounter = counterDirectory.RegisterCounter(categoryName, 0, 1, 123.45f, "valid name 4",
1294                                                                    "valid description",
1295                                                                    armnn::EmptyOptional(),    // Units
1296                                                                    armnn::EmptyOptional(),    // Number of cores
1297                                                                    100),                      // Device UID
1298                       armnn::InvalidArgumentException);
1299     BOOST_CHECK(counterDirectory.GetCounterCount() == 3);
1300     BOOST_CHECK(!noCounter);
1301
1302     // Register a device for testing
1303     const std::string deviceName = "some_device";
1304     const Device* device         = nullptr;
1305     BOOST_CHECK_NO_THROW(device = counterDirectory.RegisterDevice(deviceName));
1306     BOOST_CHECK(counterDirectory.GetDeviceCount() == 1);
1307     BOOST_CHECK(device);
1308     BOOST_CHECK(device->m_Name == deviceName);
1309     BOOST_CHECK(device->m_Uid >= 1);
1310     BOOST_CHECK(device->m_Cores == 0);
1311
1312     // Register a counter with a valid parent category name and associated to a device
1313     const Counter* counterWDevice = nullptr;
1314     BOOST_CHECK_NO_THROW(counterWDevice = counterDirectory.RegisterCounter(categoryName, 0, 1, 123.45f, "valid name 5",
1315                                                                            "valid description",
1316                                                                            armnn::EmptyOptional(),    // Units
1317                                                                            armnn::EmptyOptional(),    // Number of cores
1318                                                                            device->m_Uid));           // Device UID
1319     BOOST_CHECK(counterDirectory.GetCounterCount() == 4);
1320     BOOST_CHECK(counterWDevice);
1321     BOOST_CHECK(counterWDevice->m_Uid >= 0);
1322     BOOST_CHECK(counterWDevice->m_Uid > counter->m_Uid);
1323     BOOST_CHECK(counterWDevice->m_MaxCounterUid == counterWDevice->m_Uid);
1324     BOOST_CHECK(counterWDevice->m_Class == 0);
1325     BOOST_CHECK(counterWDevice->m_Interpolation == 1);
1326     BOOST_CHECK(counterWDevice->m_Multiplier == 123.45f);
1327     BOOST_CHECK(counterWDevice->m_Name == "valid name 5");
1328     BOOST_CHECK(counterWDevice->m_Description == "valid description");
1329     BOOST_CHECK(counterWDevice->m_Units == "");
1330     BOOST_CHECK(counterWDevice->m_DeviceUid == device->m_Uid);
1331     BOOST_CHECK(counterWDevice->m_CounterSetUid == 0);
1332     BOOST_CHECK(category->m_Counters.size() == 4);
1333     BOOST_CHECK(category->m_Counters.back() == counterWDevice->m_Uid);
1334
1335     // Register a counter with a valid parent category name and not associated with a counter set
1336     const Counter* counterWoCounterSet = nullptr;
1337     BOOST_CHECK_NO_THROW(counterWoCounterSet = counterDirectory.RegisterCounter(
1338                              categoryName, 0, 1, 123.45f, "valid name 6", "valid description",
1339                              armnn::EmptyOptional(),    // Units
1340                              armnn::EmptyOptional(),    // Number of cores
1341                              armnn::EmptyOptional(),    // Device UID
1342                              0));                       // Counter set UID
1343     BOOST_CHECK(counterDirectory.GetCounterCount() == 5);
1344     BOOST_CHECK(counterWoCounterSet);
1345     BOOST_CHECK(counterWoCounterSet->m_Uid >= 0);
1346     BOOST_CHECK(counterWoCounterSet->m_Uid > counter->m_Uid);
1347     BOOST_CHECK(counterWoCounterSet->m_MaxCounterUid == counterWoCounterSet->m_Uid);
1348     BOOST_CHECK(counterWoCounterSet->m_Class == 0);
1349     BOOST_CHECK(counterWoCounterSet->m_Interpolation == 1);
1350     BOOST_CHECK(counterWoCounterSet->m_Multiplier == 123.45f);
1351     BOOST_CHECK(counterWoCounterSet->m_Name == "valid name 6");
1352     BOOST_CHECK(counterWoCounterSet->m_Description == "valid description");
1353     BOOST_CHECK(counterWoCounterSet->m_Units == "");
1354     BOOST_CHECK(counterWoCounterSet->m_DeviceUid == 0);
1355     BOOST_CHECK(counterWoCounterSet->m_CounterSetUid == 0);
1356     BOOST_CHECK(category->m_Counters.size() == 5);
1357     BOOST_CHECK(category->m_Counters.back() == counterWoCounterSet->m_Uid);
1358
1359     // Register a counter with a valid parent category name and associated to an invalid counter set
1360     BOOST_CHECK_THROW(noCounter = counterDirectory.RegisterCounter(categoryName, 0, 1, 123.45f, "valid name 7",
1361                                                                    "valid description",
1362                                                                    armnn::EmptyOptional(),    // Units
1363                                                                    armnn::EmptyOptional(),    // Number of cores
1364                                                                    armnn::EmptyOptional(),    // Device UID
1365                                                                    100),                      // Counter set UID
1366                       armnn::InvalidArgumentException);
1367     BOOST_CHECK(counterDirectory.GetCounterCount() == 5);
1368     BOOST_CHECK(!noCounter);
1369
1370     // Register a counter with a valid parent category name and with a given number of cores
1371     const Counter* counterWNumberOfCores = nullptr;
1372     uint16_t numberOfCores               = 15;
1373     BOOST_CHECK_NO_THROW(counterWNumberOfCores = counterDirectory.RegisterCounter(
1374                              categoryName, 0, 1, 123.45f, "valid name 8", "valid description",
1375                              armnn::EmptyOptional(),      // Units
1376                              numberOfCores,               // Number of cores
1377                              armnn::EmptyOptional(),      // Device UID
1378                              armnn::EmptyOptional()));    // Counter set UID
1379     BOOST_CHECK(counterDirectory.GetCounterCount() == 20);
1380     BOOST_CHECK(counterWNumberOfCores);
1381     BOOST_CHECK(counterWNumberOfCores->m_Uid >= 0);
1382     BOOST_CHECK(counterWNumberOfCores->m_Uid > counter->m_Uid);
1383     BOOST_CHECK(counterWNumberOfCores->m_MaxCounterUid == counterWNumberOfCores->m_Uid + numberOfCores - 1);
1384     BOOST_CHECK(counterWNumberOfCores->m_Class == 0);
1385     BOOST_CHECK(counterWNumberOfCores->m_Interpolation == 1);
1386     BOOST_CHECK(counterWNumberOfCores->m_Multiplier == 123.45f);
1387     BOOST_CHECK(counterWNumberOfCores->m_Name == "valid name 8");
1388     BOOST_CHECK(counterWNumberOfCores->m_Description == "valid description");
1389     BOOST_CHECK(counterWNumberOfCores->m_Units == "");
1390     BOOST_CHECK(counterWNumberOfCores->m_DeviceUid == 0);
1391     BOOST_CHECK(counterWNumberOfCores->m_CounterSetUid == 0);
1392     BOOST_CHECK(category->m_Counters.size() == 20);
1393     for (size_t i = 0; i < numberOfCores; i++)
1394     {
1395         BOOST_CHECK(category->m_Counters[category->m_Counters.size() - numberOfCores + i] ==
1396                     counterWNumberOfCores->m_Uid + i);
1397     }
1398
1399     // Register a multi-core device for testing
1400     const std::string multiCoreDeviceName = "some_multi_core_device";
1401     const Device* multiCoreDevice         = nullptr;
1402     BOOST_CHECK_NO_THROW(multiCoreDevice = counterDirectory.RegisterDevice(multiCoreDeviceName, 4));
1403     BOOST_CHECK(counterDirectory.GetDeviceCount() == 2);
1404     BOOST_CHECK(multiCoreDevice);
1405     BOOST_CHECK(multiCoreDevice->m_Name == multiCoreDeviceName);
1406     BOOST_CHECK(multiCoreDevice->m_Uid >= 1);
1407     BOOST_CHECK(multiCoreDevice->m_Cores == 4);
1408
1409     // Register a counter with a valid parent category name and associated to the multi-core device
1410     const Counter* counterWMultiCoreDevice = nullptr;
1411     BOOST_CHECK_NO_THROW(counterWMultiCoreDevice = counterDirectory.RegisterCounter(
1412                              categoryName, 0, 1, 123.45f, "valid name 9", "valid description",
1413                              armnn::EmptyOptional(),      // Units
1414                              armnn::EmptyOptional(),      // Number of cores
1415                              multiCoreDevice->m_Uid,      // Device UID
1416                              armnn::EmptyOptional()));    // Counter set UID
1417     BOOST_CHECK(counterDirectory.GetCounterCount() == 24);
1418     BOOST_CHECK(counterWMultiCoreDevice);
1419     BOOST_CHECK(counterWMultiCoreDevice->m_Uid >= 0);
1420     BOOST_CHECK(counterWMultiCoreDevice->m_Uid > counter->m_Uid);
1421     BOOST_CHECK(counterWMultiCoreDevice->m_MaxCounterUid ==
1422                 counterWMultiCoreDevice->m_Uid + multiCoreDevice->m_Cores - 1);
1423     BOOST_CHECK(counterWMultiCoreDevice->m_Class == 0);
1424     BOOST_CHECK(counterWMultiCoreDevice->m_Interpolation == 1);
1425     BOOST_CHECK(counterWMultiCoreDevice->m_Multiplier == 123.45f);
1426     BOOST_CHECK(counterWMultiCoreDevice->m_Name == "valid name 9");
1427     BOOST_CHECK(counterWMultiCoreDevice->m_Description == "valid description");
1428     BOOST_CHECK(counterWMultiCoreDevice->m_Units == "");
1429     BOOST_CHECK(counterWMultiCoreDevice->m_DeviceUid == multiCoreDevice->m_Uid);
1430     BOOST_CHECK(counterWMultiCoreDevice->m_CounterSetUid == 0);
1431     BOOST_CHECK(category->m_Counters.size() == 24);
1432     for (size_t i = 0; i < 4; i++)
1433     {
1434         BOOST_CHECK(category->m_Counters[category->m_Counters.size() - 4 + i] == counterWMultiCoreDevice->m_Uid + i);
1435     }
1436
1437     // Register a multi-core device associate to a parent category for testing
1438     const std::string multiCoreDeviceNameWParentCategory = "some_multi_core_device_with_parent_category";
1439     const Device* multiCoreDeviceWParentCategory         = nullptr;
1440     BOOST_CHECK_NO_THROW(multiCoreDeviceWParentCategory =
1441                              counterDirectory.RegisterDevice(multiCoreDeviceNameWParentCategory, 2, categoryName));
1442     BOOST_CHECK(counterDirectory.GetDeviceCount() == 3);
1443     BOOST_CHECK(multiCoreDeviceWParentCategory);
1444     BOOST_CHECK(multiCoreDeviceWParentCategory->m_Name == multiCoreDeviceNameWParentCategory);
1445     BOOST_CHECK(multiCoreDeviceWParentCategory->m_Uid >= 1);
1446     BOOST_CHECK(multiCoreDeviceWParentCategory->m_Cores == 2);
1447
1448     // Register a counter with a valid parent category name and getting the number of cores of the multi-core device
1449     // associated to that category
1450     const Counter* counterWMultiCoreDeviceWParentCategory = nullptr;
1451     BOOST_CHECK_NO_THROW(counterWMultiCoreDeviceWParentCategory = counterDirectory.RegisterCounter(
1452                              categoryName, 0, 1, 123.45f, "valid name 10", "valid description",
1453                              armnn::EmptyOptional(),      // Units
1454                              armnn::EmptyOptional(),      // Number of cores
1455                              armnn::EmptyOptional(),      // Device UID
1456                              armnn::EmptyOptional()));    // Counter set UID
1457     BOOST_CHECK(counterDirectory.GetCounterCount() == 26);
1458     BOOST_CHECK(counterWMultiCoreDeviceWParentCategory);
1459     BOOST_CHECK(counterWMultiCoreDeviceWParentCategory->m_Uid >= 0);
1460     BOOST_CHECK(counterWMultiCoreDeviceWParentCategory->m_Uid > counter->m_Uid);
1461     BOOST_CHECK(counterWMultiCoreDeviceWParentCategory->m_MaxCounterUid ==
1462                 counterWMultiCoreDeviceWParentCategory->m_Uid + multiCoreDeviceWParentCategory->m_Cores - 1);
1463     BOOST_CHECK(counterWMultiCoreDeviceWParentCategory->m_Class == 0);
1464     BOOST_CHECK(counterWMultiCoreDeviceWParentCategory->m_Interpolation == 1);
1465     BOOST_CHECK(counterWMultiCoreDeviceWParentCategory->m_Multiplier == 123.45f);
1466     BOOST_CHECK(counterWMultiCoreDeviceWParentCategory->m_Name == "valid name 10");
1467     BOOST_CHECK(counterWMultiCoreDeviceWParentCategory->m_Description == "valid description");
1468     BOOST_CHECK(counterWMultiCoreDeviceWParentCategory->m_Units == "");
1469     BOOST_CHECK(counterWMultiCoreDeviceWParentCategory->m_DeviceUid == 0);
1470     BOOST_CHECK(counterWMultiCoreDeviceWParentCategory->m_CounterSetUid == 0);
1471     BOOST_CHECK(category->m_Counters.size() == 26);
1472     for (size_t i = 0; i < 2; i++)
1473     {
1474         BOOST_CHECK(category->m_Counters[category->m_Counters.size() - 2 + i] ==
1475                     counterWMultiCoreDeviceWParentCategory->m_Uid + i);
1476     }
1477
1478     // Register a counter set for testing
1479     const std::string counterSetName = "some_counter_set";
1480     const CounterSet* counterSet     = nullptr;
1481     BOOST_CHECK_NO_THROW(counterSet = counterDirectory.RegisterCounterSet(counterSetName));
1482     BOOST_CHECK(counterDirectory.GetCounterSetCount() == 1);
1483     BOOST_CHECK(counterSet);
1484     BOOST_CHECK(counterSet->m_Name == counterSetName);
1485     BOOST_CHECK(counterSet->m_Uid >= 1);
1486     BOOST_CHECK(counterSet->m_Count == 0);
1487
1488     // Register a counter with a valid parent category name and associated to a counter set
1489     const Counter* counterWCounterSet = nullptr;
1490     BOOST_CHECK_NO_THROW(counterWCounterSet = counterDirectory.RegisterCounter(
1491                              categoryName, 0, 1, 123.45f, "valid name 11", "valid description",
1492                              armnn::EmptyOptional(),    // Units
1493                              0,                         // Number of cores
1494                              armnn::EmptyOptional(),    // Device UID
1495                              counterSet->m_Uid));       // Counter set UID
1496     BOOST_CHECK(counterDirectory.GetCounterCount() == 27);
1497     BOOST_CHECK(counterWCounterSet);
1498     BOOST_CHECK(counterWCounterSet->m_Uid >= 0);
1499     BOOST_CHECK(counterWCounterSet->m_Uid > counter->m_Uid);
1500     BOOST_CHECK(counterWCounterSet->m_MaxCounterUid == counterWCounterSet->m_Uid);
1501     BOOST_CHECK(counterWCounterSet->m_Class == 0);
1502     BOOST_CHECK(counterWCounterSet->m_Interpolation == 1);
1503     BOOST_CHECK(counterWCounterSet->m_Multiplier == 123.45f);
1504     BOOST_CHECK(counterWCounterSet->m_Name == "valid name 11");
1505     BOOST_CHECK(counterWCounterSet->m_Description == "valid description");
1506     BOOST_CHECK(counterWCounterSet->m_Units == "");
1507     BOOST_CHECK(counterWCounterSet->m_DeviceUid == 0);
1508     BOOST_CHECK(counterWCounterSet->m_CounterSetUid == counterSet->m_Uid);
1509     BOOST_CHECK(category->m_Counters.size() == 27);
1510     BOOST_CHECK(category->m_Counters.back() == counterWCounterSet->m_Uid);
1511
1512     // Register a counter with a valid parent category name and associated to a device and a counter set
1513     const Counter* counterWDeviceWCounterSet = nullptr;
1514     BOOST_CHECK_NO_THROW(counterWDeviceWCounterSet = counterDirectory.RegisterCounter(
1515                              categoryName, 0, 1, 123.45f, "valid name 12", "valid description",
1516                              armnn::EmptyOptional(),    // Units
1517                              1,                         // Number of cores
1518                              device->m_Uid,             // Device UID
1519                              counterSet->m_Uid));       // Counter set UID
1520     BOOST_CHECK(counterDirectory.GetCounterCount() == 28);
1521     BOOST_CHECK(counterWDeviceWCounterSet);
1522     BOOST_CHECK(counterWDeviceWCounterSet->m_Uid >= 0);
1523     BOOST_CHECK(counterWDeviceWCounterSet->m_Uid > counter->m_Uid);
1524     BOOST_CHECK(counterWDeviceWCounterSet->m_MaxCounterUid == counterWDeviceWCounterSet->m_Uid);
1525     BOOST_CHECK(counterWDeviceWCounterSet->m_Class == 0);
1526     BOOST_CHECK(counterWDeviceWCounterSet->m_Interpolation == 1);
1527     BOOST_CHECK(counterWDeviceWCounterSet->m_Multiplier == 123.45f);
1528     BOOST_CHECK(counterWDeviceWCounterSet->m_Name == "valid name 12");
1529     BOOST_CHECK(counterWDeviceWCounterSet->m_Description == "valid description");
1530     BOOST_CHECK(counterWDeviceWCounterSet->m_Units == "");
1531     BOOST_CHECK(counterWDeviceWCounterSet->m_DeviceUid == device->m_Uid);
1532     BOOST_CHECK(counterWDeviceWCounterSet->m_CounterSetUid == counterSet->m_Uid);
1533     BOOST_CHECK(category->m_Counters.size() == 28);
1534     BOOST_CHECK(category->m_Counters.back() == counterWDeviceWCounterSet->m_Uid);
1535
1536     // Register another category for testing
1537     const std::string anotherCategoryName = "some_other_category";
1538     const Category* anotherCategory       = nullptr;
1539     BOOST_CHECK_NO_THROW(anotherCategory = counterDirectory.RegisterCategory(anotherCategoryName));
1540     BOOST_CHECK(counterDirectory.GetCategoryCount() == 2);
1541     BOOST_CHECK(anotherCategory);
1542     BOOST_CHECK(anotherCategory != category);
1543     BOOST_CHECK(anotherCategory->m_Name == anotherCategoryName);
1544     BOOST_CHECK(anotherCategory->m_Counters.empty());
1545     BOOST_CHECK(anotherCategory->m_DeviceUid == 0);
1546     BOOST_CHECK(anotherCategory->m_CounterSetUid == 0);
1547
1548     // Register a counter to the other category
1549     const Counter* anotherCounter = nullptr;
1550     BOOST_CHECK_NO_THROW(anotherCounter = counterDirectory.RegisterCounter(anotherCategoryName, 1, 0, .00043f,
1551                                                                            "valid name", "valid description",
1552                                                                            armnn::EmptyOptional(),    // Units
1553                                                                            armnn::EmptyOptional(),    // Number of cores
1554                                                                            device->m_Uid,             // Device UID
1555                                                                            counterSet->m_Uid));       // Counter set UID
1556     BOOST_CHECK(counterDirectory.GetCounterCount() == 29);
1557     BOOST_CHECK(anotherCounter);
1558     BOOST_CHECK(anotherCounter->m_Uid >= 0);
1559     BOOST_CHECK(anotherCounter->m_MaxCounterUid == anotherCounter->m_Uid);
1560     BOOST_CHECK(anotherCounter->m_Class == 1);
1561     BOOST_CHECK(anotherCounter->m_Interpolation == 0);
1562     BOOST_CHECK(anotherCounter->m_Multiplier == .00043f);
1563     BOOST_CHECK(anotherCounter->m_Name == "valid name");
1564     BOOST_CHECK(anotherCounter->m_Description == "valid description");
1565     BOOST_CHECK(anotherCounter->m_Units == "");
1566     BOOST_CHECK(anotherCounter->m_DeviceUid == device->m_Uid);
1567     BOOST_CHECK(anotherCounter->m_CounterSetUid == counterSet->m_Uid);
1568     BOOST_CHECK(anotherCategory->m_Counters.size() == 1);
1569     BOOST_CHECK(anotherCategory->m_Counters.back() == anotherCounter->m_Uid);
1570 }
1571
1572 BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData)
1573 {
1574     using boost::numeric_cast;
1575
1576     ProfilingStateMachine profilingStateMachine;
1577
1578     class TestCaptureThread : public IPeriodicCounterCapture
1579     {
1580         void Start() override
1581         {}
1582         void Stop() override
1583         {}
1584     };
1585
1586     class TestReadCounterValues : public IReadCounterValues
1587     {
1588         bool IsCounterRegistered(uint16_t counterUid) const override
1589         {
1590             return true;
1591         }
1592         uint16_t GetCounterCount() const override
1593         {
1594             return 0;
1595         }
1596         uint32_t GetCounterValue(uint16_t counterUid) const override
1597         {
1598             return 0;
1599         }
1600     };
1601     const uint32_t familyId = 0;
1602     const uint32_t packetId = 0x40000;
1603
1604     uint32_t version = 1;
1605     Holder holder;
1606     TestCaptureThread captureThread;
1607     TestReadCounterValues readCounterValues;
1608     MockBufferManager mockBuffer(512);
1609     SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer);
1610
1611     uint32_t sizeOfUint32 = numeric_cast<uint32_t>(sizeof(uint32_t));
1612     uint32_t sizeOfUint16 = numeric_cast<uint32_t>(sizeof(uint16_t));
1613
1614     // Data with period and counters
1615     uint32_t period1     = 10;
1616     uint32_t dataLength1 = 8;
1617     uint32_t offset      = 0;
1618
1619     std::unique_ptr<unsigned char[]> uniqueData1 = std::make_unique<unsigned char[]>(dataLength1);
1620     unsigned char* data1                         = reinterpret_cast<unsigned char*>(uniqueData1.get());
1621
1622     WriteUint32(data1, offset, period1);
1623     offset += sizeOfUint32;
1624     WriteUint16(data1, offset, 4000);
1625     offset += sizeOfUint16;
1626     WriteUint16(data1, offset, 5000);
1627
1628     Packet packetA(packetId, dataLength1, uniqueData1);
1629
1630     PeriodicCounterSelectionCommandHandler commandHandler(familyId, packetId, version, holder, captureThread,
1631                                                           readCounterValues, sendCounterPacket, profilingStateMachine);
1632
1633     profilingStateMachine.TransitionToState(ProfilingState::Uninitialised);
1634     BOOST_CHECK_THROW(commandHandler(packetA), armnn::RuntimeException);
1635     profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
1636     BOOST_CHECK_THROW(commandHandler(packetA), armnn::RuntimeException);
1637     profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
1638     BOOST_CHECK_THROW(commandHandler(packetA), armnn::RuntimeException);
1639     profilingStateMachine.TransitionToState(ProfilingState::Active);
1640     BOOST_CHECK_NO_THROW(commandHandler(packetA));
1641
1642     const std::vector<uint16_t> counterIdsA = holder.GetCaptureData().GetCounterIds();
1643
1644     BOOST_TEST(holder.GetCaptureData().GetCapturePeriod() == period1);
1645     BOOST_TEST(counterIdsA.size() == 2);
1646     BOOST_TEST(counterIdsA[0] == 4000);
1647     BOOST_TEST(counterIdsA[1] == 5000);
1648
1649     auto readBuffer = mockBuffer.GetReadableBuffer();
1650
1651     offset = 0;
1652
1653     uint32_t headerWord0 = ReadUint32(readBuffer, offset);
1654     offset += sizeOfUint32;
1655     uint32_t headerWord1 = ReadUint32(readBuffer, offset);
1656     offset += sizeOfUint32;
1657     uint32_t period = ReadUint32(readBuffer, offset);
1658
1659     BOOST_TEST(((headerWord0 >> 26) & 0x3F) == 0);     // packet family
1660     BOOST_TEST(((headerWord0 >> 16) & 0x3FF) == 4);    // packet id
1661     BOOST_TEST(headerWord1 == 8);                      // data lenght
1662     BOOST_TEST(period == 10);                          // capture period
1663
1664     uint16_t counterId = 0;
1665     offset += sizeOfUint32;
1666     counterId = ReadUint16(readBuffer, offset);
1667     BOOST_TEST(counterId == 4000);
1668     offset += sizeOfUint16;
1669     counterId = ReadUint16(readBuffer, offset);
1670     BOOST_TEST(counterId == 5000);
1671
1672     mockBuffer.MarkRead(readBuffer);
1673
1674     // Data with period only
1675     uint32_t period2     = 11;
1676     uint32_t dataLength2 = 4;
1677
1678     std::unique_ptr<unsigned char[]> uniqueData2 = std::make_unique<unsigned char[]>(dataLength2);
1679
1680     WriteUint32(reinterpret_cast<unsigned char*>(uniqueData2.get()), 0, period2);
1681
1682     Packet packetB(packetId, dataLength2, uniqueData2);
1683
1684     commandHandler(packetB);
1685
1686     const std::vector<uint16_t> counterIdsB = holder.GetCaptureData().GetCounterIds();
1687
1688     BOOST_TEST(holder.GetCaptureData().GetCapturePeriod() == period2);
1689     BOOST_TEST(counterIdsB.size() == 0);
1690
1691     readBuffer = mockBuffer.GetReadableBuffer();
1692
1693     offset = 0;
1694
1695     headerWord0 = ReadUint32(readBuffer, offset);
1696     offset += sizeOfUint32;
1697     headerWord1 = ReadUint32(readBuffer, offset);
1698     offset += sizeOfUint32;
1699     period = ReadUint32(readBuffer, offset);
1700
1701     BOOST_TEST(((headerWord0 >> 26) & 0x3F) == 0);     // packet family
1702     BOOST_TEST(((headerWord0 >> 16) & 0x3FF) == 4);    // packet id
1703     BOOST_TEST(headerWord1 == 4);                      // data length
1704     BOOST_TEST(period == 11);                          // capture period
1705 }
1706
1707 BOOST_AUTO_TEST_CASE(CheckConnectionAcknowledged)
1708 {
1709     using boost::numeric_cast;
1710
1711     const uint32_t packetFamilyId     = 0;
1712     const uint32_t connectionPacketId = 0x10000;
1713     const uint32_t version            = 1;
1714
1715     uint32_t sizeOfUint32 = numeric_cast<uint32_t>(sizeof(uint32_t));
1716     uint32_t sizeOfUint16 = numeric_cast<uint32_t>(sizeof(uint16_t));
1717
1718     // Data with period and counters
1719     uint32_t period1     = 10;
1720     uint32_t dataLength1 = 8;
1721     uint32_t offset      = 0;
1722
1723     std::unique_ptr<unsigned char[]> uniqueData1 = std::make_unique<unsigned char[]>(dataLength1);
1724     unsigned char* data1                         = reinterpret_cast<unsigned char*>(uniqueData1.get());
1725
1726     WriteUint32(data1, offset, period1);
1727     offset += sizeOfUint32;
1728     WriteUint16(data1, offset, 4000);
1729     offset += sizeOfUint16;
1730     WriteUint16(data1, offset, 5000);
1731
1732     Packet packetA(connectionPacketId, dataLength1, uniqueData1);
1733
1734     ProfilingStateMachine profilingState(ProfilingState::Uninitialised);
1735     BOOST_CHECK(profilingState.GetCurrentState() == ProfilingState::Uninitialised);
1736     CounterDirectory counterDirectory;
1737     MockBufferManager mockBuffer(1024);
1738     SendCounterPacket sendCounterPacket(profilingState, mockBuffer);
1739     SendTimelinePacket sendTimelinePacket(mockBuffer);
1740
1741     ConnectionAcknowledgedCommandHandler commandHandler(packetFamilyId, connectionPacketId, version, counterDirectory,
1742                                                         sendCounterPacket, sendTimelinePacket, profilingState);
1743
1744     // command handler received packet on ProfilingState::Uninitialised
1745     BOOST_CHECK_THROW(commandHandler(packetA), armnn::Exception);
1746
1747     profilingState.TransitionToState(ProfilingState::NotConnected);
1748     BOOST_CHECK(profilingState.GetCurrentState() == ProfilingState::NotConnected);
1749     // command handler received packet on ProfilingState::NotConnected
1750     BOOST_CHECK_THROW(commandHandler(packetA), armnn::Exception);
1751
1752     profilingState.TransitionToState(ProfilingState::WaitingForAck);
1753     BOOST_CHECK(profilingState.GetCurrentState() == ProfilingState::WaitingForAck);
1754     // command handler received packet on ProfilingState::WaitingForAck
1755     BOOST_CHECK_NO_THROW(commandHandler(packetA));
1756     BOOST_CHECK(profilingState.GetCurrentState() == ProfilingState::Active);
1757
1758     // command handler received packet on ProfilingState::Active
1759     BOOST_CHECK_NO_THROW(commandHandler(packetA));
1760     BOOST_CHECK(profilingState.GetCurrentState() == ProfilingState::Active);
1761
1762     // command handler received different packet
1763     const uint32_t differentPacketId = 0x40000;
1764     Packet packetB(differentPacketId, dataLength1, uniqueData1);
1765     profilingState.TransitionToState(ProfilingState::NotConnected);
1766     profilingState.TransitionToState(ProfilingState::WaitingForAck);
1767     ConnectionAcknowledgedCommandHandler differentCommandHandler(packetFamilyId, differentPacketId, version,
1768                                                                  counterDirectory, sendCounterPacket,
1769                                                                  sendTimelinePacket, profilingState);
1770     BOOST_CHECK_THROW(differentCommandHandler(packetB), armnn::Exception);
1771 }
1772
1773 BOOST_AUTO_TEST_CASE(CheckSocketProfilingConnection)
1774 {
1775     // Check that creating a SocketProfilingConnection results in an exception as the Gator UDS doesn't exist.
1776     BOOST_CHECK_THROW(new SocketProfilingConnection(), armnn::Exception);
1777 }
1778
1779 BOOST_AUTO_TEST_CASE(SwTraceIsValidCharTest)
1780 {
1781     // Only ASCII 7-bit encoding supported
1782     for (unsigned char c = 0; c < 128; c++)
1783     {
1784         BOOST_CHECK(SwTraceCharPolicy::IsValidChar(c));
1785     }
1786
1787     // Not ASCII
1788     for (unsigned char c = 255; c >= 128; c++)
1789     {
1790         BOOST_CHECK(!SwTraceCharPolicy::IsValidChar(c));
1791     }
1792 }
1793
1794 BOOST_AUTO_TEST_CASE(SwTraceIsValidNameCharTest)
1795 {
1796     // Only alpha-numeric and underscore ASCII 7-bit encoding supported
1797     const unsigned char validChars[] = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_";
1798     for (unsigned char i = 0; i < sizeof(validChars) / sizeof(validChars[0]) - 1; i++)
1799     {
1800         BOOST_CHECK(SwTraceNameCharPolicy::IsValidChar(validChars[i]));
1801     }
1802
1803     // Non alpha-numeric chars
1804     for (unsigned char c = 0; c < 48; c++)
1805     {
1806         BOOST_CHECK(!SwTraceNameCharPolicy::IsValidChar(c));
1807     }
1808     for (unsigned char c = 58; c < 65; c++)
1809     {
1810         BOOST_CHECK(!SwTraceNameCharPolicy::IsValidChar(c));
1811     }
1812     for (unsigned char c = 91; c < 95; c++)
1813     {
1814         BOOST_CHECK(!SwTraceNameCharPolicy::IsValidChar(c));
1815     }
1816     for (unsigned char c = 96; c < 97; c++)
1817     {
1818         BOOST_CHECK(!SwTraceNameCharPolicy::IsValidChar(c));
1819     }
1820     for (unsigned char c = 123; c < 128; c++)
1821     {
1822         BOOST_CHECK(!SwTraceNameCharPolicy::IsValidChar(c));
1823     }
1824
1825     // Not ASCII
1826     for (unsigned char c = 255; c >= 128; c++)
1827     {
1828         BOOST_CHECK(!SwTraceNameCharPolicy::IsValidChar(c));
1829     }
1830 }
1831
1832 BOOST_AUTO_TEST_CASE(IsValidSwTraceStringTest)
1833 {
1834     // Valid SWTrace strings
1835     BOOST_CHECK(IsValidSwTraceString<SwTraceCharPolicy>(""));
1836     BOOST_CHECK(IsValidSwTraceString<SwTraceCharPolicy>("_"));
1837     BOOST_CHECK(IsValidSwTraceString<SwTraceCharPolicy>("0123"));
1838     BOOST_CHECK(IsValidSwTraceString<SwTraceCharPolicy>("valid_string"));
1839     BOOST_CHECK(IsValidSwTraceString<SwTraceCharPolicy>("VALID_string_456"));
1840     BOOST_CHECK(IsValidSwTraceString<SwTraceCharPolicy>(" "));
1841     BOOST_CHECK(IsValidSwTraceString<SwTraceCharPolicy>("valid string"));
1842     BOOST_CHECK(IsValidSwTraceString<SwTraceCharPolicy>("!$%"));
1843     BOOST_CHECK(IsValidSwTraceString<SwTraceCharPolicy>("valid|\\~string#123"));
1844
1845     // Invalid SWTrace strings
1846     BOOST_CHECK(!IsValidSwTraceString<SwTraceCharPolicy>("€£"));
1847     BOOST_CHECK(!IsValidSwTraceString<SwTraceCharPolicy>("invalid‡string"));
1848     BOOST_CHECK(!IsValidSwTraceString<SwTraceCharPolicy>("12Ž34"));
1849 }
1850
1851 BOOST_AUTO_TEST_CASE(IsValidSwTraceNameStringTest)
1852 {
1853     // Valid SWTrace name strings
1854     BOOST_CHECK(IsValidSwTraceString<SwTraceNameCharPolicy>(""));
1855     BOOST_CHECK(IsValidSwTraceString<SwTraceNameCharPolicy>("_"));
1856     BOOST_CHECK(IsValidSwTraceString<SwTraceNameCharPolicy>("0123"));
1857     BOOST_CHECK(IsValidSwTraceString<SwTraceNameCharPolicy>("valid_string"));
1858     BOOST_CHECK(IsValidSwTraceString<SwTraceNameCharPolicy>("VALID_string_456"));
1859
1860     // Invalid SWTrace name strings
1861     BOOST_CHECK(!IsValidSwTraceString<SwTraceNameCharPolicy>(" "));
1862     BOOST_CHECK(!IsValidSwTraceString<SwTraceNameCharPolicy>("invalid string"));
1863     BOOST_CHECK(!IsValidSwTraceString<SwTraceNameCharPolicy>("!$%"));
1864     BOOST_CHECK(!IsValidSwTraceString<SwTraceNameCharPolicy>("invalid|\\~string#123"));
1865     BOOST_CHECK(!IsValidSwTraceString<SwTraceNameCharPolicy>("€£"));
1866     BOOST_CHECK(!IsValidSwTraceString<SwTraceNameCharPolicy>("invalid‡string"));
1867     BOOST_CHECK(!IsValidSwTraceString<SwTraceNameCharPolicy>("12Ž34"));
1868 }
1869
1870 template <typename SwTracePolicy>
1871 void StringToSwTraceStringTestHelper(const std::string& testString, std::vector<uint32_t> buffer, size_t expectedSize)
1872 {
1873     // Convert the test string to a SWTrace string
1874     BOOST_CHECK(StringToSwTraceString<SwTracePolicy>(testString, buffer));
1875
1876     // The buffer must contain at least the length of the string
1877     BOOST_CHECK(!buffer.empty());
1878
1879     // The buffer must be of the expected size (in words)
1880     BOOST_CHECK(buffer.size() == expectedSize);
1881
1882     // The first word of the byte must be the length of the string including the null-terminator
1883     BOOST_CHECK(buffer[0] == testString.size() + 1);
1884
1885     // The contents of the buffer must match the test string
1886     BOOST_CHECK(std::memcmp(testString.data(), buffer.data() + 1, testString.size()) == 0);
1887
1888     // The buffer must include the null-terminator at the end of the string
1889     size_t nullTerminatorIndex = sizeof(uint32_t) + testString.size();
1890     BOOST_CHECK(reinterpret_cast<unsigned char*>(buffer.data())[nullTerminatorIndex] == '\0');
1891 }
1892
1893 BOOST_AUTO_TEST_CASE(StringToSwTraceStringTest)
1894 {
1895     std::vector<uint32_t> buffer;
1896
1897     // Valid SWTrace strings (expected size in words)
1898     StringToSwTraceStringTestHelper<SwTraceCharPolicy>("", buffer, 2);
1899     StringToSwTraceStringTestHelper<SwTraceCharPolicy>("_", buffer, 2);
1900     StringToSwTraceStringTestHelper<SwTraceCharPolicy>("0123", buffer, 3);
1901     StringToSwTraceStringTestHelper<SwTraceCharPolicy>("valid_string", buffer, 5);
1902     StringToSwTraceStringTestHelper<SwTraceCharPolicy>("VALID_string_456", buffer, 6);
1903     StringToSwTraceStringTestHelper<SwTraceCharPolicy>(" ", buffer, 2);
1904     StringToSwTraceStringTestHelper<SwTraceCharPolicy>("valid string", buffer, 5);
1905     StringToSwTraceStringTestHelper<SwTraceCharPolicy>("!$%", buffer, 2);
1906     StringToSwTraceStringTestHelper<SwTraceCharPolicy>("valid|\\~string#123", buffer, 6);
1907
1908     // Invalid SWTrace strings
1909     BOOST_CHECK(!StringToSwTraceString<SwTraceCharPolicy>("€£", buffer));
1910     BOOST_CHECK(buffer.empty());
1911     BOOST_CHECK(!StringToSwTraceString<SwTraceCharPolicy>("invalid‡string", buffer));
1912     BOOST_CHECK(buffer.empty());
1913     BOOST_CHECK(!StringToSwTraceString<SwTraceCharPolicy>("12Ž34", buffer));
1914     BOOST_CHECK(buffer.empty());
1915 }
1916
1917 BOOST_AUTO_TEST_CASE(StringToSwTraceNameStringTest)
1918 {
1919     std::vector<uint32_t> buffer;
1920
1921     // Valid SWTrace namestrings (expected size in words)
1922     StringToSwTraceStringTestHelper<SwTraceNameCharPolicy>("", buffer, 2);
1923     StringToSwTraceStringTestHelper<SwTraceNameCharPolicy>("_", buffer, 2);
1924     StringToSwTraceStringTestHelper<SwTraceNameCharPolicy>("0123", buffer, 3);
1925     StringToSwTraceStringTestHelper<SwTraceNameCharPolicy>("valid_string", buffer, 5);
1926     StringToSwTraceStringTestHelper<SwTraceNameCharPolicy>("VALID_string_456", buffer, 6);
1927
1928     // Invalid SWTrace namestrings
1929     BOOST_CHECK(!StringToSwTraceString<SwTraceNameCharPolicy>(" ", buffer));
1930     BOOST_CHECK(buffer.empty());
1931     BOOST_CHECK(!StringToSwTraceString<SwTraceNameCharPolicy>("invalid string", buffer));
1932     BOOST_CHECK(buffer.empty());
1933     BOOST_CHECK(!StringToSwTraceString<SwTraceNameCharPolicy>("!$%", buffer));
1934     BOOST_CHECK(buffer.empty());
1935     BOOST_CHECK(!StringToSwTraceString<SwTraceNameCharPolicy>("invalid|\\~string#123", buffer));
1936     BOOST_CHECK(buffer.empty());
1937     BOOST_CHECK(!StringToSwTraceString<SwTraceNameCharPolicy>("€£", buffer));
1938     BOOST_CHECK(buffer.empty());
1939     BOOST_CHECK(!StringToSwTraceString<SwTraceNameCharPolicy>("invalid‡string", buffer));
1940     BOOST_CHECK(buffer.empty());
1941     BOOST_CHECK(!StringToSwTraceString<SwTraceNameCharPolicy>("12Ž34", buffer));
1942     BOOST_CHECK(buffer.empty());
1943 }
1944
1945 BOOST_AUTO_TEST_CASE(CheckPeriodicCounterCaptureThread)
1946 {
1947     class CaptureReader : public IReadCounterValues
1948     {
1949     public:
1950         CaptureReader(uint16_t counterSize)
1951         {
1952             for (uint16_t i = 0; i < counterSize; ++i)
1953             {
1954                 m_Data[i] = 0;
1955             }
1956             m_CounterSize = counterSize;
1957         }
1958         //not used
1959         bool IsCounterRegistered(uint16_t counterUid) const override
1960         {
1961             return false;
1962         }
1963
1964         uint16_t GetCounterCount() const override
1965         {
1966             return m_CounterSize;
1967         }
1968
1969         uint32_t GetCounterValue(uint16_t counterUid) const override
1970         {
1971             if (counterUid > m_CounterSize)
1972             {
1973                 BOOST_FAIL("Invalid counter Uid");
1974             }
1975             return m_Data.at(counterUid).load();
1976         }
1977
1978         void SetCounterValue(uint16_t counterUid, uint32_t value)
1979         {
1980             if (counterUid > m_CounterSize)
1981             {
1982                 BOOST_FAIL("Invalid counter Uid");
1983             }
1984             m_Data.at(counterUid).store(value);
1985         }
1986
1987     private:
1988         std::unordered_map<uint16_t, std::atomic<uint32_t>> m_Data;
1989         uint16_t m_CounterSize;
1990     };
1991
1992     ProfilingStateMachine profilingStateMachine;
1993
1994     Holder data;
1995     std::vector<uint16_t> captureIds1 = { 0, 1 };
1996     std::vector<uint16_t> captureIds2;
1997
1998     MockBufferManager mockBuffer(512);
1999     SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer);
2000
2001     std::vector<uint16_t> counterIds;
2002     CaptureReader captureReader(2);
2003
2004     unsigned int valueA   = 10;
2005     unsigned int valueB   = 15;
2006     unsigned int numSteps = 5;
2007
2008     PeriodicCounterCapture periodicCounterCapture(std::ref(data), std::ref(sendCounterPacket), captureReader);
2009
2010     for (unsigned int i = 0; i < numSteps; ++i)
2011     {
2012         data.SetCaptureData(1, captureIds1);
2013         captureReader.SetCounterValue(0, valueA * (i + 1));
2014         captureReader.SetCounterValue(1, valueB * (i + 1));
2015
2016         periodicCounterCapture.Start();
2017         periodicCounterCapture.Stop();
2018     }
2019
2020     auto buffer = mockBuffer.GetReadableBuffer();
2021
2022     uint32_t headerWord0 = ReadUint32(buffer, 0);
2023     uint32_t headerWord1 = ReadUint32(buffer, 4);
2024
2025     BOOST_TEST(((headerWord0 >> 26) & 0x0000003F) == 1);    // packet family
2026     BOOST_TEST(((headerWord0 >> 19) & 0x0000007F) == 0);    // packet class
2027     BOOST_TEST(((headerWord0 >> 16) & 0x00000007) == 0);    // packet type
2028     BOOST_TEST(headerWord1 == 20);
2029
2030     uint32_t offset    = 16;
2031     uint16_t readIndex = ReadUint16(buffer, offset);
2032     BOOST_TEST(0 == readIndex);
2033
2034     offset += 2;
2035     uint32_t readValue = ReadUint32(buffer, offset);
2036     BOOST_TEST((valueA * numSteps) == readValue);
2037
2038     offset += 4;
2039     readIndex = ReadUint16(buffer, offset);
2040     BOOST_TEST(1 == readIndex);
2041
2042     offset += 2;
2043     readValue = ReadUint32(buffer, offset);
2044     BOOST_TEST((valueB * numSteps) == readValue);
2045 }
2046
2047 BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest1)
2048 {
2049     using boost::numeric_cast;
2050
2051     const uint32_t familyId = 0;
2052     const uint32_t packetId = 3;
2053     const uint32_t version  = 1;
2054     ProfilingStateMachine profilingStateMachine;
2055     CounterDirectory counterDirectory;
2056     MockBufferManager mockBuffer(1024);
2057     SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer);
2058     RequestCounterDirectoryCommandHandler commandHandler(familyId, packetId, version, counterDirectory,
2059                                                          sendCounterPacket, profilingStateMachine);
2060
2061     const uint32_t wrongPacketId = 47;
2062     const uint32_t wrongHeader   = (wrongPacketId & 0x000003FF) << 16;
2063
2064     Packet wrongPacket(wrongHeader);
2065
2066     profilingStateMachine.TransitionToState(ProfilingState::Uninitialised);
2067     BOOST_CHECK_THROW(commandHandler(wrongPacket), armnn::RuntimeException);    // Wrong profiling state
2068     profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
2069     BOOST_CHECK_THROW(commandHandler(wrongPacket), armnn::RuntimeException);    // Wrong profiling state
2070     profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
2071     BOOST_CHECK_THROW(commandHandler(wrongPacket), armnn::RuntimeException);    // Wrong profiling state
2072     profilingStateMachine.TransitionToState(ProfilingState::Active);
2073     BOOST_CHECK_THROW(commandHandler(wrongPacket), armnn::InvalidArgumentException);    // Wrong packet
2074
2075     const uint32_t rightHeader = (packetId & 0x000003FF) << 16;
2076
2077     Packet rightPacket(rightHeader);
2078
2079     BOOST_CHECK_NO_THROW(commandHandler(rightPacket));    // Right packet
2080
2081     auto readBuffer = mockBuffer.GetReadableBuffer();
2082
2083     uint32_t headerWord0 = ReadUint32(readBuffer, 0);
2084     uint32_t headerWord1 = ReadUint32(readBuffer, 4);
2085
2086     BOOST_TEST(((headerWord0 >> 26) & 0x0000003F) == 0);    // packet family
2087     BOOST_TEST(((headerWord0 >> 16) & 0x000003FF) == 2);    // packet id
2088     BOOST_TEST(headerWord1 == 24);                          // data length
2089
2090     uint32_t bodyHeaderWord0   = ReadUint32(readBuffer, 8);
2091     uint16_t deviceRecordCount = numeric_cast<uint16_t>(bodyHeaderWord0 >> 16);
2092     BOOST_TEST(deviceRecordCount == 0);    // device_records_count
2093 }
2094
2095 BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest2)
2096 {
2097     using boost::numeric_cast;
2098
2099     const uint32_t familyId = 0;
2100     const uint32_t packetId = 3;
2101     const uint32_t version  = 1;
2102     ProfilingStateMachine profilingStateMachine;
2103     CounterDirectory counterDirectory;
2104     MockBufferManager mockBuffer(1024);
2105     SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer);
2106     RequestCounterDirectoryCommandHandler commandHandler(familyId, packetId, version, counterDirectory,
2107                                                          sendCounterPacket, profilingStateMachine);
2108     const uint32_t header = (packetId & 0x000003FF) << 16;
2109     Packet packet(header);
2110
2111     const Device* device         = counterDirectory.RegisterDevice("deviceA", 1);
2112     const CounterSet* counterSet = counterDirectory.RegisterCounterSet("countersetA");
2113     counterDirectory.RegisterCategory("categoryA", device->m_Uid, counterSet->m_Uid);
2114     counterDirectory.RegisterCounter("categoryA", 0, 1, 2.0f, "counterA", "descA");
2115     counterDirectory.RegisterCounter("categoryA", 1, 1, 3.0f, "counterB", "descB");
2116
2117     profilingStateMachine.TransitionToState(ProfilingState::Uninitialised);
2118     BOOST_CHECK_THROW(commandHandler(packet), armnn::RuntimeException);    // Wrong profiling state
2119     profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
2120     BOOST_CHECK_THROW(commandHandler(packet), armnn::RuntimeException);    // Wrong profiling state
2121     profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
2122     BOOST_CHECK_THROW(commandHandler(packet), armnn::RuntimeException);    // Wrong profiling state
2123     profilingStateMachine.TransitionToState(ProfilingState::Active);
2124     BOOST_CHECK_NO_THROW(commandHandler(packet));
2125
2126     auto readBuffer = mockBuffer.GetReadableBuffer();
2127
2128     uint32_t headerWord0 = ReadUint32(readBuffer, 0);
2129     uint32_t headerWord1 = ReadUint32(readBuffer, 4);
2130
2131     BOOST_TEST(((headerWord0 >> 26) & 0x0000003F) == 0);    // packet family
2132     BOOST_TEST(((headerWord0 >> 16) & 0x000003FF) == 2);    // packet id
2133     BOOST_TEST(headerWord1 == 240);                         // data length
2134
2135     uint32_t bodyHeaderWord0       = ReadUint32(readBuffer, 8);
2136     uint32_t bodyHeaderWord1       = ReadUint32(readBuffer, 12);
2137     uint32_t bodyHeaderWord2       = ReadUint32(readBuffer, 16);
2138     uint32_t bodyHeaderWord3       = ReadUint32(readBuffer, 20);
2139     uint32_t bodyHeaderWord4       = ReadUint32(readBuffer, 24);
2140     uint32_t bodyHeaderWord5       = ReadUint32(readBuffer, 28);
2141     uint16_t deviceRecordCount     = numeric_cast<uint16_t>(bodyHeaderWord0 >> 16);
2142     uint16_t counterSetRecordCount = numeric_cast<uint16_t>(bodyHeaderWord2 >> 16);
2143     uint16_t categoryRecordCount   = numeric_cast<uint16_t>(bodyHeaderWord4 >> 16);
2144     BOOST_TEST(deviceRecordCount == 1);        // device_records_count
2145     BOOST_TEST(bodyHeaderWord1 == 0);          // device_records_pointer_table_offset
2146     BOOST_TEST(counterSetRecordCount == 1);    // counter_set_count
2147     BOOST_TEST(bodyHeaderWord3 == 4);          // counter_set_pointer_table_offset
2148     BOOST_TEST(categoryRecordCount == 1);      // categories_count
2149     BOOST_TEST(bodyHeaderWord5 == 8);          // categories_pointer_table_offset
2150
2151     uint32_t deviceRecordOffset = ReadUint32(readBuffer, 32);
2152     BOOST_TEST(deviceRecordOffset == 0);
2153
2154     uint32_t counterSetRecordOffset = ReadUint32(readBuffer, 36);
2155     BOOST_TEST(counterSetRecordOffset == 20);
2156
2157     uint32_t categoryRecordOffset = ReadUint32(readBuffer, 40);
2158     BOOST_TEST(categoryRecordOffset == 44);
2159 }
2160
2161 BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket)
2162 {
2163     // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
2164     LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning);
2165
2166     // Swap the profiling connection factory in the profiling service instance with our mock one
2167     SwapProfilingConnectionFactoryHelper helper;
2168
2169     // Redirect the standard output to a local stream so that we can parse the warning message
2170     std::stringstream ss;
2171     StreamRedirector streamRedirector(std::cout, ss.rdbuf());
2172
2173     // Calculate the size of a Stream Metadata packet
2174     std::string processName      = GetProcessName().substr(0, 60);
2175     unsigned int processNameSize = processName.empty() ? 0 : boost::numeric_cast<unsigned int>(processName.size()) + 1;
2176     unsigned int streamMetadataPacketsize = 118 + processNameSize;
2177
2178     // Reset the profiling service to the uninitialized state
2179     armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
2180     options.m_EnableProfiling          = true;
2181     ProfilingService& profilingService = ProfilingService::Instance();
2182     profilingService.ResetExternalProfilingOptions(options, true);
2183
2184     // Bring the profiling service to the "WaitingForAck" state
2185     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2186     profilingService.Update();    // Initialize the counter directory
2187     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2188     profilingService.Update();    // Create the profiling connection
2189
2190     // Get the mock profiling connection
2191     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
2192     BOOST_CHECK(mockProfilingConnection);
2193
2194     // Remove the packets received so far
2195     mockProfilingConnection->Clear();
2196
2197     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2198     profilingService.Update();
2199
2200     // Wait for the Stream Metadata packet to be sent
2201     helper.WaitForProfilingPacketsSent();
2202
2203     // Check that the mock profiling connection contains one Stream Metadata packet
2204     const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
2205     BOOST_TEST(writtenData.size() == 1);
2206     BOOST_TEST(writtenData[0] == streamMetadataPacketsize);
2207
2208     // Write a valid "Connection Acknowledged" packet into the mock profiling connection, to simulate a valid
2209     // reply from an external profiling service
2210
2211     // Connection Acknowledged Packet header (word 0, word 1 is always zero):
2212     // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
2213     // 16:25 [10] packet_id: Packet identifier, value 0b0000000001
2214     // 8:15  [8]  reserved: Reserved, value 0b00000000
2215     // 0:7   [8]  reserved: Reserved, value 0b00000000
2216     uint32_t packetFamily = 0;
2217     uint32_t packetId     = 37;    // Wrong packet id!!!
2218     uint32_t header       = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
2219
2220     // Create the Connection Acknowledged Packet
2221     Packet connectionAcknowledgedPacket(header);
2222
2223     // Write the packet to the mock profiling connection
2224     mockProfilingConnection->WritePacket(std::move(connectionAcknowledgedPacket));
2225
2226     // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that
2227     // the Connection Acknowledged packet gets processed by the profiling service
2228     std::this_thread::sleep_for(std::chrono::seconds(2));
2229
2230     // Check that the expected error has occurred and logged to the standard output
2231     BOOST_CHECK(boost::contains(ss.str(), "Functor with requested PacketId=37 and Version=4194304 does not exist"));
2232
2233     // The Connection Acknowledged Command Handler should not have updated the profiling state
2234     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2235
2236     // Reset the profiling service to stop any running thread
2237     options.m_EnableProfiling = false;
2238     profilingService.ResetExternalProfilingOptions(options, true);
2239 }
2240
2241 BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket)
2242 {
2243     // Swap the profiling connection factory in the profiling service instance with our mock one
2244     SwapProfilingConnectionFactoryHelper helper;
2245
2246     // Calculate the size of a Stream Metadata packet
2247     std::string processName      = GetProcessName().substr(0, 60);
2248     unsigned int processNameSize = processName.empty() ? 0 : boost::numeric_cast<unsigned int>(processName.size()) + 1;
2249     unsigned int streamMetadataPacketsize = 118 + processNameSize;
2250
2251     // Reset the profiling service to the uninitialized state
2252     armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
2253     options.m_EnableProfiling          = true;
2254     ProfilingService& profilingService = ProfilingService::Instance();
2255     profilingService.ResetExternalProfilingOptions(options, true);
2256
2257     // Bring the profiling service to the "WaitingForAck" state
2258     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2259     profilingService.Update();    // Initialize the counter directory
2260     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2261     profilingService.Update();    // Create the profiling connection
2262
2263     // Get the mock profiling connection
2264     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
2265     BOOST_CHECK(mockProfilingConnection);
2266
2267     // Remove the packets received so far
2268     mockProfilingConnection->Clear();
2269
2270     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2271     profilingService.Update();    // Start the command handler and the send thread
2272
2273     // Wait for the Stream Metadata packet to be sent
2274     helper.WaitForProfilingPacketsSent();
2275
2276     // Check that the mock profiling connection contains one Stream Metadata packet
2277     const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
2278     BOOST_TEST(writtenData.size() == 1);
2279     BOOST_TEST(writtenData[0] == streamMetadataPacketsize);
2280
2281     // Write a valid "Connection Acknowledged" packet into the mock profiling connection, to simulate a valid
2282     // reply from an external profiling service
2283
2284     // Connection Acknowledged Packet header (word 0, word 1 is always zero):
2285     // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
2286     // 16:25 [10] packet_id: Packet identifier, value 0b0000000001
2287     // 8:15  [8]  reserved: Reserved, value 0b00000000
2288     // 0:7   [8]  reserved: Reserved, value 0b00000000
2289     uint32_t packetFamily = 0;
2290     uint32_t packetId     = 1;
2291     uint32_t header       = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
2292
2293     // Create the Connection Acknowledged Packet
2294     Packet connectionAcknowledgedPacket(header);
2295
2296     // Write the packet to the mock profiling connection
2297     mockProfilingConnection->WritePacket(std::move(connectionAcknowledgedPacket));
2298
2299     // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that
2300     // the Connection Acknowledged packet gets processed by the profiling service
2301     std::this_thread::sleep_for(std::chrono::seconds(2));
2302
2303     // The Connection Acknowledged Command Handler should have updated the profiling state accordingly
2304     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2305
2306     // Reset the profiling service to stop any running thread
2307     options.m_EnableProfiling = false;
2308     profilingService.ResetExternalProfilingOptions(options, true);
2309 }
2310
2311 BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadRequestCounterDirectoryPacket)
2312 {
2313     // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
2314     LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning);
2315
2316     // Swap the profiling connection factory in the profiling service instance with our mock one
2317     SwapProfilingConnectionFactoryHelper helper;
2318
2319     // Redirect the standard output to a local stream so that we can parse the warning message
2320     std::stringstream ss;
2321     StreamRedirector streamRedirector(std::cout, ss.rdbuf());
2322
2323     // Reset the profiling service to the uninitialized state
2324     armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
2325     options.m_EnableProfiling          = true;
2326     ProfilingService& profilingService = ProfilingService::Instance();
2327     profilingService.ResetExternalProfilingOptions(options, true);
2328
2329     // Bring the profiling service to the "Active" state
2330     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2331     helper.ForceTransitionToState(ProfilingState::NotConnected);
2332     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2333     profilingService.Update();    // Create the profiling connection
2334     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2335     profilingService.Update();    // Start the command handler and the send thread
2336
2337     // Wait for the Stream Metadata packet the be sent
2338     // (we are not testing the connection acknowledgement here so it will be ignored by this test)
2339     helper.WaitForProfilingPacketsSent();
2340
2341     // Force the profiling service to the "Active" state
2342     helper.ForceTransitionToState(ProfilingState::Active);
2343     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2344
2345     // Get the mock profiling connection
2346     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
2347     BOOST_CHECK(mockProfilingConnection);
2348
2349     // Remove the packets received so far
2350     mockProfilingConnection->Clear();
2351
2352     // Write a valid "Request Counter Directory" packet into the mock profiling connection, to simulate a valid
2353     // reply from an external profiling service
2354
2355     // Request Counter Directory packet header (word 0, word 1 is always zero):
2356     // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
2357     // 16:25 [10] packet_id: Packet identifier, value 0b0000000011
2358     // 8:15  [8]  reserved: Reserved, value 0b00000000
2359     // 0:7   [8]  reserved: Reserved, value 0b00000000
2360     uint32_t packetFamily = 0;
2361     uint32_t packetId     = 123;    // Wrong packet id!!!
2362     uint32_t header       = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
2363
2364     // Create the Request Counter Directory packet
2365     Packet requestCounterDirectoryPacket(header);
2366
2367     // Write the packet to the mock profiling connection
2368     mockProfilingConnection->WritePacket(std::move(requestCounterDirectoryPacket));
2369
2370     // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that
2371     // the Create the Request Counter packet gets processed by the profiling service
2372     std::this_thread::sleep_for(std::chrono::seconds(2));
2373
2374     // Check that the expected error has occurred and logged to the standard output
2375     BOOST_CHECK(boost::contains(ss.str(), "Functor with requested PacketId=123 and Version=4194304 does not exist"));
2376
2377     // The Request Counter Directory Command Handler should not have updated the profiling state
2378     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2379
2380     // Reset the profiling service to stop any running thread
2381     options.m_EnableProfiling = false;
2382     profilingService.ResetExternalProfilingOptions(options, true);
2383 }
2384
2385 BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodRequestCounterDirectoryPacket)
2386 {
2387     // Swap the profiling connection factory in the profiling service instance with our mock one
2388     SwapProfilingConnectionFactoryHelper helper;
2389
2390     // Reset the profiling service to the uninitialized state
2391     armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
2392     options.m_EnableProfiling          = true;
2393     ProfilingService& profilingService = ProfilingService::Instance();
2394     profilingService.ResetExternalProfilingOptions(options, true);
2395
2396     // Bring the profiling service to the "Active" state
2397     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2398     profilingService.Update();    // Initialize the counter directory
2399     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2400     profilingService.Update();    // Create the profiling connection
2401     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2402     profilingService.Update();    // Start the command handler and the send thread
2403
2404     // Wait for the Stream Metadata packet the be sent
2405     // (we are not testing the connection acknowledgement here so it will be ignored by this test)
2406     helper.WaitForProfilingPacketsSent();
2407
2408     // Force the profiling service to the "Active" state
2409     helper.ForceTransitionToState(ProfilingState::Active);
2410     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2411
2412     // Get the mock profiling connection
2413     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
2414     BOOST_CHECK(mockProfilingConnection);
2415
2416     // Remove the packets received so far
2417     mockProfilingConnection->Clear();
2418
2419     // Write a valid "Request Counter Directory" packet into the mock profiling connection, to simulate a valid
2420     // reply from an external profiling service
2421
2422     // Request Counter Directory packet header (word 0, word 1 is always zero):
2423     // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
2424     // 16:25 [10] packet_id: Packet identifier, value 0b0000000011
2425     // 8:15  [8]  reserved: Reserved, value 0b00000000
2426     // 0:7   [8]  reserved: Reserved, value 0b00000000
2427     uint32_t packetFamily = 0;
2428     uint32_t packetId     = 3;
2429     uint32_t header       = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
2430
2431     // Create the Request Counter Directory packet
2432     Packet requestCounterDirectoryPacket(header);
2433
2434     // Write the packet to the mock profiling connection
2435     mockProfilingConnection->WritePacket(std::move(requestCounterDirectoryPacket));
2436
2437     // Wait for the Counter Directory packet to be sent
2438     helper.WaitForProfilingPacketsSent();
2439
2440     // Check that the mock profiling connection contains one Counter Directory packet
2441     const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
2442     BOOST_TEST(writtenData.size() == 1);
2443     BOOST_TEST(writtenData[0] == 416);    // The size of the expected Counter Directory packet
2444
2445     // The Request Counter Directory Command Handler should not have updated the profiling state
2446     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2447
2448     // Reset the profiling service to stop any running thread
2449     options.m_EnableProfiling = false;
2450     profilingService.ResetExternalProfilingOptions(options, true);
2451 }
2452
2453 BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadPeriodicCounterSelectionPacket)
2454 {
2455     // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
2456     LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning);
2457
2458     // Swap the profiling connection factory in the profiling service instance with our mock one
2459     SwapProfilingConnectionFactoryHelper helper;
2460
2461     // Redirect the standard output to a local stream so that we can parse the warning message
2462     std::stringstream ss;
2463     StreamRedirector streamRedirector(std::cout, ss.rdbuf());
2464
2465     // Reset the profiling service to the uninitialized state
2466     armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
2467     options.m_EnableProfiling          = true;
2468     ProfilingService& profilingService = ProfilingService::Instance();
2469     profilingService.ResetExternalProfilingOptions(options, true);
2470
2471     // Bring the profiling service to the "Active" state
2472     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2473     profilingService.Update();    // Initialize the counter directory
2474     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2475     profilingService.Update();    // Create the profiling connection
2476     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2477     profilingService.Update();    // Start the command handler and the send thread
2478
2479     // Wait for the Stream Metadata packet the be sent
2480     // (we are not testing the connection acknowledgement here so it will be ignored by this test)
2481     helper.WaitForProfilingPacketsSent();
2482
2483     // Force the profiling service to the "Active" state
2484     helper.ForceTransitionToState(ProfilingState::Active);
2485     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2486
2487     // Get the mock profiling connection
2488     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
2489     BOOST_CHECK(mockProfilingConnection);
2490
2491     // Remove the packets received so far
2492     mockProfilingConnection->Clear();
2493
2494     // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
2495     // external profiling service
2496
2497     // Periodic Counter Selection packet header:
2498     // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
2499     // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
2500     // 8:15  [8]  reserved: Reserved, value 0b00000000
2501     // 0:7   [8]  reserved: Reserved, value 0b00000000
2502     uint32_t packetFamily = 0;
2503     uint32_t packetId     = 999;    // Wrong packet id!!!
2504     uint32_t header       = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
2505
2506     // Create the Periodic Counter Selection packet
2507     Packet periodicCounterSelectionPacket(header);    // Length == 0, this will disable the collection of counters
2508
2509     // Write the packet to the mock profiling connection
2510     mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
2511
2512     // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that
2513     // the Periodic Counter Selection packet gets processed by the profiling service
2514     std::this_thread::sleep_for(std::chrono::seconds(2));
2515
2516     // Check that the expected error has occurred and logged to the standard output
2517     BOOST_CHECK(boost::contains(ss.str(), "Functor with requested PacketId=999 and Version=4194304 does not exist"));
2518
2519     // The Periodic Counter Selection Handler should not have updated the profiling state
2520     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2521
2522     // Reset the profiling service to stop any running thread
2523     options.m_EnableProfiling = false;
2524     profilingService.ResetExternalProfilingOptions(options, true);
2525 }
2526
2527 BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadPeriodicCounterSelectionPacketInvalidCounterUid)
2528 {
2529     // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
2530     LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning);
2531
2532     // Swap the profiling connection factory in the profiling service instance with our mock one
2533     SwapProfilingConnectionFactoryHelper helper;
2534
2535     // Reset the profiling service to the uninitialized state
2536     armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
2537     options.m_EnableProfiling          = true;
2538     ProfilingService& profilingService = ProfilingService::Instance();
2539     profilingService.ResetExternalProfilingOptions(options, true);
2540
2541     // Bring the profiling service to the "Active" state
2542     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2543     profilingService.Update();    // Initialize the counter directory
2544     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2545     profilingService.Update();    // Create the profiling connection
2546     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2547     profilingService.Update();    // Start the command handler and the send thread
2548
2549     // Wait for the Stream Metadata packet the be sent
2550     // (we are not testing the connection acknowledgement here so it will be ignored by this test)
2551     helper.WaitForProfilingPacketsSent();
2552
2553     // Force the profiling service to the "Active" state
2554     helper.ForceTransitionToState(ProfilingState::Active);
2555     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2556
2557     // Get the mock profiling connection
2558     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
2559     BOOST_CHECK(mockProfilingConnection);
2560
2561     // Remove the packets received so far
2562     mockProfilingConnection->Clear();
2563
2564     // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
2565     // external profiling service
2566
2567     // Periodic Counter Selection packet header:
2568     // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
2569     // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
2570     // 8:15  [8]  reserved: Reserved, value 0b00000000
2571     // 0:7   [8]  reserved: Reserved, value 0b00000000
2572     uint32_t packetFamily = 0;
2573     uint32_t packetId     = 4;
2574     uint32_t header       = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
2575
2576     uint32_t capturePeriod = 123456;    // Some capture period (microseconds)
2577
2578     // Get the first valid counter UID
2579     const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory();
2580     const Counters& counters                  = counterDirectory.GetCounters();
2581     BOOST_CHECK(counters.size() > 1);
2582     uint16_t counterUidA = counters.begin()->first;    // First valid counter UID
2583     uint16_t counterUidB = 9999;                       // Second invalid counter UID
2584
2585     uint32_t length = 8;
2586
2587     auto data = std::make_unique<unsigned char[]>(length);
2588     WriteUint32(data.get(), 0, capturePeriod);
2589     WriteUint16(data.get(), 4, counterUidA);
2590     WriteUint16(data.get(), 6, counterUidB);
2591
2592     // Create the Periodic Counter Selection packet
2593     Packet periodicCounterSelectionPacket(header, length, data);    // Length > 0, this will start the Period Counter
2594                                                                     // Capture thread
2595
2596     // Write the packet to the mock profiling connection
2597     mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
2598
2599     // Expecting one Periodic Counter Selection packet and at least one Periodic Counter Capture packet
2600     int expectedPackets = 2;
2601     std::vector<uint32_t> receivedPackets;
2602
2603     // Keep waiting until all the expected packets have been received
2604     do
2605     {
2606         helper.WaitForProfilingPacketsSent();
2607         const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
2608         if (writtenData.empty())
2609         {
2610             BOOST_ERROR("Packets should be available for reading at this point");
2611             return;
2612         }
2613         receivedPackets.insert(receivedPackets.end(), writtenData.begin(), writtenData.end());
2614         expectedPackets -= boost::numeric_cast<int>(writtenData.size());
2615     } while (expectedPackets > 0);
2616     BOOST_TEST(!receivedPackets.empty());
2617
2618     // The size of the expected Periodic Counter Selection packet
2619     BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 14) != receivedPackets.end()));
2620     // The size of the expected Periodic Counter Capture packet
2621     BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 22) != receivedPackets.end()));
2622
2623     // The Periodic Counter Selection Handler should not have updated the profiling state
2624     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2625
2626     // Reset the profiling service to stop any running thread
2627     options.m_EnableProfiling = false;
2628     profilingService.ResetExternalProfilingOptions(options, true);
2629 }
2630
2631 BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodPeriodicCounterSelectionPacketNoCounters)
2632 {
2633     // Swap the profiling connection factory in the profiling service instance with our mock one
2634     SwapProfilingConnectionFactoryHelper helper;
2635
2636     // Reset the profiling service to the uninitialized state
2637     armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
2638     options.m_EnableProfiling          = true;
2639     ProfilingService& profilingService = ProfilingService::Instance();
2640     profilingService.ResetExternalProfilingOptions(options, true);
2641
2642     // Bring the profiling service to the "Active" state
2643     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2644     profilingService.Update();    // Initialize the counter directory
2645     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2646     profilingService.Update();    // Create the profiling connection
2647     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2648     profilingService.Update();    // Start the command handler and the send thread
2649
2650     // Wait for the Stream Metadata packet the be sent
2651     // (we are not testing the connection acknowledgement here so it will be ignored by this test)
2652     helper.WaitForProfilingPacketsSent();
2653
2654     // Force the profiling service to the "Active" state
2655     helper.ForceTransitionToState(ProfilingState::Active);
2656     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2657
2658     // Get the mock profiling connection
2659     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
2660     BOOST_CHECK(mockProfilingConnection);
2661
2662     // Remove the packets received so far
2663     mockProfilingConnection->Clear();
2664
2665     // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
2666     // external profiling service
2667
2668     // Periodic Counter Selection packet header:
2669     // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
2670     // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
2671     // 8:15  [8]  reserved: Reserved, value 0b00000000
2672     // 0:7   [8]  reserved: Reserved, value 0b00000000
2673     uint32_t packetFamily = 0;
2674     uint32_t packetId     = 4;
2675     uint32_t header       = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
2676
2677     // Create the Periodic Counter Selection packet
2678     Packet periodicCounterSelectionPacket(header);    // Length == 0, this will disable the collection of counters
2679
2680     // Write the packet to the mock profiling connection
2681     mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
2682
2683     // Wait for the Periodic Counter Selection packet to be sent
2684     helper.WaitForProfilingPacketsSent();
2685
2686     // The Periodic Counter Selection Handler should not have updated the profiling state
2687     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2688
2689     // Check that the mock profiling connection contains one Periodic Counter Selection
2690     const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
2691     BOOST_TEST(writtenData.size() == 1);    // Only one packet is expected (no Periodic Counter packets)
2692     BOOST_TEST(writtenData[0] == 12);       // The size of the expected Periodic Counter Selection (echos the sent one)
2693
2694     // Reset the profiling service to stop any running thread
2695     options.m_EnableProfiling = false;
2696     profilingService.ResetExternalProfilingOptions(options, true);
2697 }
2698
2699 BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodPeriodicCounterSelectionPacketSingleCounter)
2700 {
2701     // Swap the profiling connection factory in the profiling service instance with our mock one
2702     SwapProfilingConnectionFactoryHelper helper;
2703
2704     // Reset the profiling service to the uninitialized state
2705     armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
2706     options.m_EnableProfiling          = true;
2707     ProfilingService& profilingService = ProfilingService::Instance();
2708     profilingService.ResetExternalProfilingOptions(options, true);
2709
2710     // Bring the profiling service to the "Active" state
2711     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2712     profilingService.Update();    // Initialize the counter directory
2713     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2714     profilingService.Update();    // Create the profiling connection
2715     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2716     profilingService.Update();    // Start the command handler and the send thread
2717
2718     // Wait for the Stream Metadata packet the be sent
2719     // (we are not testing the connection acknowledgement here so it will be ignored by this test)
2720     helper.WaitForProfilingPacketsSent();
2721
2722     // Force the profiling service to the "Active" state
2723     helper.ForceTransitionToState(ProfilingState::Active);
2724     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2725
2726     // Get the mock profiling connection
2727     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
2728     BOOST_CHECK(mockProfilingConnection);
2729
2730     // Remove the packets received so far
2731     mockProfilingConnection->Clear();
2732
2733     // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
2734     // external profiling service
2735
2736     // Periodic Counter Selection packet header:
2737     // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
2738     // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
2739     // 8:15  [8]  reserved: Reserved, value 0b00000000
2740     // 0:7   [8]  reserved: Reserved, value 0b00000000
2741     uint32_t packetFamily = 0;
2742     uint32_t packetId     = 4;
2743     uint32_t header       = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
2744
2745     uint32_t capturePeriod = 123456;    // Some capture period (microseconds)
2746
2747     // Get the first valid counter UID
2748     const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory();
2749     const Counters& counters                  = counterDirectory.GetCounters();
2750     BOOST_CHECK(!counters.empty());
2751     uint16_t counterUid = counters.begin()->first;    // Valid counter UID
2752
2753     uint32_t length = 6;
2754
2755     auto data = std::make_unique<unsigned char[]>(length);
2756     WriteUint32(data.get(), 0, capturePeriod);
2757     WriteUint16(data.get(), 4, counterUid);
2758
2759     // Create the Periodic Counter Selection packet
2760     Packet periodicCounterSelectionPacket(header, length, data);    // Length > 0, this will start the Period Counter
2761                                                                     // Capture thread
2762
2763     // Write the packet to the mock profiling connection
2764     mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
2765
2766     // Expecting one Periodic Counter Selection packet and at least one Periodic Counter Capture packet
2767     int expectedPackets = 2;
2768     std::vector<uint32_t> receivedPackets;
2769
2770     // Keep waiting until all the expected packets have been received
2771     do
2772     {
2773         helper.WaitForProfilingPacketsSent();
2774         const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
2775         if (writtenData.empty())
2776         {
2777             BOOST_ERROR("Packets should be available for reading at this point");
2778             return;
2779         }
2780         receivedPackets.insert(receivedPackets.end(), writtenData.begin(), writtenData.end());
2781         expectedPackets -= boost::numeric_cast<int>(writtenData.size());
2782     } while (expectedPackets > 0);
2783     BOOST_TEST(!receivedPackets.empty());
2784
2785     // The size of the expected Periodic Counter Selection packet (echos the sent one)
2786     BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 14) != receivedPackets.end()));
2787     // The size of the expected Periodic Counter Capture packet
2788     BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 22) != receivedPackets.end()));
2789
2790     // The Periodic Counter Selection Handler should not have updated the profiling state
2791     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2792
2793     // Reset the profiling service to stop any running thread
2794     options.m_EnableProfiling = false;
2795     profilingService.ResetExternalProfilingOptions(options, true);
2796 }
2797
2798 BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodPeriodicCounterSelectionPacketMultipleCounters)
2799 {
2800     // Swap the profiling connection factory in the profiling service instance with our mock one
2801     SwapProfilingConnectionFactoryHelper helper;
2802
2803     // Reset the profiling service to the uninitialized state
2804     armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
2805     options.m_EnableProfiling          = true;
2806     ProfilingService& profilingService = ProfilingService::Instance();
2807     profilingService.ResetExternalProfilingOptions(options, true);
2808
2809     // Bring the profiling service to the "Active" state
2810     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2811     profilingService.Update();    // Initialize the counter directory
2812     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2813     profilingService.Update();    // Create the profiling connection
2814     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2815     profilingService.Update();    // Start the command handler and the send thread
2816
2817     // Wait for the Stream Metadata packet the be sent
2818     // (we are not testing the connection acknowledgement here so it will be ignored by this test)
2819     helper.WaitForProfilingPacketsSent();
2820
2821     // Force the profiling service to the "Active" state
2822     helper.ForceTransitionToState(ProfilingState::Active);
2823     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2824
2825     // Get the mock profiling connection
2826     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
2827     BOOST_CHECK(mockProfilingConnection);
2828
2829     // Remove the packets received so far
2830     mockProfilingConnection->Clear();
2831
2832     // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
2833     // external profiling service
2834
2835     // Periodic Counter Selection packet header:
2836     // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
2837     // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
2838     // 8:15  [8]  reserved: Reserved, value 0b00000000
2839     // 0:7   [8]  reserved: Reserved, value 0b00000000
2840     uint32_t packetFamily = 0;
2841     uint32_t packetId     = 4;
2842     uint32_t header       = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
2843
2844     uint32_t capturePeriod = 123456;    // Some capture period (microseconds)
2845
2846     // Get the first valid counter UID
2847     const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory();
2848     const Counters& counters                  = counterDirectory.GetCounters();
2849     BOOST_CHECK(counters.size() > 1);
2850     uint16_t counterUidA = counters.begin()->first;        // First valid counter UID
2851     uint16_t counterUidB = (counters.begin()++)->first;    // Second valid counter UID
2852
2853     uint32_t length = 8;
2854
2855     auto data = std::make_unique<unsigned char[]>(length);
2856     WriteUint32(data.get(), 0, capturePeriod);
2857     WriteUint16(data.get(), 4, counterUidA);
2858     WriteUint16(data.get(), 6, counterUidB);
2859
2860     // Create the Periodic Counter Selection packet
2861     Packet periodicCounterSelectionPacket(header, length, data);    // Length > 0, this will start the Period Counter
2862                                                                     // Capture thread
2863
2864     // Write the packet to the mock profiling connection
2865     mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
2866
2867     // Expecting one Periodic Counter Selection packet and at least one Periodic Counter Capture packet
2868     int expectedPackets = 2;
2869     std::vector<uint32_t> receivedPackets;
2870
2871     // Keep waiting until all the expected packets have been received
2872     do
2873     {
2874         helper.WaitForProfilingPacketsSent();
2875         const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
2876         if (writtenData.empty())
2877         {
2878             BOOST_ERROR("Packets should be available for reading at this point");
2879             return;
2880         }
2881         receivedPackets.insert(receivedPackets.end(), writtenData.begin(), writtenData.end());
2882         expectedPackets -= boost::numeric_cast<int>(writtenData.size());
2883     } while (expectedPackets > 0);
2884     BOOST_TEST(!receivedPackets.empty());
2885
2886     // The size of the expected Periodic Counter Selection packet (echos the sent one)
2887     BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 16) != receivedPackets.end()));
2888     // The size of the expected Periodic Counter Capture packet
2889     BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 28) != receivedPackets.end()));
2890
2891     // The Periodic Counter Selection Handler should not have updated the profiling state
2892     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2893
2894     // Reset the profiling service to stop any running thread
2895     options.m_EnableProfiling = false;
2896     profilingService.ResetExternalProfilingOptions(options, true);
2897 }
2898
2899 BOOST_AUTO_TEST_CASE(CheckProfilingServiceDisconnect)
2900 {
2901     // Swap the profiling connection factory in the profiling service instance with our mock one
2902     SwapProfilingConnectionFactoryHelper helper;
2903
2904     // Reset the profiling service to the uninitialized state
2905     armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
2906     options.m_EnableProfiling          = true;
2907     ProfilingService& profilingService = ProfilingService::Instance();
2908     profilingService.ResetExternalProfilingOptions(options, true);
2909
2910     // Try to disconnect the profiling service while in the "Uninitialised" state
2911     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2912     profilingService.Disconnect();
2913     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);    // The state should not change
2914
2915     // Try to disconnect the profiling service while in the "NotConnected" state
2916     profilingService.Update();    // Initialize the counter directory
2917     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2918     profilingService.Disconnect();
2919     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);    // The state should not change
2920
2921     // Try to disconnect the profiling service while in the "WaitingForAck" state
2922     profilingService.Update();    // Create the profiling connection
2923     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2924     profilingService.Disconnect();
2925     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);    // The state should not change
2926
2927     // Try to disconnect the profiling service while in the "Active" state
2928     profilingService.Update();    // Start the command handler and the send thread
2929
2930     // Wait for the Stream Metadata packet the be sent
2931     // (we are not testing the connection acknowledgement here so it will be ignored by this test)
2932     helper.WaitForProfilingPacketsSent();
2933
2934     // Force the profiling service to the "Active" state
2935     helper.ForceTransitionToState(ProfilingState::Active);
2936     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2937
2938     // Get the mock profiling connection
2939     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
2940     BOOST_CHECK(mockProfilingConnection);
2941
2942     // Check that the profiling connection is open
2943     BOOST_CHECK(mockProfilingConnection->IsOpen());
2944
2945     profilingService.Disconnect();
2946     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);    // The state should have changed
2947
2948     // Check that the profiling connection has been reset
2949     mockProfilingConnection = helper.GetMockProfilingConnection();
2950     BOOST_CHECK(mockProfilingConnection == nullptr);
2951
2952     // Reset the profiling service to stop any running thread
2953     options.m_EnableProfiling = false;
2954     profilingService.ResetExternalProfilingOptions(options, true);
2955 }
2956
2957 BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodPerJobCounterSelectionPacket)
2958 {
2959     // Swap the profiling connection factory in the profiling service instance with our mock one
2960     SwapProfilingConnectionFactoryHelper helper;
2961
2962     // Reset the profiling service to the uninitialized state
2963     armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
2964     options.m_EnableProfiling          = true;
2965     ProfilingService& profilingService = ProfilingService::Instance();
2966     profilingService.ResetExternalProfilingOptions(options, true);
2967
2968     // Bring the profiling service to the "Active" state
2969     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2970     profilingService.Update();    // Initialize the counter directory
2971     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2972     profilingService.Update();    // Create the profiling connection
2973     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2974     profilingService.Update();    // Start the command handler and the send thread
2975
2976     // Wait for the Stream Metadata packet the be sent
2977     // (we are not testing the connection acknowledgement here so it will be ignored by this test)
2978     helper.WaitForProfilingPacketsSent();
2979
2980     // Force the profiling service to the "Active" state
2981     helper.ForceTransitionToState(ProfilingState::Active);
2982     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2983
2984     // Get the mock profiling connection
2985     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
2986     BOOST_CHECK(mockProfilingConnection);
2987
2988     // Remove the packets received so far
2989     mockProfilingConnection->Clear();
2990
2991     // Write a "Per-Job Counter Selection" packet into the mock profiling connection, to simulate an input from an
2992     // external profiling service
2993
2994     // Per-Job Counter Selection packet header:
2995     // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
2996     // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
2997     // 8:15  [8]  reserved: Reserved, value 0b00000000
2998     // 0:7   [8]  reserved: Reserved, value 0b00000000
2999     uint32_t packetFamily = 0;
3000     uint32_t packetId     = 5;
3001     uint32_t header       = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
3002
3003     // Create the Per-Job Counter Selection packet
3004     Packet periodicCounterSelectionPacket(header);    // Length == 0, this will disable the collection of counters
3005
3006     // Write the packet to the mock profiling connection
3007     mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
3008
3009     // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that
3010     // the Per-Job Counter Selection packet gets processed by the profiling service
3011     std::this_thread::sleep_for(std::chrono::seconds(2));
3012
3013     // The Per-Job Counter Selection packets are dropped silently, so there should be no reply coming
3014     // from the profiling service
3015     const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
3016     BOOST_TEST(writtenData.empty());
3017
3018     // The Per-Job Counter Selection Command Handler should not have updated the profiling state
3019     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
3020
3021     // Reset the profiling service to stop any running thread
3022     options.m_EnableProfiling = false;
3023     profilingService.ResetExternalProfilingOptions(options, true);
3024 }
3025
3026 BOOST_AUTO_TEST_CASE(CheckConfigureProfilingServiceOn)
3027 {
3028     armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
3029     options.m_EnableProfiling          = true;
3030     ProfilingService& profilingService = ProfilingService::Instance();
3031     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3032     profilingService.ConfigureProfilingService(options);
3033     // should get as far as NOT_CONNECTED
3034     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
3035     // Reset the profiling service to stop any running thread
3036     options.m_EnableProfiling = false;
3037     profilingService.ResetExternalProfilingOptions(options, true);
3038 }
3039
3040 BOOST_AUTO_TEST_CASE(CheckConfigureProfilingServiceOff)
3041 {
3042     armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
3043     ProfilingService& profilingService = ProfilingService::Instance();
3044     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3045     profilingService.ConfigureProfilingService(options);
3046     // should not move from Uninitialised
3047     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3048     // Reset the profiling service to stop any running thread
3049     options.m_EnableProfiling = false;
3050     profilingService.ResetExternalProfilingOptions(options, true);
3051 }
3052
// Close the "ExternalProfiling" test suite opened with BOOST_AUTO_TEST_SUITE at the top of this file
BOOST_AUTO_TEST_SUITE_END()