Add support for handling non-blocking GnuTLS channels
[framework/connectivity/connman.git] / gweb / giognutls.c
1 /*
2  *
3  *  Web service library with GLib integration
4  *
5  *  Copyright (C) 2009-2010  Intel Corporation. All rights reserved.
6  *
7  *  This program is free software; you can redistribute it and/or modify
8  *  it under the terms of the GNU General Public License version 2 as
9  *  published by the Free Software Foundation.
10  *
11  *  This program is distributed in the hope that it will be useful,
12  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  *  GNU General Public License for more details.
15  *
16  *  You should have received a copy of the GNU General Public License
17  *  along with this program; if not, write to the Free Software
18  *  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
19  *
20  */
21
22 #ifdef HAVE_CONFIG_H
23 #include <config.h>
24 #endif
25
26 #include <stdio.h>
27 #include <errno.h>
28 #include <ctype.h>
29 #include <unistd.h>
30
31 #include <gnutls/gnutls.h>
32
33 #include "giognutls.h"
34
/*
 * Debug trace hook.  It expands to nothing, so tracing is compiled out.
 * To trace to stdout, define it instead as:
 *
 *     printf("%s: " fmt "\n" , __func__ , ## arg)
 */
#define DBG(fmt, arg...)
37
38 typedef struct _GIOGnuTLSChannel GIOGnuTLSChannel;
39 typedef struct _GIOGnuTLSWatch GIOGnuTLSWatch;
40
41 struct _GIOGnuTLSChannel {
42         GIOChannel channel;
43         GIOChannel *transport;
44         gnutls_certificate_credentials_t cred;
45         gnutls_session session;
46         gboolean established;
47 };
48
49 struct _GIOGnuTLSWatch {
50         GSource source;
51         GPollFD pollfd;
52         GIOChannel *channel;
53         GIOCondition condition;
54 };
55
56 static volatile gint global_init_done = 0;
57
58 static inline void g_io_gnutls_global_init(void)
59 {
60         if (g_atomic_int_compare_and_exchange(&global_init_done, 0, 1) == TRUE)
61                 gnutls_global_init();
62 }
63
64 static GIOStatus check_handshake(GIOChannel *channel, GError **err)
65 {
66         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
67         int result;
68
69         DBG("channel %p", channel);
70
71 again:
72         if (gnutls_channel->established == TRUE)
73                 return G_IO_STATUS_NORMAL;
74
75         result = gnutls_handshake(gnutls_channel->session);
76
77         if (result == GNUTLS_E_INTERRUPTED || result == GNUTLS_E_AGAIN) {
78                 GIOFlags flags = g_io_channel_get_flags(channel);
79
80                 if (flags & G_IO_FLAG_NONBLOCK)
81                         return G_IO_STATUS_AGAIN;
82
83                 goto again;
84         }
85
86         if (result < 0) {
87                 g_set_error(err, G_IO_CHANNEL_ERROR,
88                                 G_IO_CHANNEL_ERROR_FAILED, "Handshake failed");
89                 return G_IO_STATUS_ERROR;
90         }
91
92         gnutls_channel->established = TRUE;
93
94         DBG("handshake done");
95
96         return G_IO_STATUS_NORMAL;
97 }
98
99 static GIOStatus g_io_gnutls_read(GIOChannel *channel, gchar *buf,
100                                 gsize count, gsize *bytes_read, GError **err)
101 {
102         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
103         GIOStatus status;
104         ssize_t result;
105
106         DBG("channel %p count %zu", channel, count);
107
108         *bytes_read = 0;
109
110 again:
111         status = check_handshake(channel, err);
112         if (status != G_IO_STATUS_NORMAL)
113                 return status;
114
115         result = gnutls_record_recv(gnutls_channel->session, buf, count);
116
117         DBG("result %zd", result);
118
119         if (result == GNUTLS_E_REHANDSHAKE) {
120                 gnutls_channel->established = FALSE;
121                 goto again;
122         }
123
124         if (result == GNUTLS_E_INTERRUPTED || result == GNUTLS_E_AGAIN) {
125                 GIOFlags flags = g_io_channel_get_flags(channel);
126
127                 if (flags & G_IO_FLAG_NONBLOCK)
128                         return G_IO_STATUS_AGAIN;
129
130                 goto again;
131         }
132
133         if (result == GNUTLS_E_UNEXPECTED_PACKET_LENGTH)
134                 return G_IO_STATUS_EOF;
135
136         if (result < 0) {
137                 g_set_error(err, G_IO_CHANNEL_ERROR,
138                                 G_IO_CHANNEL_ERROR_FAILED, "Stream corrupted");
139                 return G_IO_STATUS_ERROR;
140         }
141
142         *bytes_read = result;
143
144         return (result > 0) ? G_IO_STATUS_NORMAL : G_IO_STATUS_EOF;
145 }
146
147 static GIOStatus g_io_gnutls_write(GIOChannel *channel, const gchar *buf,
148                                 gsize count, gsize *bytes_written, GError **err)
149 {
150         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
151         GIOStatus status;
152         ssize_t result;
153
154         DBG("channel %p count %zu", channel, count);
155
156         *bytes_written = 0;
157
158 again:
159         status = check_handshake(channel, err);
160         if (status != G_IO_STATUS_NORMAL)
161                 return status;
162
163         result = gnutls_record_send(gnutls_channel->session, buf, count);
164
165         DBG("result %zd", result);
166
167         if (result == GNUTLS_E_REHANDSHAKE) {
168                 gnutls_channel->established = FALSE;
169                 goto again;
170         }
171
172         if (result == GNUTLS_E_INTERRUPTED || result == GNUTLS_E_AGAIN) {
173                 GIOFlags flags = g_io_channel_get_flags(channel);
174
175                 if (flags & G_IO_FLAG_NONBLOCK)
176                         return G_IO_STATUS_AGAIN;
177
178                 goto again;
179         }
180
181         if (result < 0) {
182                 g_set_error(err, G_IO_CHANNEL_ERROR,
183                                 G_IO_CHANNEL_ERROR_FAILED, "Stream corrupted");
184                 return G_IO_STATUS_ERROR;
185         }
186
187         *bytes_written = result;
188
189         return (result > 0) ? G_IO_STATUS_NORMAL : G_IO_STATUS_EOF;
190 }
191
192 static GIOStatus g_io_gnutls_seek(GIOChannel *channel, gint64 offset,
193                                                 GSeekType type, GError **err)
194 {
195         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
196         GIOChannel *transport = gnutls_channel->transport;
197
198         DBG("channel %p", channel);
199
200         return transport->funcs->io_seek(transport, offset, type, err);
201 }
202
203 static GIOStatus g_io_gnutls_close(GIOChannel *channel, GError **err)
204 {
205         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
206         GIOChannel *transport = gnutls_channel->transport;
207
208         DBG("channel %p", channel);
209
210         if (gnutls_channel->established == TRUE)
211                 gnutls_bye(gnutls_channel->session, GNUTLS_SHUT_RDWR);
212
213         return transport->funcs->io_close(transport, err);
214 }
215
216 static void g_io_gnutls_free(GIOChannel *channel)
217 {
218         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
219
220         DBG("channel %p", channel);
221
222         g_io_channel_unref(gnutls_channel->transport);
223
224         gnutls_deinit(gnutls_channel->session);
225
226         gnutls_certificate_free_credentials(gnutls_channel->cred);
227
228         g_free(gnutls_channel);
229 }
230
231 static GIOStatus g_io_gnutls_set_flags(GIOChannel *channel,
232                                                 GIOFlags flags, GError **err)
233 {
234         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
235         GIOChannel *transport = gnutls_channel->transport;
236
237         DBG("channel %p flags %u", channel, flags);
238
239         return transport->funcs->io_set_flags(transport, flags, err);
240 }
241
242 static GIOFlags g_io_gnutls_get_flags(GIOChannel *channel)
243 {
244         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
245         GIOChannel *transport = gnutls_channel->transport;
246
247         DBG("channel %p", channel);
248
249         return transport->funcs->io_get_flags(transport);
250 }
251
252 static gboolean g_io_gnutls_prepare(GSource *source, gint *timeout)
253 {
254         DBG("source %p", source);
255
256         *timeout = -1;
257
258         return FALSE;
259 }
260
261 static gboolean g_io_gnutls_check(GSource *source)
262 {
263         GIOGnuTLSWatch *watch = (GIOGnuTLSWatch *) source;
264         GIOCondition condition = watch->pollfd.revents;
265
266         DBG("source %p condition %u", source, condition);
267
268         if (condition & watch->condition)
269                 return TRUE;
270
271         return FALSE;
272 }
273
274 static gboolean g_io_gnutls_dispatch(GSource *source, GSourceFunc callback,
275                                                         gpointer user_data)
276 {
277         GIOGnuTLSWatch *watch = (GIOGnuTLSWatch *) source;
278         GIOFunc func = (GIOFunc) callback;
279         GIOCondition condition = watch->pollfd.revents;
280
281         DBG("source %p condition %u", source, condition);
282
283         if (func == NULL)
284                 return FALSE;
285
286         return func(watch->channel, condition & watch->condition, user_data);
287 }
288
289 static void g_io_gnutls_finalize(GSource *source)
290 {
291         GIOGnuTLSWatch *watch = (GIOGnuTLSWatch *) source;
292
293         DBG("source %p", source);
294
295         g_io_channel_unref(watch->channel);
296 }
297
298 static GSourceFuncs gnutls_watch_funcs = {
299         g_io_gnutls_prepare,
300         g_io_gnutls_check,
301         g_io_gnutls_dispatch,
302         g_io_gnutls_finalize,
303 };
304
305 static GSource *g_io_gnutls_create_watch(GIOChannel *channel,
306                                                 GIOCondition condition)
307 {
308         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
309         GIOGnuTLSWatch *watch;
310         GSource *source;
311
312         DBG("channel %p condition %u", channel, condition);
313
314         source = g_source_new(&gnutls_watch_funcs, sizeof(GIOGnuTLSWatch));
315
316         watch = (GIOGnuTLSWatch *) source;
317
318         watch->channel = channel;
319         g_io_channel_ref(channel);
320
321         watch->condition = condition;
322
323         watch->pollfd.fd = g_io_channel_unix_get_fd(gnutls_channel->transport);
324         watch->pollfd.events = condition;
325
326         g_source_add_poll(source, &watch->pollfd);
327
328         return source;
329 }
330
331 static GIOFuncs gnutls_channel_funcs = {
332         g_io_gnutls_read,
333         g_io_gnutls_write,
334         g_io_gnutls_seek,
335         g_io_gnutls_close,
336         g_io_gnutls_create_watch,
337         g_io_gnutls_free,
338         g_io_gnutls_set_flags,
339         g_io_gnutls_get_flags,
340 };
341
342 static ssize_t g_io_gnutls_push_func(gnutls_transport_ptr_t transport_data,
343                                                 const void *buf, size_t count)
344 {
345         GIOGnuTLSChannel *gnutls_channel = transport_data;
346         ssize_t result;
347         int fd;
348
349         DBG("transport %p count %zu", gnutls_channel->transport, count);
350
351         fd = g_io_channel_unix_get_fd(gnutls_channel->transport);
352
353         result = write(fd, buf, count);
354
355         DBG("result %zd", result);
356
357         return result;
358 }
359
360 static ssize_t g_io_gnutls_pull_func(gnutls_transport_ptr_t transport_data,
361                                                 void *buf, size_t count)
362 {
363         GIOGnuTLSChannel *gnutls_channel = transport_data;
364         ssize_t result;
365         int fd;
366
367         DBG("transport %p count %zu", gnutls_channel->transport, count);
368
369         fd = g_io_channel_unix_get_fd(gnutls_channel->transport);
370
371         result = read(fd, buf, count);
372
373         DBG("result %zd", result);
374
375         return result;
376 }
377
378 GIOChannel *g_io_channel_gnutls_new(int fd)
379 {
380         GIOGnuTLSChannel *gnutls_channel;
381         GIOChannel *channel;
382         int err;
383
384         DBG("");
385
386         gnutls_channel = g_new(GIOGnuTLSChannel, 1);
387
388         channel = (GIOChannel *) gnutls_channel;
389
390         g_io_channel_init(channel);
391         channel->funcs = &gnutls_channel_funcs;
392
393         gnutls_channel->transport = g_io_channel_unix_new(fd);
394
395         g_io_channel_set_encoding(gnutls_channel->transport, NULL, NULL);
396         g_io_channel_set_buffered(gnutls_channel->transport, FALSE);
397
398         channel->is_seekable = FALSE;
399         channel->is_readable = TRUE;
400         channel->is_writeable = TRUE;
401
402         channel->do_encode = FALSE;
403
404         g_io_gnutls_global_init();
405
406         err = gnutls_init(&gnutls_channel->session, GNUTLS_CLIENT);
407         if (err < 0) {
408                 g_free(gnutls_channel);
409                 return NULL;
410         }
411
412         gnutls_transport_set_ptr(gnutls_channel->session, gnutls_channel);
413         gnutls_transport_set_push_function(gnutls_channel->session,
414                                                 g_io_gnutls_push_func);
415         gnutls_transport_set_pull_function(gnutls_channel->session,
416                                                 g_io_gnutls_pull_func);
417         gnutls_transport_set_lowat(gnutls_channel->session, 0);
418
419         gnutls_priority_set_direct(gnutls_channel->session,
420                                 "NORMAL:!VERS-TLS1.1:!VERS-TLS1.0", NULL);
421
422         gnutls_certificate_allocate_credentials(&gnutls_channel->cred);
423         gnutls_credentials_set(gnutls_channel->session,
424                                 GNUTLS_CRD_CERTIFICATE, gnutls_channel->cred);
425
426         DBG("channel %p transport %p", channel, gnutls_channel->transport);
427
428         return channel;
429 }