2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
6 #include "../CommandHandlerKey.hpp"
7 #include "../CommandHandlerFunctor.hpp"
8 #include "../CommandHandlerRegistry.hpp"
9 #include "../EncodeVersion.hpp"
10 #include "../Packet.hpp"
11 #include "../PacketVersionResolver.hpp"
12 #include "../ProfilingStateMachine.hpp"
14 #include <boost/test/unit_test.hpp>
// Opens the ExternalProfiling test suite; it is closed by BOOST_AUTO_TEST_SUITE_END() at the
// end of this file, so every test case in between belongs to this suite.
23 BOOST_AUTO_TEST_SUITE(ExternalProfiling)
// Presumably provides the unqualified profiling names used below (CommandHandlerKey, Packet,
// ProfilingStateMachine, ...).
25 using namespace armnn::profiling;
27 BOOST_AUTO_TEST_CASE(CheckCommandHandlerKeyComparisons)
29 CommandHandlerKey testKey0(1, 1);
30 CommandHandlerKey testKey1(1, 1);
31 CommandHandlerKey testKey2(1, 1);
32 CommandHandlerKey testKey3(0, 0);
33 CommandHandlerKey testKey4(2, 2);
34 CommandHandlerKey testKey5(0, 2);
36 BOOST_CHECK(testKey1<testKey4);
37 BOOST_CHECK(testKey1>testKey3);
38 BOOST_CHECK(testKey1<=testKey4);
39 BOOST_CHECK(testKey1>=testKey3);
40 BOOST_CHECK(testKey1<=testKey2);
41 BOOST_CHECK(testKey1>=testKey2);
42 BOOST_CHECK(testKey1==testKey2);
43 BOOST_CHECK(testKey1==testKey1);
45 BOOST_CHECK(!(testKey1==testKey5));
46 BOOST_CHECK(!(testKey1!=testKey1));
47 BOOST_CHECK(testKey1!=testKey5);
49 BOOST_CHECK(testKey1==testKey2 && testKey2==testKey1);
50 BOOST_CHECK(testKey0==testKey1 && testKey1==testKey2 && testKey0==testKey2);
52 BOOST_CHECK(testKey1.GetPacketId()==1);
53 BOOST_CHECK(testKey1.GetVersion()==1);
55 std::vector<CommandHandlerKey> vect =
57 CommandHandlerKey(0,1), CommandHandlerKey(2,0), CommandHandlerKey(1,0),
58 CommandHandlerKey(2,1), CommandHandlerKey(1,1), CommandHandlerKey(0,1),
59 CommandHandlerKey(2,0), CommandHandlerKey(0,0)
62 std::sort(vect.begin(), vect.end());
64 std::vector<CommandHandlerKey> expectedVect =
66 CommandHandlerKey(0,0), CommandHandlerKey(0,1), CommandHandlerKey(0,1),
67 CommandHandlerKey(1,0), CommandHandlerKey(1,1), CommandHandlerKey(2,0),
68 CommandHandlerKey(2,0), CommandHandlerKey(2,1)
71 BOOST_CHECK(vect == expectedVect);
// Checks that Version splits an encoded value into major/minor/patch, and that building a
// Version from (major, minor, patch) yields the expected encoded value.
// The expectations are consistent with patch in the low bits, minor starting at bit 12
// (4108 == (1 << 12) + 12) and major starting at bit 22 (4198412 == (1 << 22) + (1 << 12) + 12).
// NOTE(review): the declarations of version1 and version4 are not visible here; judging by
// their checks they should encode 12 and 0 respectively — confirm they exist.
74 BOOST_AUTO_TEST_CASE(CheckEncodeVersion)
// version1: patch only
78 BOOST_CHECK(version1.GetMajor() == 0);
79 BOOST_CHECK(version1.GetMinor() == 0);
80 BOOST_CHECK(version1.GetPatch() == 12);
// version2: minor 1, patch 12
82 Version version2(4108);
84 BOOST_CHECK(version2.GetMajor() == 0);
85 BOOST_CHECK(version2.GetMinor() == 1);
86 BOOST_CHECK(version2.GetPatch() == 12);
// version3: major 1, minor 1, patch 12
88 Version version3(4198412);
90 BOOST_CHECK(version3.GetMajor() == 1);
91 BOOST_CHECK(version3.GetMinor() == 1);
92 BOOST_CHECK(version3.GetPatch() == 12);
// version4: every field zero
96 BOOST_CHECK(version4.GetMajor() == 0);
97 BOOST_CHECK(version4.GetMinor() == 0);
98 BOOST_CHECK(version4.GetPatch() == 0);
// Encoding direction: 1.0.0 -> 1 << 22
100 Version version5(1, 0, 0);
101 BOOST_CHECK(version5.GetEncodedValue() == 4194304);
104 BOOST_AUTO_TEST_CASE(CheckPacketClass)
106 const char* data = "test";
107 unsigned int length = static_cast<unsigned int>(std::strlen(data));
109 Packet packetTest1(472580096,length,data);
110 BOOST_CHECK_THROW(Packet packetTest2(472580096,0,""), armnn::Exception);
112 Packet packetTest3(472580096,0, nullptr);
114 BOOST_CHECK(packetTest1.GetLength() == length);
115 BOOST_CHECK(packetTest1.GetData() == data);
117 BOOST_CHECK(packetTest1.GetPacketFamily() == 7);
118 BOOST_CHECK(packetTest1.GetPacketId() == 43);
119 BOOST_CHECK(packetTest1.GetPacketType() == 3);
120 BOOST_CHECK(packetTest1.GetPacketClass() == 5);
123 // Create classes derived from CommandHandlerFunctor, used below to check that lookups dispatch to the right object
// TestFunctorA counts how often it is invoked: the tests below expect each call through
// operator() to raise GetCount() by one. NOTE(review): the body of operator() and the counter
// member are not visible in this view.
124 class TestFunctorA : public CommandHandlerFunctor
// Inherit CommandHandlerFunctor's constructor, used below as (packetId, version).
127 using CommandHandlerFunctor::CommandHandlerFunctor;
// Returns the current value of the invocation counter.
129 int GetCount() { return m_Count; }
// Packet handler; overrides the virtual call operator of CommandHandlerFunctor.
131 void operator()(const Packet& packet) override
140 class TestFunctorB : public TestFunctorA
142 using TestFunctorA::TestFunctorA;
145 class TestFunctorC : public TestFunctorA
147 using TestFunctorA::TestFunctorA;
// Checks, using a plain std::map instead of CommandHandlerRegistry, that entries keyed by
// CommandHandlerKey iterate in key order and that calling a functor through a
// CommandHandlerFunctor* runs the derived object's operator().
150 BOOST_AUTO_TEST_CASE(CheckCommandHandlerFunctor)
152 // Hard code the version as it will be the same during a single profiling session
153 uint32_t version = 1;
// Constructed with packet IDs 461 < 963 < 983 and the same version, so keyA < keyB < keyC.
155 TestFunctorA testFunctorA(461, version);
156 TestFunctorB testFunctorB(963, version);
157 TestFunctorC testFunctorC(983, version);
159 CommandHandlerKey keyA(testFunctorA.GetPacketId(), testFunctorA.GetVersion());
160 CommandHandlerKey keyB(testFunctorB.GetPacketId(), testFunctorB.GetVersion());
161 CommandHandlerKey keyC(testFunctorC.GetPacketId(), testFunctorC.GetVersion());
163 // Create the unwrapped map to simulate the Command Handler Registry
164 std::map<CommandHandlerKey, CommandHandlerFunctor*> registry;
// Inserted out of order (B, A, C) so the ordering check below is meaningful.
166 registry.insert(std::make_pair(keyB, &testFunctorB));
167 registry.insert(std::make_pair(keyA, &testFunctorA));
168 registry.insert(std::make_pair(keyC, &testFunctorC));
170 // Check the order of the map is correct
// NOTE(review): no iterator increment is visible between the three checks below; as shown,
// `it` stays on the first entry and the keyB/keyC checks could not pass — confirm the
// advancing statements exist in the real file.
171 auto it = registry.begin();
172 BOOST_CHECK(it->first==keyA);
174 BOOST_CHECK(it->first==keyB);
176 BOOST_CHECK(it->first==keyC);
// Headers whose packet IDs are expected to be 461, 963 and 983, so each packet selects
// testFunctorA, testFunctorB and testFunctorC respectively (see the count checks below).
178 Packet packetA(500000000, 0, nullptr);
179 Packet packetB(600000000, 0, nullptr);
180 Packet packetC(400000000, 0, nullptr);
182 // Check the correct operator of derived class is called
// std::map::at throws std::out_of_range if the key is missing, so a bad header fails loudly.
183 registry.at(CommandHandlerKey(packetA.GetPacketId(), version))->operator()(packetA);
184 BOOST_CHECK(testFunctorA.GetCount() == 1);
185 BOOST_CHECK(testFunctorB.GetCount() == 0);
186 BOOST_CHECK(testFunctorC.GetCount() == 0);
188 registry.at(CommandHandlerKey(packetB.GetPacketId(), version))->operator()(packetB);
189 BOOST_CHECK(testFunctorA.GetCount() == 1);
190 BOOST_CHECK(testFunctorB.GetCount() == 1);
191 BOOST_CHECK(testFunctorC.GetCount() == 0);
193 registry.at(CommandHandlerKey(packetC.GetPacketId(), version))->operator()(packetC);
194 BOOST_CHECK(testFunctorA.GetCount() == 1);
195 BOOST_CHECK(testFunctorB.GetCount() == 1);
196 BOOST_CHECK(testFunctorC.GetCount() == 1);
199 BOOST_AUTO_TEST_CASE(CheckCommandHandlerRegistry)
201 // Hard code the version as it will be the same during a single profiling session
202 uint32_t version = 1;
204 TestFunctorA testFunctorA(461, version);
205 TestFunctorB testFunctorB(963, version);
206 TestFunctorC testFunctorC(983, version);
208 // Create the Command Handler Registry
209 CommandHandlerRegistry registry;
211 // Register multiple different derived classes
212 registry.RegisterFunctor(&testFunctorA, testFunctorA.GetPacketId(), testFunctorA.GetVersion());
213 registry.RegisterFunctor(&testFunctorB, testFunctorB.GetPacketId(), testFunctorB.GetVersion());
214 registry.RegisterFunctor(&testFunctorC, testFunctorC.GetPacketId(), testFunctorC.GetVersion());
216 Packet packetA(500000000, 0, nullptr);
217 Packet packetB(600000000, 0, nullptr);
218 Packet packetC(400000000, 0, nullptr);
220 // Check the correct operator of derived class is called
221 registry.GetFunctor(packetA.GetPacketId(), version)->operator()(packetA);
222 BOOST_CHECK(testFunctorA.GetCount() == 1);
223 BOOST_CHECK(testFunctorB.GetCount() == 0);
224 BOOST_CHECK(testFunctorC.GetCount() == 0);
226 registry.GetFunctor(packetB.GetPacketId(), version)->operator()(packetB);
227 BOOST_CHECK(testFunctorA.GetCount() == 1);
228 BOOST_CHECK(testFunctorB.GetCount() == 1);
229 BOOST_CHECK(testFunctorC.GetCount() == 0);
231 registry.GetFunctor(packetC.GetPacketId(), version)->operator()(packetC);
232 BOOST_CHECK(testFunctorA.GetCount() == 1);
233 BOOST_CHECK(testFunctorB.GetCount() == 1);
234 BOOST_CHECK(testFunctorC.GetCount() == 1);
236 // Re-register an existing key with a new function
237 registry.RegisterFunctor(&testFunctorC, testFunctorA.GetPacketId(), version);
238 registry.GetFunctor(packetA.GetPacketId(), version)->operator()(packetC);
239 BOOST_CHECK(testFunctorA.GetCount() == 1);
240 BOOST_CHECK(testFunctorB.GetCount() == 1);
241 BOOST_CHECK(testFunctorC.GetCount() == 2);
243 // Check that non-existent key returns nullptr for its functor
244 BOOST_CHECK_THROW(registry.GetFunctor(0, 0), armnn::Exception);
247 BOOST_AUTO_TEST_CASE(CheckPacketVersionResolver)
249 // Set up random number generator for generating packetId values
250 std::random_device device;
251 std::mt19937 generator(device());
252 std::uniform_int_distribution<uint32_t> distribution(std::numeric_limits<uint32_t>::min(),
253 std::numeric_limits<uint32_t>::max());
255 // NOTE: Expected version is always 1.0.0, regardless of packetId
256 const Version expectedVersion(1, 0, 0);
258 PacketVersionResolver packetVersionResolver;
260 constexpr unsigned int numTests = 10u;
262 for (unsigned int i = 0u; i < numTests; ++i)
264 const uint32_t packetId = distribution(generator);
265 Version resolvedVersion = packetVersionResolver.ResolvePacketVersion(packetId);
267 BOOST_TEST(resolvedVersion == expectedVersion);
270 void ProfilingCurrentStateThreadImpl(ProfilingStateMachine& states)
272 ProfilingState newState = ProfilingState::NotConnected;
273 states.GetCurrentState();
274 states.TransitionToState(newState);
// Checks the ProfilingStateMachine transition table, then that concurrent transitions from
// several threads leave the machine in the expected state.
// Allowed here: Uninitialised->{Uninitialised, NotConnected}, NotConnected->{NotConnected,
// WaitingForAck}, WaitingForAck->{WaitingForAck, Active}, Active->{NotConnected, Active}.
// Rejected here: Uninitialised->{WaitingForAck, Active}, NotConnected->{Uninitialised, Active},
// WaitingForAck->{Uninitialised, NotConnected}, Active->{Uninitialised, WaitingForAck}.
277 BOOST_AUTO_TEST_CASE(CheckProfilingStateMachine)
// --- Allowed transitions: the new state must be observable afterwards ---
279 ProfilingStateMachine profilingState1(ProfilingState::Uninitialised);
280 profilingState1.TransitionToState(ProfilingState::Uninitialised);
281 BOOST_CHECK(profilingState1.GetCurrentState() == ProfilingState::Uninitialised);
283 ProfilingStateMachine profilingState2(ProfilingState::Uninitialised);
284 profilingState2.TransitionToState(ProfilingState::NotConnected);
285 BOOST_CHECK(profilingState2.GetCurrentState() == ProfilingState::NotConnected);
287 ProfilingStateMachine profilingState3(ProfilingState::NotConnected);
288 profilingState3.TransitionToState(ProfilingState::NotConnected);
289 BOOST_CHECK(profilingState3.GetCurrentState() == ProfilingState::NotConnected);
291 ProfilingStateMachine profilingState4(ProfilingState::NotConnected);
292 profilingState4.TransitionToState(ProfilingState::WaitingForAck);
293 BOOST_CHECK(profilingState4.GetCurrentState() == ProfilingState::WaitingForAck);
295 ProfilingStateMachine profilingState5(ProfilingState::WaitingForAck);
296 profilingState5.TransitionToState(ProfilingState::WaitingForAck);
297 BOOST_CHECK(profilingState5.GetCurrentState() == ProfilingState::WaitingForAck);
299 ProfilingStateMachine profilingState6(ProfilingState::WaitingForAck);
300 profilingState6.TransitionToState(ProfilingState::Active);
301 BOOST_CHECK(profilingState6.GetCurrentState() == ProfilingState::Active);
303 ProfilingStateMachine profilingState7(ProfilingState::Active);
304 profilingState7.TransitionToState(ProfilingState::NotConnected);
305 BOOST_CHECK(profilingState7.GetCurrentState() == ProfilingState::NotConnected);
307 ProfilingStateMachine profilingState8(ProfilingState::Active);
308 profilingState8.TransitionToState(ProfilingState::Active);
309 BOOST_CHECK(profilingState8.GetCurrentState() == ProfilingState::Active);
// --- Rejected transitions: each must throw ---
// NOTE(review): the exception-type arguments of these BOOST_CHECK_THROW calls are not
// visible in this view — confirm which exception type is expected.
311 ProfilingStateMachine profilingState9(ProfilingState::Uninitialised);
312 BOOST_CHECK_THROW(profilingState9.TransitionToState(ProfilingState::WaitingForAck),
315 ProfilingStateMachine profilingState10(ProfilingState::Uninitialised);
316 BOOST_CHECK_THROW(profilingState10.TransitionToState(ProfilingState::Active),
319 ProfilingStateMachine profilingState11(ProfilingState::NotConnected);
320 BOOST_CHECK_THROW(profilingState11.TransitionToState(ProfilingState::Uninitialised),
323 ProfilingStateMachine profilingState12(ProfilingState::NotConnected);
324 BOOST_CHECK_THROW(profilingState12.TransitionToState(ProfilingState::Active),
327 ProfilingStateMachine profilingState13(ProfilingState::WaitingForAck);
328 BOOST_CHECK_THROW(profilingState13.TransitionToState(ProfilingState::Uninitialised),
331 ProfilingStateMachine profilingState14(ProfilingState::WaitingForAck);
332 BOOST_CHECK_THROW(profilingState14.TransitionToState(ProfilingState::NotConnected),
335 ProfilingStateMachine profilingState15(ProfilingState::Active);
336 BOOST_CHECK_THROW(profilingState15.TransitionToState(ProfilingState::Uninitialised),
339 ProfilingStateMachine profilingState16(armnn::profiling::ProfilingState::Active);
340 BOOST_CHECK_THROW(profilingState16.TransitionToState(ProfilingState::WaitingForAck),
// --- Concurrency: five threads share one machine starting in Uninitialised ---
// Every thread requests NotConnected. Uninitialised->NotConnected and NotConnected->NotConnected
// are both allowed above, so any interleaving should end in NotConnected (assuming the machine
// is safe to use from several threads — TODO confirm).
// NOTE(review): the thread joins are not visible in this view. Each std::thread must be joined
// before the final check and before destruction, otherwise std::terminate is called.
343 ProfilingStateMachine profilingState17(ProfilingState::Uninitialised);
345 std::thread thread1 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
346 std::thread thread2 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
347 std::thread thread3 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
348 std::thread thread4 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
349 std::thread thread5 (ProfilingCurrentStateThreadImpl,std::ref(profilingState17));
357 BOOST_TEST((profilingState17.GetCurrentState() == ProfilingState::NotConnected));
// Closes the ExternalProfiling test suite opened at the top of this file.
360 BOOST_AUTO_TEST_SUITE_END()