Fix CVE-2017-6891 in minitasn1 code
[platform/upstream/gnutls.git] / src / socket.c
1 /*
2  * Copyright (C) 2000-2012 Free Software Foundation, Inc.
3  *
4  * This file is part of GnuTLS.
5  *
6  * GnuTLS is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * GnuTLS is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
18  *
19  * In addition, as a special exception, the copyright holders give
20  * permission to link the code of portions of this program with the
21  * OpenSSL library under certain conditions as described in each
22  * individual source file, and distribute linked combinations including
23  * the two.
24  * 
25  * You must obey the GNU General Public License in all respects for all
26  * of the code used other than OpenSSL. If you modify file(s) with this
27  * exception, you may extend this exception to your version of the
28  * file(s), but you are not obligated to do so. If you do not wish to do
29  * so, delete this exception statement from your version. If you delete
30  * this exception statement from all source files in the program, then
31  * also delete it here.
32  */
33
34 #include <config.h>
35
36 #if HAVE_SYS_SOCKET_H
37 #include <sys/socket.h>
38 #elif HAVE_WS2TCPIP_H
39 #include <ws2tcpip.h>
40 #endif
41 #include <netdb.h>
42 #include <string.h>
43 #include <errno.h>
44 #include <sys/select.h>
45 #include <sys/types.h>
46 #include <stdio.h>
47 #include <stdlib.h>
48 #include <unistd.h>
49 #ifndef _WIN32
50 #include <arpa/inet.h>
51 #include <signal.h>
52 #endif
53 #include <socket.h>
54 #include <c-ctype.h>
55 #include "sockets.h"
56
57 #define MAX_BUF 4096
58
59 /* Functions to manipulate sockets
60  */
61
62 ssize_t
63 socket_recv(const socket_st * socket, void *buffer, int buffer_size)
64 {
65         int ret;
66
67         if (socket->secure) {
68                 do {
69                         ret =
70                             gnutls_record_recv(socket->session, buffer,
71                                                buffer_size);
72                         if (ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED)
73                                 gnutls_heartbeat_pong(socket->session, 0);
74                 }
75                 while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN
76                        || ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED);
77
78         } else
79                 do {
80                         ret = recv(socket->fd, buffer, buffer_size, 0);
81                 }
82                 while (ret == -1 && errno == EINTR);
83
84         return ret;
85 }
86
87 ssize_t
88 socket_recv_timeout(const socket_st * socket, void *buffer, int buffer_size, unsigned ms)
89 {
90         int ret;
91
92         if (socket->secure)
93                 gnutls_record_set_timeout(socket->session, ms);
94         ret = socket_recv(socket, buffer, buffer_size);
95
96         if (socket->secure)
97                 gnutls_record_set_timeout(socket->session, 0);
98
99         return ret;
100 }
101
102 ssize_t
103 socket_send(const socket_st * socket, const void *buffer, int buffer_size)
104 {
105         return socket_send_range(socket, buffer, buffer_size, NULL);
106 }
107
108
109 ssize_t
110 socket_send_range(const socket_st * socket, const void *buffer,
111                   int buffer_size, gnutls_range_st * range)
112 {
113         int ret;
114
115         if (socket->secure)
116                 do {
117                         if (range == NULL)
118                                 ret =
119                                     gnutls_record_send(socket->session,
120                                                        buffer,
121                                                        buffer_size);
122                         else
123                                 ret =
124                                     gnutls_record_send_range(socket->
125                                                              session,
126                                                              buffer,
127                                                              buffer_size,
128                                                              range);
129                 }
130                 while (ret == GNUTLS_E_AGAIN
131                        || ret == GNUTLS_E_INTERRUPTED);
132         else
133                 do {
134                         ret = send(socket->fd, buffer, buffer_size, 0);
135                 }
136                 while (ret == -1 && errno == EINTR);
137
138         if (ret > 0 && ret != buffer_size && socket->verbose)
139                 fprintf(stderr,
140                         "*** Only sent %d bytes instead of %d.\n", ret,
141                         buffer_size);
142
143         return ret;
144 }
145
/* Writes the NUL-terminated string txt (without its terminator) to
 * the plain socket fd.
 *
 * send() may accept only part of the data, so the call is repeated
 * until every byte has been written; a send() interrupted by a signal
 * (EINTR) is retried.  Any other send() failure prints an error to
 * stderr and terminates the process with exit status 1.
 *
 * Returns the number of bytes written, i.e. strlen(txt).
 */
