// SPDX-License-Identifier: GPL-2.0-only
/*
 * Establish a TLS session for a kernel socket consumer
 * using the tlshd user space handler.
 *
 * Author: Chuck Lever <chuck.lever@oracle.com>
 *
 * Copyright (c) 2021-2023, Oracle and/or its affiliates.
 */
11 #include <linux/types.h>
12 #include <linux/socket.h>
13 #include <linux/kernel.h>
14 #include <linux/module.h>
15 #include <linux/slab.h>
16 #include <linux/key.h>
19 #include <net/handshake.h>
20 #include <net/genetlink.h>
22 #include <uapi/linux/keyctl.h>
23 #include <uapi/linux/handshake.h>
24 #include "handshake.h"
26 struct tls_handshake_req {
27 void (*th_consumer_done)(void *data, int status,
29 void *th_consumer_data;
32 unsigned int th_timeout_ms;
34 key_serial_t th_keyring;
35 key_serial_t th_certificate;
36 key_serial_t th_privkey;
38 unsigned int th_num_peerids;
39 key_serial_t th_peerid[5];
42 static struct tls_handshake_req *
43 tls_handshake_req_init(struct handshake_req *req,
44 const struct tls_handshake_args *args)
46 struct tls_handshake_req *treq = handshake_req_private(req);
48 treq->th_timeout_ms = args->ta_timeout_ms;
49 treq->th_consumer_done = args->ta_done;
50 treq->th_consumer_data = args->ta_data;
51 treq->th_keyring = args->ta_keyring;
52 treq->th_num_peerids = 0;
53 treq->th_certificate = TLS_NO_CERT;
54 treq->th_privkey = TLS_NO_PRIVKEY;
58 static void tls_handshake_remote_peerids(struct tls_handshake_req *treq,
59 struct genl_info *info)
61 struct nlattr *head = nlmsg_attrdata(info->nlhdr, GENL_HDRLEN);
62 int rem, len = nlmsg_attrlen(info->nlhdr, GENL_HDRLEN);
67 nla_for_each_attr(nla, head, len, rem) {
68 if (nla_type(nla) == HANDSHAKE_A_DONE_REMOTE_AUTH)
73 treq->th_num_peerids = min_t(unsigned int, i,
74 ARRAY_SIZE(treq->th_peerid));
77 nla_for_each_attr(nla, head, len, rem) {
78 if (nla_type(nla) == HANDSHAKE_A_DONE_REMOTE_AUTH)
79 treq->th_peerid[i++] = nla_get_u32(nla);
80 if (i >= treq->th_num_peerids)
86 * tls_handshake_done - callback to handle a CMD_DONE request
87 * @req: socket on which the handshake was performed
88 * @status: session status code
89 * @info: full results of session establishment
92 static void tls_handshake_done(struct handshake_req *req,
93 unsigned int status, struct genl_info *info)
95 struct tls_handshake_req *treq = handshake_req_private(req);
97 treq->th_peerid[0] = TLS_NO_PEERID;
99 tls_handshake_remote_peerids(treq, info);
101 treq->th_consumer_done(treq->th_consumer_data, -status,
105 #if IS_ENABLED(CONFIG_KEYS)
106 static int tls_handshake_private_keyring(struct tls_handshake_req *treq)
108 key_ref_t process_keyring_ref, keyring_ref;
111 if (treq->th_keyring == TLS_NO_KEYRING)
114 process_keyring_ref = lookup_user_key(KEY_SPEC_PROCESS_KEYRING,
117 if (IS_ERR(process_keyring_ref)) {
118 ret = PTR_ERR(process_keyring_ref);
122 keyring_ref = lookup_user_key(treq->th_keyring, KEY_LOOKUP_CREATE,
124 if (IS_ERR(keyring_ref)) {
125 ret = PTR_ERR(keyring_ref);
129 ret = key_link(key_ref_to_ptr(process_keyring_ref),
130 key_ref_to_ptr(keyring_ref));
132 key_ref_put(keyring_ref);
134 key_ref_put(process_keyring_ref);
139 static int tls_handshake_private_keyring(struct tls_handshake_req *treq)
145 static int tls_handshake_put_peer_identity(struct sk_buff *msg,
146 struct tls_handshake_req *treq)
150 for (i = 0; i < treq->th_num_peerids; i++)
151 if (nla_put_u32(msg, HANDSHAKE_A_ACCEPT_PEER_IDENTITY,
152 treq->th_peerid[i]) < 0)
157 static int tls_handshake_put_certificate(struct sk_buff *msg,
158 struct tls_handshake_req *treq)
160 struct nlattr *entry_attr;
162 if (treq->th_certificate == TLS_NO_CERT &&
163 treq->th_privkey == TLS_NO_PRIVKEY)
166 entry_attr = nla_nest_start(msg, HANDSHAKE_A_ACCEPT_CERTIFICATE);
170 if (nla_put_u32(msg, HANDSHAKE_A_X509_CERT,
171 treq->th_certificate) ||
172 nla_put_u32(msg, HANDSHAKE_A_X509_PRIVKEY,
174 nla_nest_cancel(msg, entry_attr);
178 nla_nest_end(msg, entry_attr);
183 * tls_handshake_accept - callback to construct a CMD_ACCEPT response
184 * @req: handshake parameters to return
185 * @info: generic netlink message context
186 * @fd: file descriptor to be returned
188 * Returns zero on success, or a negative errno on failure.
190 static int tls_handshake_accept(struct handshake_req *req,
191 struct genl_info *info, int fd)
193 struct tls_handshake_req *treq = handshake_req_private(req);
194 struct nlmsghdr *hdr;
198 ret = tls_handshake_private_keyring(treq);
203 msg = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
206 hdr = handshake_genl_put(msg, info);
211 ret = nla_put_u32(msg, HANDSHAKE_A_ACCEPT_SOCKFD, fd);
214 ret = nla_put_u32(msg, HANDSHAKE_A_ACCEPT_MESSAGE_TYPE, treq->th_type);
217 if (treq->th_timeout_ms) {
218 ret = nla_put_u32(msg, HANDSHAKE_A_ACCEPT_TIMEOUT, treq->th_timeout_ms);
223 ret = nla_put_u32(msg, HANDSHAKE_A_ACCEPT_AUTH_MODE,
227 switch (treq->th_auth_mode) {
228 case HANDSHAKE_AUTH_PSK:
229 ret = tls_handshake_put_peer_identity(msg, treq);
233 case HANDSHAKE_AUTH_X509:
234 ret = tls_handshake_put_certificate(msg, treq);
240 genlmsg_end(msg, hdr);
241 return genlmsg_reply(msg, info);
244 genlmsg_cancel(msg, hdr);
249 static const struct handshake_proto tls_handshake_proto = {
250 .hp_handler_class = HANDSHAKE_HANDLER_CLASS_TLSHD,
251 .hp_privsize = sizeof(struct tls_handshake_req),
252 .hp_flags = BIT(HANDSHAKE_F_PROTO_NOTIFY),
254 .hp_accept = tls_handshake_accept,
255 .hp_done = tls_handshake_done,
259 * tls_client_hello_anon - request an anonymous TLS handshake on a socket
260 * @args: socket and handshake parameters for this request
261 * @flags: memory allocation control flags
264 * %0: Handshake request enqueue; ->done will be called when complete
265 * %-ESRCH: No user agent is available
266 * %-ENOMEM: Memory allocation failed
268 int tls_client_hello_anon(const struct tls_handshake_args *args, gfp_t flags)
270 struct tls_handshake_req *treq;
271 struct handshake_req *req;
273 req = handshake_req_alloc(&tls_handshake_proto, flags);
276 treq = tls_handshake_req_init(req, args);
277 treq->th_type = HANDSHAKE_MSG_TYPE_CLIENTHELLO;
278 treq->th_auth_mode = HANDSHAKE_AUTH_UNAUTH;
280 return handshake_req_submit(args->ta_sock, req, flags);
282 EXPORT_SYMBOL(tls_client_hello_anon);
285 * tls_client_hello_x509 - request an x.509-based TLS handshake on a socket
286 * @args: socket and handshake parameters for this request
287 * @flags: memory allocation control flags
290 * %0: Handshake request enqueue; ->done will be called when complete
291 * %-ESRCH: No user agent is available
292 * %-ENOMEM: Memory allocation failed
294 int tls_client_hello_x509(const struct tls_handshake_args *args, gfp_t flags)
296 struct tls_handshake_req *treq;
297 struct handshake_req *req;
299 req = handshake_req_alloc(&tls_handshake_proto, flags);
302 treq = tls_handshake_req_init(req, args);
303 treq->th_type = HANDSHAKE_MSG_TYPE_CLIENTHELLO;
304 treq->th_auth_mode = HANDSHAKE_AUTH_X509;
305 treq->th_certificate = args->ta_my_cert;
306 treq->th_privkey = args->ta_my_privkey;
308 return handshake_req_submit(args->ta_sock, req, flags);
310 EXPORT_SYMBOL(tls_client_hello_x509);
313 * tls_client_hello_psk - request a PSK-based TLS handshake on a socket
314 * @args: socket and handshake parameters for this request
315 * @flags: memory allocation control flags
318 * %0: Handshake request enqueue; ->done will be called when complete
319 * %-EINVAL: Wrong number of local peer IDs
320 * %-ESRCH: No user agent is available
321 * %-ENOMEM: Memory allocation failed
323 int tls_client_hello_psk(const struct tls_handshake_args *args, gfp_t flags)
325 struct tls_handshake_req *treq;
326 struct handshake_req *req;
329 if (!args->ta_num_peerids ||
330 args->ta_num_peerids > ARRAY_SIZE(treq->th_peerid))
333 req = handshake_req_alloc(&tls_handshake_proto, flags);
336 treq = tls_handshake_req_init(req, args);
337 treq->th_type = HANDSHAKE_MSG_TYPE_CLIENTHELLO;
338 treq->th_auth_mode = HANDSHAKE_AUTH_PSK;
339 treq->th_num_peerids = args->ta_num_peerids;
340 for (i = 0; i < args->ta_num_peerids; i++)
341 treq->th_peerid[i] = args->ta_my_peerids[i];
343 return handshake_req_submit(args->ta_sock, req, flags);
345 EXPORT_SYMBOL(tls_client_hello_psk);
348 * tls_server_hello_x509 - request a server TLS handshake on a socket
349 * @args: socket and handshake parameters for this request
350 * @flags: memory allocation control flags
353 * %0: Handshake request enqueue; ->done will be called when complete
354 * %-ESRCH: No user agent is available
355 * %-ENOMEM: Memory allocation failed
357 int tls_server_hello_x509(const struct tls_handshake_args *args, gfp_t flags)
359 struct tls_handshake_req *treq;
360 struct handshake_req *req;
362 req = handshake_req_alloc(&tls_handshake_proto, flags);
365 treq = tls_handshake_req_init(req, args);
366 treq->th_type = HANDSHAKE_MSG_TYPE_SERVERHELLO;
367 treq->th_auth_mode = HANDSHAKE_AUTH_X509;
368 treq->th_certificate = args->ta_my_cert;
369 treq->th_privkey = args->ta_my_privkey;
371 return handshake_req_submit(args->ta_sock, req, flags);
373 EXPORT_SYMBOL(tls_server_hello_x509);
376 * tls_server_hello_psk - request a server TLS handshake on a socket
377 * @args: socket and handshake parameters for this request
378 * @flags: memory allocation control flags
381 * %0: Handshake request enqueue; ->done will be called when complete
382 * %-ESRCH: No user agent is available
383 * %-ENOMEM: Memory allocation failed
385 int tls_server_hello_psk(const struct tls_handshake_args *args, gfp_t flags)
387 struct tls_handshake_req *treq;
388 struct handshake_req *req;
390 req = handshake_req_alloc(&tls_handshake_proto, flags);
393 treq = tls_handshake_req_init(req, args);
394 treq->th_type = HANDSHAKE_MSG_TYPE_SERVERHELLO;
395 treq->th_auth_mode = HANDSHAKE_AUTH_PSK;
396 treq->th_num_peerids = 1;
397 treq->th_peerid[0] = args->ta_my_peerids[0];
399 return handshake_req_submit(args->ta_sock, req, flags);
401 EXPORT_SYMBOL(tls_server_hello_psk);
404 * tls_handshake_cancel - cancel a pending handshake
405 * @sk: socket on which there is an ongoing handshake
407 * Request cancellation races with request completion. To determine
408 * who won, callers examine the return value from this function.
411 * %true - Uncompleted handshake request was canceled
412 * %false - Handshake request already completed or not found
414 bool tls_handshake_cancel(struct sock *sk)
416 return handshake_req_cancel(sk);
418 EXPORT_SYMBOL(tls_handshake_cancel);