Add support for g_io_channel_gnutls_new() function
[framework/connectivity/connman.git] / gweb / giognutls.c
1 /*
2  *
3  *  Web service library with GLib integration
4  *
5  *  Copyright (C) 2009-2010  Intel Corporation. All rights reserved.
6  *
7  *  This program is free software; you can redistribute it and/or modify
8  *  it under the terms of the GNU General Public License version 2 as
9  *  published by the Free Software Foundation.
10  *
11  *  This program is distributed in the hope that it will be useful,
12  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  *  GNU General Public License for more details.
15  *
16  *  You should have received a copy of the GNU General Public License
17  *  along with this program; if not, write to the Free Software
18  *  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
19  *
20  */
21
22 #ifdef HAVE_CONFIG_H
23 #include <config.h>
24 #endif
25
26 #include <stdio.h>
27 #include <errno.h>
28 #include <ctype.h>
29 #include <unistd.h>
30
31 #include <gnutls/gnutls.h>
32
33 #include "giognutls.h"
34
/*
 * Debug tracing.  Build with -DGIOGNUTLS_DEBUG to print each traced
 * call (prefixed with the function name) to stdout; otherwise DBG()
 * expands to nothing and its arguments are not evaluated.
 */
#ifdef GIOGNUTLS_DEBUG
#define DBG(fmt, arg...)  printf("%s: " fmt "\n" , __func__ , ## arg)
#else
#define DBG(fmt, arg...)
#endif
37
38 typedef struct _GIOGnuTLSChannel GIOGnuTLSChannel;
39 typedef struct _GIOGnuTLSWatch GIOGnuTLSWatch;
40
41 struct _GIOGnuTLSChannel {
42         GIOChannel channel;
43         GIOChannel *transport;
44         gnutls_certificate_credentials_t cred;
45         gnutls_session session;
46         gboolean established;
47 };
48
49 struct _GIOGnuTLSWatch {
50         GSource source;
51         GPollFD pollfd;
52         GIOChannel *channel;
53         GIOCondition condition;
54 };
55
56 static volatile gint global_init_done = 0;
57
58 static inline void g_io_gnutls_global_init(void)
59 {
60         if (g_atomic_int_compare_and_exchange(&global_init_done, 0, 1) == TRUE)
61                 gnutls_global_init();
62 }
63
64 static GIOStatus check_handshake(GIOChannel *channel, GError **error)
65 {
66         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
67         int err;
68
69         DBG("channel %p", channel);
70
71         if (gnutls_channel->established == TRUE)
72                 return G_IO_STATUS_NORMAL;
73
74         err = gnutls_handshake(gnutls_channel->session);
75         if (err < 0)
76                 return G_IO_STATUS_AGAIN;
77
78         gnutls_channel->established = TRUE;
79
80         DBG("handshake done");
81
82         return G_IO_STATUS_NORMAL;
83 }
84
85 static GIOStatus g_io_gnutls_read(GIOChannel *channel, gchar *buf,
86                                 gsize count, gsize *bytes_read, GError **err)
87 {
88         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
89         GIOStatus status;
90         ssize_t result;
91
92         DBG("channel %p count %zu", channel, count);
93
94         *bytes_read = 0;
95
96 again:
97         status = check_handshake(channel, err);
98         if (status != G_IO_STATUS_NORMAL)
99                 return status;
100
101         result = gnutls_record_recv(gnutls_channel->session, buf, count);
102
103         DBG("result %zd", result);
104
105         if (result == GNUTLS_E_REHANDSHAKE) {
106                 gnutls_channel->established = FALSE;
107                 goto again;
108         }
109
110         if (result == GNUTLS_E_INTERRUPTED || result == GNUTLS_E_AGAIN)
111                 goto again;
112
113         if (result == GNUTLS_E_UNEXPECTED_PACKET_LENGTH)
114                 return G_IO_STATUS_EOF;
115
116         if (result < 0) {
117                 g_set_error(err, G_IO_CHANNEL_ERROR,
118                                 G_IO_CHANNEL_ERROR_FAILED, "Stream corrupted");
119                 return G_IO_STATUS_ERROR;
120         }
121
122         *bytes_read = result;
123
124         return (result > 0) ? G_IO_STATUS_NORMAL : G_IO_STATUS_EOF;
125 }
126
127 static GIOStatus g_io_gnutls_write(GIOChannel *channel, const gchar *buf,
128                                 gsize count, gsize *bytes_written, GError **err)
129 {
130         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
131         GIOStatus status;
132         ssize_t result;
133
134         DBG("channel %p count %zu", channel, count);
135
136         *bytes_written = 0;
137
138 again:
139         status = check_handshake(channel, err);
140         if (status != G_IO_STATUS_NORMAL)
141                 return status;
142
143         result = gnutls_record_send(gnutls_channel->session, buf, count);
144
145         DBG("result %zd", result);
146
147         if (result == GNUTLS_E_REHANDSHAKE) {
148                 gnutls_channel->established = FALSE;
149                 goto again;
150         }
151
152         if (result == GNUTLS_E_INTERRUPTED || result == GNUTLS_E_AGAIN)
153                 goto again;
154
155         if (result < 0) {
156                 g_set_error(err, G_IO_CHANNEL_ERROR,
157                                 G_IO_CHANNEL_ERROR_FAILED, "Stream corrupted");
158                 return G_IO_STATUS_ERROR;
159         }
160
161         *bytes_written = result;
162
163         return (result > 0) ? G_IO_STATUS_NORMAL : G_IO_STATUS_EOF;
164 }
165
166 static GIOStatus g_io_gnutls_seek(GIOChannel *channel, gint64 offset,
167                                                 GSeekType type, GError **err)
168 {
169         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
170         GIOChannel *transport = gnutls_channel->transport;
171
172         DBG("channel %p", channel);
173
174         return transport->funcs->io_seek(transport, offset, type, err);
175 }
176
177 static GIOStatus g_io_gnutls_close(GIOChannel *channel, GError **err)
178 {
179         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
180         GIOChannel *transport = gnutls_channel->transport;
181
182         DBG("channel %p", channel);
183
184         if (gnutls_channel->established == TRUE)
185                 gnutls_bye(gnutls_channel->session, GNUTLS_SHUT_RDWR);
186
187         return transport->funcs->io_close(transport, err);
188 }
189
190 static void g_io_gnutls_free(GIOChannel *channel)
191 {
192         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
193
194         DBG("channel %p", channel);
195
196         g_io_channel_unref(gnutls_channel->transport);
197
198         gnutls_deinit(gnutls_channel->session);
199
200         gnutls_certificate_free_credentials(gnutls_channel->cred);
201
202         g_free(gnutls_channel);
203 }
204
205 static GIOStatus g_io_gnutls_set_flags(GIOChannel *channel,
206                                                 GIOFlags flags, GError **err)
207 {
208         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
209         GIOChannel *transport = gnutls_channel->transport;
210
211         DBG("channel %p flags %u", channel, flags);
212
213         return transport->funcs->io_set_flags(transport, flags, err);
214 }
215
216 static GIOFlags g_io_gnutls_get_flags(GIOChannel *channel)
217 {
218         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
219         GIOChannel *transport = gnutls_channel->transport;
220
221         DBG("channel %p", channel);
222
223         return transport->funcs->io_get_flags(transport);
224 }
225
226 static gboolean g_io_gnutls_prepare(GSource *source, gint *timeout)
227 {
228         DBG("source %p", source);
229
230         *timeout = -1;
231
232         return FALSE;
233 }
234
235 static gboolean g_io_gnutls_check(GSource *source)
236 {
237         GIOGnuTLSWatch *watch = (GIOGnuTLSWatch *) source;
238         GIOCondition condition = watch->pollfd.revents;
239
240         DBG("source %p condition %u", source, condition);
241
242         if (condition & watch->condition)
243                 return TRUE;
244
245         return FALSE;
246 }
247
248 static gboolean g_io_gnutls_dispatch(GSource *source, GSourceFunc callback,
249                                                         gpointer user_data)
250 {
251         GIOGnuTLSWatch *watch = (GIOGnuTLSWatch *) source;
252         GIOFunc func = (GIOFunc) callback;
253         GIOCondition condition = watch->pollfd.revents;
254
255         DBG("source %p condition %u", source, condition);
256
257         if (func == NULL)
258                 return FALSE;
259
260         return func(watch->channel, condition & watch->condition, user_data);
261 }
262
263 static void g_io_gnutls_finalize(GSource *source)
264 {
265         GIOGnuTLSWatch *watch = (GIOGnuTLSWatch *) source;
266
267         DBG("source %p", source);
268
269         g_io_channel_unref(watch->channel);
270 }
271
272 static GSourceFuncs gnutls_watch_funcs = {
273         g_io_gnutls_prepare,
274         g_io_gnutls_check,
275         g_io_gnutls_dispatch,
276         g_io_gnutls_finalize,
277 };
278
279 static GSource *g_io_gnutls_create_watch(GIOChannel *channel,
280                                                 GIOCondition condition)
281 {
282         GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
283         GIOGnuTLSWatch *watch;
284         GSource *source;
285
286         DBG("channel %p condition %u", channel, condition);
287
288         source = g_source_new(&gnutls_watch_funcs, sizeof(GIOGnuTLSWatch));
289
290         watch = (GIOGnuTLSWatch *) source;
291
292         watch->channel = channel;
293         g_io_channel_ref(channel);
294
295         watch->condition = condition;
296
297         watch->pollfd.fd = g_io_channel_unix_get_fd(gnutls_channel->transport);
298         watch->pollfd.events = condition;
299
300         g_source_add_poll(source, &watch->pollfd);
301
302         return source;
303 }
304
305 static GIOFuncs gnutls_channel_funcs = {
306         g_io_gnutls_read,
307         g_io_gnutls_write,
308         g_io_gnutls_seek,
309         g_io_gnutls_close,
310         g_io_gnutls_create_watch,
311         g_io_gnutls_free,
312         g_io_gnutls_set_flags,
313         g_io_gnutls_get_flags,
314 };
315
316 static ssize_t g_io_gnutls_push_func(gnutls_transport_ptr_t transport_data,
317                                                 const void *buf, size_t count)
318 {
319         GIOGnuTLSChannel *gnutls_channel = transport_data;
320         ssize_t result;
321         int fd;
322
323         DBG("transport %p count %zu", gnutls_channel->transport, count);
324
325         fd = g_io_channel_unix_get_fd(gnutls_channel->transport);
326
327         result = write(fd, buf, count);
328
329         DBG("result %zd", result);
330
331         return result;
332 }
333
334 static ssize_t g_io_gnutls_pull_func(gnutls_transport_ptr_t transport_data,
335                                                 void *buf, size_t count)
336 {
337         GIOGnuTLSChannel *gnutls_channel = transport_data;
338         ssize_t result;
339         int fd;
340
341         DBG("transport %p count %zu", gnutls_channel->transport, count);
342
343         fd = g_io_channel_unix_get_fd(gnutls_channel->transport);
344
345         result = read(fd, buf, count);
346
347         DBG("result %zd", result);
348
349         return result;
350 }
351
352 GIOChannel *g_io_channel_gnutls_new(int fd)
353 {
354         GIOGnuTLSChannel *gnutls_channel;
355         GIOChannel *channel;
356         int err;
357
358         DBG("");
359
360         gnutls_channel = g_new(GIOGnuTLSChannel, 1);
361
362         channel = (GIOChannel *) gnutls_channel;
363
364         g_io_channel_init(channel);
365         channel->funcs = &gnutls_channel_funcs;
366
367         gnutls_channel->transport = g_io_channel_unix_new(fd);
368
369         g_io_channel_set_encoding(gnutls_channel->transport, NULL, NULL);
370         g_io_channel_set_buffered(gnutls_channel->transport, FALSE);
371
372         channel->is_seekable = FALSE;
373         channel->is_readable = TRUE;
374         channel->is_writeable = TRUE;
375
376         g_io_channel_set_encoding(channel, NULL, NULL);
377         g_io_channel_set_buffered(channel, FALSE);
378
379         g_io_gnutls_global_init();
380
381         err = gnutls_init(&gnutls_channel->session, GNUTLS_CLIENT);
382         if (err < 0) {
383                 g_free(gnutls_channel);
384                 return NULL;
385         }
386
387         gnutls_transport_set_ptr(gnutls_channel->session, gnutls_channel);
388         gnutls_transport_set_push_function(gnutls_channel->session,
389                                                 g_io_gnutls_push_func);
390         gnutls_transport_set_pull_function(gnutls_channel->session,
391                                                 g_io_gnutls_pull_func);
392         gnutls_transport_set_lowat(gnutls_channel->session, 0);
393
394         gnutls_priority_set_direct(gnutls_channel->session,
395                                 "NORMAL:!VERS-TLS1.1:!VERS-TLS1.0", NULL);
396
397         gnutls_certificate_allocate_credentials(&gnutls_channel->cred);
398         gnutls_credentials_set(gnutls_channel->session,
399                                 GNUTLS_CRD_CERTIFICATE, gnutls_channel->cred);
400
401         DBG("channel %p transport %p", channel, gnutls_channel->transport);
402
403         return channel;
404 }