Merge remote-tracking branch 'remotes/kevin/tags/for-upstream' into staging
[sdk/emulator/qemu.git] / io / channel-tls.c
1 /*
2  * QEMU I/O channels TLS driver
3  *
4  * Copyright (c) 2015 Red Hat, Inc.
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20
21 #include "qemu/osdep.h"
22 #include "qapi/error.h"
23 #include "io/channel-tls.h"
24 #include "trace.h"
25
26
27 static ssize_t qio_channel_tls_write_handler(const char *buf,
28                                              size_t len,
29                                              void *opaque)
30 {
31     QIOChannelTLS *tioc = QIO_CHANNEL_TLS(opaque);
32     ssize_t ret;
33
34     ret = qio_channel_write(tioc->master, buf, len, NULL);
35     if (ret == QIO_CHANNEL_ERR_BLOCK) {
36         errno = EAGAIN;
37         return -1;
38     } else if (ret < 0) {
39         errno = EIO;
40         return -1;
41     }
42     return ret;
43 }
44
45 static ssize_t qio_channel_tls_read_handler(char *buf,
46                                             size_t len,
47                                             void *opaque)
48 {
49     QIOChannelTLS *tioc = QIO_CHANNEL_TLS(opaque);
50     ssize_t ret;
51
52     ret = qio_channel_read(tioc->master, buf, len, NULL);
53     if (ret == QIO_CHANNEL_ERR_BLOCK) {
54         errno = EAGAIN;
55         return -1;
56     } else if (ret < 0) {
57         errno = EIO;
58         return -1;
59     }
60     return ret;
61 }
62
63
64 QIOChannelTLS *
65 qio_channel_tls_new_server(QIOChannel *master,
66                            QCryptoTLSCreds *creds,
67                            const char *aclname,
68                            Error **errp)
69 {
70     QIOChannelTLS *ioc;
71
72     ioc = QIO_CHANNEL_TLS(object_new(TYPE_QIO_CHANNEL_TLS));
73
74     ioc->master = master;
75     object_ref(OBJECT(master));
76
77     ioc->session = qcrypto_tls_session_new(
78         creds,
79         NULL,
80         aclname,
81         QCRYPTO_TLS_CREDS_ENDPOINT_SERVER,
82         errp);
83     if (!ioc->session) {
84         goto error;
85     }
86
87     qcrypto_tls_session_set_callbacks(
88         ioc->session,
89         qio_channel_tls_write_handler,
90         qio_channel_tls_read_handler,
91         ioc);
92
93     trace_qio_channel_tls_new_server(ioc, master, creds, aclname);
94     return ioc;
95
96  error:
97     object_unref(OBJECT(ioc));
98     return NULL;
99 }
100
101 QIOChannelTLS *
102 qio_channel_tls_new_client(QIOChannel *master,
103                            QCryptoTLSCreds *creds,
104                            const char *hostname,
105                            Error **errp)
106 {
107     QIOChannelTLS *tioc;
108     QIOChannel *ioc;
109
110     tioc = QIO_CHANNEL_TLS(object_new(TYPE_QIO_CHANNEL_TLS));
111     ioc = QIO_CHANNEL(tioc);
112
113     tioc->master = master;
114     if (master->features & (1 << QIO_CHANNEL_FEATURE_SHUTDOWN)) {
115         ioc->features |= (1 << QIO_CHANNEL_FEATURE_SHUTDOWN);
116     }
117     object_ref(OBJECT(master));
118
119     tioc->session = qcrypto_tls_session_new(
120         creds,
121         hostname,
122         NULL,
123         QCRYPTO_TLS_CREDS_ENDPOINT_CLIENT,
124         errp);
125     if (!tioc->session) {
126         goto error;
127     }
128
129     qcrypto_tls_session_set_callbacks(
130         tioc->session,
131         qio_channel_tls_write_handler,
132         qio_channel_tls_read_handler,
133         tioc);
134
135     trace_qio_channel_tls_new_client(tioc, master, creds, hostname);
136     return tioc;
137
138  error:
139     object_unref(OBJECT(tioc));
140     return NULL;
141 }
142
143
144 static gboolean qio_channel_tls_handshake_io(QIOChannel *ioc,
145                                              GIOCondition condition,
146                                              gpointer user_data);
147
148 static void qio_channel_tls_handshake_task(QIOChannelTLS *ioc,
149                                            QIOTask *task)
150 {
151     Error *err = NULL;
152     QCryptoTLSSessionHandshakeStatus status;
153
154     if (qcrypto_tls_session_handshake(ioc->session, &err) < 0) {
155         trace_qio_channel_tls_handshake_fail(ioc);
156         qio_task_abort(task, err);
157         goto cleanup;
158     }
159
160     status = qcrypto_tls_session_get_handshake_status(ioc->session);
161     if (status == QCRYPTO_TLS_HANDSHAKE_COMPLETE) {
162         trace_qio_channel_tls_handshake_complete(ioc);
163         if (qcrypto_tls_session_check_credentials(ioc->session,
164                                                   &err) < 0) {
165             trace_qio_channel_tls_credentials_deny(ioc);
166             qio_task_abort(task, err);
167             goto cleanup;
168         }
169         trace_qio_channel_tls_credentials_allow(ioc);
170         qio_task_complete(task);
171     } else {
172         GIOCondition condition;
173         if (status == QCRYPTO_TLS_HANDSHAKE_SENDING) {
174             condition = G_IO_OUT;
175         } else {
176             condition = G_IO_IN;
177         }
178
179         trace_qio_channel_tls_handshake_pending(ioc, status);
180         qio_channel_add_watch(ioc->master,
181                               condition,
182                               qio_channel_tls_handshake_io,
183                               task,
184                               NULL);
185     }
186
187  cleanup:
188     error_free(err);
189 }
190
191
192 static gboolean qio_channel_tls_handshake_io(QIOChannel *ioc,
193                                              GIOCondition condition,
194                                              gpointer user_data)
195 {
196     QIOTask *task = user_data;
197     QIOChannelTLS *tioc = QIO_CHANNEL_TLS(
198         qio_task_get_source(task));
199
200     qio_channel_tls_handshake_task(
201        tioc, task);
202
203     object_unref(OBJECT(tioc));
204
205     return FALSE;
206 }
207
208 void qio_channel_tls_handshake(QIOChannelTLS *ioc,
209                                QIOTaskFunc func,
210                                gpointer opaque,
211                                GDestroyNotify destroy)
212 {
213     QIOTask *task;
214
215     task = qio_task_new(OBJECT(ioc),
216                         func, opaque, destroy);
217
218     trace_qio_channel_tls_handshake_start(ioc);
219     qio_channel_tls_handshake_task(ioc, task);
220 }
221
222
223 static void qio_channel_tls_init(Object *obj G_GNUC_UNUSED)
224 {
225 }
226
227
228 static void qio_channel_tls_finalize(Object *obj)
229 {
230     QIOChannelTLS *ioc = QIO_CHANNEL_TLS(obj);
231
232     object_unref(OBJECT(ioc->master));
233     qcrypto_tls_session_free(ioc->session);
234 }
235
236
237 static ssize_t qio_channel_tls_readv(QIOChannel *ioc,
238                                      const struct iovec *iov,
239                                      size_t niov,
240                                      int **fds,
241                                      size_t *nfds,
242                                      Error **errp)
243 {
244     QIOChannelTLS *tioc = QIO_CHANNEL_TLS(ioc);
245     size_t i;
246     ssize_t got = 0;
247
248     for (i = 0 ; i < niov ; i++) {
249         ssize_t ret = qcrypto_tls_session_read(tioc->session,
250                                                iov[i].iov_base,
251                                                iov[i].iov_len);
252         if (ret < 0) {
253             if (errno == EAGAIN) {
254                 if (got) {
255                     return got;
256                 } else {
257                     return QIO_CHANNEL_ERR_BLOCK;
258                 }
259             }
260
261             error_setg_errno(errp, errno,
262                              "Cannot read from TLS channel");
263             return -1;
264         }
265         got += ret;
266         if (ret < iov[i].iov_len) {
267             break;
268         }
269     }
270     return got;
271 }
272
273
274 static ssize_t qio_channel_tls_writev(QIOChannel *ioc,
275                                       const struct iovec *iov,
276                                       size_t niov,
277                                       int *fds,
278                                       size_t nfds,
279                                       Error **errp)
280 {
281     QIOChannelTLS *tioc = QIO_CHANNEL_TLS(ioc);
282     size_t i;
283     ssize_t done = 0;
284
285     for (i = 0 ; i < niov ; i++) {
286         ssize_t ret = qcrypto_tls_session_write(tioc->session,
287                                                 iov[i].iov_base,
288                                                 iov[i].iov_len);
289         if (ret <= 0) {
290             if (errno == EAGAIN) {
291                 if (done) {
292                     return done;
293                 } else {
294                     return QIO_CHANNEL_ERR_BLOCK;
295                 }
296             }
297
298             error_setg_errno(errp, errno,
299                              "Cannot write to TLS channel");
300             return -1;
301         }
302         done += ret;
303         if (ret < iov[i].iov_len) {
304             break;
305         }
306     }
307     return done;
308 }
309
310 static int qio_channel_tls_set_blocking(QIOChannel *ioc,
311                                         bool enabled,
312                                         Error **errp)
313 {
314     QIOChannelTLS *tioc = QIO_CHANNEL_TLS(ioc);
315
316     return qio_channel_set_blocking(tioc->master, enabled, errp);
317 }
318
319 static void qio_channel_tls_set_delay(QIOChannel *ioc,
320                                       bool enabled)
321 {
322     QIOChannelTLS *tioc = QIO_CHANNEL_TLS(ioc);
323
324     qio_channel_set_delay(tioc->master, enabled);
325 }
326
327 static void qio_channel_tls_set_cork(QIOChannel *ioc,
328                                      bool enabled)
329 {
330     QIOChannelTLS *tioc = QIO_CHANNEL_TLS(ioc);
331
332     qio_channel_set_cork(tioc->master, enabled);
333 }
334
335 static int qio_channel_tls_shutdown(QIOChannel *ioc,
336                                     QIOChannelShutdown how,
337                                     Error **errp)
338 {
339     QIOChannelTLS *tioc = QIO_CHANNEL_TLS(ioc);
340
341     return qio_channel_shutdown(tioc->master, how, errp);
342 }
343
344 static int qio_channel_tls_close(QIOChannel *ioc,
345                                  Error **errp)
346 {
347     QIOChannelTLS *tioc = QIO_CHANNEL_TLS(ioc);
348
349     return qio_channel_close(tioc->master, errp);
350 }
351
352 static GSource *qio_channel_tls_create_watch(QIOChannel *ioc,
353                                              GIOCondition condition)
354 {
355     QIOChannelTLS *tioc = QIO_CHANNEL_TLS(ioc);
356
357     return qio_channel_create_watch(tioc->master, condition);
358 }
359
360 QCryptoTLSSession *
361 qio_channel_tls_get_session(QIOChannelTLS *ioc)
362 {
363     return ioc->session;
364 }
365
366 static void qio_channel_tls_class_init(ObjectClass *klass,
367                                        void *class_data G_GNUC_UNUSED)
368 {
369     QIOChannelClass *ioc_klass = QIO_CHANNEL_CLASS(klass);
370
371     ioc_klass->io_writev = qio_channel_tls_writev;
372     ioc_klass->io_readv = qio_channel_tls_readv;
373     ioc_klass->io_set_blocking = qio_channel_tls_set_blocking;
374     ioc_klass->io_set_delay = qio_channel_tls_set_delay;
375     ioc_klass->io_set_cork = qio_channel_tls_set_cork;
376     ioc_klass->io_close = qio_channel_tls_close;
377     ioc_klass->io_shutdown = qio_channel_tls_shutdown;
378     ioc_klass->io_create_watch = qio_channel_tls_create_watch;
379 }
380
381 static const TypeInfo qio_channel_tls_info = {
382     .parent = TYPE_QIO_CHANNEL,
383     .name = TYPE_QIO_CHANNEL_TLS,
384     .instance_size = sizeof(QIOChannelTLS),
385     .instance_init = qio_channel_tls_init,
386     .instance_finalize = qio_channel_tls_finalize,
387     .class_init = qio_channel_tls_class_init,
388 };
389
390 static void qio_channel_tls_register_types(void)
391 {
392     type_register_static(&qio_channel_tls_info);
393 }
394
395 type_init(qio_channel_tls_register_types);