static
ssize_t send_line(int fd, const char *txt)
{
        size_t len = strlen(txt);
        size_t done = 0;

        while (done < len) {
                ssize_t n = send(fd, txt + done, len - done, 0);

                if (n == -1) {
                        if (errno == EINTR)
                                continue;

                        fprintf(stderr, "error sending %s\n", txt);
                        exit(1);
                }

                done += (size_t) n;
        }

        return (ssize_t) done;
}
161
/* Reads from the plain socket fd until the txt_size bytes at txt are
 * seen in the incoming data.
 *
 * Each round waits up to 10 seconds for data and then reads one chunk
 * of at most 511 bytes.  The wait ends when the chunk starts with the
 * marker, or when the marker appears in the chunk directly after a
 * '\n'.
 *
 * A select() timeout or failure, a recv() failure, or the peer
 * closing the connection (recv() returning 0) prints an error to
 * stderr and terminates the process with exit status 1.  Without the
 * end-of-stream check the loop would never finish, because a closed
 * socket is always reported as readable.
 *
 * NOTE: every round overwrites the buffer, so a marker that is split
 * across two chunks is not recognised.
 *
 * Returns the size of the last chunk read.
 */
static
ssize_t wait_for_text(int fd, const char *txt, unsigned txt_size)
{
        char buf[512];
        char *p;
        int ret;
        fd_set read_fds;
        struct timeval tv;

        do {
                FD_ZERO(&read_fds);
                FD_SET(fd, &read_fds);
                tv.tv_sec = 10;
                tv.tv_usec = 0;
                ret = select(fd + 1, &read_fds, NULL, NULL, &tv);
                if (ret <= 0)
                        ret = -1;
                else
                        ret = recv(fd, buf, sizeof(buf)-1, 0);

                /* 0 means the peer closed the connection: the text
                 * we are waiting for can no longer arrive */
                if (ret <= 0) {
                        fprintf(stderr, "error receiving %s\n", txt);
                        exit(1);
                }
                buf[ret] = 0;

                /* marker at the start of a later line in this chunk */
                p = memmem(buf, ret, txt, txt_size);
                if (p != NULL && p != buf) {
                        p--;
                        if (*p == '\n')
                                break;
                }
        } while(ret < (int)txt_size || strncmp(buf, txt, txt_size) != 0);

        return ret;
}
197
198 void
199 socket_starttls(socket_st * socket, const char *app_proto)
200 {
201         if (socket->secure)
202                 return;
203
204         if (app_proto == NULL || strcasecmp(app_proto, "https") == 0)
205                 return;
206
207         if (strcasecmp(app_proto, "smtp") == 0 || strcasecmp(app_proto, "submission") == 0) {
208                 if (socket->verbose)
209                         printf("Negotiating SMTP STARTTLS\n");
210
211                 wait_for_text(socket->fd, "220 ", 4);
212                 send_line(socket->fd, "EHLO mail.example.com\n");
213                 wait_for_text(socket->fd, "250 ", 4);
214                 send_line(socket->fd, "STARTTLS\n");
215                 wait_for_text(socket->fd, "220 ", 4);
216         } else if (strcasecmp(app_proto, "imap") == 0 || strcasecmp(app_proto, "imap2") == 0) {
217                 if (socket->verbose)
218                         printf("Negotiating IMAP STARTTLS\n");
219
220                 send_line(socket->fd, "a CAPABILITY\r\n");
221                 wait_for_text(socket->fd, "a OK", 4);
222                 send_line(socket->fd, "a STARTTLS\r\n");
223                 wait_for_text(socket->fd, "a OK", 4);
224         } else if (strcasecmp(app_proto, "ftp") == 0 || strcasecmp(app_proto, "ftps") == 0) {
225                 if (socket->verbose)
226                         printf("Negotiating FTP STARTTLS\n");
227
228                 send_line(socket->fd, "FEAT\n");
229                 wait_for_text(socket->fd, "211 End", 7);
230                 send_line(socket->fd, "AUTH TLS\n");
231                 wait_for_text(socket->fd, "234", 3);
232         } else {
233                 if (!c_isdigit(app_proto[0])) {
234                         static int warned = 0;
235                         if (warned == 0) {
236                                 fprintf(stderr, "unknown protocol %s\n", app_proto);
237                                 warned = 1;
238                         }
239                 }
240         }
241
242         return;
243 }
244
245 void socket_bye(socket_st * socket)
246 {
247         int ret;
248         if (socket->secure) {
249                 do
250                         ret = gnutls_bye(socket->session, GNUTLS_SHUT_WR);
251                 while (ret == GNUTLS_E_INTERRUPTED
252                        || ret == GNUTLS_E_AGAIN);
253                 if (ret < 0)
254                         fprintf(stderr, "*** gnutls_bye() error: %s\n",
255                                 gnutls_strerror(ret));
256                 gnutls_deinit(socket->session);
257                 socket->session = NULL;
258         }
259
260         freeaddrinfo(socket->addr_info);
261         socket->addr_info = socket->ptr = NULL;
262
263         free(socket->ip);
264         free(socket->hostname);
265         free(socket->service);
266
267         shutdown(socket->fd, SHUT_RDWR);        /* no more receptions */
268         close(socket->fd);
269
270         socket->fd = -1;
271         socket->secure = 0;
272 }
273
274 void
275 socket_open(socket_st * hd, const char *hostname, const char *service,
276             int udp, const char *msg)
277 {
278         struct addrinfo hints, *res, *ptr;
279         int sd, err;
280         char buffer[MAX_BUF + 1];
281         char portname[16] = { 0 };
282
283         if (msg != NULL)
284                 printf("Resolving '%s'...\n", hostname);
285
286         /* get server name */
287         memset(&hints, 0, sizeof(hints));
288
289 #ifdef AI_IDN
290         hints.ai_flags = AI_IDN|AI_IDN_ALLOW_UNASSIGNED;
291 #endif
292
293         hints.ai_socktype = udp ? SOCK_DGRAM : SOCK_STREAM;
294         if ((err = getaddrinfo(hostname, service, &hints, &res))) {
295                 fprintf(stderr, "Cannot resolve %s:%s: %s\n", hostname,
296                         service, gai_strerror(err));
297                 exit(1);
298         }
299
300         sd = -1;
301         for (ptr = res; ptr != NULL; ptr = ptr->ai_next) {
302                 sd = socket(ptr->ai_family, ptr->ai_socktype,
303                             ptr->ai_protocol);
304                 if (sd == -1)
305                         continue;
306
307                 if ((err =
308                      getnameinfo(ptr->ai_addr, ptr->ai_addrlen, buffer,
309                                  MAX_BUF, portname, sizeof(portname),
310                                  NI_NUMERICHOST | NI_NUMERICSERV)) != 0) {
311                         fprintf(stderr, "getnameinfo(): %s\n",
312                                 gai_strerror(err));
313                         continue;
314                 }
315
316                 if (hints.ai_socktype == SOCK_DGRAM) {
317 #if defined(IP_DONTFRAG)
318                         int yes = 1;
319                         if (setsockopt(sd, IPPROTO_IP, IP_DONTFRAG,
320                                        (const void *) &yes,
321                                        sizeof(yes)) < 0)
322                                 perror("setsockopt(IP_DF) failed");
323 #elif defined(IP_MTU_DISCOVER)
324                         int yes = IP_PMTUDISC_DO;
325                         if (setsockopt(sd, IPPROTO_IP, IP_MTU_DISCOVER,
326                                        (const void *) &yes,
327                                        sizeof(yes)) < 0)
328                                 perror("setsockopt(IP_DF) failed");
329 #endif
330                 }
331
332
333                 if (msg)
334                         printf("%s '%s:%s'...\n", msg, buffer, portname);
335
336                 err = connect(sd, ptr->ai_addr, ptr->ai_addrlen);
337                 if (err < 0) {
338                         int e = errno;
339                         fprintf(stderr, "Cannot connect to %s:%s: %s\n",
340                                 buffer, portname, strerror(e));
341                         continue;
342                 }
343                 break;
344         }
345
346         if (err != 0)
347                 exit(1);
348
349         if (sd == -1) {
350                 fprintf(stderr, "Could not find a supported socket\n");
351                 exit(1);
352         }
353
354         hd->secure = 0;
355         hd->fd = sd;
356         hd->hostname = strdup(hostname);
357         hd->ip = strdup(buffer);
358         hd->service = strdup(portname);
359         hd->ptr = ptr;
360         hd->addr_info = res;
361
362         return;
363 }
364
/* One-time process setup for socket use.
 *
 * On Windows this starts Winsock, requesting version 1.1, and reports
 * a failure with perror().  Elsewhere it makes the process ignore
 * SIGPIPE.
 */
