Merge git://git.kernel.org/pub/scm/linux/kernel/git/netdev/net-next
[platform/kernel/linux-starfive.git] / net / handshake / tlshd.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Establish a TLS session for a kernel socket consumer
4  * using the tlshd user space handler.
5  *
6  * Author: Chuck Lever <chuck.lever@oracle.com>
7  *
8  * Copyright (c) 2021-2023, Oracle and/or its affiliates.
9  */
10
11 #include <linux/types.h>
12 #include <linux/socket.h>
13 #include <linux/kernel.h>
14 #include <linux/module.h>
15 #include <linux/slab.h>
16 #include <linux/key.h>
17
18 #include <net/sock.h>
19 #include <net/handshake.h>
20 #include <net/genetlink.h>
21
22 #include <uapi/linux/keyctl.h>
23 #include <uapi/linux/handshake.h>
24 #include "handshake.h"
25
26 struct tls_handshake_req {
27         void                    (*th_consumer_done)(void *data, int status,
28                                                     key_serial_t peerid);
29         void                    *th_consumer_data;
30
31         int                     th_type;
32         unsigned int            th_timeout_ms;
33         int                     th_auth_mode;
34         key_serial_t            th_keyring;
35         key_serial_t            th_certificate;
36         key_serial_t            th_privkey;
37
38         unsigned int            th_num_peerids;
39         key_serial_t            th_peerid[5];
40 };
41
42 static struct tls_handshake_req *
43 tls_handshake_req_init(struct handshake_req *req,
44                        const struct tls_handshake_args *args)
45 {
46         struct tls_handshake_req *treq = handshake_req_private(req);
47
48         treq->th_timeout_ms = args->ta_timeout_ms;
49         treq->th_consumer_done = args->ta_done;
50         treq->th_consumer_data = args->ta_data;
51         treq->th_keyring = args->ta_keyring;
52         treq->th_num_peerids = 0;
53         treq->th_certificate = TLS_NO_CERT;
54         treq->th_privkey = TLS_NO_PRIVKEY;
55         return treq;
56 }
57
58 static void tls_handshake_remote_peerids(struct tls_handshake_req *treq,
59                                          struct genl_info *info)
60 {
61         struct nlattr *head = nlmsg_attrdata(info->nlhdr, GENL_HDRLEN);
62         int rem, len = nlmsg_attrlen(info->nlhdr, GENL_HDRLEN);
63         struct nlattr *nla;
64         unsigned int i;
65
66         i = 0;
67         nla_for_each_attr(nla, head, len, rem) {
68                 if (nla_type(nla) == HANDSHAKE_A_DONE_REMOTE_AUTH)
69                         i++;
70         }
71         if (!i)
72                 return;
73         treq->th_num_peerids = min_t(unsigned int, i,
74                                      ARRAY_SIZE(treq->th_peerid));
75
76         i = 0;
77         nla_for_each_attr(nla, head, len, rem) {
78                 if (nla_type(nla) == HANDSHAKE_A_DONE_REMOTE_AUTH)
79                         treq->th_peerid[i++] = nla_get_u32(nla);
80                 if (i >= treq->th_num_peerids)
81                         break;
82         }
83 }
84
85 /**
86  * tls_handshake_done - callback to handle a CMD_DONE request
87  * @req: socket on which the handshake was performed
88  * @status: session status code
89  * @info: full results of session establishment
90  *
91  */
92 static void tls_handshake_done(struct handshake_req *req,
93                                unsigned int status, struct genl_info *info)
94 {
95         struct tls_handshake_req *treq = handshake_req_private(req);
96
97         treq->th_peerid[0] = TLS_NO_PEERID;
98         if (info)
99                 tls_handshake_remote_peerids(treq, info);
100
101         treq->th_consumer_done(treq->th_consumer_data, -status,
102                                treq->th_peerid[0]);
103 }
104
105 #if IS_ENABLED(CONFIG_KEYS)
106 static int tls_handshake_private_keyring(struct tls_handshake_req *treq)
107 {
108         key_ref_t process_keyring_ref, keyring_ref;
109         int ret;
110
111         if (treq->th_keyring == TLS_NO_KEYRING)
112                 return 0;
113
114         process_keyring_ref = lookup_user_key(KEY_SPEC_PROCESS_KEYRING,
115                                               KEY_LOOKUP_CREATE,
116                                               KEY_NEED_WRITE);
117         if (IS_ERR(process_keyring_ref)) {
118                 ret = PTR_ERR(process_keyring_ref);
119                 goto out;
120         }
121
122         keyring_ref = lookup_user_key(treq->th_keyring, KEY_LOOKUP_CREATE,
123                                       KEY_NEED_LINK);
124         if (IS_ERR(keyring_ref)) {
125                 ret = PTR_ERR(keyring_ref);
126                 goto out_put_key;
127         }
128
129         ret = key_link(key_ref_to_ptr(process_keyring_ref),
130                        key_ref_to_ptr(keyring_ref));
131
132         key_ref_put(keyring_ref);
133 out_put_key:
134         key_ref_put(process_keyring_ref);
135 out:
136         return ret;
137 }
138 #else
139 static int tls_handshake_private_keyring(struct tls_handshake_req *treq)
140 {
141         return 0;
142 }
143 #endif
144
145 static int tls_handshake_put_peer_identity(struct sk_buff *msg,
146                                            struct tls_handshake_req *treq)
147 {
148         unsigned int i;
149
150         for (i = 0; i < treq->th_num_peerids; i++)
151                 if (nla_put_u32(msg, HANDSHAKE_A_ACCEPT_PEER_IDENTITY,
152                                 treq->th_peerid[i]) < 0)
153                         return -EMSGSIZE;
154         return 0;
155 }
156
157 static int tls_handshake_put_certificate(struct sk_buff *msg,
158                                          struct tls_handshake_req *treq)
159 {
160         struct nlattr *entry_attr;
161
162         if (treq->th_certificate == TLS_NO_CERT &&
163             treq->th_privkey == TLS_NO_PRIVKEY)
164                 return 0;
165
166         entry_attr = nla_nest_start(msg, HANDSHAKE_A_ACCEPT_CERTIFICATE);
167         if (!entry_attr)
168                 return -EMSGSIZE;
169
170         if (nla_put_u32(msg, HANDSHAKE_A_X509_CERT,
171                         treq->th_certificate) ||
172             nla_put_u32(msg, HANDSHAKE_A_X509_PRIVKEY,
173                         treq->th_privkey)) {
174                 nla_nest_cancel(msg, entry_attr);
175                 return -EMSGSIZE;
176         }
177
178         nla_nest_end(msg, entry_attr);
179         return 0;
180 }
181
182 /**
183  * tls_handshake_accept - callback to construct a CMD_ACCEPT response
184  * @req: handshake parameters to return
185  * @info: generic netlink message context
186  * @fd: file descriptor to be returned
187  *
188  * Returns zero on success, or a negative errno on failure.
189  */
190 static int tls_handshake_accept(struct handshake_req *req,
191                                 struct genl_info *info, int fd)
192 {
193         struct tls_handshake_req *treq = handshake_req_private(req);
194         struct nlmsghdr *hdr;
195         struct sk_buff *msg;
196         int ret;
197
198         ret = tls_handshake_private_keyring(treq);
199         if (ret < 0)
200                 goto out;
201
202         ret = -ENOMEM;
203         msg = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
204         if (!msg)
205                 goto out;
206         hdr = handshake_genl_put(msg, info);
207         if (!hdr)
208                 goto out_cancel;
209
210         ret = -EMSGSIZE;
211         ret = nla_put_u32(msg, HANDSHAKE_A_ACCEPT_SOCKFD, fd);
212         if (ret < 0)
213                 goto out_cancel;
214         ret = nla_put_u32(msg, HANDSHAKE_A_ACCEPT_MESSAGE_TYPE, treq->th_type);
215         if (ret < 0)
216                 goto out_cancel;
217         if (treq->th_timeout_ms) {
218                 ret = nla_put_u32(msg, HANDSHAKE_A_ACCEPT_TIMEOUT, treq->th_timeout_ms);
219                 if (ret < 0)
220                         goto out_cancel;
221         }
222
223         ret = nla_put_u32(msg, HANDSHAKE_A_ACCEPT_AUTH_MODE,
224                           treq->th_auth_mode);
225         if (ret < 0)
226                 goto out_cancel;
227         switch (treq->th_auth_mode) {
228         case HANDSHAKE_AUTH_PSK:
229                 ret = tls_handshake_put_peer_identity(msg, treq);
230                 if (ret < 0)
231                         goto out_cancel;
232                 break;
233         case HANDSHAKE_AUTH_X509:
234                 ret = tls_handshake_put_certificate(msg, treq);
235                 if (ret < 0)
236                         goto out_cancel;
237                 break;
238         }
239
240         genlmsg_end(msg, hdr);
241         return genlmsg_reply(msg, info);
242
243 out_cancel:
244         genlmsg_cancel(msg, hdr);
245 out:
246         return ret;
247 }
248
249 static const struct handshake_proto tls_handshake_proto = {
250         .hp_handler_class       = HANDSHAKE_HANDLER_CLASS_TLSHD,
251         .hp_privsize            = sizeof(struct tls_handshake_req),
252         .hp_flags               = BIT(HANDSHAKE_F_PROTO_NOTIFY),
253
254         .hp_accept              = tls_handshake_accept,
255         .hp_done                = tls_handshake_done,
256 };
257
258 /**
259  * tls_client_hello_anon - request an anonymous TLS handshake on a socket
260  * @args: socket and handshake parameters for this request
261  * @flags: memory allocation control flags
262  *
263  * Return values:
264  *   %0: Handshake request enqueue; ->done will be called when complete
265  *   %-ESRCH: No user agent is available
266  *   %-ENOMEM: Memory allocation failed
267  */
268 int tls_client_hello_anon(const struct tls_handshake_args *args, gfp_t flags)
269 {
270         struct tls_handshake_req *treq;
271         struct handshake_req *req;
272
273         req = handshake_req_alloc(&tls_handshake_proto, flags);
274         if (!req)
275                 return -ENOMEM;
276         treq = tls_handshake_req_init(req, args);
277         treq->th_type = HANDSHAKE_MSG_TYPE_CLIENTHELLO;
278         treq->th_auth_mode = HANDSHAKE_AUTH_UNAUTH;
279
280         return handshake_req_submit(args->ta_sock, req, flags);
281 }
282 EXPORT_SYMBOL(tls_client_hello_anon);
283
284 /**
285  * tls_client_hello_x509 - request an x.509-based TLS handshake on a socket
286  * @args: socket and handshake parameters for this request
287  * @flags: memory allocation control flags
288  *
289  * Return values:
290  *   %0: Handshake request enqueue; ->done will be called when complete
291  *   %-ESRCH: No user agent is available
292  *   %-ENOMEM: Memory allocation failed
293  */
294 int tls_client_hello_x509(const struct tls_handshake_args *args, gfp_t flags)
295 {
296         struct tls_handshake_req *treq;
297         struct handshake_req *req;
298
299         req = handshake_req_alloc(&tls_handshake_proto, flags);
300         if (!req)
301                 return -ENOMEM;
302         treq = tls_handshake_req_init(req, args);
303         treq->th_type = HANDSHAKE_MSG_TYPE_CLIENTHELLO;
304         treq->th_auth_mode = HANDSHAKE_AUTH_X509;
305         treq->th_certificate = args->ta_my_cert;
306         treq->th_privkey = args->ta_my_privkey;
307
308         return handshake_req_submit(args->ta_sock, req, flags);
309 }
310 EXPORT_SYMBOL(tls_client_hello_x509);
311
312 /**
313  * tls_client_hello_psk - request a PSK-based TLS handshake on a socket
314  * @args: socket and handshake parameters for this request
315  * @flags: memory allocation control flags
316  *
317  * Return values:
318  *   %0: Handshake request enqueue; ->done will be called when complete
319  *   %-EINVAL: Wrong number of local peer IDs
320  *   %-ESRCH: No user agent is available
321  *   %-ENOMEM: Memory allocation failed
322  */
323 int tls_client_hello_psk(const struct tls_handshake_args *args, gfp_t flags)
324 {
325         struct tls_handshake_req *treq;
326         struct handshake_req *req;
327         unsigned int i;
328
329         if (!args->ta_num_peerids ||
330             args->ta_num_peerids > ARRAY_SIZE(treq->th_peerid))
331                 return -EINVAL;
332
333         req = handshake_req_alloc(&tls_handshake_proto, flags);
334         if (!req)
335                 return -ENOMEM;
336         treq = tls_handshake_req_init(req, args);
337         treq->th_type = HANDSHAKE_MSG_TYPE_CLIENTHELLO;
338         treq->th_auth_mode = HANDSHAKE_AUTH_PSK;
339         treq->th_num_peerids = args->ta_num_peerids;
340         for (i = 0; i < args->ta_num_peerids; i++)
341                 treq->th_peerid[i] = args->ta_my_peerids[i];
342
343         return handshake_req_submit(args->ta_sock, req, flags);
344 }
345 EXPORT_SYMBOL(tls_client_hello_psk);
346
347 /**
348  * tls_server_hello_x509 - request a server TLS handshake on a socket
349  * @args: socket and handshake parameters for this request
350  * @flags: memory allocation control flags
351  *
352  * Return values:
353  *   %0: Handshake request enqueue; ->done will be called when complete
354  *   %-ESRCH: No user agent is available
355  *   %-ENOMEM: Memory allocation failed
356  */
357 int tls_server_hello_x509(const struct tls_handshake_args *args, gfp_t flags)
358 {
359         struct tls_handshake_req *treq;
360         struct handshake_req *req;
361
362         req = handshake_req_alloc(&tls_handshake_proto, flags);
363         if (!req)
364                 return -ENOMEM;
365         treq = tls_handshake_req_init(req, args);
366         treq->th_type = HANDSHAKE_MSG_TYPE_SERVERHELLO;
367         treq->th_auth_mode = HANDSHAKE_AUTH_X509;
368         treq->th_certificate = args->ta_my_cert;
369         treq->th_privkey = args->ta_my_privkey;
370
371         return handshake_req_submit(args->ta_sock, req, flags);
372 }
373 EXPORT_SYMBOL(tls_server_hello_x509);
374
375 /**
376  * tls_server_hello_psk - request a server TLS handshake on a socket
377  * @args: socket and handshake parameters for this request
378  * @flags: memory allocation control flags
379  *
380  * Return values:
381  *   %0: Handshake request enqueue; ->done will be called when complete
382  *   %-ESRCH: No user agent is available
383  *   %-ENOMEM: Memory allocation failed
384  */
385 int tls_server_hello_psk(const struct tls_handshake_args *args, gfp_t flags)
386 {
387         struct tls_handshake_req *treq;
388         struct handshake_req *req;
389
390         req = handshake_req_alloc(&tls_handshake_proto, flags);
391         if (!req)
392                 return -ENOMEM;
393         treq = tls_handshake_req_init(req, args);
394         treq->th_type = HANDSHAKE_MSG_TYPE_SERVERHELLO;
395         treq->th_auth_mode = HANDSHAKE_AUTH_PSK;
396         treq->th_num_peerids = 1;
397         treq->th_peerid[0] = args->ta_my_peerids[0];
398
399         return handshake_req_submit(args->ta_sock, req, flags);
400 }
401 EXPORT_SYMBOL(tls_server_hello_psk);
402
403 /**
404  * tls_handshake_cancel - cancel a pending handshake
405  * @sk: socket on which there is an ongoing handshake
406  *
407  * Request cancellation races with request completion. To determine
408  * who won, callers examine the return value from this function.
409  *
410  * Return values:
411  *   %true - Uncompleted handshake request was canceled
412  *   %false - Handshake request already completed or not found
413  */
414 bool tls_handshake_cancel(struct sock *sk)
415 {
416         return handshake_req_cancel(sk);
417 }
418 EXPORT_SYMBOL(tls_handshake_cancel);