IVGCVSW-4129 Fix thread starvation due to low capture periods
[platform/upstream/armnn.git] / tests / profiling / gatordmock / GatordMockService.cpp
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "GatordMockService.hpp"
7
8 #include <CommandHandlerRegistry.hpp>
9 #include <PacketVersionResolver.hpp>
10 #include <ProfilingUtils.hpp>
11
12 #include <cerrno>
13 #include <fcntl.h>
14 #include <iomanip>
15 #include <iostream>
16 #include <poll.h>
17 #include <string>
18 #include <sys/ioctl.h>
19 #include <sys/socket.h>
20 #include <sys/un.h>
21 #include <unistd.h>
22
23 namespace armnn
24 {
25
26 namespace gatordmock
27 {
28
29 bool GatordMockService::OpenListeningSocket(std::string udsNamespace)
30 {
31     m_ListeningSocket = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
32     if (-1 == m_ListeningSocket)
33     {
34         std::cerr << ": Socket construction failed: " << strerror(errno) << std::endl;
35         return false;
36     }
37
38     sockaddr_un udsAddress;
39     memset(&udsAddress, 0, sizeof(sockaddr_un));
40     // We've set the first element of sun_path to be 0, skip over it and copy the namespace after it.
41     memcpy(udsAddress.sun_path + 1, udsNamespace.c_str(), strlen(udsNamespace.c_str()));
42     udsAddress.sun_family = AF_UNIX;
43
44     // Bind the socket to the UDS namespace.
45     if (-1 == bind(m_ListeningSocket, reinterpret_cast<const sockaddr*>(&udsAddress), sizeof(sockaddr_un)))
46     {
47         std::cerr << ": Binding on socket failed: " << strerror(errno) << std::endl;
48         return false;
49     }
50     // Listen for 1 connection.
51     if (-1 == listen(m_ListeningSocket, 1))
52     {
53         std::cerr << ": Listen call on socket failed: " << strerror(errno) << std::endl;
54         return false;
55     }
56     return true;
57 }
58
59 int GatordMockService::BlockForOneClient()
60 {
61     m_ClientConnection = accept4(m_ListeningSocket, nullptr, nullptr, SOCK_CLOEXEC);
62     if (-1 == m_ClientConnection)
63     {
64         std::cerr << ": Failure when waiting for a client connection: " << strerror(errno) << std::endl;
65         return -1;
66     }
67     return m_ClientConnection;
68 }
69
70 bool GatordMockService::WaitForStreamMetaData()
71 {
72     if (m_EchoPackets)
73     {
74         std::cout << "Waiting for stream meta data..." << std::endl;
75     }
76     // The start of the stream metadata is 2x32bit words, 0 and packet length.
77     uint8_t header[8];
78     if (!ReadFromSocket(header, 8))
79     {
80         return false;
81     }
82     EchoPacket(PacketDirection::ReceivedHeader, header, 8);
83     // The first word, stream_metadata_identifer, should always be 0.
84     if (ToUint32(&header[0], TargetEndianness::BeWire) != 0)
85     {
86         std::cerr << ": Protocol error. The stream_metadata_identifer was not 0." << std::endl;
87         return false;
88     }
89
90     uint8_t pipeMagic[4];
91     if (!ReadFromSocket(pipeMagic, 4))
92     {
93         return false;
94     }
95     EchoPacket(PacketDirection::ReceivedData, pipeMagic, 4);
96
97     // Before we interpret the length we need to read the pipe_magic word to determine endianness.
98     if (ToUint32(&pipeMagic[0], TargetEndianness::BeWire) == PIPE_MAGIC)
99     {
100         m_Endianness = TargetEndianness::BeWire;
101     }
102     else if (ToUint32(&pipeMagic[0], TargetEndianness::LeWire) == PIPE_MAGIC)
103     {
104         m_Endianness = TargetEndianness::LeWire;
105     }
106     else
107     {
108         std::cerr << ": Protocol read error. Unable to read PIPE_MAGIC value." << std::endl;
109         return false;
110     }
111     // Now we know the endianness we can get the length from the header.
112     // Remember we already read the pipe magic 4 bytes.
113     uint32_t metaDataLength = ToUint32(&header[4], m_Endianness) - 4;
114     // Read the entire packet.
115     uint8_t packetData[metaDataLength];
116     if (metaDataLength != boost::numeric_cast<uint32_t>(read(m_ClientConnection, &packetData, metaDataLength)))
117     {
118         std::cerr << ": Protocol read error. Data length mismatch." << std::endl;
119         return false;
120     }
121     EchoPacket(PacketDirection::ReceivedData, packetData, metaDataLength);
122     m_StreamMetaDataVersion    = ToUint32(&packetData[0], m_Endianness);
123     m_StreamMetaDataMaxDataLen = ToUint32(&packetData[4], m_Endianness);
124     m_StreamMetaDataPid        = ToUint32(&packetData[8], m_Endianness);
125
126     return true;
127 }
128
129 void GatordMockService::SendConnectionAck()
130 {
131     if (m_EchoPackets)
132     {
133         std::cout << "Sending connection acknowledgement." << std::endl;
134     }
135     // The connection ack packet is an empty data packet with packetId == 1.
136     SendPacket(0, 1, nullptr, 0);
137 }
138
139 void GatordMockService::SendRequestCounterDir()
140 {
141     if (m_EchoPackets)
142     {
143         std::cout << "Sending connection acknowledgement." << std::endl;
144     }
145     // The connection ack packet is an empty data packet with packetId == 1.
146     SendPacket(0, 3, nullptr, 0);
147 }
148
149 bool GatordMockService::LaunchReceivingThread()
150 {
151     if (m_EchoPackets)
152     {
153         std::cout << "Launching receiving thread." << std::endl;
154     }
155     // At this point we want to make the socket non blocking.
156     const int currentFlags = fcntl(m_ClientConnection, F_GETFL);
157     if (0 != fcntl(m_ClientConnection, F_SETFL, currentFlags | O_NONBLOCK))
158     {
159         close(m_ClientConnection);
160         std::cerr << "Failed to set socket as non blocking: " << strerror(errno) << std::endl;
161         return false;
162     }
163     m_ListeningThread = std::thread(&GatordMockService::ReceiveLoop, this, std::ref(*this));
164     return true;
165 }
166
167 void GatordMockService::WaitForReceivingThread()
168 {
169     // The receiving thread may already have died.
170     if (m_CloseReceivingThread != true)
171     {
172         m_CloseReceivingThread.store(true);
173     }
174     // Check that the receiving thread is running
175     if (m_ListeningThread.joinable())
176     {
177         // Wait for the receiving thread to complete operations
178         m_ListeningThread.join();
179     }
180 }
181
182 void GatordMockService::SendPeriodicCounterSelectionList(uint32_t period, std::vector<uint16_t> counters)
183 {
184     // The packet body consists of a UINT32 representing the period following by zero or more
185     // UINT16's representing counter UID's. If the list is empty it implies all counters are to
186     // be disabled.
187
188     if (m_EchoPackets)
189     {
190         std::cout << "SendPeriodicCounterSelectionList: Period=" << std::dec << period << "uSec" << std::endl;
191         std::cout << "List length=" << counters.size() << std::endl;
192         ;
193     }
194     // Start by calculating the length of the packet body in bytes. This will be at least 4.
195     uint32_t dataLength = static_cast<uint32_t>(4 + (counters.size() * 2));
196
197     std::unique_ptr<unsigned char[]> uniqueData = std::make_unique<unsigned char[]>(dataLength);
198     unsigned char* data                         = reinterpret_cast<unsigned char*>(uniqueData.get());
199
200     uint32_t offset = 0;
201     profiling::WriteUint32(data, offset, period);
202     offset += 4;
203     for (std::vector<uint16_t>::iterator it = counters.begin(); it != counters.end(); ++it)
204     {
205         profiling::WriteUint16(data, offset, *it);
206         offset += 2;
207     }
208
209     // Send the packet.
210     SendPacket(0, 4, data, dataLength);
211     // There will be an echo response packet sitting in the receive thread. PeriodicCounterSelectionResponseHandler
212     // should deal with it.
213 }
214
215 void GatordMockService::WaitCommand(uint timeout)
216 {
217     // Wait for a maximum of timeout microseconds or if the receive thread has closed.
218     // There is a certain level of rounding involved in this timing.
219     uint iterations = timeout / 1000;
220     std::cout << std::dec << "Wait command with timeout of " << timeout << " iterations =  " << iterations << std::endl;
221     uint count = 0;
222     while ((this->ReceiveThreadRunning() && (count < iterations)))
223     {
224         std::this_thread::sleep_for(std::chrono::microseconds(1000));
225         ++count;
226     }
227     if (m_EchoPackets)
228     {
229         std::cout << std::dec << "Wait command with timeout of " << timeout << " microseconds completed. " << std::endl;
230     }
231 }
232
233 void GatordMockService::ReceiveLoop(GatordMockService& mockService)
234 {
235     m_CloseReceivingThread.store(false);
236     while (!m_CloseReceivingThread.load())
237     {
238         try
239         {
240             armnn::profiling::Packet packet = mockService.WaitForPacket(500);
241         }
242         catch (const armnn::TimeoutException&)
243         {
244             // In this case we ignore timeouts and and keep trying to receive.
245         }
246         catch (const armnn::InvalidArgumentException& e)
247         {
248             // We couldn't find a functor to handle the packet?
249             std::cerr << "Packet received that could not be processed: " << e.what() << std::endl;
250         }
251         catch (const armnn::RuntimeException& e)
252         {
253             // A runtime exception occurred which means we must exit the loop.
254             std::cerr << "Receive thread closing: " << e.what() << std::endl;
255             m_CloseReceivingThread.store(true);
256         }
257     }
258 }
259
260 armnn::profiling::Packet GatordMockService::WaitForPacket(uint32_t timeoutMs)
261 {
262     // Is there currently more than a headers worth of data waiting to be read?
263     int bytes_available;
264     ioctl(m_ClientConnection, FIONREAD, &bytes_available);
265     if (bytes_available > 8)
266     {
267         // Yes there is. Read it:
268         return ReceivePacket();
269     }
270     else
271     {
272         // No there's not. Poll for more data.
273         struct pollfd pollingFd[1]{};
274         pollingFd[0].fd = m_ClientConnection;
275         int pollResult  = poll(pollingFd, 1, static_cast<int>(timeoutMs));
276
277         switch (pollResult)
278         {
279             // Error
280             case -1:
281                 throw armnn::RuntimeException(std::string("File descriptor reported an error during polling: ") +
282                                               strerror(errno));
283
284             // Timeout
285             case 0:
286                 throw armnn::TimeoutException("Timeout while waiting to receive packet.");
287
288             // Normal poll return. It could still contain an error signal
289             default:
290                 // Check if the socket reported an error
291                 if (pollingFd[0].revents & (POLLNVAL | POLLERR | POLLHUP))
292                 {
293                     if (pollingFd[0].revents == POLLNVAL)
294                     {
295                         throw armnn::RuntimeException(std::string("Error while polling receiving socket: POLLNVAL"));
296                     }
297                     if (pollingFd[0].revents == POLLERR)
298                     {
299                         throw armnn::RuntimeException(std::string("Error while polling receiving socket: POLLERR: ") +
300                                                       strerror(errno));
301                     }
302                     if (pollingFd[0].revents == POLLHUP)
303                     {
304                         throw armnn::RuntimeException(std::string("Connection closed by remote client: POLLHUP"));
305                     }
306                 }
307
308                 // Check if there is data to read
309                 if (!(pollingFd[0].revents & (POLLIN)))
310                 {
311                     // This is a corner case. The socket as been woken up but not with any data.
312                     // We'll throw a timeout exception to loop around again.
313                     throw armnn::TimeoutException("File descriptor was polled but no data was available to receive.");
314                 }
315                 return ReceivePacket();
316         }
317     }
318 }
319
320 armnn::profiling::Packet GatordMockService::ReceivePacket()
321 {
322     uint32_t header[2];
323     if (!ReadHeader(header))
324     {
325         return armnn::profiling::Packet();
326     }
327     // Read data_length bytes from the socket.
328     std::unique_ptr<unsigned char[]> uniquePacketData = std::make_unique<unsigned char[]>(header[1]);
329     unsigned char* packetData                         = reinterpret_cast<unsigned char*>(uniquePacketData.get());
330
331     if (!ReadFromSocket(packetData, header[1]))
332     {
333         return armnn::profiling::Packet();
334     }
335
336     EchoPacket(PacketDirection::ReceivedData, packetData, header[1]);
337
338     // Construct received packet
339     armnn::profiling::PacketVersionResolver packetVersionResolver;
340     armnn::profiling::Packet packetRx = armnn::profiling::Packet(header[0], header[1], uniquePacketData);
341     if (m_EchoPackets)
342     {
343         std::cout << "Processing packet ID= " << packetRx.GetPacketId() << " Length=" << packetRx.GetLength()
344                   << std::endl;
345     }
346
347     profiling::Version version =
348         packetVersionResolver.ResolvePacketVersion(packetRx.GetPacketFamily(), packetRx.GetPacketId());
349
350     profiling::CommandHandlerFunctor* commandHandlerFunctor =
351         m_HandlerRegistry.GetFunctor(packetRx.GetPacketFamily(), packetRx.GetPacketId(), version.GetEncodedValue());
352     BOOST_ASSERT(commandHandlerFunctor);
353     commandHandlerFunctor->operator()(packetRx);
354     return packetRx;
355 }
356
357 bool GatordMockService::SendPacket(uint32_t packetFamily, uint32_t packetId, const uint8_t* data, uint32_t dataLength)
358 {
359     // Construct a packet from the id and data given and send it to the client.
360     // Encode the header.
361     uint32_t header[2];
362     header[0] = packetFamily << 26 | packetId << 16;
363     header[1] = dataLength;
364     // Add the header to the packet.
365     uint8_t packet[8 + dataLength];
366     InsertU32(header[0], packet, m_Endianness);
367     InsertU32(header[1], packet + 4, m_Endianness);
368     // And the rest of the data if there is any.
369     if (dataLength > 0)
370     {
371         memcpy((packet + 8), data, dataLength);
372     }
373     EchoPacket(PacketDirection::Sending, packet, sizeof(packet));
374     if (-1 == write(m_ClientConnection, packet, sizeof(packet)))
375     {
376         std::cerr << ": Failure when writing to client socket: " << strerror(errno) << std::endl;
377         return false;
378     }
379     return true;
380 }
381
382 bool GatordMockService::ReadHeader(uint32_t headerAsWords[2])
383 {
384     // The header will always be 2x32bit words.
385     uint8_t header[8];
386     if (!ReadFromSocket(header, 8))
387     {
388         return false;
389     }
390     EchoPacket(PacketDirection::ReceivedHeader, header, 8);
391     headerAsWords[0] = ToUint32(&header[0], m_Endianness);
392     headerAsWords[1] = ToUint32(&header[4], m_Endianness);
393     return true;
394 }
395
396 bool GatordMockService::ReadFromSocket(uint8_t* packetData, uint32_t expectedLength)
397 {
398     // This is a blocking read until either expectedLength has been received or an error is detected.
399     ssize_t totalBytesRead = 0;
400     while (boost::numeric_cast<uint32_t>(totalBytesRead) < expectedLength)
401     {
402         ssize_t bytesRead = recv(m_ClientConnection, packetData, expectedLength, 0);
403         if (bytesRead < 0)
404         {
405             std::cerr << ": Failure when reading from client socket: " << strerror(errno) << std::endl;
406             return false;
407         }
408         if (bytesRead == 0)
409         {
410             std::cerr << ": EOF while reading from client socket." << std::endl;
411             return false;
412         }
413         totalBytesRead += bytesRead;
414     }
415     return true;
416 };
417
418 void GatordMockService::EchoPacket(PacketDirection direction, uint8_t* packet, size_t lengthInBytes)
419 {
420     // If enabled print the contents of the data packet to the console.
421     if (m_EchoPackets)
422     {
423         if (direction == PacketDirection::Sending)
424         {
425             std::cout << "TX " << std::dec << lengthInBytes << " bytes : ";
426         }
427         else if (direction == PacketDirection::ReceivedHeader)
428         {
429             std::cout << "RX Header " << std::dec << lengthInBytes << " bytes : ";
430         }
431         else
432         {
433             std::cout << "RX Data " << std::dec << lengthInBytes << " bytes : ";
434         }
435         for (unsigned int i = 0; i < lengthInBytes; i++)
436         {
437             if ((i % 10) == 0)
438             {
439                 std::cout << std::endl;
440             }
441             std::cout << "0x" << std::setfill('0') << std::setw(2) << std::hex << static_cast<unsigned int>(packet[i])
442                       << " ";
443         }
444         std::cout << std::endl;
445     }
446 }
447
448 uint32_t GatordMockService::ToUint32(uint8_t* data, TargetEndianness endianness)
449 {
450     // Extract the first 4 bytes starting at data and push them into a 32bit integer based on the
451     // specified endianness.
452     if (endianness == TargetEndianness::BeWire)
453     {
454         return static_cast<uint32_t>(data[0]) << 24 | static_cast<uint32_t>(data[1]) << 16 |
455                static_cast<uint32_t>(data[2]) << 8 | static_cast<uint32_t>(data[3]);
456     }
457     else
458     {
459         return static_cast<uint32_t>(data[3]) << 24 | static_cast<uint32_t>(data[2]) << 16 |
460                static_cast<uint32_t>(data[1]) << 8 | static_cast<uint32_t>(data[0]);
461     }
462 }
463
464 void GatordMockService::InsertU32(uint32_t value, uint8_t* data, TargetEndianness endianness)
465 {
466     // Take the bytes of a 32bit integer and copy them into char array starting at data considering
467     // the endianness value.
468     if (endianness == TargetEndianness::BeWire)
469     {
470         *data       = static_cast<uint8_t>((value >> 24) & 0xFF);
471         *(data + 1) = static_cast<uint8_t>((value >> 16) & 0xFF);
472         *(data + 2) = static_cast<uint8_t>((value >> 8) & 0xFF);
473         *(data + 3) = static_cast<uint8_t>(value & 0xFF);
474     }
475     else
476     {
477         *(data + 3) = static_cast<uint8_t>((value >> 24) & 0xFF);
478         *(data + 2) = static_cast<uint8_t>((value >> 16) & 0xFF);
479         *(data + 1) = static_cast<uint8_t>((value >> 8) & 0xFF);
480         *data       = static_cast<uint8_t>(value & 0xFF);
481     }
482 }
483
484 }    // namespace gatordmock
485
486 }    // namespace armnn