Imported Upstream version 3.0.21
[platform/upstream/gnutls.git] / src / udp-serv.c
1 /*
2  * Copyright (C) 2011-2012 Free Software Foundation, Inc.
3  *
4  * This file is part of GnuTLS.
5  *
6  * GnuTLS is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * GnuTLS is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
18  */
19
20 #include <config.h>
21
22 #include <stdio.h>
23 #if HAVE_SYS_SOCKET_H
24 # include <sys/socket.h>
25 #elif HAVE_WS2TCPIP_H
26 # include <ws2tcpip.h>
27 #endif
28 #include <arpa/inet.h>
29 #ifndef _WIN32
30 # include <netinet/in.h>
31 #endif
32 #include <sys/select.h>
33 #include <stdlib.h>
34 #include <string.h>
35 #include <unistd.h>
36 #include <errno.h>
37 #include <common.h>
38 #include "udp-serv.h"
39 #include "list.h"
40
41 typedef struct {
42   gnutls_session_t session;
43   int fd;
44   struct sockaddr * cli_addr;
45   socklen_t cli_addr_size;
46 } priv_data_st;
47
48 static int pull_timeout_func(gnutls_transport_ptr_t ptr, unsigned int ms);
49 static ssize_t push_func (gnutls_transport_ptr_t p, const void * data, size_t size);
50 static ssize_t pull_func(gnutls_transport_ptr_t p, void * data, size_t size);
51
52 #define MAX_BUFFER 255     /* Longest string to echo */
53
54 void udp_server(const char* name, int port, int mtu)
55 {
56     int sock, ret;
57     struct sockaddr_in cli_addr;
58     socklen_t cli_addr_size;
59     char buffer[MAX_BUFFER];
60     priv_data_st priv;
61     gnutls_session_t session;
62     gnutls_datum_t cookie_key;
63     gnutls_dtls_prestate_st prestate;
64     unsigned char sequence[8];
65
66     ret = gnutls_key_generate(&cookie_key, GNUTLS_COOKIE_KEY_SIZE);
67     if (ret < 0)
68       {
69         fprintf(stderr, "Cannot generate key\n");
70         exit(1);
71       }
72
73     ret = listen_socket (name, port, SOCK_DGRAM);
74     if (ret < 0)
75       {
76         fprintf(stderr, "Cannot listen\n");
77         exit (1);
78       }
79
80     for (;;)
81       {
82         printf("Waiting for connection...\n");
83         sock = wait_for_connection();
84         if (sock < 0)
85           continue;
86
87         cli_addr_size = sizeof(cli_addr);
88         ret = recvfrom(sock, buffer, sizeof(buffer), MSG_PEEK, (struct sockaddr*)&cli_addr, &cli_addr_size);
89         if (ret > 0)
90           {
91             memset(&prestate, 0, sizeof(prestate));
92             ret = gnutls_dtls_cookie_verify(&cookie_key, &cli_addr, sizeof(cli_addr), buffer, ret, &prestate);
93             if (ret < 0) /* cookie not valid */
94               {
95                 priv_data_st s;
96                 
97                 memset(&s,0,sizeof(s));
98                 s.fd = sock;
99                 s.cli_addr = (void*)&cli_addr;
100                 s.cli_addr_size = sizeof(cli_addr);
101                 
102                 printf("Sending hello verify request to %s\n", human_addr ((struct sockaddr *)
103                   &cli_addr, sizeof(cli_addr), buffer, sizeof(buffer)));
104                 gnutls_dtls_cookie_send(&cookie_key, &cli_addr, sizeof(cli_addr), &prestate, (gnutls_transport_ptr_t)&s, push_func);
105
106                 /* discard peeked data*/
107                 recvfrom(sock, buffer, sizeof(buffer), 0, (struct sockaddr*)&cli_addr, &cli_addr_size);
108                 continue;
109               }
110             printf ("Accepted connection from %s\n",
111                             human_addr ((struct sockaddr *)
112                                         &cli_addr, sizeof(cli_addr), buffer,
113                                         sizeof (buffer)));
114           }
115         else
116           continue;
117
118         session = initialize_session(1);
119         gnutls_dtls_prestate_set(session, &prestate);
120         if (mtu) gnutls_dtls_set_mtu(session, mtu);
121
122         priv.session = session;
123         priv.fd = sock;
124         priv.cli_addr = (struct sockaddr *)&cli_addr;
125         priv.cli_addr_size = sizeof(cli_addr);
126
127         gnutls_transport_set_ptr (session, &priv);
128         gnutls_transport_set_push_function (session, push_func);
129         gnutls_transport_set_pull_function (session, pull_func);
130         gnutls_transport_set_pull_timeout_function (session, pull_timeout_func);
131
132         do
133           {
134             ret = gnutls_handshake(session);
135           }
136         while(ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED);
137
138         if (ret < 0)
139           {
140             fprintf(stderr, "Error in handshake(): %s\n", gnutls_strerror(ret));
141             gnutls_deinit(session);
142             continue;
143           }
144
145         for(;;)
146           {
147             do {
148               ret = gnutls_record_recv_seq(session, buffer, MAX_BUFFER, sequence);
149             } while(ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED);
150
151             if (ret == GNUTLS_E_REHANDSHAKE)
152               {
153                 fprintf (stderr, "*** Received hello message\n");
154                 do
155                   {
156                     ret = gnutls_handshake (session);
157                   }
158                 while (ret == GNUTLS_E_INTERRUPTED ||
159                        ret == GNUTLS_E_AGAIN);
160                 
161                 if (ret == 0) continue;
162               }
163             if (ret < 0)
164               {
165                 fprintf(stderr, "Error in recv(): %s\n", gnutls_strerror(ret));
166                 break;
167               }
168             if (ret == 0)
169               {
170                 printf("EOF\n\n");
171                 break;
172               }
173               
174             buffer[ret] = 0;
175             printf("received[%.2x%.2x%.2x%.2x%.2x%.2x%.2x%.2x]: %s\n", sequence[0], sequence[1], sequence[2],
176                    sequence[3], sequence[4], sequence[5], sequence[6], sequence[7], buffer);
177
178             if (check_command(session, buffer) == 0)
179               {
180                 /* reply back */
181                 ret = gnutls_record_send(session, buffer, ret);
182                 if (ret < 0)
183                   {
184                     fprintf(stderr, "Error in send(): %s\n", gnutls_strerror(ret));
185                     break;
186                   }
187               }
188           }
189       }
190     gnutls_deinit(session);
191 }
192
193 /* Wait for data to be received within a timeout period in milliseconds
194  */
195 static int pull_timeout_func(gnutls_transport_ptr_t ptr, unsigned int ms)
196 {
197 fd_set rfds;
198 struct timeval tv;
199 priv_data_st *priv = ptr;
200 struct sockaddr_in cli_addr;
201 socklen_t cli_addr_size;
202 int ret;
203 char c;
204
205   FD_ZERO(&rfds);
206   FD_SET(priv->fd, &rfds);
207   
208   tv.tv_sec = 0;
209   tv.tv_usec = ms * 1000;
210
211   while(tv.tv_usec >= 1000000)
212     {
213       tv.tv_usec -= 1000000;
214       tv.tv_sec++;
215     }
216
217   ret = select(priv->fd+1, &rfds, NULL, NULL, &tv);
218
219   if (ret <= 0)
220     return ret;
221
222   /* only report ok if the next message is from the peer we expect
223    * from 
224    */
225   cli_addr_size = sizeof(cli_addr);
226   ret = recvfrom(priv->fd, &c, 1, MSG_PEEK, (struct sockaddr*)&cli_addr, &cli_addr_size);
227   if (ret > 0)
228     {
229       if (cli_addr_size == priv->cli_addr_size && memcmp(&cli_addr, priv->cli_addr, sizeof(cli_addr))==0)
230         return 1;
231     }
232
233   return 0;
234 }
235
236 static ssize_t push_func (gnutls_transport_ptr_t p, const void * data, size_t size)
237 {
238 priv_data_st *priv = p;
239
240   return sendto(priv->fd, data, size, 0, priv->cli_addr, priv->cli_addr_size);
241 }
242
243 static ssize_t pull_func(gnutls_transport_ptr_t p, void * data, size_t size)
244 {
245 priv_data_st *priv = p;
246 struct sockaddr_in cli_addr;
247 socklen_t cli_addr_size;
248 char buffer[64];
249 int ret;
250
251   cli_addr_size = sizeof(cli_addr);
252   ret = recvfrom(priv->fd, data, size, 0, (struct sockaddr*)&cli_addr, &cli_addr_size);
253   if (ret == -1)
254     return ret;
255
256   if (cli_addr_size == priv->cli_addr_size && memcmp(&cli_addr, priv->cli_addr, sizeof(cli_addr))==0)
257     return ret;
258
259   printf ("Denied connection from %s\n",
260                 human_addr ((struct sockaddr *)
261                             &cli_addr, sizeof(cli_addr), buffer,
262                             sizeof (buffer)));
263   
264   gnutls_transport_set_errno(priv->session, EAGAIN);
265   return -1;
266 }