53f580deafaae5781b6aa47ee5012ce59000b89d
[platform/upstream/armnn.git] / tests / profiling / gatordmock / tests / GatordMockTests.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "../GatordMockService.hpp"
7 #include "../PeriodicCounterCaptureCommandHandler.hpp"
8
9 #include <CommandHandlerRegistry.hpp>
10 #include <DirectoryCaptureCommandHandler.hpp>
11 #include <ProfilingService.hpp>
12
13 #include <test/SendCounterPacketTests.hpp>
14
15 #include <boost/cast.hpp>
16 #include <boost/test/test_tools.hpp>
17 #include <boost/test/unit_test_suite.hpp>
18
19 BOOST_AUTO_TEST_SUITE(GatordMockTests)
20
21 using namespace armnn;
22 using namespace std::this_thread;    // sleep_for, sleep_until
23 using namespace std::chrono_literals;
24
25 BOOST_AUTO_TEST_CASE(CounterCaptureHandlingTest)
26 {
27     using boost::numeric_cast;
28
29     profiling::PacketVersionResolver packetVersionResolver;
30
31     // Data with timestamp, counter idx & counter values
32     std::vector<std::pair<uint16_t, uint32_t>> indexValuePairs;
33     indexValuePairs.reserve(5);
34     indexValuePairs.emplace_back(std::make_pair<uint16_t, uint32_t>(0, 100));
35     indexValuePairs.emplace_back(std::make_pair<uint16_t, uint32_t>(1, 200));
36     indexValuePairs.emplace_back(std::make_pair<uint16_t, uint32_t>(2, 300));
37     indexValuePairs.emplace_back(std::make_pair<uint16_t, uint32_t>(3, 400));
38     indexValuePairs.emplace_back(std::make_pair<uint16_t, uint32_t>(4, 500));
39
40     // ((uint16_t (2 bytes) + uint32_t (4 bytes)) * 5) + word1 + word2
41     uint32_t dataLength = 38;
42
43     // Simulate two different packets incoming 500 ms apart
44     uint64_t time = static_cast<uint64_t>(
45         std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now().time_since_epoch())
46             .count());
47
48     uint64_t time2 = time + 5000;
49
50     // UniqueData required for Packet class
51     std::unique_ptr<unsigned char[]> uniqueData1 = std::make_unique<unsigned char[]>(dataLength);
52     unsigned char* data1                         = reinterpret_cast<unsigned char*>(uniqueData1.get());
53
54     std::unique_ptr<unsigned char[]> uniqueData2 = std::make_unique<unsigned char[]>(dataLength);
55     unsigned char* data2                         = reinterpret_cast<unsigned char*>(uniqueData2.get());
56
57     uint32_t sizeOfUint64 = numeric_cast<uint32_t>(sizeof(uint64_t));
58     uint32_t sizeOfUint32 = numeric_cast<uint32_t>(sizeof(uint32_t));
59     uint32_t sizeOfUint16 = numeric_cast<uint32_t>(sizeof(uint16_t));
60     // Offset index to point to mem address
61     uint32_t offset = 0;
62
63     profiling::WriteUint64(data1, offset, time);
64     offset += sizeOfUint64;
65     for (const auto& pair : indexValuePairs)
66     {
67         profiling::WriteUint16(data1, offset, pair.first);
68         offset += sizeOfUint16;
69         profiling::WriteUint32(data1, offset, pair.second);
70         offset += sizeOfUint32;
71     }
72
73     offset = 0;
74
75     profiling::WriteUint64(data2, offset, time2);
76     offset += sizeOfUint64;
77     for (const auto& pair : indexValuePairs)
78     {
79         profiling::WriteUint16(data2, offset, pair.first);
80         offset += sizeOfUint16;
81         profiling::WriteUint32(data2, offset, pair.second);
82         offset += sizeOfUint32;
83     }
84
85     uint32_t headerWord1 = packetVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue();
86     // Create packet to send through to the command functor
87     profiling::Packet packet1(headerWord1, dataLength, uniqueData1);
88     profiling::Packet packet2(headerWord1, dataLength, uniqueData2);
89
90     gatordmock::PeriodicCounterCaptureCommandHandler commandHandler(0, 4, headerWord1, true);
91
92     // Simulate two separate packets coming in to calculate period
93     commandHandler(packet1);
94     commandHandler(packet2);
95
96     BOOST_ASSERT(commandHandler.m_CurrentPeriodValue == 5000);
97
98     for (size_t i = 0; i < commandHandler.m_CounterCaptureValues.m_Uids.size(); ++i)
99     {
100         BOOST_ASSERT(commandHandler.m_CounterCaptureValues.m_Uids[i] == i);
101     }
102 }
103
104 BOOST_AUTO_TEST_CASE(GatorDMockEndToEnd)
105 {
106     // The purpose of this test is to setup both sides of the profiling service and get to the point of receiving
107     // performance data.
108
109     //These variables are used to wait for the profiling service
110     uint32_t timeout   = 2000;
111     uint32_t sleepTime = 50;
112     uint32_t timeSlept = 0;
113
114     profiling::PacketVersionResolver packetVersionResolver;
115
116     // Create the Command Handler Registry
117     profiling::CommandHandlerRegistry registry;
118
119     // Update with derived functors
120     gatordmock::PeriodicCounterCaptureCommandHandler counterCaptureCommandHandler(
121         0, 4, packetVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue(), true);
122
123     profiling::DirectoryCaptureCommandHandler directoryCaptureCommandHandler(
124         0, 2, packetVersionResolver.ResolvePacketVersion(0, 2).GetEncodedValue(), true);
125
126     // Register different derived functors
127     registry.RegisterFunctor(&counterCaptureCommandHandler);
128     registry.RegisterFunctor(&directoryCaptureCommandHandler);
129     // Setup the mock service to bind to the UDS.
130     std::string udsNamespace = "gatord_namespace";
131     gatordmock::GatordMockService mockService(registry, false);
132     mockService.OpenListeningSocket(udsNamespace);
133
134     // Enable the profiling service.
135     armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
136     options.m_EnableProfiling                     = true;
137     profiling::ProfilingService& profilingService = profiling::ProfilingService::Instance();
138     profilingService.ResetExternalProfilingOptions(options, true);
139
140     // Bring the profiling service to the "WaitingForAck" state
141     BOOST_CHECK(profilingService.GetCurrentState() == profiling::ProfilingState::Uninitialised);
142     profilingService.Update();
143     BOOST_CHECK(profilingService.GetCurrentState() == profiling::ProfilingState::NotConnected);
144     profilingService.Update();
145
146     // Connect the profiling service to the mock Gatord.
147     int clientFd = mockService.BlockForOneClient();
148     if (-1 == clientFd)
149     {
150         BOOST_FAIL("Failed to connect client");
151     }
152
153     // Give the profiling service sending thread time start executing and send the stream metadata.
154     while (profilingService.GetCurrentState() != profiling::ProfilingState::WaitingForAck)
155     {
156         if (timeSlept >= timeout)
157         {
158             BOOST_FAIL("Timeout: Profiling service did not switch to WaitingForAck state");
159         }
160         std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime));
161         timeSlept += sleepTime;
162     }
163
164     profilingService.Update();
165     // Read the stream metadata on the mock side.
166     if (!mockService.WaitForStreamMetaData())
167     {
168         BOOST_FAIL("Failed to receive StreamMetaData");
169     }
170     // Send Ack from GatorD
171     mockService.SendConnectionAck();
172
173     timeSlept = 0;
174     while (profilingService.GetCurrentState() != profiling::ProfilingState::Active)
175     {
176         if (timeSlept >= timeout)
177         {
178             BOOST_FAIL("Timeout: Profiling service did not switch to Active state");
179         }
180         std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime));
181         timeSlept += sleepTime;
182     }
183
184     mockService.LaunchReceivingThread();
185     // As part of the default startup of the profiling service a counter directory packet will be sent.
186     timeSlept = 0;
187     while (!directoryCaptureCommandHandler.ParsedCounterDirectory())
188     {
189         if (timeSlept >= timeout)
190         {
191             BOOST_FAIL("Timeout: MockGatord did not receive counter directory packet");
192         }
193         std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime));
194         timeSlept += sleepTime;
195     }
196
197     const profiling::ICounterDirectory& serviceCounterDirectory  = profilingService.GetCounterDirectory();
198     const profiling::ICounterDirectory& receivedCounterDirectory = directoryCaptureCommandHandler.GetCounterDirectory();
199
200     // Compare thre basics of the counter directory from the service and the one we received over the wire.
201     BOOST_ASSERT(serviceCounterDirectory.GetDeviceCount() == receivedCounterDirectory.GetDeviceCount());
202     BOOST_ASSERT(serviceCounterDirectory.GetCounterSetCount() == receivedCounterDirectory.GetCounterSetCount());
203     BOOST_ASSERT(serviceCounterDirectory.GetCategoryCount() == receivedCounterDirectory.GetCategoryCount());
204     BOOST_ASSERT(serviceCounterDirectory.GetCounterCount() == receivedCounterDirectory.GetCounterCount());
205
206     receivedCounterDirectory.GetDeviceCount();
207     serviceCounterDirectory.GetDeviceCount();
208
209     const profiling::Devices& serviceDevices = serviceCounterDirectory.GetDevices();
210     for (auto& device : serviceDevices)
211     {
212         // Find the same device in the received counter directory.
213         auto foundDevice = receivedCounterDirectory.GetDevices().find(device.second->m_Uid);
214         BOOST_CHECK(foundDevice != receivedCounterDirectory.GetDevices().end());
215         BOOST_CHECK(device.second->m_Name.compare((*foundDevice).second->m_Name) == 0);
216         BOOST_CHECK(device.second->m_Cores == (*foundDevice).second->m_Cores);
217     }
218
219     const profiling::CounterSets& serviceCounterSets = serviceCounterDirectory.GetCounterSets();
220     for (auto& counterSet : serviceCounterSets)
221     {
222         // Find the same counter set in the received counter directory.
223         auto foundCounterSet = receivedCounterDirectory.GetCounterSets().find(counterSet.second->m_Uid);
224         BOOST_CHECK(foundCounterSet != receivedCounterDirectory.GetCounterSets().end());
225         BOOST_CHECK(counterSet.second->m_Name.compare((*foundCounterSet).second->m_Name) == 0);
226         BOOST_CHECK(counterSet.second->m_Count == (*foundCounterSet).second->m_Count);
227     }
228
229     const profiling::Categories& serviceCategories = serviceCounterDirectory.GetCategories();
230     for (auto& category : serviceCategories)
231     {
232         for (auto& receivedCategory : receivedCounterDirectory.GetCategories())
233         {
234             if (receivedCategory->m_Name.compare(category->m_Name) == 0)
235             {
236                 // We've found the matching category.
237                 BOOST_CHECK(category->m_DeviceUid == receivedCategory->m_DeviceUid);
238                 BOOST_CHECK(category->m_CounterSetUid == receivedCategory->m_CounterSetUid);
239                 // Now look at the interiors of the counters. Start by sorting them.
240                 std::sort(category->m_Counters.begin(), category->m_Counters.end());
241                 std::sort(receivedCategory->m_Counters.begin(), receivedCategory->m_Counters.end());
242                 // When comparing uid's here we need to translate them.
243                 std::function<bool(const uint16_t&, const uint16_t&)> comparator =
244                     [&directoryCaptureCommandHandler](const uint16_t& first, const uint16_t& second) {
245                         uint16_t translated = directoryCaptureCommandHandler.TranslateUIDCopyToOriginal(second);
246                         if (translated == first)
247                         {
248                             return true;
249                         }
250                         return false;
251                     };
252                 // Then let vector == do the work.
253                 BOOST_CHECK(std::equal(category->m_Counters.begin(), category->m_Counters.end(),
254                                        receivedCategory->m_Counters.begin(), comparator));
255                 break;
256             }
257         }
258     }
259
260     // Finally check the content of the counters.
261     const profiling::Counters& receivedCounters = receivedCounterDirectory.GetCounters();
262     for (auto& receivedCounter : receivedCounters)
263     {
264         // Translate the Uid and find the corresponding counter in the original counter directory.
265         // Note we can't check m_MaxCounterUid here as it will likely differ between the two counter directories.
266         uint16_t translated = directoryCaptureCommandHandler.TranslateUIDCopyToOriginal(receivedCounter.first);
267         const profiling::Counter* serviceCounter = serviceCounterDirectory.GetCounter(translated);
268         BOOST_CHECK(serviceCounter->m_DeviceUid == receivedCounter.second->m_DeviceUid);
269         BOOST_CHECK(serviceCounter->m_Name.compare(receivedCounter.second->m_Name) == 0);
270         BOOST_CHECK(serviceCounter->m_CounterSetUid == receivedCounter.second->m_CounterSetUid);
271         BOOST_CHECK(serviceCounter->m_Multiplier == receivedCounter.second->m_Multiplier);
272         BOOST_CHECK(serviceCounter->m_Interpolation == receivedCounter.second->m_Interpolation);
273         BOOST_CHECK(serviceCounter->m_Class == receivedCounter.second->m_Class);
274         BOOST_CHECK(serviceCounter->m_Units.compare(receivedCounter.second->m_Units) == 0);
275         BOOST_CHECK(serviceCounter->m_Description.compare(receivedCounter.second->m_Description) == 0);
276     }
277
278     mockService.WaitForReceivingThread();
279     options.m_EnableProfiling = false;
280     profilingService.ResetExternalProfilingOptions(options, true);
281
282     // Future tests here will add counters to the ProfilingService, increment values and examine
283     // PeriodicCounterCapture data received. These are yet to be integrated.
284 }
285
286 BOOST_AUTO_TEST_SUITE_END()