void sockets_init(void)
{
#ifndef _WIN32
        signal(SIGPIPE, SIG_IGN);
#else
        WSADATA wsaData;
        WORD wVersionRequested = MAKEWORD(1, 1);

        if (WSAStartup(wVersionRequested, &wsaData) != 0)
                perror("WSA_STARTUP_ERROR");
#endif
}
380
/* Converts a textual port number to the name of the service
 * registered for it.
 *
 * sport - service name or decimal port number.
 * proto - protocol name handed to getservbyport().
 *
 * sport itself is returned when it does not start with a digit, when
 * its value is outside 1..65535 (such a value cannot be a port and
 * would otherwise be truncated to 16 bits and matched against an
 * unrelated service), or when no service is registered for the port;
 * the last case also prints a warning to stderr.
 *
 * Otherwise the s_name member of the entry returned by
 * getservbyport() is returned; that storage is owned by the C library
 * and must not be freed by the caller.
 */
const char *port_to_service(const char *sport, const char *proto)
{
        long port;
        struct servent *sr;

        if (!c_isdigit(sport[0]))
                return sport;

        /* strtol() reports overflow through errno, unlike atoi() */
        errno = 0;
        port = strtol(sport, NULL, 10);
        if (errno != 0 || port <= 0 || port > 65535)
                return sport;

        /* getservbyport() expects the port in network byte order */
        sr = getservbyport(htons((unsigned short) port), proto);
        if (sr == NULL) {
                fprintf(stderr,
                        "Warning: getservbyport(%s) failed. Using port number as service.\n", sport);
                return sport;
        }

        return sr->s_name;
}
407
/* Converts a service name or a textual port number to a port number
 * in host byte order.
 *
 * service - decimal port number or service name.
 * proto   - protocol name handed to getservbyname().
 *
 * When service starts with a decimal number in the range 1..65535
 * that number is returned.  Anything else, including numbers that are
 * zero, negative or too large to be a port, is looked up as a service
 * name.  If that lookup fails a warning is printed to stderr and the
 * process is terminated with exit status 1.
 */
int service_to_port(const char *service, const char *proto)
{
        long port;
        struct servent *sr;

        /* strtol() reports overflow through errno, unlike atoi() */
        errno = 0;
        port = strtol(service, NULL, 10);
        if (errno == 0 && port > 0 && port <= 65535)
                return (int) port;

        sr = getservbyname(service, proto);
        if (sr == NULL) {
                fprintf(stderr, "Warning: getservbyname() failed.\n");
                exit(1);
        }

        /* s_port is stored in network byte order */
        return ntohs(sr->s_port);
}