Use file descriptor directly instead of GIOChannel
[platform/upstream/connman.git] / gweb / giognutls.c
1 /*
2  *
3  *  Web service library with GLib integration
4  *
5  *  Copyright (C) 2009-2010  Intel Corporation. All rights reserved.
6  *
7  *  This program is free software; you can redistribute it and/or modify
8  *  it under the terms of the GNU General Public License version 2 as
9  *  published by the Free Software Foundation.
10  *
11  *  This program is distributed in the hope that it will be useful,
12  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  *  GNU General Public License for more details.
15  *
16  *  You should have received a copy of the GNU General Public License
17  *  along with this program; if not, write to the Free Software
18  *  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
19  *
20  */
21
22 #ifdef HAVE_CONFIG_H
23 #include <config.h>
24 #endif
25
26 #include <stdio.h>
27 #include <errno.h>
28 #include <fcntl.h>
29 #include <unistd.h>
30
31 #include <gnutls/gnutls.h>
32
33 #include "giognutls.h"
34
35 //#define DBG(fmt, arg...)  printf("%s: " fmt "\n" , __func__ , ## arg)
36 #define DBG(fmt, arg...)
37
38 typedef struct _GIOGnuTLSChannel GIOGnuTLSChannel;
39 typedef struct _GIOGnuTLSWatch GIOGnuTLSWatch;
40
41 struct _GIOGnuTLSChannel {
42         GIOChannel channel;
43         gint fd;
44         gnutls_certificate_credentials_t cred;
45         gnutls_session session;
46         gboolean established;
47         gboolean again;
48 };
49
50 struct _GIOGnuTLSWatch {
51         GSource source;
52         GPollFD pollfd;
53         GIOChannel *channel;
54         GIOCondition condition;
55 };
56
57 static volatile gint global_init_done = 0;
58
59 static inline void g_io_gnutls_global_init(void)
60 {
61         if (g_atomic_int_compare_and_exchange(&global_init_done, 0, 1) == TRUE)
62                 gnutls_global_init();
63 }
64
65 static GIOStatus check_handshake(GIOChannel *channel, GError **err)
66 {
67         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
68         int result;
69
70         DBG("channel %p", channel);
71
72         if (gnutls_channel->established == TRUE)
73                 return G_IO_STATUS_NORMAL;
74
75 again:
76         result = gnutls_handshake(gnutls_channel->session);
77
78         if (result == GNUTLS_E_INTERRUPTED || result == GNUTLS_E_AGAIN) {
79                 GIOFlags flags = g_io_channel_get_flags(channel);
80
81                 if (gnutls_channel->again == TRUE)
82                         return G_IO_STATUS_AGAIN;
83
84                 if (flags & G_IO_FLAG_NONBLOCK)
85                         return G_IO_STATUS_AGAIN;
86
87                 goto again;
88         }
89
90         if (result < 0) {
91                 g_set_error(err, G_IO_CHANNEL_ERROR,
92                                 G_IO_CHANNEL_ERROR_FAILED, "Handshake failed");
93                 return G_IO_STATUS_ERROR;
94         }
95
96         gnutls_channel->established = TRUE;
97
98         DBG("handshake done");
99
100         return G_IO_STATUS_NORMAL;
101 }
102
103 static GIOStatus g_io_gnutls_read(GIOChannel *channel, gchar *buf,
104                                 gsize count, gsize *bytes_read, GError **err)
105 {
106         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
107         GIOStatus status;
108         ssize_t result;
109
110         DBG("channel %p count %zu", channel, count);
111
112         *bytes_read = 0;
113
114 again:
115         status = check_handshake(channel, err);
116         if (status != G_IO_STATUS_NORMAL)
117                 return status;
118
119         result = gnutls_record_recv(gnutls_channel->session, buf, count);
120
121         DBG("result %zd", result);
122
123         if (result == GNUTLS_E_REHANDSHAKE) {
124                 gnutls_channel->established = FALSE;
125                 goto again;
126         }
127
128         if (result == GNUTLS_E_INTERRUPTED || result == GNUTLS_E_AGAIN) {
129                 GIOFlags flags = g_io_channel_get_flags(channel);
130
131                 if (gnutls_channel->again == TRUE)
132                         return G_IO_STATUS_AGAIN;
133
134                 if (flags & G_IO_FLAG_NONBLOCK)
135                         return G_IO_STATUS_AGAIN;
136
137                 goto again;
138         }
139
140         if (result == GNUTLS_E_UNEXPECTED_PACKET_LENGTH)
141                 return G_IO_STATUS_EOF;
142
143         if (result < 0) {
144                 g_set_error(err, G_IO_CHANNEL_ERROR,
145                                 G_IO_CHANNEL_ERROR_FAILED, "Stream corrupted");
146                 return G_IO_STATUS_ERROR;
147         }
148
149         *bytes_read = result;
150
151         return (result > 0) ? G_IO_STATUS_NORMAL : G_IO_STATUS_EOF;
152 }
153
154 static GIOStatus g_io_gnutls_write(GIOChannel *channel, const gchar *buf,
155                                 gsize count, gsize *bytes_written, GError **err)
156 {
157         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
158         GIOStatus status;
159         ssize_t result;
160
161         DBG("channel %p count %zu", channel, count);
162
163         *bytes_written = 0;
164
165 again:
166         status = check_handshake(channel, err);
167         if (status != G_IO_STATUS_NORMAL)
168                 return status;
169
170         result = gnutls_record_send(gnutls_channel->session, buf, count);
171
172         DBG("result %zd", result);
173
174         if (result == GNUTLS_E_REHANDSHAKE) {
175                 gnutls_channel->established = FALSE;
176                 goto again;
177         }
178
179         if (result == GNUTLS_E_INTERRUPTED || result == GNUTLS_E_AGAIN) {
180                 GIOFlags flags = g_io_channel_get_flags(channel);
181
182                 if (gnutls_channel->again == TRUE)
183                         return G_IO_STATUS_AGAIN;
184
185                 if (flags & G_IO_FLAG_NONBLOCK)
186                         return G_IO_STATUS_AGAIN;
187
188                 goto again;
189         }
190
191         if (result < 0) {
192                 g_set_error(err, G_IO_CHANNEL_ERROR,
193                                 G_IO_CHANNEL_ERROR_FAILED, "Stream corrupted");
194                 return G_IO_STATUS_ERROR;
195         }
196
197         *bytes_written = result;
198
199         return (result > 0) ? G_IO_STATUS_NORMAL : G_IO_STATUS_EOF;
200 }
201
202 static GIOStatus g_io_gnutls_seek(GIOChannel *channel, gint64 offset,
203                                                 GSeekType type, GError **err)
204 {
205         DBG("channel %p", channel);
206
207         g_set_error_literal(err, G_IO_CHANNEL_ERROR,
208                                 G_IO_CHANNEL_ERROR_FAILED, "Not supported");
209         return G_IO_STATUS_ERROR;
210 }
211
212 static GIOStatus g_io_gnutls_close(GIOChannel *channel, GError **err)
213 {
214         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
215
216         DBG("channel %p", channel);
217
218         if (gnutls_channel->established == TRUE)
219                 gnutls_bye(gnutls_channel->session, GNUTLS_SHUT_RDWR);
220
221         if (close(gnutls_channel->fd) < 0) {
222                 g_set_error_literal(err, G_IO_CHANNEL_ERROR,
223                                 G_IO_CHANNEL_ERROR_FAILED, "Closing failed");
224                 return G_IO_STATUS_ERROR;
225         }
226
227         return G_IO_STATUS_NORMAL;
228 }
229
230 static void g_io_gnutls_free(GIOChannel *channel)
231 {
232         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
233
234         DBG("channel %p", channel);
235
236         gnutls_deinit(gnutls_channel->session);
237
238         gnutls_certificate_free_credentials(gnutls_channel->cred);
239
240         g_free(gnutls_channel);
241 }
242
243 static GIOStatus g_io_gnutls_set_flags(GIOChannel *channel,
244                                                 GIOFlags flags, GError **err)
245 {
246         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
247         glong fcntl_flags = 0;
248
249         DBG("channel %p flags %u", channel, flags);
250
251         if (flags & G_IO_FLAG_NONBLOCK)
252                 fcntl_flags |= O_NONBLOCK;
253
254         if (fcntl(gnutls_channel->fd, F_SETFL, fcntl_flags) < 0) {
255                 g_set_error_literal(err, G_IO_CHANNEL_ERROR,
256                         G_IO_CHANNEL_ERROR_FAILED, "Setting flags failed");
257                 return G_IO_STATUS_ERROR;
258         }
259
260         return G_IO_STATUS_NORMAL;
261 }
262
263 static GIOFlags g_io_gnutls_get_flags(GIOChannel *channel)
264 {
265         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
266         GIOFlags flags = 0;
267         glong fcntl_flags;
268
269         DBG("channel %p", channel);
270
271         fcntl_flags = fcntl(gnutls_channel->fd, F_GETFL);
272         if (fcntl_flags < 0)
273                 return 0;
274
275         if (fcntl_flags & O_NONBLOCK)
276                 flags |= G_IO_FLAG_NONBLOCK;
277
278         return flags;
279 }
280
281 static gboolean g_io_gnutls_prepare(GSource *source, gint *timeout)
282 {
283         DBG("source %p", source);
284
285         *timeout = -1;
286
287         return FALSE;
288 }
289
290 static gboolean g_io_gnutls_check(GSource *source)
291 {
292         GIOGnuTLSWatch *watch = (GIOGnuTLSWatch *) source;
293         GIOCondition condition = watch->pollfd.revents;
294
295         DBG("source %p condition %u", source, condition);
296
297         if (condition & watch->condition)
298                 return TRUE;
299
300         return FALSE;
301 }
302
303 static gboolean g_io_gnutls_dispatch(GSource *source, GSourceFunc callback,
304                                                         gpointer user_data)
305 {
306         GIOGnuTLSWatch *watch = (GIOGnuTLSWatch *) source;
307         GIOFunc func = (GIOFunc) callback;
308         GIOCondition condition = watch->pollfd.revents;
309
310         DBG("source %p condition %u", source, condition);
311
312         if (func == NULL)
313                 return FALSE;
314
315         return func(watch->channel, condition & watch->condition, user_data);
316 }
317
318 static void g_io_gnutls_finalize(GSource *source)
319 {
320         GIOGnuTLSWatch *watch = (GIOGnuTLSWatch *) source;
321
322         DBG("source %p", source);
323
324         g_io_channel_unref(watch->channel);
325 }
326
327 static GSourceFuncs gnutls_watch_funcs = {
328         g_io_gnutls_prepare,
329         g_io_gnutls_check,
330         g_io_gnutls_dispatch,
331         g_io_gnutls_finalize,
332 };
333
334 static GSource *g_io_gnutls_create_watch(GIOChannel *channel,
335                                                 GIOCondition condition)
336 {
337         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
338         GIOGnuTLSWatch *watch;
339         GSource *source;
340
341         DBG("channel %p condition %u", channel, condition);
342
343         source = g_source_new(&gnutls_watch_funcs, sizeof(GIOGnuTLSWatch));
344
345         watch = (GIOGnuTLSWatch *) source;
346
347         watch->channel = channel;
348         g_io_channel_ref(channel);
349
350         watch->condition = condition;
351
352         watch->pollfd.fd = gnutls_channel->fd;
353         watch->pollfd.events = condition;
354
355         g_source_add_poll(source, &watch->pollfd);
356
357         return source;
358 }
359
360 static GIOFuncs gnutls_channel_funcs = {
361         g_io_gnutls_read,
362         g_io_gnutls_write,
363         g_io_gnutls_seek,
364         g_io_gnutls_close,
365         g_io_gnutls_create_watch,
366         g_io_gnutls_free,
367         g_io_gnutls_set_flags,
368         g_io_gnutls_get_flags,
369 };
370
371 static ssize_t g_io_gnutls_push_func(gnutls_transport_ptr_t transport_data,
372                                                 const void *buf, size_t count)
373 {
374         GIOGnuTLSChannel *gnutls_channel = transport_data;
375         ssize_t result;
376
377         DBG("count %zu", count);
378
379         result = write(gnutls_channel->fd, buf, count);
380
381         if (result < 0 && errno == EAGAIN)
382                 gnutls_channel->again = TRUE;
383         else
384                 gnutls_channel->again = FALSE;
385
386         DBG("result %zd", result);
387
388         return result;
389 }
390
391 static ssize_t g_io_gnutls_pull_func(gnutls_transport_ptr_t transport_data,
392                                                 void *buf, size_t count)
393 {
394         GIOGnuTLSChannel *gnutls_channel = transport_data;
395         ssize_t result;
396
397         DBG("count %zu", count);
398
399         result = read(gnutls_channel->fd, buf, count);
400
401         if (result < 0 && errno == EAGAIN)
402                 gnutls_channel->again = TRUE;
403         else
404                 gnutls_channel->again = FALSE;
405
406         DBG("result %zd", result);
407
408         return result;
409 }
410
411 GIOChannel *g_io_channel_gnutls_new(int fd)
412 {
413         GIOGnuTLSChannel *gnutls_channel;
414         GIOChannel *channel;
415         int err;
416
417         DBG("");
418
419         gnutls_channel = g_new(GIOGnuTLSChannel, 1);
420
421         channel = (GIOChannel *) gnutls_channel;
422
423         g_io_channel_init(channel);
424         channel->funcs = &gnutls_channel_funcs;
425
426         gnutls_channel->fd = fd;
427
428         channel->is_seekable = FALSE;
429         channel->is_readable = TRUE;
430         channel->is_writeable = TRUE;
431
432         channel->do_encode = FALSE;
433
434         g_io_gnutls_global_init();
435
436         err = gnutls_init(&gnutls_channel->session, GNUTLS_CLIENT);
437         if (err < 0) {
438                 g_free(gnutls_channel);
439                 return NULL;
440         }
441
442         gnutls_transport_set_ptr(gnutls_channel->session, gnutls_channel);
443         gnutls_transport_set_push_function(gnutls_channel->session,
444                                                 g_io_gnutls_push_func);
445         gnutls_transport_set_pull_function(gnutls_channel->session,
446                                                 g_io_gnutls_pull_func);
447         gnutls_transport_set_lowat(gnutls_channel->session, 0);
448
449         gnutls_priority_set_direct(gnutls_channel->session,
450                                 "NORMAL:!VERS-TLS1.1:!VERS-TLS1.0", NULL);
451
452         gnutls_certificate_allocate_credentials(&gnutls_channel->cred);
453         gnutls_credentials_set(gnutls_channel->session,
454                                 GNUTLS_CRD_CERTIFICATE, gnutls_channel->cred);
455
456         DBG("channel %p", channel);
457
458         return channel;
459 }