rxrpc: Drop rxrpc_conn_parameters from rxrpc_connection and rxrpc_bundle
[platform/kernel/linux-rpi.git] / net / rxrpc / rxkad.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /* Kerberos-based RxRPC security
3  *
4  * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved.
5  * Written by David Howells (dhowells@redhat.com)
6  */
7
8 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
9
10 #include <crypto/skcipher.h>
11 #include <linux/module.h>
12 #include <linux/net.h>
13 #include <linux/skbuff.h>
14 #include <linux/udp.h>
15 #include <linux/scatterlist.h>
16 #include <linux/ctype.h>
17 #include <linux/slab.h>
18 #include <linux/key-type.h>
19 #include <net/sock.h>
20 #include <net/af_rxrpc.h>
21 #include <keys/rxrpc-type.h>
22 #include "ar-internal.h"
23
24 #define RXKAD_VERSION                   2
25 #define MAXKRB5TICKETLEN                1024
26 #define RXKAD_TKT_TYPE_KERBEROS_V5      256
27 #define ANAME_SZ                        40      /* size of authentication name */
28 #define INST_SZ                         40      /* size of principal's instance */
29 #define REALM_SZ                        40      /* size of principal's auth domain */
30 #define SNAME_SZ                        40      /* size of service name */
31 #define RXKAD_ALIGN                     8
32
33 struct rxkad_level1_hdr {
34         __be32  data_size;      /* true data size (excluding padding) */
35 };
36
37 struct rxkad_level2_hdr {
38         __be32  data_size;      /* true data size (excluding padding) */
39         __be32  checksum;       /* decrypted data checksum */
40 };
41
42 static int rxkad_prime_packet_security(struct rxrpc_connection *conn,
43                                        struct crypto_sync_skcipher *ci);
44
45 /*
46  * this holds a pinned cipher so that keventd doesn't get called by the cipher
47  * alloc routine, but since we have it to hand, we use it to decrypt RESPONSE
48  * packets
49  */
50 static struct crypto_sync_skcipher *rxkad_ci;
51 static struct skcipher_request *rxkad_ci_req;
52 static DEFINE_MUTEX(rxkad_ci_mutex);
53
54 /*
55  * Parse the information from a server key
56  *
57  * The data should be the 8-byte secret key.
58  */
59 static int rxkad_preparse_server_key(struct key_preparsed_payload *prep)
60 {
61         struct crypto_skcipher *ci;
62
63         if (prep->datalen != 8)
64                 return -EINVAL;
65
66         memcpy(&prep->payload.data[2], prep->data, 8);
67
68         ci = crypto_alloc_skcipher("pcbc(des)", 0, CRYPTO_ALG_ASYNC);
69         if (IS_ERR(ci)) {
70                 _leave(" = %ld", PTR_ERR(ci));
71                 return PTR_ERR(ci);
72         }
73
74         if (crypto_skcipher_setkey(ci, prep->data, 8) < 0)
75                 BUG();
76
77         prep->payload.data[0] = ci;
78         _leave(" = 0");
79         return 0;
80 }
81
82 static void rxkad_free_preparse_server_key(struct key_preparsed_payload *prep)
83 {
84
85         if (prep->payload.data[0])
86                 crypto_free_skcipher(prep->payload.data[0]);
87 }
88
89 static void rxkad_destroy_server_key(struct key *key)
90 {
91         if (key->payload.data[0]) {
92                 crypto_free_skcipher(key->payload.data[0]);
93                 key->payload.data[0] = NULL;
94         }
95 }
96
97 /*
98  * initialise connection security
99  */
100 static int rxkad_init_connection_security(struct rxrpc_connection *conn,
101                                           struct rxrpc_key_token *token)
102 {
103         struct crypto_sync_skcipher *ci;
104         int ret;
105
106         _enter("{%d},{%x}", conn->debug_id, key_serial(conn->key));
107
108         conn->security_ix = token->security_index;
109
110         ci = crypto_alloc_sync_skcipher("pcbc(fcrypt)", 0, 0);
111         if (IS_ERR(ci)) {
112                 _debug("no cipher");
113                 ret = PTR_ERR(ci);
114                 goto error;
115         }
116
117         if (crypto_sync_skcipher_setkey(ci, token->kad->session_key,
118                                    sizeof(token->kad->session_key)) < 0)
119                 BUG();
120
121         switch (conn->security_level) {
122         case RXRPC_SECURITY_PLAIN:
123         case RXRPC_SECURITY_AUTH:
124         case RXRPC_SECURITY_ENCRYPT:
125                 break;
126         default:
127                 ret = -EKEYREJECTED;
128                 goto error;
129         }
130
131         ret = rxkad_prime_packet_security(conn, ci);
132         if (ret < 0)
133                 goto error_ci;
134
135         conn->rxkad.cipher = ci;
136         return 0;
137
138 error_ci:
139         crypto_free_sync_skcipher(ci);
140 error:
141         _leave(" = %d", ret);
142         return ret;
143 }
144
145 /*
146  * Work out how much data we can put in a packet.
147  */
148 static int rxkad_how_much_data(struct rxrpc_call *call, size_t remain,
149                                size_t *_buf_size, size_t *_data_size, size_t *_offset)
150 {
151         size_t shdr, buf_size, chunk;
152
153         switch (call->conn->security_level) {
154         default:
155                 buf_size = chunk = min_t(size_t, remain, RXRPC_JUMBO_DATALEN);
156                 shdr = 0;
157                 goto out;
158         case RXRPC_SECURITY_AUTH:
159                 shdr = sizeof(struct rxkad_level1_hdr);
160                 break;
161         case RXRPC_SECURITY_ENCRYPT:
162                 shdr = sizeof(struct rxkad_level2_hdr);
163                 break;
164         }
165
166         buf_size = round_down(RXRPC_JUMBO_DATALEN, RXKAD_ALIGN);
167
168         chunk = buf_size - shdr;
169         if (remain < chunk)
170                 buf_size = round_up(shdr + remain, RXKAD_ALIGN);
171
172 out:
173         *_buf_size = buf_size;
174         *_data_size = chunk;
175         *_offset = shdr;
176         return 0;
177 }
178
179 /*
180  * prime the encryption state with the invariant parts of a connection's
181  * description
182  */
183 static int rxkad_prime_packet_security(struct rxrpc_connection *conn,
184                                        struct crypto_sync_skcipher *ci)
185 {
186         struct skcipher_request *req;
187         struct rxrpc_key_token *token;
188         struct scatterlist sg;
189         struct rxrpc_crypt iv;
190         __be32 *tmpbuf;
191         size_t tmpsize = 4 * sizeof(__be32);
192
193         _enter("");
194
195         if (!conn->key)
196                 return 0;
197
198         tmpbuf = kmalloc(tmpsize, GFP_KERNEL);
199         if (!tmpbuf)
200                 return -ENOMEM;
201
202         req = skcipher_request_alloc(&ci->base, GFP_NOFS);
203         if (!req) {
204                 kfree(tmpbuf);
205                 return -ENOMEM;
206         }
207
208         token = conn->key->payload.data[0];
209         memcpy(&iv, token->kad->session_key, sizeof(iv));
210
211         tmpbuf[0] = htonl(conn->proto.epoch);
212         tmpbuf[1] = htonl(conn->proto.cid);
213         tmpbuf[2] = 0;
214         tmpbuf[3] = htonl(conn->security_ix);
215
216         sg_init_one(&sg, tmpbuf, tmpsize);
217         skcipher_request_set_sync_tfm(req, ci);
218         skcipher_request_set_callback(req, 0, NULL, NULL);
219         skcipher_request_set_crypt(req, &sg, &sg, tmpsize, iv.x);
220         crypto_skcipher_encrypt(req);
221         skcipher_request_free(req);
222
223         memcpy(&conn->rxkad.csum_iv, tmpbuf + 2, sizeof(conn->rxkad.csum_iv));
224         kfree(tmpbuf);
225         _leave(" = 0");
226         return 0;
227 }
228
229 /*
230  * Allocate and prepare the crypto request on a call.  For any particular call,
231  * this is called serially for the packets, so no lock should be necessary.
232  */
233 static struct skcipher_request *rxkad_get_call_crypto(struct rxrpc_call *call)
234 {
235         struct crypto_skcipher *tfm = &call->conn->rxkad.cipher->base;
236
237         return skcipher_request_alloc(tfm, GFP_NOFS);
238 }
239
240 /*
241  * Clean up the crypto on a call.
242  */
243 static void rxkad_free_call_crypto(struct rxrpc_call *call)
244 {
245 }
246
247 /*
248  * partially encrypt a packet (level 1 security)
249  */
250 static int rxkad_secure_packet_auth(const struct rxrpc_call *call,
251                                     struct rxrpc_txbuf *txb,
252                                     struct skcipher_request *req)
253 {
254         struct rxkad_level1_hdr *hdr = (void *)txb->data;
255         struct rxrpc_crypt iv;
256         struct scatterlist sg;
257         size_t pad;
258         u16 check;
259
260         _enter("");
261
262         check = txb->seq ^ ntohl(txb->wire.callNumber);
263         hdr->data_size = htonl((u32)check << 16 | txb->len);
264
265         txb->len += sizeof(struct rxkad_level1_hdr);
266         pad = txb->len;
267         pad = RXKAD_ALIGN - pad;
268         pad &= RXKAD_ALIGN - 1;
269         if (pad) {
270                 memset(txb->data + txb->offset, 0, pad);
271                 txb->len += pad;
272         }
273
274         /* start the encryption afresh */
275         memset(&iv, 0, sizeof(iv));
276
277         sg_init_one(&sg, txb->data, 8);
278         skcipher_request_set_sync_tfm(req, call->conn->rxkad.cipher);
279         skcipher_request_set_callback(req, 0, NULL, NULL);
280         skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x);
281         crypto_skcipher_encrypt(req);
282         skcipher_request_zero(req);
283
284         _leave(" = 0");
285         return 0;
286 }
287
288 /*
289  * wholly encrypt a packet (level 2 security)
290  */
291 static int rxkad_secure_packet_encrypt(const struct rxrpc_call *call,
292                                        struct rxrpc_txbuf *txb,
293                                        struct skcipher_request *req)
294 {
295         const struct rxrpc_key_token *token;
296         struct rxkad_level2_hdr *rxkhdr = (void *)txb->data;
297         struct rxrpc_crypt iv;
298         struct scatterlist sg;
299         size_t pad;
300         u16 check;
301         int ret;
302
303         _enter("");
304
305         check = txb->seq ^ ntohl(txb->wire.callNumber);
306
307         rxkhdr->data_size = htonl(txb->len | (u32)check << 16);
308         rxkhdr->checksum = 0;
309
310         txb->len += sizeof(struct rxkad_level2_hdr);
311         pad = txb->len;
312         pad = RXKAD_ALIGN - pad;
313         pad &= RXKAD_ALIGN - 1;
314         if (pad) {
315                 memset(txb->data + txb->offset, 0, pad);
316                 txb->len += pad;
317         }
318
319         /* encrypt from the session key */
320         token = call->conn->key->payload.data[0];
321         memcpy(&iv, token->kad->session_key, sizeof(iv));
322
323         sg_init_one(&sg, txb->data, txb->len);
324         skcipher_request_set_sync_tfm(req, call->conn->rxkad.cipher);
325         skcipher_request_set_callback(req, 0, NULL, NULL);
326         skcipher_request_set_crypt(req, &sg, &sg, txb->len, iv.x);
327         ret = crypto_skcipher_encrypt(req);
328         skcipher_request_zero(req);
329         return ret;
330 }
331
332 /*
333  * checksum an RxRPC packet header
334  */
335 static int rxkad_secure_packet(struct rxrpc_call *call, struct rxrpc_txbuf *txb)
336 {
337         struct skcipher_request *req;
338         struct rxrpc_crypt iv;
339         struct scatterlist sg;
340         union {
341                 __be32 buf[2];
342         } crypto __aligned(8);
343         u32 x, y;
344         int ret;
345
346         _enter("{%d{%x}},{#%u},%u,",
347                call->debug_id, key_serial(call->conn->key),
348                txb->seq, txb->len);
349
350         if (!call->conn->rxkad.cipher)
351                 return 0;
352
353         ret = key_validate(call->conn->key);
354         if (ret < 0)
355                 return ret;
356
357         req = rxkad_get_call_crypto(call);
358         if (!req)
359                 return -ENOMEM;
360
361         /* continue encrypting from where we left off */
362         memcpy(&iv, call->conn->rxkad.csum_iv.x, sizeof(iv));
363
364         /* calculate the security checksum */
365         x = (ntohl(txb->wire.cid) & RXRPC_CHANNELMASK) << (32 - RXRPC_CIDSHIFT);
366         x |= txb->seq & 0x3fffffff;
367         crypto.buf[0] = txb->wire.callNumber;
368         crypto.buf[1] = htonl(x);
369
370         sg_init_one(&sg, crypto.buf, 8);
371         skcipher_request_set_sync_tfm(req, call->conn->rxkad.cipher);
372         skcipher_request_set_callback(req, 0, NULL, NULL);
373         skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x);
374         crypto_skcipher_encrypt(req);
375         skcipher_request_zero(req);
376
377         y = ntohl(crypto.buf[1]);
378         y = (y >> 16) & 0xffff;
379         if (y == 0)
380                 y = 1; /* zero checksums are not permitted */
381         txb->wire.cksum = htons(y);
382
383         switch (call->conn->security_level) {
384         case RXRPC_SECURITY_PLAIN:
385                 ret = 0;
386                 break;
387         case RXRPC_SECURITY_AUTH:
388                 ret = rxkad_secure_packet_auth(call, txb, req);
389                 break;
390         case RXRPC_SECURITY_ENCRYPT:
391                 ret = rxkad_secure_packet_encrypt(call, txb, req);
392                 break;
393         default:
394                 ret = -EPERM;
395                 break;
396         }
397
398         skcipher_request_free(req);
399         _leave(" = %d [set %x]", ret, y);
400         return ret;
401 }
402
403 /*
404  * decrypt partial encryption on a packet (level 1 security)
405  */
406 static int rxkad_verify_packet_1(struct rxrpc_call *call, struct sk_buff *skb,
407                                  rxrpc_seq_t seq,
408                                  struct skcipher_request *req)
409 {
410         struct rxkad_level1_hdr sechdr;
411         struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
412         struct rxrpc_crypt iv;
413         struct scatterlist sg[16];
414         bool aborted;
415         u32 data_size, buf;
416         u16 check;
417         int ret;
418
419         _enter("");
420
421         if (sp->len < 8) {
422                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_hdr", "V1H",
423                                              RXKADSEALEDINCON);
424                 goto protocol_error;
425         }
426
427         /* Decrypt the skbuff in-place.  TODO: We really want to decrypt
428          * directly into the target buffer.
429          */
430         sg_init_table(sg, ARRAY_SIZE(sg));
431         ret = skb_to_sgvec(skb, sg, sp->offset, 8);
432         if (unlikely(ret < 0))
433                 return ret;
434
435         /* start the decryption afresh */
436         memset(&iv, 0, sizeof(iv));
437
438         skcipher_request_set_sync_tfm(req, call->conn->rxkad.cipher);
439         skcipher_request_set_callback(req, 0, NULL, NULL);
440         skcipher_request_set_crypt(req, sg, sg, 8, iv.x);
441         crypto_skcipher_decrypt(req);
442         skcipher_request_zero(req);
443
444         /* Extract the decrypted packet length */
445         if (skb_copy_bits(skb, sp->offset, &sechdr, sizeof(sechdr)) < 0) {
446                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_len", "XV1",
447                                              RXKADDATALEN);
448                 goto protocol_error;
449         }
450         sp->offset += sizeof(sechdr);
451         sp->len    -= sizeof(sechdr);
452
453         buf = ntohl(sechdr.data_size);
454         data_size = buf & 0xffff;
455
456         check = buf >> 16;
457         check ^= seq ^ call->call_id;
458         check &= 0xffff;
459         if (check != 0) {
460                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_check", "V1C",
461                                              RXKADSEALEDINCON);
462                 goto protocol_error;
463         }
464
465         if (data_size > sp->len) {
466                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_datalen", "V1L",
467                                              RXKADDATALEN);
468                 goto protocol_error;
469         }
470         sp->len = data_size;
471
472         _leave(" = 0 [dlen=%x]", data_size);
473         return 0;
474
475 protocol_error:
476         if (aborted)
477                 rxrpc_send_abort_packet(call);
478         return -EPROTO;
479 }
480
481 /*
482  * wholly decrypt a packet (level 2 security)
483  */
484 static int rxkad_verify_packet_2(struct rxrpc_call *call, struct sk_buff *skb,
485                                  rxrpc_seq_t seq,
486                                  struct skcipher_request *req)
487 {
488         const struct rxrpc_key_token *token;
489         struct rxkad_level2_hdr sechdr;
490         struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
491         struct rxrpc_crypt iv;
492         struct scatterlist _sg[4], *sg;
493         bool aborted;
494         u32 data_size, buf;
495         u16 check;
496         int nsg, ret;
497
498         _enter(",{%d}", sp->len);
499
500         if (sp->len < 8) {
501                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_hdr", "V2H",
502                                              RXKADSEALEDINCON);
503                 goto protocol_error;
504         }
505
506         /* Decrypt the skbuff in-place.  TODO: We really want to decrypt
507          * directly into the target buffer.
508          */
509         sg = _sg;
510         nsg = skb_shinfo(skb)->nr_frags + 1;
511         if (nsg <= 4) {
512                 nsg = 4;
513         } else {
514                 sg = kmalloc_array(nsg, sizeof(*sg), GFP_NOIO);
515                 if (!sg)
516                         goto nomem;
517         }
518
519         sg_init_table(sg, nsg);
520         ret = skb_to_sgvec(skb, sg, sp->offset, sp->len);
521         if (unlikely(ret < 0)) {
522                 if (sg != _sg)
523                         kfree(sg);
524                 return ret;
525         }
526
527         /* decrypt from the session key */
528         token = call->conn->key->payload.data[0];
529         memcpy(&iv, token->kad->session_key, sizeof(iv));
530
531         skcipher_request_set_sync_tfm(req, call->conn->rxkad.cipher);
532         skcipher_request_set_callback(req, 0, NULL, NULL);
533         skcipher_request_set_crypt(req, sg, sg, sp->len, iv.x);
534         crypto_skcipher_decrypt(req);
535         skcipher_request_zero(req);
536         if (sg != _sg)
537                 kfree(sg);
538
539         /* Extract the decrypted packet length */
540         if (skb_copy_bits(skb, sp->offset, &sechdr, sizeof(sechdr)) < 0) {
541                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_len", "XV2",
542                                              RXKADDATALEN);
543                 goto protocol_error;
544         }
545         sp->offset += sizeof(sechdr);
546         sp->len    -= sizeof(sechdr);
547
548         buf = ntohl(sechdr.data_size);
549         data_size = buf & 0xffff;
550
551         check = buf >> 16;
552         check ^= seq ^ call->call_id;
553         check &= 0xffff;
554         if (check != 0) {
555                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_check", "V2C",
556                                              RXKADSEALEDINCON);
557                 goto protocol_error;
558         }
559
560         if (data_size > sp->len) {
561                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_datalen", "V2L",
562                                              RXKADDATALEN);
563                 goto protocol_error;
564         }
565
566         sp->len = data_size;
567         _leave(" = 0 [dlen=%x]", data_size);
568         return 0;
569
570 protocol_error:
571         if (aborted)
572                 rxrpc_send_abort_packet(call);
573         return -EPROTO;
574
575 nomem:
576         _leave(" = -ENOMEM");
577         return -ENOMEM;
578 }
579
580 /*
581  * Verify the security on a received packet and the subpackets therein.
582  */
583 static int rxkad_verify_packet(struct rxrpc_call *call, struct sk_buff *skb)
584 {
585         struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
586         struct skcipher_request *req;
587         struct rxrpc_crypt iv;
588         struct scatterlist sg;
589         union {
590                 __be32 buf[2];
591         } crypto __aligned(8);
592         rxrpc_seq_t seq = sp->hdr.seq;
593         bool aborted;
594         int ret;
595         u16 cksum;
596         u32 x, y;
597
598         _enter("{%d{%x}},{#%u}",
599                call->debug_id, key_serial(call->conn->key), seq);
600
601         if (!call->conn->rxkad.cipher)
602                 return 0;
603
604         req = rxkad_get_call_crypto(call);
605         if (!req)
606                 return -ENOMEM;
607
608         /* continue encrypting from where we left off */
609         memcpy(&iv, call->conn->rxkad.csum_iv.x, sizeof(iv));
610
611         /* validate the security checksum */
612         x = (call->cid & RXRPC_CHANNELMASK) << (32 - RXRPC_CIDSHIFT);
613         x |= seq & 0x3fffffff;
614         crypto.buf[0] = htonl(call->call_id);
615         crypto.buf[1] = htonl(x);
616
617         sg_init_one(&sg, crypto.buf, 8);
618         skcipher_request_set_sync_tfm(req, call->conn->rxkad.cipher);
619         skcipher_request_set_callback(req, 0, NULL, NULL);
620         skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x);
621         crypto_skcipher_encrypt(req);
622         skcipher_request_zero(req);
623
624         y = ntohl(crypto.buf[1]);
625         cksum = (y >> 16) & 0xffff;
626         if (cksum == 0)
627                 cksum = 1; /* zero checksums are not permitted */
628
629         if (cksum != sp->hdr.cksum) {
630                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_csum", "VCK",
631                                              RXKADSEALEDINCON);
632                 goto protocol_error;
633         }
634
635         switch (call->conn->security_level) {
636         case RXRPC_SECURITY_PLAIN:
637                 ret = 0;
638                 break;
639         case RXRPC_SECURITY_AUTH:
640                 ret = rxkad_verify_packet_1(call, skb, seq, req);
641                 break;
642         case RXRPC_SECURITY_ENCRYPT:
643                 ret = rxkad_verify_packet_2(call, skb, seq, req);
644                 break;
645         default:
646                 ret = -ENOANO;
647                 break;
648         }
649
650         skcipher_request_free(req);
651         return ret;
652
653 protocol_error:
654         if (aborted)
655                 rxrpc_send_abort_packet(call);
656         return -EPROTO;
657 }
658
659 /*
660  * issue a challenge
661  */
662 static int rxkad_issue_challenge(struct rxrpc_connection *conn)
663 {
664         struct rxkad_challenge challenge;
665         struct rxrpc_wire_header whdr;
666         struct msghdr msg;
667         struct kvec iov[2];
668         size_t len;
669         u32 serial;
670         int ret;
671
672         _enter("{%d}", conn->debug_id);
673
674         get_random_bytes(&conn->rxkad.nonce, sizeof(conn->rxkad.nonce));
675
676         challenge.version       = htonl(2);
677         challenge.nonce         = htonl(conn->rxkad.nonce);
678         challenge.min_level     = htonl(0);
679         challenge.__padding     = 0;
680
681         msg.msg_name    = &conn->peer->srx.transport;
682         msg.msg_namelen = conn->peer->srx.transport_len;
683         msg.msg_control = NULL;
684         msg.msg_controllen = 0;
685         msg.msg_flags   = 0;
686
687         whdr.epoch      = htonl(conn->proto.epoch);
688         whdr.cid        = htonl(conn->proto.cid);
689         whdr.callNumber = 0;
690         whdr.seq        = 0;
691         whdr.type       = RXRPC_PACKET_TYPE_CHALLENGE;
692         whdr.flags      = conn->out_clientflag;
693         whdr.userStatus = 0;
694         whdr.securityIndex = conn->security_ix;
695         whdr._rsvd      = 0;
696         whdr.serviceId  = htons(conn->service_id);
697
698         iov[0].iov_base = &whdr;
699         iov[0].iov_len  = sizeof(whdr);
700         iov[1].iov_base = &challenge;
701         iov[1].iov_len  = sizeof(challenge);
702
703         len = iov[0].iov_len + iov[1].iov_len;
704
705         serial = atomic_inc_return(&conn->serial);
706         whdr.serial = htonl(serial);
707
708         ret = kernel_sendmsg(conn->local->socket, &msg, iov, 2, len);
709         if (ret < 0) {
710                 trace_rxrpc_tx_fail(conn->debug_id, serial, ret,
711                                     rxrpc_tx_point_rxkad_challenge);
712                 return -EAGAIN;
713         }
714
715         conn->peer->last_tx_at = ktime_get_seconds();
716         trace_rxrpc_tx_packet(conn->debug_id, &whdr,
717                               rxrpc_tx_point_rxkad_challenge);
718         _leave(" = 0");
719         return 0;
720 }
721
722 /*
723  * send a Kerberos security response
724  */
725 static int rxkad_send_response(struct rxrpc_connection *conn,
726                                struct rxrpc_host_header *hdr,
727                                struct rxkad_response *resp,
728                                const struct rxkad_key *s2)
729 {
730         struct rxrpc_wire_header whdr;
731         struct msghdr msg;
732         struct kvec iov[3];
733         size_t len;
734         u32 serial;
735         int ret;
736
737         _enter("");
738
739         msg.msg_name    = &conn->peer->srx.transport;
740         msg.msg_namelen = conn->peer->srx.transport_len;
741         msg.msg_control = NULL;
742         msg.msg_controllen = 0;
743         msg.msg_flags   = 0;
744
745         memset(&whdr, 0, sizeof(whdr));
746         whdr.epoch      = htonl(hdr->epoch);
747         whdr.cid        = htonl(hdr->cid);
748         whdr.type       = RXRPC_PACKET_TYPE_RESPONSE;
749         whdr.flags      = conn->out_clientflag;
750         whdr.securityIndex = hdr->securityIndex;
751         whdr.serviceId  = htons(hdr->serviceId);
752
753         iov[0].iov_base = &whdr;
754         iov[0].iov_len  = sizeof(whdr);
755         iov[1].iov_base = resp;
756         iov[1].iov_len  = sizeof(*resp);
757         iov[2].iov_base = (void *)s2->ticket;
758         iov[2].iov_len  = s2->ticket_len;
759
760         len = iov[0].iov_len + iov[1].iov_len + iov[2].iov_len;
761
762         serial = atomic_inc_return(&conn->serial);
763         whdr.serial = htonl(serial);
764
765         ret = kernel_sendmsg(conn->local->socket, &msg, iov, 3, len);
766         if (ret < 0) {
767                 trace_rxrpc_tx_fail(conn->debug_id, serial, ret,
768                                     rxrpc_tx_point_rxkad_response);
769                 return -EAGAIN;
770         }
771
772         conn->peer->last_tx_at = ktime_get_seconds();
773         _leave(" = 0");
774         return 0;
775 }
776
777 /*
778  * calculate the response checksum
779  */
780 static void rxkad_calc_response_checksum(struct rxkad_response *response)
781 {
782         u32 csum = 1000003;
783         int loop;
784         u8 *p = (u8 *) response;
785
786         for (loop = sizeof(*response); loop > 0; loop--)
787                 csum = csum * 0x10204081 + *p++;
788
789         response->encrypted.checksum = htonl(csum);
790 }
791
792 /*
793  * encrypt the response packet
794  */
795 static int rxkad_encrypt_response(struct rxrpc_connection *conn,
796                                   struct rxkad_response *resp,
797                                   const struct rxkad_key *s2)
798 {
799         struct skcipher_request *req;
800         struct rxrpc_crypt iv;
801         struct scatterlist sg[1];
802
803         req = skcipher_request_alloc(&conn->rxkad.cipher->base, GFP_NOFS);
804         if (!req)
805                 return -ENOMEM;
806
807         /* continue encrypting from where we left off */
808         memcpy(&iv, s2->session_key, sizeof(iv));
809
810         sg_init_table(sg, 1);
811         sg_set_buf(sg, &resp->encrypted, sizeof(resp->encrypted));
812         skcipher_request_set_sync_tfm(req, conn->rxkad.cipher);
813         skcipher_request_set_callback(req, 0, NULL, NULL);
814         skcipher_request_set_crypt(req, sg, sg, sizeof(resp->encrypted), iv.x);
815         crypto_skcipher_encrypt(req);
816         skcipher_request_free(req);
817         return 0;
818 }
819
820 /*
821  * respond to a challenge packet
822  */
823 static int rxkad_respond_to_challenge(struct rxrpc_connection *conn,
824                                       struct sk_buff *skb,
825                                       u32 *_abort_code)
826 {
827         const struct rxrpc_key_token *token;
828         struct rxkad_challenge challenge;
829         struct rxkad_response *resp;
830         struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
831         const char *eproto;
832         u32 version, nonce, min_level, abort_code;
833         int ret;
834
835         _enter("{%d,%x}", conn->debug_id, key_serial(conn->key));
836
837         eproto = tracepoint_string("chall_no_key");
838         abort_code = RX_PROTOCOL_ERROR;
839         if (!conn->key)
840                 goto protocol_error;
841
842         abort_code = RXKADEXPIRED;
843         ret = key_validate(conn->key);
844         if (ret < 0)
845                 goto other_error;
846
847         eproto = tracepoint_string("chall_short");
848         abort_code = RXKADPACKETSHORT;
849         if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header),
850                           &challenge, sizeof(challenge)) < 0)
851                 goto protocol_error;
852
853         version = ntohl(challenge.version);
854         nonce = ntohl(challenge.nonce);
855         min_level = ntohl(challenge.min_level);
856
857         trace_rxrpc_rx_challenge(conn, sp->hdr.serial, version, nonce, min_level);
858
859         eproto = tracepoint_string("chall_ver");
860         abort_code = RXKADINCONSISTENCY;
861         if (version != RXKAD_VERSION)
862                 goto protocol_error;
863
864         abort_code = RXKADLEVELFAIL;
865         ret = -EACCES;
866         if (conn->security_level < min_level)
867                 goto other_error;
868
869         token = conn->key->payload.data[0];
870
871         /* build the response packet */
872         resp = kzalloc(sizeof(struct rxkad_response), GFP_NOFS);
873         if (!resp)
874                 return -ENOMEM;
875
876         resp->version                   = htonl(RXKAD_VERSION);
877         resp->encrypted.epoch           = htonl(conn->proto.epoch);
878         resp->encrypted.cid             = htonl(conn->proto.cid);
879         resp->encrypted.securityIndex   = htonl(conn->security_ix);
880         resp->encrypted.inc_nonce       = htonl(nonce + 1);
881         resp->encrypted.level           = htonl(conn->security_level);
882         resp->kvno                      = htonl(token->kad->kvno);
883         resp->ticket_len                = htonl(token->kad->ticket_len);
884         resp->encrypted.call_id[0]      = htonl(conn->channels[0].call_counter);
885         resp->encrypted.call_id[1]      = htonl(conn->channels[1].call_counter);
886         resp->encrypted.call_id[2]      = htonl(conn->channels[2].call_counter);
887         resp->encrypted.call_id[3]      = htonl(conn->channels[3].call_counter);
888
889         /* calculate the response checksum and then do the encryption */
890         rxkad_calc_response_checksum(resp);
891         ret = rxkad_encrypt_response(conn, resp, token->kad);
892         if (ret == 0)
893                 ret = rxkad_send_response(conn, &sp->hdr, resp, token->kad);
894         kfree(resp);
895         return ret;
896
897 protocol_error:
898         trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto);
899         ret = -EPROTO;
900 other_error:
901         *_abort_code = abort_code;
902         return ret;
903 }
904
905 /*
906  * decrypt the kerberos IV ticket in the response
907  */
908 static int rxkad_decrypt_ticket(struct rxrpc_connection *conn,
909                                 struct key *server_key,
910                                 struct sk_buff *skb,
911                                 void *ticket, size_t ticket_len,
912                                 struct rxrpc_crypt *_session_key,
913                                 time64_t *_expiry,
914                                 u32 *_abort_code)
915 {
916         struct skcipher_request *req;
917         struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
918         struct rxrpc_crypt iv, key;
919         struct scatterlist sg[1];
920         struct in_addr addr;
921         unsigned int life;
922         const char *eproto;
923         time64_t issue, now;
924         bool little_endian;
925         int ret;
926         u32 abort_code;
927         u8 *p, *q, *name, *end;
928
929         _enter("{%d},{%x}", conn->debug_id, key_serial(server_key));
930
931         *_expiry = 0;
932
933         ASSERT(server_key->payload.data[0] != NULL);
934         ASSERTCMP((unsigned long) ticket & 7UL, ==, 0);
935
936         memcpy(&iv, &server_key->payload.data[2], sizeof(iv));
937
938         ret = -ENOMEM;
939         req = skcipher_request_alloc(server_key->payload.data[0], GFP_NOFS);
940         if (!req)
941                 goto temporary_error;
942
943         sg_init_one(&sg[0], ticket, ticket_len);
944         skcipher_request_set_callback(req, 0, NULL, NULL);
945         skcipher_request_set_crypt(req, sg, sg, ticket_len, iv.x);
946         crypto_skcipher_decrypt(req);
947         skcipher_request_free(req);
948
949         p = ticket;
950         end = p + ticket_len;
951
952 #define Z(field)                                        \
953         ({                                              \
954                 u8 *__str = p;                          \
955                 eproto = tracepoint_string("rxkad_bad_"#field); \
956                 q = memchr(p, 0, end - p);              \
957                 if (!q || q - p > (field##_SZ))         \
958                         goto bad_ticket;                \
959                 for (; p < q; p++)                      \
960                         if (!isprint(*p))               \
961                                 goto bad_ticket;        \
962                 p++;                                    \
963                 __str;                                  \
964         })
965
966         /* extract the ticket flags */
967         _debug("KIV FLAGS: %x", *p);
968         little_endian = *p & 1;
969         p++;
970
971         /* extract the authentication name */
972         name = Z(ANAME);
973         _debug("KIV ANAME: %s", name);
974
975         /* extract the principal's instance */
976         name = Z(INST);
977         _debug("KIV INST : %s", name);
978
979         /* extract the principal's authentication domain */
980         name = Z(REALM);
981         _debug("KIV REALM: %s", name);
982
983         eproto = tracepoint_string("rxkad_bad_len");
984         if (end - p < 4 + 8 + 4 + 2)
985                 goto bad_ticket;
986
987         /* get the IPv4 address of the entity that requested the ticket */
988         memcpy(&addr, p, sizeof(addr));
989         p += 4;
990         _debug("KIV ADDR : %pI4", &addr);
991
992         /* get the session key from the ticket */
993         memcpy(&key, p, sizeof(key));
994         p += 8;
995         _debug("KIV KEY  : %08x %08x", ntohl(key.n[0]), ntohl(key.n[1]));
996         memcpy(_session_key, &key, sizeof(key));
997
998         /* get the ticket's lifetime */
999         life = *p++ * 5 * 60;
1000         _debug("KIV LIFE : %u", life);
1001
1002         /* get the issue time of the ticket */
1003         if (little_endian) {
1004                 __le32 stamp;
1005                 memcpy(&stamp, p, 4);
1006                 issue = rxrpc_u32_to_time64(le32_to_cpu(stamp));
1007         } else {
1008                 __be32 stamp;
1009                 memcpy(&stamp, p, 4);
1010                 issue = rxrpc_u32_to_time64(be32_to_cpu(stamp));
1011         }
1012         p += 4;
1013         now = ktime_get_real_seconds();
1014         _debug("KIV ISSUE: %llx [%llx]", issue, now);
1015
1016         /* check the ticket is in date */
1017         if (issue > now) {
1018                 abort_code = RXKADNOAUTH;
1019                 ret = -EKEYREJECTED;
1020                 goto other_error;
1021         }
1022
1023         if (issue < now - life) {
1024                 abort_code = RXKADEXPIRED;
1025                 ret = -EKEYEXPIRED;
1026                 goto other_error;
1027         }
1028
1029         *_expiry = issue + life;
1030
1031         /* get the service name */
1032         name = Z(SNAME);
1033         _debug("KIV SNAME: %s", name);
1034
1035         /* get the service instance name */
1036         name = Z(INST);
1037         _debug("KIV SINST: %s", name);
1038         return 0;
1039
1040 bad_ticket:
1041         trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto);
1042         abort_code = RXKADBADTICKET;
1043         ret = -EPROTO;
1044 other_error:
1045         *_abort_code = abort_code;
1046         return ret;
1047 temporary_error:
1048         return ret;
1049 }
1050
1051 /*
1052  * decrypt the response packet
1053  */
1054 static void rxkad_decrypt_response(struct rxrpc_connection *conn,
1055                                    struct rxkad_response *resp,
1056                                    const struct rxrpc_crypt *session_key)
1057 {
1058         struct skcipher_request *req = rxkad_ci_req;
1059         struct scatterlist sg[1];
1060         struct rxrpc_crypt iv;
1061
1062         _enter(",,%08x%08x",
1063                ntohl(session_key->n[0]), ntohl(session_key->n[1]));
1064
1065         mutex_lock(&rxkad_ci_mutex);
1066         if (crypto_sync_skcipher_setkey(rxkad_ci, session_key->x,
1067                                         sizeof(*session_key)) < 0)
1068                 BUG();
1069
1070         memcpy(&iv, session_key, sizeof(iv));
1071
1072         sg_init_table(sg, 1);
1073         sg_set_buf(sg, &resp->encrypted, sizeof(resp->encrypted));
1074         skcipher_request_set_sync_tfm(req, rxkad_ci);
1075         skcipher_request_set_callback(req, 0, NULL, NULL);
1076         skcipher_request_set_crypt(req, sg, sg, sizeof(resp->encrypted), iv.x);
1077         crypto_skcipher_decrypt(req);
1078         skcipher_request_zero(req);
1079
1080         mutex_unlock(&rxkad_ci_mutex);
1081
1082         _leave("");
1083 }
1084
1085 /*
1086  * verify a response
1087  */
1088 static int rxkad_verify_response(struct rxrpc_connection *conn,
1089                                  struct sk_buff *skb,
1090                                  u32 *_abort_code)
1091 {
1092         struct rxkad_response *response;
1093         struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
1094         struct rxrpc_crypt session_key;
1095         struct key *server_key;
1096         const char *eproto;
1097         time64_t expiry;
1098         void *ticket;
1099         u32 abort_code, version, kvno, ticket_len, level;
1100         __be32 csum;
1101         int ret, i;
1102
1103         _enter("{%d}", conn->debug_id);
1104
1105         server_key = rxrpc_look_up_server_security(conn, skb, 0, 0);
1106         if (IS_ERR(server_key)) {
1107                 switch (PTR_ERR(server_key)) {
1108                 case -ENOKEY:
1109                         abort_code = RXKADUNKNOWNKEY;
1110                         break;
1111                 case -EKEYEXPIRED:
1112                         abort_code = RXKADEXPIRED;
1113                         break;
1114                 default:
1115                         abort_code = RXKADNOAUTH;
1116                         break;
1117                 }
1118                 trace_rxrpc_abort(0, "SVK",
1119                                   sp->hdr.cid, sp->hdr.callNumber, sp->hdr.seq,
1120                                   abort_code, PTR_ERR(server_key));
1121                 *_abort_code = abort_code;
1122                 return -EPROTO;
1123         }
1124
1125         ret = -ENOMEM;
1126         response = kzalloc(sizeof(struct rxkad_response), GFP_NOFS);
1127         if (!response)
1128                 goto temporary_error;
1129
1130         eproto = tracepoint_string("rxkad_rsp_short");
1131         abort_code = RXKADPACKETSHORT;
1132         if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header),
1133                           response, sizeof(*response)) < 0)
1134                 goto protocol_error;
1135
1136         version = ntohl(response->version);
1137         ticket_len = ntohl(response->ticket_len);
1138         kvno = ntohl(response->kvno);
1139
1140         trace_rxrpc_rx_response(conn, sp->hdr.serial, version, kvno, ticket_len);
1141
1142         eproto = tracepoint_string("rxkad_rsp_ver");
1143         abort_code = RXKADINCONSISTENCY;
1144         if (version != RXKAD_VERSION)
1145                 goto protocol_error;
1146
1147         eproto = tracepoint_string("rxkad_rsp_tktlen");
1148         abort_code = RXKADTICKETLEN;
1149         if (ticket_len < 4 || ticket_len > MAXKRB5TICKETLEN)
1150                 goto protocol_error;
1151
1152         eproto = tracepoint_string("rxkad_rsp_unkkey");
1153         abort_code = RXKADUNKNOWNKEY;
1154         if (kvno >= RXKAD_TKT_TYPE_KERBEROS_V5)
1155                 goto protocol_error;
1156
1157         /* extract the kerberos ticket and decrypt and decode it */
1158         ret = -ENOMEM;
1159         ticket = kmalloc(ticket_len, GFP_NOFS);
1160         if (!ticket)
1161                 goto temporary_error_free_resp;
1162
1163         eproto = tracepoint_string("rxkad_tkt_short");
1164         abort_code = RXKADPACKETSHORT;
1165         ret = skb_copy_bits(skb, sizeof(struct rxrpc_wire_header) + sizeof(*response),
1166                             ticket, ticket_len);
1167         if (ret < 0)
1168                 goto temporary_error_free_ticket;
1169
1170         ret = rxkad_decrypt_ticket(conn, server_key, skb, ticket, ticket_len,
1171                                    &session_key, &expiry, _abort_code);
1172         if (ret < 0)
1173                 goto temporary_error_free_ticket;
1174
1175         /* use the session key from inside the ticket to decrypt the
1176          * response */
1177         rxkad_decrypt_response(conn, response, &session_key);
1178
1179         eproto = tracepoint_string("rxkad_rsp_param");
1180         abort_code = RXKADSEALEDINCON;
1181         if (ntohl(response->encrypted.epoch) != conn->proto.epoch)
1182                 goto protocol_error_free;
1183         if (ntohl(response->encrypted.cid) != conn->proto.cid)
1184                 goto protocol_error_free;
1185         if (ntohl(response->encrypted.securityIndex) != conn->security_ix)
1186                 goto protocol_error_free;
1187         csum = response->encrypted.checksum;
1188         response->encrypted.checksum = 0;
1189         rxkad_calc_response_checksum(response);
1190         eproto = tracepoint_string("rxkad_rsp_csum");
1191         if (response->encrypted.checksum != csum)
1192                 goto protocol_error_free;
1193
1194         spin_lock(&conn->bundle->channel_lock);
1195         for (i = 0; i < RXRPC_MAXCALLS; i++) {
1196                 struct rxrpc_call *call;
1197                 u32 call_id = ntohl(response->encrypted.call_id[i]);
1198
1199                 eproto = tracepoint_string("rxkad_rsp_callid");
1200                 if (call_id > INT_MAX)
1201                         goto protocol_error_unlock;
1202
1203                 eproto = tracepoint_string("rxkad_rsp_callctr");
1204                 if (call_id < conn->channels[i].call_counter)
1205                         goto protocol_error_unlock;
1206
1207                 eproto = tracepoint_string("rxkad_rsp_callst");
1208                 if (call_id > conn->channels[i].call_counter) {
1209                         call = rcu_dereference_protected(
1210                                 conn->channels[i].call,
1211                                 lockdep_is_held(&conn->bundle->channel_lock));
1212                         if (call && call->state < RXRPC_CALL_COMPLETE)
1213                                 goto protocol_error_unlock;
1214                         conn->channels[i].call_counter = call_id;
1215                 }
1216         }
1217         spin_unlock(&conn->bundle->channel_lock);
1218
1219         eproto = tracepoint_string("rxkad_rsp_seq");
1220         abort_code = RXKADOUTOFSEQUENCE;
1221         if (ntohl(response->encrypted.inc_nonce) != conn->rxkad.nonce + 1)
1222                 goto protocol_error_free;
1223
1224         eproto = tracepoint_string("rxkad_rsp_level");
1225         abort_code = RXKADLEVELFAIL;
1226         level = ntohl(response->encrypted.level);
1227         if (level > RXRPC_SECURITY_ENCRYPT)
1228                 goto protocol_error_free;
1229         conn->security_level = level;
1230
1231         /* create a key to hold the security data and expiration time - after
1232          * this the connection security can be handled in exactly the same way
1233          * as for a client connection */
1234         ret = rxrpc_get_server_data_key(conn, &session_key, expiry, kvno);
1235         if (ret < 0)
1236                 goto temporary_error_free_ticket;
1237
1238         kfree(ticket);
1239         kfree(response);
1240         _leave(" = 0");
1241         return 0;
1242
1243 protocol_error_unlock:
1244         spin_unlock(&conn->bundle->channel_lock);
1245 protocol_error_free:
1246         kfree(ticket);
1247 protocol_error:
1248         kfree(response);
1249         trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto);
1250         key_put(server_key);
1251         *_abort_code = abort_code;
1252         return -EPROTO;
1253
1254 temporary_error_free_ticket:
1255         kfree(ticket);
1256 temporary_error_free_resp:
1257         kfree(response);
1258 temporary_error:
1259         /* Ignore the response packet if we got a temporary error such as
1260          * ENOMEM.  We just want to send the challenge again.  Note that we
1261          * also come out this way if the ticket decryption fails.
1262          */
1263         key_put(server_key);
1264         return ret;
1265 }
1266
1267 /*
1268  * clear the connection security
1269  */
1270 static void rxkad_clear(struct rxrpc_connection *conn)
1271 {
1272         _enter("");
1273
1274         if (conn->rxkad.cipher)
1275                 crypto_free_sync_skcipher(conn->rxkad.cipher);
1276 }
1277
1278 /*
1279  * Initialise the rxkad security service.
1280  */
1281 static int rxkad_init(void)
1282 {
1283         struct crypto_sync_skcipher *tfm;
1284         struct skcipher_request *req;
1285
1286         /* pin the cipher we need so that the crypto layer doesn't invoke
1287          * keventd to go get it */
1288         tfm = crypto_alloc_sync_skcipher("pcbc(fcrypt)", 0, 0);
1289         if (IS_ERR(tfm))
1290                 return PTR_ERR(tfm);
1291
1292         req = skcipher_request_alloc(&tfm->base, GFP_KERNEL);
1293         if (!req)
1294                 goto nomem_tfm;
1295
1296         rxkad_ci_req = req;
1297         rxkad_ci = tfm;
1298         return 0;
1299
1300 nomem_tfm:
1301         crypto_free_sync_skcipher(tfm);
1302         return -ENOMEM;
1303 }
1304
1305 /*
1306  * Clean up the rxkad security service.
1307  */
1308 static void rxkad_exit(void)
1309 {
1310         crypto_free_sync_skcipher(rxkad_ci);
1311         skcipher_request_free(rxkad_ci_req);
1312 }
1313
1314 /*
1315  * RxRPC Kerberos-based security
1316  */
1317 const struct rxrpc_security rxkad = {
1318         .name                           = "rxkad",
1319         .security_index                 = RXRPC_SECURITY_RXKAD,
1320         .no_key_abort                   = RXKADUNKNOWNKEY,
1321         .init                           = rxkad_init,
1322         .exit                           = rxkad_exit,
1323         .preparse_server_key            = rxkad_preparse_server_key,
1324         .free_preparse_server_key       = rxkad_free_preparse_server_key,
1325         .destroy_server_key             = rxkad_destroy_server_key,
1326         .init_connection_security       = rxkad_init_connection_security,
1327         .how_much_data                  = rxkad_how_much_data,
1328         .secure_packet                  = rxkad_secure_packet,
1329         .verify_packet                  = rxkad_verify_packet,
1330         .free_call_crypto               = rxkad_free_call_crypto,
1331         .issue_challenge                = rxkad_issue_challenge,
1332         .respond_to_challenge           = rxkad_respond_to_challenge,
1333         .verify_response                = rxkad_verify_response,
1334         .clear                          = rxkad_clear,
1335 };