IVGCVSW-3431 Create Profiling Service State Machine
[platform/upstream/armnn.git] / src / profiling / test / ProfilingTests.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "../CommandHandlerKey.hpp"
7 #include "../CommandHandlerFunctor.hpp"
8 #include "../CommandHandlerRegistry.hpp"
9 #include "../EncodeVersion.hpp"
10 #include "../Packet.hpp"
11 #include "../PacketVersionResolver.hpp"
12 #include "../ProfilingStateMachine.hpp"
13
14 #include <boost/test/unit_test.hpp>
15
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <functional>
#include <limits>
#include <map>
#include <random>
#include <thread>
#include <vector>
22
23 BOOST_AUTO_TEST_SUITE(ExternalProfiling)
24
25 using namespace armnn::profiling;
26
27 BOOST_AUTO_TEST_CASE(CheckCommandHandlerKeyComparisons)
28 {
29     CommandHandlerKey testKey0(1, 1);
30     CommandHandlerKey testKey1(1, 1);
31     CommandHandlerKey testKey2(1, 1);
32     CommandHandlerKey testKey3(0, 0);
33     CommandHandlerKey testKey4(2, 2);
34     CommandHandlerKey testKey5(0, 2);
35
36     BOOST_CHECK(testKey1<testKey4);
37     BOOST_CHECK(testKey1>testKey3);
38     BOOST_CHECK(testKey1<=testKey4);
39     BOOST_CHECK(testKey1>=testKey3);
40     BOOST_CHECK(testKey1<=testKey2);
41     BOOST_CHECK(testKey1>=testKey2);
42     BOOST_CHECK(testKey1==testKey2);
43     BOOST_CHECK(testKey1==testKey1);
44
45     BOOST_CHECK(!(testKey1==testKey5));
46     BOOST_CHECK(!(testKey1!=testKey1));
47     BOOST_CHECK(testKey1!=testKey5);
48
49     BOOST_CHECK(testKey1==testKey2 && testKey2==testKey1);
50     BOOST_CHECK(testKey0==testKey1 && testKey1==testKey2 && testKey0==testKey2);
51
52     BOOST_CHECK(testKey1.GetPacketId()==1);
53     BOOST_CHECK(testKey1.GetVersion()==1);
54
55     std::vector<CommandHandlerKey> vect =
56     {
57         CommandHandlerKey(0,1), CommandHandlerKey(2,0), CommandHandlerKey(1,0),
58         CommandHandlerKey(2,1), CommandHandlerKey(1,1), CommandHandlerKey(0,1),
59         CommandHandlerKey(2,0), CommandHandlerKey(0,0)
60     };
61
62     std::sort(vect.begin(), vect.end());
63
64     std::vector<CommandHandlerKey> expectedVect =
65     {
66         CommandHandlerKey(0,0), CommandHandlerKey(0,1), CommandHandlerKey(0,1),
67         CommandHandlerKey(1,0), CommandHandlerKey(1,1), CommandHandlerKey(2,0),
68         CommandHandlerKey(2,0), CommandHandlerKey(2,1)
69     };
70
71     BOOST_CHECK(vect == expectedVect);
72 }
73
74 BOOST_AUTO_TEST_CASE(CheckEncodeVersion)
75 {
76     Version version1(12);
77
78     BOOST_CHECK(version1.GetMajor() == 0);
79     BOOST_CHECK(version1.GetMinor() == 0);
80     BOOST_CHECK(version1.GetPatch() == 12);
81
82     Version version2(4108);
83
84     BOOST_CHECK(version2.GetMajor() == 0);
85     BOOST_CHECK(version2.GetMinor() == 1);
86     BOOST_CHECK(version2.GetPatch() == 12);
87
88     Version version3(4198412);
89
90     BOOST_CHECK(version3.GetMajor() == 1);
91     BOOST_CHECK(version3.GetMinor() == 1);
92     BOOST_CHECK(version3.GetPatch() == 12);
93
94     Version version4(0);
95
96     BOOST_CHECK(version4.GetMajor() == 0);
97     BOOST_CHECK(version4.GetMinor() == 0);
98     BOOST_CHECK(version4.GetPatch() == 0);
99
100     Version version5(1, 0, 0);
101     BOOST_CHECK(version5.GetEncodedValue() == 4194304);
102 }
103
104 BOOST_AUTO_TEST_CASE(CheckPacketClass)
105 {
106     const char* data = "test";
107     unsigned int length = static_cast<unsigned int>(std::strlen(data));
108
109     Packet packetTest1(472580096,length,data);
110     BOOST_CHECK_THROW(Packet packetTest2(472580096,0,""), armnn::Exception);
111
112     Packet packetTest3(472580096,0, nullptr);
113
114     BOOST_CHECK(packetTest1.GetLength() == length);
115     BOOST_CHECK(packetTest1.GetData() == data);
116
117     BOOST_CHECK(packetTest1.GetPacketFamily() == 7);
118     BOOST_CHECK(packetTest1.GetPacketId() == 43);
119     BOOST_CHECK(packetTest1.GetPacketType() == 3);
120     BOOST_CHECK(packetTest1.GetPacketClass() == 5);
121 }
122
123 // Create Derived Classes
124 class TestFunctorA : public CommandHandlerFunctor
125 {
126 public:
127     using CommandHandlerFunctor::CommandHandlerFunctor;
128
129     int GetCount() { return m_Count; }
130
131     void operator()(const Packet& packet) override
132     {
133         m_Count++;
134     }
135
136 private:
137     int m_Count = 0;
138 };
139
140 class TestFunctorB : public TestFunctorA
141 {
142     using TestFunctorA::TestFunctorA;
143 };
144
145 class TestFunctorC : public TestFunctorA
146 {
147     using TestFunctorA::TestFunctorA;
148 };
149
150 BOOST_AUTO_TEST_CASE(CheckCommandHandlerFunctor)
151 {
152     // Hard code the version as it will be the same during a single profiling session
153     uint32_t version = 1;
154
155     TestFunctorA testFunctorA(461, version);
156     TestFunctorB testFunctorB(963, version);
157     TestFunctorC testFunctorC(983, version);
158
159     CommandHandlerKey keyA(testFunctorA.GetPacketId(), testFunctorA.GetVersion());
160     CommandHandlerKey keyB(testFunctorB.GetPacketId(), testFunctorB.GetVersion());
161     CommandHandlerKey keyC(testFunctorC.GetPacketId(), testFunctorC.GetVersion());
162
163     // Create the unwrapped map to simulate the Command Handler Registry
164     std::map<CommandHandlerKey, CommandHandlerFunctor*> registry;
165
166     registry.insert(std::make_pair(keyB, &testFunctorB));
167     registry.insert(std::make_pair(keyA, &testFunctorA));
168     registry.insert(std::make_pair(keyC, &testFunctorC));
169
170     // Check the order of the map is correct
171     auto it = registry.begin();
172     BOOST_CHECK(it->first==keyA);
173     it++;
174     BOOST_CHECK(it->first==keyB);
175     it++;
176     BOOST_CHECK(it->first==keyC);
177
178     Packet packetA(500000000, 0, nullptr);
179     Packet packetB(600000000, 0, nullptr);
180     Packet packetC(400000000, 0, nullptr);
181
182     // Check the correct operator of derived class is called
183     registry.at(CommandHandlerKey(packetA.GetPacketId(), version))->operator()(packetA);
184     BOOST_CHECK(testFunctorA.GetCount() == 1);
185     BOOST_CHECK(testFunctorB.GetCount() == 0);
186     BOOST_CHECK(testFunctorC.GetCount() == 0);
187
188     registry.at(CommandHandlerKey(packetB.GetPacketId(), version))->operator()(packetB);
189     BOOST_CHECK(testFunctorA.GetCount() == 1);
190     BOOST_CHECK(testFunctorB.GetCount() == 1);
191     BOOST_CHECK(testFunctorC.GetCount() == 0);
192
193     registry.at(CommandHandlerKey(packetC.GetPacketId(), version))->operator()(packetC);
194     BOOST_CHECK(testFunctorA.GetCount() == 1);
195     BOOST_CHECK(testFunctorB.GetCount() == 1);
196     BOOST_CHECK(testFunctorC.GetCount() == 1);
197 }
198
199 BOOST_AUTO_TEST_CASE(CheckCommandHandlerRegistry)
200 {
201     // Hard code the version as it will be the same during a single profiling session
202     uint32_t version = 1;
203
204     TestFunctorA testFunctorA(461, version);
205     TestFunctorB testFunctorB(963, version);
206     TestFunctorC testFunctorC(983, version);
207
208     // Create the Command Handler Registry
209     CommandHandlerRegistry registry;
210
211     // Register multiple different derived classes
212     registry.RegisterFunctor(&testFunctorA, testFunctorA.GetPacketId(), testFunctorA.GetVersion());
213     registry.RegisterFunctor(&testFunctorB, testFunctorB.GetPacketId(), testFunctorB.GetVersion());
214     registry.RegisterFunctor(&testFunctorC, testFunctorC.GetPacketId(), testFunctorC.GetVersion());
215
216     Packet packetA(500000000, 0, nullptr);
217     Packet packetB(600000000, 0, nullptr);
218     Packet packetC(400000000, 0, nullptr);
219
220     // Check the correct operator of derived class is called
221     registry.GetFunctor(packetA.GetPacketId(), version)->operator()(packetA);
222     BOOST_CHECK(testFunctorA.GetCount() == 1);
223     BOOST_CHECK(testFunctorB.GetCount() == 0);
224     BOOST_CHECK(testFunctorC.GetCount() == 0);
225
226     registry.GetFunctor(packetB.GetPacketId(), version)->operator()(packetB);
227     BOOST_CHECK(testFunctorA.GetCount() == 1);
228     BOOST_CHECK(testFunctorB.GetCount() == 1);
229     BOOST_CHECK(testFunctorC.GetCount() == 0);
230
231     registry.GetFunctor(packetC.GetPacketId(), version)->operator()(packetC);
232     BOOST_CHECK(testFunctorA.GetCount() == 1);
233     BOOST_CHECK(testFunctorB.GetCount() == 1);
234     BOOST_CHECK(testFunctorC.GetCount() == 1);
235
236     // Re-register an existing key with a new function
237     registry.RegisterFunctor(&testFunctorC, testFunctorA.GetPacketId(), version);
238     registry.GetFunctor(packetA.GetPacketId(), version)->operator()(packetC);
239     BOOST_CHECK(testFunctorA.GetCount() == 1);
240     BOOST_CHECK(testFunctorB.GetCount() == 1);
241     BOOST_CHECK(testFunctorC.GetCount() == 2);
242
243     // Check that non-existent key returns nullptr for its functor
244     BOOST_CHECK_THROW(registry.GetFunctor(0, 0), armnn::Exception);
245 }
246
247 BOOST_AUTO_TEST_CASE(CheckPacketVersionResolver)
248 {
249     // Set up random number generator for generating packetId values
250     std::random_device device;
251     std::mt19937 generator(device());
252     std::uniform_int_distribution<uint32_t> distribution(std::numeric_limits<uint32_t>::min(),
253                                                          std::numeric_limits<uint32_t>::max());
254
255     // NOTE: Expected version is always 1.0.0, regardless of packetId
256     const Version expectedVersion(1, 0, 0);
257
258     PacketVersionResolver packetVersionResolver;
259
260     constexpr unsigned int numTests = 10u;
261
262     for (unsigned int i = 0u; i < numTests; ++i)
263     {
264         const uint32_t packetId = distribution(generator);
265         Version resolvedVersion = packetVersionResolver.ResolvePacketVersion(packetId);
266
267         BOOST_TEST(resolvedVersion == expectedVersion);
268     }
269 }
270 void ProfilingCurrentStateThreadImpl(ProfilingStateMachine& states)
271 {
272     ProfilingState newState = ProfilingState::NotConnected;
273     states.GetCurrentState();
274     states.TransitionToState(newState);
275 }
276
277 BOOST_AUTO_TEST_CASE(CheckProfilingStateMachine)
278 {
279     ProfilingStateMachine profilingState1(ProfilingState::Uninitialised);
280     profilingState1.TransitionToState(ProfilingState::Uninitialised);
281     BOOST_CHECK(profilingState1.GetCurrentState() ==  ProfilingState::Uninitialised);
282
283     ProfilingStateMachine profilingState2(ProfilingState::Uninitialised);
284     profilingState2.TransitionToState(ProfilingState::NotConnected);
285     BOOST_CHECK(profilingState2.GetCurrentState() == ProfilingState::NotConnected);
286
287     ProfilingStateMachine profilingState3(ProfilingState::NotConnected);
288     profilingState3.TransitionToState(ProfilingState::NotConnected);
289     BOOST_CHECK(profilingState3.GetCurrentState() == ProfilingState::NotConnected);
290
291     ProfilingStateMachine profilingState4(ProfilingState::NotConnected);
292     profilingState4.TransitionToState(ProfilingState::WaitingForAck);
293     BOOST_CHECK(profilingState4.GetCurrentState() == ProfilingState::WaitingForAck);
294
295     ProfilingStateMachine profilingState5(ProfilingState::WaitingForAck);
296     profilingState5.TransitionToState(ProfilingState::WaitingForAck);
297     BOOST_CHECK(profilingState5.GetCurrentState() == ProfilingState::WaitingForAck);
298
299     ProfilingStateMachine profilingState6(ProfilingState::WaitingForAck);
300     profilingState6.TransitionToState(ProfilingState::Active);
301     BOOST_CHECK(profilingState6.GetCurrentState() == ProfilingState::Active);
302
303     ProfilingStateMachine profilingState7(ProfilingState::Active);
304     profilingState7.TransitionToState(ProfilingState::NotConnected);
305     BOOST_CHECK(profilingState7.GetCurrentState() == ProfilingState::NotConnected);
306
307     ProfilingStateMachine profilingState8(ProfilingState::Active);
308     profilingState8.TransitionToState(ProfilingState::Active);
309     BOOST_CHECK(profilingState8.GetCurrentState() == ProfilingState::Active);
310
311     ProfilingStateMachine profilingState9(ProfilingState::Uninitialised);
312     BOOST_CHECK_THROW(profilingState9.TransitionToState(ProfilingState::WaitingForAck),
313                       armnn::Exception);
314
315     ProfilingStateMachine profilingState10(ProfilingState::Uninitialised);
316     BOOST_CHECK_THROW(profilingState10.TransitionToState(ProfilingState::Active),
317                       armnn::Exception);
318
319     ProfilingStateMachine profilingState11(ProfilingState::NotConnected);
320     BOOST_CHECK_THROW(profilingState11.TransitionToState(ProfilingState::Uninitialised),
321                       armnn::Exception);
322
323     ProfilingStateMachine profilingState12(ProfilingState::NotConnected);
324     BOOST_CHECK_THROW(profilingState12.TransitionToState(ProfilingState::Active),
325                       armnn::Exception);
326
327     ProfilingStateMachine profilingState13(ProfilingState::WaitingForAck);
328     BOOST_CHECK_THROW(profilingState13.TransitionToState(ProfilingState::Uninitialised),
329                       armnn::Exception);
330
331     ProfilingStateMachine profilingState14(ProfilingState::WaitingForAck);
332     BOOST_CHECK_THROW(profilingState14.TransitionToState(ProfilingState::NotConnected),
333                       armnn::Exception);
334
335     ProfilingStateMachine profilingState15(ProfilingState::Active);
336     BOOST_CHECK_THROW(profilingState15.TransitionToState(ProfilingState::Uninitialised),
337                       armnn::Exception);
338
339     ProfilingStateMachine profilingState16(armnn::profiling::ProfilingState::Active);
340     BOOST_CHECK_THROW(profilingState16.TransitionToState(ProfilingState::WaitingForAck),
341                       armnn::Exception);
342
343     ProfilingStateMachine profilingState17(ProfilingState::Uninitialised);
344
345     std::thread thread1 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
346     std::thread thread2 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
347     std::thread thread3 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
348     std::thread thread4 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
349     std::thread thread5 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
350
351     thread1.join();
352     thread2.join();
353     thread3.join();
354     thread4.join();
355     thread5.join();
356
357     BOOST_TEST((profilingState17.GetCurrentState() == ProfilingState::NotConnected));
358 }
359
360 BOOST_AUTO_TEST_SUITE_END()