3 * Web service library with GLib integration
5 * Copyright (C) 2009-2010 Intel Corporation. All rights reserved.
7 * This program is free software; you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License version 2 as
9 * published by the Free Software Foundation.
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU General Public License for more details.
16 * You should have received a copy of the GNU General Public License
17 * along with this program; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
31 #include <gnutls/gnutls.h>
33 #include "giognutls.h"
/*
 * Debug tracing macro. The active definition (second line) expands to
 * nothing, so every DBG() call in this file compiles away. The
 * commented-out definition prints "<function>: <message>" to stdout via
 * printf(); swap the two lines to enable tracing. Both use the GNU named
 * variadic macro form (arg...).
 */
35 //#define DBG(fmt, arg...) printf("%s: " fmt "\n" , __func__ , ## arg)
36 #define DBG(fmt, arg...)
/* Short names for the private TLS channel and watch structs defined below. */
38 typedef struct _GIOGnuTLSChannel GIOGnuTLSChannel;
39 typedef struct _GIOGnuTLSWatch GIOGnuTLSWatch;
/*
 * A GIOChannel that carries TLS over a plain unix socket channel.
 *
 * The functions in this file cast a GIOChannel * directly to
 * GIOGnuTLSChannel * (e.g. check_handshake). That cast is only valid if a
 * GIOChannel is the first member. That member, and the "established" and
 * "again" flags the code reads and writes, are on lines not visible in
 * this excerpt.
 */
