IVGCVSW-4760 Change the offsets in the counter directory body_header to be from the...
[platform/upstream/armnn.git] / src / profiling / SocketProfilingConnection.cpp
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "SocketProfilingConnection.hpp"
7
8 #include "common/include/SocketConnectionException.hpp"
9
10 #include <cerrno>
11 #include <fcntl.h>
12 #include <string>
13
14 using namespace armnnUtils;
15
16 namespace armnn
17 {
18 namespace profiling
19 {
20
21 SocketProfilingConnection::SocketProfilingConnection()
22 {
23     Sockets::Initialize();
24     memset(m_Socket, 0, sizeof(m_Socket));
25     // Note: we're using Linux specific SOCK_CLOEXEC flag.
26     m_Socket[0].fd = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
27     if (m_Socket[0].fd == -1)
28     {
29         throw armnnProfiling::SocketConnectionException(
30             std::string("SocketProfilingConnection: Socket construction failed: ")  + strerror(errno),
31             m_Socket[0].fd,
32             errno);
33     }
34
35     // Connect to the named unix domain socket.
36     sockaddr_un server{};
37     memset(&server, 0, sizeof(sockaddr_un));
38     // As m_GatorNamespace begins with a null character we need to ignore that when getting its length.
39     memcpy(server.sun_path, m_GatorNamespace, strlen(m_GatorNamespace + 1) + 1);
40     server.sun_family = AF_UNIX;
41     if (0 != connect(m_Socket[0].fd, reinterpret_cast<const sockaddr*>(&server), sizeof(sockaddr_un)))
42     {
43         Close();
44         throw armnnProfiling::SocketConnectionException(
45             std::string("SocketProfilingConnection: Cannot connect to stream socket: ")  + strerror(errno),
46             m_Socket[0].fd,
47             errno);
48     }
49
50     // Our socket will only be interested in polling reads.
51     m_Socket[0].events = POLLIN;
52
53     // Make the socket non blocking.
54     if (!Sockets::SetNonBlocking(m_Socket[0].fd))
55     {
56         Close();
57         throw armnnProfiling::SocketConnectionException(
58             std::string("SocketProfilingConnection: Failed to set socket as non blocking: ")  + strerror(errno),
59             m_Socket[0].fd,
60             errno);
61     }
62 }
63
64 bool SocketProfilingConnection::IsOpen() const
65 {
66     return m_Socket[0].fd > 0;
67 }
68
69 void SocketProfilingConnection::Close()
70 {
71     if (Sockets::Close(m_Socket[0].fd) != 0)
72     {
73         throw armnnProfiling::SocketConnectionException(
74             std::string("SocketProfilingConnection: Cannot close stream socket: ")  + strerror(errno),
75             m_Socket[0].fd,
76             errno);
77     }
78
79     memset(m_Socket, 0, sizeof(m_Socket));
80 }
81
82 bool SocketProfilingConnection::WritePacket(const unsigned char* buffer, uint32_t length)
83 {
84     if (buffer == nullptr || length == 0)
85     {
86         return false;
87     }
88
89     return Sockets::Write(m_Socket[0].fd, buffer, length) != -1;
90 }
91
92 Packet SocketProfilingConnection::ReadPacket(uint32_t timeout)
93 {
94     // Is there currently at least a header worth of data waiting to be read?
95     int bytes_available = 0;
96     Sockets::Ioctl(m_Socket[0].fd, FIONREAD, &bytes_available);
97     if (bytes_available >= 8)
98     {
99         // Yes there is. Read it:
100         return ReceivePacket();
101     }
102
103     // Poll for data on the socket or until timeout occurs
104     int pollResult = Sockets::Poll(&m_Socket[0], 1, static_cast<int>(timeout));
105
106     switch (pollResult)
107     {
108     case -1: // Error
109         throw armnnProfiling::SocketConnectionException(
110             std::string("SocketProfilingConnection: Error occured while reading from socket: ") + strerror(errno),
111             m_Socket[0].fd,
112             errno);
113
114     case 0: // Timeout
115         throw TimeoutException("SocketProfilingConnection: Timeout while reading from socket");
116
117     default: // Normal poll return but it could still contain an error signal
118         // Check if the socket reported an error
119         if (m_Socket[0].revents & (POLLNVAL | POLLERR | POLLHUP))
120         {
121             if (m_Socket[0].revents == POLLNVAL)
122             {
123                 // This is an unrecoverable error.
124                 Close();
125                 throw armnnProfiling::SocketConnectionException(
126                     std::string("SocketProfilingConnection: Error occured while polling receiving socket: POLLNVAL."),
127                     m_Socket[0].fd);
128             }
129             if (m_Socket[0].revents == POLLERR)
130             {
131                 throw armnnProfiling::SocketConnectionException(
132                     std::string(
133                         "SocketProfilingConnection: Error occured while polling receiving socket: POLLERR: ")
134                         + strerror(errno),
135                     m_Socket[0].fd,
136                     errno);
137             }
138             if (m_Socket[0].revents == POLLHUP)
139             {
140                 // This is an unrecoverable error.
141                 Close();
142                 throw armnnProfiling::SocketConnectionException(
143                     std::string("SocketProfilingConnection: Connection closed by remote client: POLLHUP."),
144                     m_Socket[0].fd);
145             }
146         }
147
148         // Check if there is data to read
149         if (!(m_Socket[0].revents & (POLLIN)))
150         {
151             // This is a corner case. The socket as been woken up but not with any data.
152             // We'll throw a timeout exception to loop around again.
153             throw armnn::TimeoutException(
154                 "SocketProfilingConnection: File descriptor was polled but no data was available to receive.");
155         }
156
157         return ReceivePacket();
158     }
159 }
160
161 Packet SocketProfilingConnection::ReceivePacket()
162 {
163     char header[8] = {};
164     long receiveResult = Sockets::Read(m_Socket[0].fd, &header, sizeof(header));
165     // We expect 8 as the result here. 0 means EOF, socket is closed. -1 means there been some other kind of error.
166     switch( receiveResult )
167     {
168         case 0:
169             // Socket has closed.
170             Close();
171             throw armnnProfiling::SocketConnectionException(
172                 std::string("SocketProfilingConnection: Remote socket has closed the connection."),
173                 m_Socket[0].fd);
174         case -1:
175             // There's been a socket error. We will presume it's unrecoverable.
176             Close();
177             throw armnnProfiling::SocketConnectionException(
178                 std::string("SocketProfilingConnection: Error occured while reading the packet: ") + strerror(errno),
179                 m_Socket[0].fd,
180                 errno);
181         default:
182             if (receiveResult < 8)
183             {
184                  throw armnnProfiling::SocketConnectionException(
185                      std::string(
186                          "SocketProfilingConnection: The received packet did not contains a valid PIPE header."),
187                      m_Socket[0].fd);
188             }
189             break;
190     }
191
192     // stream_metadata_identifier is the first 4 bytes
193     uint32_t metadataIdentifier = 0;
194     std::memcpy(&metadataIdentifier, header, sizeof(metadataIdentifier));
195
196     // data_length is the next 4 bytes
197     uint32_t dataLength = 0;
198     std::memcpy(&dataLength, header + 4u, sizeof(dataLength));
199
200     std::unique_ptr<unsigned char[]> packetData;
201     if (dataLength > 0)
202     {
203         packetData = std::make_unique<unsigned char[]>(dataLength);
204         long receivedLength = Sockets::Read(m_Socket[0].fd, packetData.get(), dataLength);
205         if (receivedLength < 0)
206         {
207             throw armnnProfiling::SocketConnectionException(
208                 std::string("SocketProfilingConnection: Error occured while reading the packet: ")  + strerror(errno),
209                 m_Socket[0].fd,
210                 errno);
211         }
212         if (dataLength != static_cast<uint32_t>(receivedLength))
213         {
214             // What do we do here if we can't read in a full packet?
215             throw armnnProfiling::SocketConnectionException(
216                 std::string("SocketProfilingConnection: Invalid PIPE packet."),
217                 m_Socket[0].fd);
218         }
219     }
220
221     return Packet(metadataIdentifier, dataLength, packetData);
222 }
223
224 } // namespace profiling
225 } // namespace armnn