41 struct _GIOGnuTLSChannel {
/* Unbuffered unix channel over the socket fd; the push/pull callbacks do raw write()/read() on its fd. */
43 GIOChannel *transport;
/* Certificate credentials attached to the session with gnutls_credentials_set(). */
44 gnutls_certificate_credentials_t cred;
/* Client-side GnuTLS session. NOTE(review): gnutls_session appears to be the legacy alias of gnutls_session_t — confirm against the GnuTLS version in use. */
45 gnutls_session session;
/*
 * GSource allocated by g_io_gnutls_create_watch() with g_source_new(). It
 * polls the transport's fd. The code also uses an embedded GSource (the
 * struct is cast from GSource *), a "pollfd" and a "channel" member.
 * Those declarations are not visible in this excerpt.
 */
50 struct _GIOGnuTLSWatch {
/* Conditions the caller asked to watch; also used to mask revents in check/dispatch. */
54 GIOCondition condition;
/*
 * One-shot guard for process-wide GnuTLS initialisation.
 *
 * Only the caller that swaps global_init_done from 0 to 1 with
 * g_atomic_int_compare_and_exchange() runs the guarded statement. All
 * other callers skip it.
 * NOTE(review): the guarded statement (presumably gnutls_global_init()) is
 * on a line not visible in this excerpt. A caller that loses the
 * compare-and-exchange returns immediately, possibly before the winner has
 * finished initialising. Confirm that concurrent first use is not
 * expected.
 */
57 static volatile gint global_init_done = 0;
59 static inline void g_io_gnutls_global_init(void)
61 if (g_atomic_int_compare_and_exchange(&global_init_done, 0, 1) == TRUE)
/*
 * Run, or resume, the TLS handshake for a GnuTLS channel.
 *
 * Returns G_IO_STATUS_NORMAL straight away once the channel is marked
 * established. Otherwise it calls gnutls_handshake() and maps the result:
 *  - GNUTLS_E_INTERRUPTED / GNUTLS_E_AGAIN: returns G_IO_STATUS_AGAIN when
 *    the last transport read/write hit EAGAIN ("again" flag) or the
 *    channel is non-blocking. The blocking-mode path is on lines not
 *    visible in this excerpt (presumably a retry).
 *  - failure: sets err to G_IO_CHANNEL_ERROR_FAILED ("Handshake failed")
 *    and returns G_IO_STATUS_ERROR. The guarding condition is not visible
 *    here (presumably result < 0).
 *  - success: marks the channel established and returns
 *    G_IO_STATUS_NORMAL.
 *
 * g_io_gnutls_read() and g_io_gnutls_write() call this before every
 * gnutls_record_recv()/gnutls_record_send().
 */
65 static GIOStatus check_handshake(GIOChannel *channel, GError **err)
67 GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
70 DBG("channel %p", channel);
/* Handshake already completed: nothing to do. */
72 if (gnutls_channel->established == TRUE)
73 return G_IO_STATUS_NORMAL;
76 result = gnutls_handshake(gnutls_channel->session);
/* The handshake could not progress without more socket I/O. */
78 if (result == GNUTLS_E_INTERRUPTED || result == GNUTLS_E_AGAIN) {
79 GIOFlags flags = g_io_channel_get_flags(channel);
/* "again" is set by the push/pull callbacks when write()/read() failed with EAGAIN. */
81 if (gnutls_channel->again == TRUE)
82 return G_IO_STATUS_AGAIN;
84 if (flags & G_IO_FLAG_NONBLOCK)
85 return G_IO_STATUS_AGAIN;
91 g_set_error(err, G_IO_CHANNEL_ERROR,
92 G_IO_CHANNEL_ERROR_FAILED, "Handshake failed");
93 return G_IO_STATUS_ERROR;
/* Later reads/writes skip the handshake until a rehandshake clears this. */
96 gnutls_channel->established = TRUE;
98 DBG("handshake done");
100 return G_IO_STATUS_NORMAL;
/*
 * GIOFuncs-style read callback: decrypt up to count bytes of TLS
 * application data into buf.
 *
 * It calls check_handshake() first. A non-NORMAL status is handled on a
 * line not visible here (presumably returned to the caller). It then maps
 * the result of gnutls_record_recv():
 *  - GNUTLS_E_REHANDSHAKE: clears "established" so the handshake runs
 *    again (the follow-up statement is not visible here).
 *  - GNUTLS_E_INTERRUPTED / GNUTLS_E_AGAIN: returns G_IO_STATUS_AGAIN when
 *    the transport reported EAGAIN or the channel is non-blocking.
 *  - GNUTLS_E_UNEXPECTED_PACKET_LENGTH: returns G_IO_STATUS_EOF.
 *    NOTE(review): older GnuTLS versions returned this code when the peer
 *    closed the connection without a TLS close_notify. Treating it as a
 *    clean EOF hides possible truncation. Confirm this is intended.
 *  - remaining errors: sets "Stream corrupted" and returns
 *    G_IO_STATUS_ERROR (the guarding condition is not visible here).
 *  - otherwise: stores the byte count in *bytes_read and returns
 *    G_IO_STATUS_NORMAL, or G_IO_STATUS_EOF when 0 bytes were received.
 */
103 static GIOStatus g_io_gnutls_read(GIOChannel *channel, gchar *buf,
104 gsize count, gsize *bytes_read, GError **err)
106 GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
110 DBG("channel %p count %zu", channel, count);
115 status = check_handshake(channel, err);
116 if (status != G_IO_STATUS_NORMAL)
119 result = gnutls_record_recv(gnutls_channel->session, buf, count);
121 DBG("result %zd", result);
/* Peer requested renegotiation: force check_handshake() to run the handshake again. */
123 if (result == GNUTLS_E_REHANDSHAKE) {
124 gnutls_channel->established = FALSE;
128 if (result == GNUTLS_E_INTERRUPTED || result == GNUTLS_E_AGAIN) {
129 GIOFlags flags = g_io_channel_get_flags(channel);
/* "again" is set by the pull callback when read() failed with EAGAIN. */
131 if (gnutls_channel->again == TRUE)
132 return G_IO_STATUS_AGAIN;
134 if (flags & G_IO_FLAG_NONBLOCK)
135 return G_IO_STATUS_AGAIN;
140 if (result == GNUTLS_E_UNEXPECTED_PACKET_LENGTH)
141 return G_IO_STATUS_EOF;
144 g_set_error(err, G_IO_CHANNEL_ERROR,
145 G_IO_CHANNEL_ERROR_FAILED, "Stream corrupted");
146 return G_IO_STATUS_ERROR;
149 *bytes_read = result;
/* A zero-length record read means the TLS stream has ended. */
151 return (result > 0) ? G_IO_STATUS_NORMAL : G_IO_STATUS_EOF;
/*
 * GIOFuncs-style write callback: encrypt count bytes from buf and send
 * them as TLS application data.
 *
 * It calls check_handshake() first. A non-NORMAL status is handled on a
 * line not visible here (presumably returned to the caller). It then maps
 * the result of gnutls_record_send():
 *  - GNUTLS_E_REHANDSHAKE: clears "established" so the handshake runs
 *    again (the follow-up statement is not visible here).
 *  - GNUTLS_E_INTERRUPTED / GNUTLS_E_AGAIN: returns G_IO_STATUS_AGAIN when
 *    the transport reported EAGAIN or the channel is non-blocking.
 *  - remaining errors: sets "Stream corrupted" and returns
 *    G_IO_STATUS_ERROR (the guarding condition is not visible here).
 *  - otherwise: stores the sent byte count in *bytes_written and returns
 *    G_IO_STATUS_NORMAL, or G_IO_STATUS_EOF when it is 0.
 * NOTE(review): a zero-byte result (presumably e.g. count == 0) is
 * reported as G_IO_STATUS_EOF, which is unusual for a write. Confirm
 * callers expect this.
 */
154 static GIOStatus g_io_gnutls_write(GIOChannel *channel, const gchar *buf,
155 gsize count, gsize *bytes_written, GError **err)
157 GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
161 DBG("channel %p count %zu", channel, count);
166 status = check_handshake(channel, err);
167 if (status != G_IO_STATUS_NORMAL)
170 result = gnutls_record_send(gnutls_channel->session, buf, count);
172 DBG("result %zd", result);
/* Renegotiation requested: force check_handshake() to run the handshake again. */
174 if (result == GNUTLS_E_REHANDSHAKE) {
175 gnutls_channel->established = FALSE;
179 if (result == GNUTLS_E_INTERRUPTED || result == GNUTLS_E_AGAIN) {
180 GIOFlags flags = g_io_channel_get_flags(channel);
/* "again" is set by the push callback when write() failed with EAGAIN. */
182 if (gnutls_channel->again == TRUE)
183 return G_IO_STATUS_AGAIN;
185 if (flags & G_IO_FLAG_NONBLOCK)
186 return G_IO_STATUS_AGAIN;
192 g_set_error(err, G_IO_CHANNEL_ERROR,
193 G_IO_CHANNEL_ERROR_FAILED, "Stream corrupted");
194 return G_IO_STATUS_ERROR;
197 *bytes_written = result;
199 return (result > 0) ? G_IO_STATUS_NORMAL : G_IO_STATUS_EOF;
/*
 * Seek callback: forwards the request unchanged to the transport
 * channel's io_seek.
 * NOTE(review): g_io_channel_gnutls_new() marks the TLS channel
 * is_seekable = FALSE. Presumably this is never expected to succeed on a
 * socket. Seeking under an active TLS session would desynchronise it.
 */
202 static GIOStatus g_io_gnutls_seek(GIOChannel *channel, gint64 offset,
203 GSeekType type, GError **err)
205 GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
206 GIOChannel *transport = gnutls_channel->transport;
208 DBG("channel %p", channel);
210 return transport->funcs->io_seek(transport, offset, type, err);
/*
 * Close callback. If the handshake completed, it first shuts the TLS
 * session down with gnutls_bye(GNUTLS_SHUT_RDWR). It then closes the
 * transport channel and returns the transport's status.
 * NOTE(review): the return value of gnutls_bye() is ignored. On a
 * non-blocking socket it may return GNUTLS_E_AGAIN, in which case the
 * shutdown alert may not be fully exchanged before the fd is closed.
 */
213 static GIOStatus g_io_gnutls_close(GIOChannel *channel, GError **err)
215 GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
216 GIOChannel *transport = gnutls_channel->transport;
218 DBG("channel %p", channel);
/* Only send a TLS shutdown if a session was actually established. */
220 if (gnutls_channel->established == TRUE)
221 gnutls_bye(gnutls_channel->session, GNUTLS_SHUT_RDWR);
223 return transport->funcs->io_close(transport, err);
/*
 * Free callback (presumably the GIOFuncs io_free entry; that initializer
 * line is not visible). It releases everything the channel owns, in
 * order: the reference on the transport channel, the GnuTLS session, the
 * certificate credentials, and finally the channel struct itself.
 */
226 static void g_io_gnutls_free(GIOChannel *channel)
228 GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
230 DBG("channel %p", channel);
232 g_io_channel_unref(gnutls_channel->transport);
234 gnutls_deinit(gnutls_channel->session);
236 gnutls_certificate_free_credentials(gnutls_channel->cred);
238 g_free(gnutls_channel);
/*
 * Set-flags callback: forwards to the transport channel. Flags such as
 * G_IO_FLAG_NONBLOCK are therefore applied to the underlying unix socket
 * channel. The read/write/handshake paths in this file check that flag
 * to decide whether to return G_IO_STATUS_AGAIN.
 */
241 static GIOStatus g_io_gnutls_set_flags(GIOChannel *channel,
242 GIOFlags flags, GError **err)
244 GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
245 GIOChannel *transport = gnutls_channel->transport;
247 DBG("channel %p flags %u", channel, flags);
249 return transport->funcs->io_set_flags(transport, flags, err);
/*
 * Get-flags callback: reports the transport channel's flags. This is what
 * g_io_channel_get_flags() on the TLS channel sees when check_handshake(),
 * g_io_gnutls_read() and g_io_gnutls_write() test G_IO_FLAG_NONBLOCK.
 */
252 static GIOFlags g_io_gnutls_get_flags(GIOChannel *channel)
254 GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
255 GIOChannel *transport = gnutls_channel->transport;
257 DBG("channel %p", channel);
259 return transport->funcs->io_get_flags(transport);
/*
 * GSourceFuncs prepare callback for TLS watches. Only the trace call is
 * visible in this excerpt. How it sets *timeout and what it returns are
 * on lines not shown.
 */
262 static gboolean g_io_gnutls_prepare(GSource *source, gint *timeout)
264 DBG("source %p", source);
/*
 * GSourceFuncs check callback. It reads the events poll() reported on the
 * transport fd and tests whether any of them is in the watched condition
 * set. The return statements are not visible in this excerpt (presumably
 * TRUE when a watched event fired, FALSE otherwise).
 */
271 static gboolean g_io_gnutls_check(GSource *source)
273 GIOGnuTLSWatch *watch = (GIOGnuTLSWatch *) source;
274 GIOCondition condition = watch->pollfd.revents;
276 DBG("source %p condition %u", source, condition);
278 if (condition & watch->condition)
/*
 * GSourceFuncs dispatch callback. It invokes the user's GIOFunc with the
 * TLS channel, only those reported events that were actually requested,
 * and user_data. It returns the callback's result, which decides whether
 * the source stays attached. The user_data parameter line and any
 * callback NULL check are not visible in this excerpt.
 */
284 static gboolean g_io_gnutls_dispatch(GSource *source, GSourceFunc callback,
287 GIOGnuTLSWatch *watch = (GIOGnuTLSWatch *) source;
288 GIOFunc func = (GIOFunc) callback;
289 GIOCondition condition = watch->pollfd.revents;
291 DBG("source %p condition %u", source, condition);
296 return func(watch->channel, condition & watch->condition, user_data);
/*
 * GSourceFuncs finalize callback. It drops the channel reference taken in
 * g_io_gnutls_create_watch() when the watch source is destroyed.
 */
299 static void g_io_gnutls_finalize(GSource *source)
301 GIOGnuTLSWatch *watch = (GIOGnuTLSWatch *) source;
303 DBG("source %p", source);
305 g_io_channel_unref(watch->channel);
/*
 * Source vtable for TLS watches, in GSourceFuncs member order. Only the
 * dispatch and finalize entries are visible in this excerpt; the prepare
 * and check entries precede them.
 */
308 static GSourceFuncs gnutls_watch_funcs = {
311 g_io_gnutls_dispatch,
312 g_io_gnutls_finalize,
/*
 * Create-watch callback. It builds a GSource that polls the transport's
 * socket fd for the requested conditions. The watch holds a reference on
 * the TLS channel, which g_io_gnutls_finalize() releases. The source's
 * declaration and the return statement are not visible in this excerpt.
 * NOTE(review): readiness comes from the raw socket only. Data that
 * GnuTLS has already decrypted and buffered inside the session will not
 * wake this watch. Confirm callers drain the channel fully on each
 * G_IO_IN.
 */
315 static GSource *g_io_gnutls_create_watch(GIOChannel *channel,
316 GIOCondition condition)
318 GIOGnuTLSChannel *gnutls_channel = (GIOGnuTLSChannel *) channel;
319 GIOGnuTLSWatch *watch;
322 DBG("channel %p condition %u", channel, condition);
/* Allocate the source with room for the whole GIOGnuTLSWatch wrapper. */
324 source = g_source_new(&gnutls_watch_funcs, sizeof(GIOGnuTLSWatch));
326 watch = (GIOGnuTLSWatch *) source;
/* Keep the channel alive for the lifetime of the watch. */
328 watch->channel = channel;
329 g_io_channel_ref(channel);
331 watch->condition = condition;
/* Poll the plain socket under the TLS layer. */
333 watch->pollfd.fd = g_io_channel_unix_get_fd(gnutls_channel->transport);
334 watch->pollfd.events = condition;
336 g_source_add_poll(source, &watch->pollfd);
/*
 * GIOChannel vtable for TLS channels, installed by
 * g_io_channel_gnutls_new(). Entries follow GIOFuncs member order. Only
 * the create_watch, set_flags and get_flags entries are visible in this
 * excerpt.
 */
341 static GIOFuncs gnutls_channel_funcs = {
346 g_io_gnutls_create_watch,
348 g_io_gnutls_set_flags,
349 g_io_gnutls_get_flags,
/*
 * GnuTLS push callback, registered with
 * gnutls_transport_set_push_function(). transport_data is the
 * GIOGnuTLSChannel passed to gnutls_transport_set_ptr().
 *
 * It write()s the encrypted bytes straight to the transport's fd,
 * bypassing the GIOChannel layer. It records in "again" whether the write
 * failed with EAGAIN, so the record/handshake paths can report
 * G_IO_STATUS_AGAIN. The branch keyword before the reset to FALSE and the
 * return statement are not visible in this excerpt (presumably "else" and
 * "return result").
 */
352 static ssize_t g_io_gnutls_push_func(gnutls_transport_ptr_t transport_data,
353 const void *buf, size_t count)
355 GIOGnuTLSChannel *gnutls_channel = transport_data;
359 DBG("transport %p count %zu", gnutls_channel->transport, count);
361 fd = g_io_channel_unix_get_fd(gnutls_channel->transport);
363 result = write(fd, buf, count);
365 if (result < 0 && errno == EAGAIN)
366 gnutls_channel->again = TRUE;
368 gnutls_channel->again = FALSE;
370 DBG("result %zd", result);
/*
 * GnuTLS pull callback, registered with
 * gnutls_transport_set_pull_function(). transport_data is the
 * GIOGnuTLSChannel passed to gnutls_transport_set_ptr().
 *
 * It read()s encrypted bytes straight from the transport's fd. It records
 * in "again" whether the read failed with EAGAIN, mirroring
 * g_io_gnutls_push_func(). The branch keyword before the reset to FALSE
 * and the return statement are not visible in this excerpt (presumably
 * "else" and "return result").
 */
375 static ssize_t g_io_gnutls_pull_func(gnutls_transport_ptr_t transport_data,
376 void *buf, size_t count)
378 GIOGnuTLSChannel *gnutls_channel = transport_data;
382 DBG("transport %p count %zu", gnutls_channel->transport, count);
384 fd = g_io_channel_unix_get_fd(gnutls_channel->transport);
386 result = read(fd, buf, count);
388 if (result < 0 && errno == EAGAIN)
389 gnutls_channel->again = TRUE;
391 gnutls_channel->again = FALSE;
393 DBG("result %zd", result);
398 GIOChannel *g_io_channel_gnutls_new(int fd)
400 GIOGnuTLSChannel *gnutls_channel;
406 gnutls_channel = g_new(GIOGnuTLSChannel, 1);
408 channel = (GIOChannel *) gnutls_channel;
410 g_io_channel_init(channel);
411 channel->funcs = &gnutls_channel_funcs;
413 gnutls_channel->transport = g_io_channel_unix_new(fd);
415 g_io_channel_set_encoding(gnutls_channel->transport, NULL, NULL);
416 g_io_channel_set_buffered(gnutls_channel->transport, FALSE);
418 channel->is_seekable = FALSE;
419 channel->is_readable = TRUE;
420 channel->is_writeable = TRUE;
422 channel->do_encode = FALSE;
424 g_io_gnutls_global_init();
426 err = gnutls_init(&gnutls_channel->session, GNUTLS_CLIENT);
428 g_free(gnutls_channel);
432 gnutls_transport_set_ptr(gnutls_channel->session, gnutls_channel);
433 gnutls_transport_set_push_function(gnutls_channel->session,
434 g_io_gnutls_push_func);
435 gnutls_transport_set_pull_function(gnutls_channel->session,
436 g_io_gnutls_pull_func);
437 gnutls_transport_set_lowat(gnutls_channel->session, 0);
439 gnutls_priority_set_direct(gnutls_channel->session,
440 "NORMAL:!VERS-TLS1.1:!VERS-TLS1.0", NULL);
442 gnutls_certificate_allocate_credentials(&gnutls_channel->cred);
443 gnutls_credentials_set(gnutls_channel->session,
444 GNUTLS_CRD_CERTIFICATE, gnutls_channel->cred);
446 DBG("channel %p transport %p", channel, gnutls_channel->transport);