Merge branch 'timers-urgent-for-linus' of git://git.kernel.org/pub/scm/linux/kernel...
[platform/kernel/linux-rpi.git] / crypto / algif_skcipher.c
1 /*
2  * algif_skcipher: User-space interface for skcipher algorithms
3  *
4  * This file provides the user-space API for symmetric key ciphers.
5  *
6  * Copyright (c) 2010 Herbert Xu <herbert@gondor.apana.org.au>
7  *
8  * This program is free software; you can redistribute it and/or modify it
9  * under the terms of the GNU General Public License as published by the Free
10  * Software Foundation; either version 2 of the License, or (at your option)
11  * any later version.
12  *
13  */
14
15 #include <crypto/scatterwalk.h>
16 #include <crypto/skcipher.h>
17 #include <crypto/if_alg.h>
18 #include <linux/init.h>
19 #include <linux/list.h>
20 #include <linux/kernel.h>
21 #include <linux/sched/signal.h>
22 #include <linux/mm.h>
23 #include <linux/module.h>
24 #include <linux/net.h>
25 #include <net/sock.h>
26
27 struct skcipher_sg_list {
28         struct list_head list;
29
30         int cur;
31
32         struct scatterlist sg[0];
33 };
34
35 struct skcipher_tfm {
36         struct crypto_skcipher *skcipher;
37         bool has_key;
38 };
39
40 struct skcipher_ctx {
41         struct list_head tsgl;
42         struct af_alg_sgl rsgl;
43
44         void *iv;
45
46         struct af_alg_completion completion;
47
48         atomic_t inflight;
49         size_t used;
50
51         unsigned int len;
52         bool more;
53         bool merge;
54         bool enc;
55
56         struct skcipher_request req;
57 };
58
59 struct skcipher_async_rsgl {
60         struct af_alg_sgl sgl;
61         struct list_head list;
62 };
63
64 struct skcipher_async_req {
65         struct kiocb *iocb;
66         struct skcipher_async_rsgl first_sgl;
67         struct list_head list;
68         struct scatterlist *tsg;
69         atomic_t *inflight;
70         struct skcipher_request req;
71 };
72
73 #define MAX_SGL_ENTS ((4096 - sizeof(struct skcipher_sg_list)) / \
74                       sizeof(struct scatterlist) - 1)
75
76 static void skcipher_free_async_sgls(struct skcipher_async_req *sreq)
77 {
78         struct skcipher_async_rsgl *rsgl, *tmp;
79         struct scatterlist *sgl;
80         struct scatterlist *sg;
81         int i, n;
82
83         list_for_each_entry_safe(rsgl, tmp, &sreq->list, list) {
84                 af_alg_free_sg(&rsgl->sgl);
85                 if (rsgl != &sreq->first_sgl)
86                         kfree(rsgl);
87         }
88         sgl = sreq->tsg;
89         n = sg_nents(sgl);
90         for_each_sg(sgl, sg, n, i)
91                 put_page(sg_page(sg));
92
93         kfree(sreq->tsg);
94 }
95
96 static void skcipher_async_cb(struct crypto_async_request *req, int err)
97 {
98         struct skcipher_async_req *sreq = req->data;
99         struct kiocb *iocb = sreq->iocb;
100
101         atomic_dec(sreq->inflight);
102         skcipher_free_async_sgls(sreq);
103         kzfree(sreq);
104         iocb->ki_complete(iocb, err, err);
105 }
106
107 static inline int skcipher_sndbuf(struct sock *sk)
108 {
109         struct alg_sock *ask = alg_sk(sk);
110         struct skcipher_ctx *ctx = ask->private;
111
112         return max_t(int, max_t(int, sk->sk_sndbuf & PAGE_MASK, PAGE_SIZE) -
113                           ctx->used, 0);
114 }
115
116 static inline bool skcipher_writable(struct sock *sk)
117 {
118         return PAGE_SIZE <= skcipher_sndbuf(sk);
119 }
120
121 static int skcipher_alloc_sgl(struct sock *sk)
122 {
123         struct alg_sock *ask = alg_sk(sk);
124         struct skcipher_ctx *ctx = ask->private;
125         struct skcipher_sg_list *sgl;
126         struct scatterlist *sg = NULL;
127
128         sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
129         if (!list_empty(&ctx->tsgl))
130                 sg = sgl->sg;
131
132         if (!sg || sgl->cur >= MAX_SGL_ENTS) {
133                 sgl = sock_kmalloc(sk, sizeof(*sgl) +
134                                        sizeof(sgl->sg[0]) * (MAX_SGL_ENTS + 1),
135                                    GFP_KERNEL);
136                 if (!sgl)
137                         return -ENOMEM;
138
139                 sg_init_table(sgl->sg, MAX_SGL_ENTS + 1);
140                 sgl->cur = 0;
141
142                 if (sg)
143                         sg_chain(sg, MAX_SGL_ENTS + 1, sgl->sg);
144
145                 list_add_tail(&sgl->list, &ctx->tsgl);
146         }
147
148         return 0;
149 }
150
151 static void skcipher_pull_sgl(struct sock *sk, size_t used, int put)
152 {
153         struct alg_sock *ask = alg_sk(sk);
154         struct skcipher_ctx *ctx = ask->private;
155         struct skcipher_sg_list *sgl;
156         struct scatterlist *sg;
157         int i;
158
159         while (!list_empty(&ctx->tsgl)) {
160                 sgl = list_first_entry(&ctx->tsgl, struct skcipher_sg_list,
161                                        list);
162                 sg = sgl->sg;
163
164                 for (i = 0; i < sgl->cur; i++) {
165                         size_t plen = min_t(size_t, used, sg[i].length);
166
167                         if (!sg_page(sg + i))
168                                 continue;
169
170                         sg[i].length -= plen;
171                         sg[i].offset += plen;
172
173                         used -= plen;
174                         ctx->used -= plen;
175
176                         if (sg[i].length)
177                                 return;
178                         if (put)
179                                 put_page(sg_page(sg + i));
180                         sg_assign_page(sg + i, NULL);
181                 }
182
183                 list_del(&sgl->list);
184                 sock_kfree_s(sk, sgl,
185                              sizeof(*sgl) + sizeof(sgl->sg[0]) *
186                                             (MAX_SGL_ENTS + 1));
187         }
188
189         if (!ctx->used)
190                 ctx->merge = 0;
191 }
192
193 static void skcipher_free_sgl(struct sock *sk)
194 {
195         struct alg_sock *ask = alg_sk(sk);
196         struct skcipher_ctx *ctx = ask->private;
197
198         skcipher_pull_sgl(sk, ctx->used, 1);
199 }
200
201 static int skcipher_wait_for_wmem(struct sock *sk, unsigned flags)
202 {
203         DEFINE_WAIT_FUNC(wait, woken_wake_function);
204         int err = -ERESTARTSYS;
205         long timeout;
206
207         if (flags & MSG_DONTWAIT)
208                 return -EAGAIN;
209
210         sk_set_bit(SOCKWQ_ASYNC_NOSPACE, sk);
211
212         add_wait_queue(sk_sleep(sk), &wait);
213         for (;;) {
214                 if (signal_pending(current))
215                         break;
216                 timeout = MAX_SCHEDULE_TIMEOUT;
217                 if (sk_wait_event(sk, &timeout, skcipher_writable(sk), &wait)) {
218                         err = 0;
219                         break;
220                 }
221         }
222         remove_wait_queue(sk_sleep(sk), &wait);
223
224         return err;
225 }
226
227 static void skcipher_wmem_wakeup(struct sock *sk)
228 {
229         struct socket_wq *wq;
230
231         if (!skcipher_writable(sk))
232                 return;
233
234         rcu_read_lock();
235         wq = rcu_dereference(sk->sk_wq);
236         if (skwq_has_sleeper(wq))
237                 wake_up_interruptible_sync_poll(&wq->wait, POLLIN |
238                                                            POLLRDNORM |
239                                                            POLLRDBAND);
240         sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN);
241         rcu_read_unlock();
242 }
243
244 static int skcipher_wait_for_data(struct sock *sk, unsigned flags)
245 {
246         DEFINE_WAIT_FUNC(wait, woken_wake_function);
247         struct alg_sock *ask = alg_sk(sk);
248         struct skcipher_ctx *ctx = ask->private;
249         long timeout;
250         int err = -ERESTARTSYS;
251
252         if (flags & MSG_DONTWAIT) {
253                 return -EAGAIN;
254         }
255
256         sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
257
258         add_wait_queue(sk_sleep(sk), &wait);
259         for (;;) {
260                 if (signal_pending(current))
261                         break;
262                 timeout = MAX_SCHEDULE_TIMEOUT;
263                 if (sk_wait_event(sk, &timeout, ctx->used, &wait)) {
264                         err = 0;
265                         break;
266                 }
267         }
268         remove_wait_queue(sk_sleep(sk), &wait);
269
270         sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
271
272         return err;
273 }
274
275 static void skcipher_data_wakeup(struct sock *sk)
276 {
277         struct alg_sock *ask = alg_sk(sk);
278         struct skcipher_ctx *ctx = ask->private;
279         struct socket_wq *wq;
280
281         if (!ctx->used)
282                 return;
283
284         rcu_read_lock();
285         wq = rcu_dereference(sk->sk_wq);
286         if (skwq_has_sleeper(wq))
287                 wake_up_interruptible_sync_poll(&wq->wait, POLLOUT |
288                                                            POLLRDNORM |
289                                                            POLLRDBAND);
290         sk_wake_async(sk, SOCK_WAKE_SPACE, POLL_OUT);
291         rcu_read_unlock();
292 }
293
294 static int skcipher_sendmsg(struct socket *sock, struct msghdr *msg,
295                             size_t size)
296 {
297         struct sock *sk = sock->sk;
298         struct alg_sock *ask = alg_sk(sk);
299         struct sock *psk = ask->parent;
300         struct alg_sock *pask = alg_sk(psk);
301         struct skcipher_ctx *ctx = ask->private;
302         struct skcipher_tfm *skc = pask->private;
303         struct crypto_skcipher *tfm = skc->skcipher;
304         unsigned ivsize = crypto_skcipher_ivsize(tfm);
305         struct skcipher_sg_list *sgl;
306         struct af_alg_control con = {};
307         long copied = 0;
308         bool enc = 0;
309         bool init = 0;
310         int err;
311         int i;
312
313         if (msg->msg_controllen) {
314                 err = af_alg_cmsg_send(msg, &con);
315                 if (err)
316                         return err;
317
318                 init = 1;
319                 switch (con.op) {
320                 case ALG_OP_ENCRYPT:
321                         enc = 1;
322                         break;
323                 case ALG_OP_DECRYPT:
324                         enc = 0;
325                         break;
326                 default:
327                         return -EINVAL;
328                 }
329
330                 if (con.iv && con.iv->ivlen != ivsize)
331                         return -EINVAL;
332         }
333
334         err = -EINVAL;
335
336         lock_sock(sk);
337         if (!ctx->more && ctx->used)
338                 goto unlock;
339
340         if (init) {
341                 ctx->enc = enc;
342                 if (con.iv)
343                         memcpy(ctx->iv, con.iv->iv, ivsize);
344         }
345
346         while (size) {
347                 struct scatterlist *sg;
348                 unsigned long len = size;
349                 size_t plen;
350
351                 if (ctx->merge) {
352                         sgl = list_entry(ctx->tsgl.prev,
353                                          struct skcipher_sg_list, list);
354                         sg = sgl->sg + sgl->cur - 1;
355                         len = min_t(unsigned long, len,
356                                     PAGE_SIZE - sg->offset - sg->length);
357
358                         err = memcpy_from_msg(page_address(sg_page(sg)) +
359                                               sg->offset + sg->length,
360                                               msg, len);
361                         if (err)
362                                 goto unlock;
363
364                         sg->length += len;
365                         ctx->merge = (sg->offset + sg->length) &
366                                      (PAGE_SIZE - 1);
367
368                         ctx->used += len;
369                         copied += len;
370                         size -= len;
371                         continue;
372                 }
373
374                 if (!skcipher_writable(sk)) {
375                         err = skcipher_wait_for_wmem(sk, msg->msg_flags);
376                         if (err)
377                                 goto unlock;
378                 }
379
380                 len = min_t(unsigned long, len, skcipher_sndbuf(sk));
381
382                 err = skcipher_alloc_sgl(sk);
383                 if (err)
384                         goto unlock;
385
386                 sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
387                 sg = sgl->sg;
388                 if (sgl->cur)
389                         sg_unmark_end(sg + sgl->cur - 1);
390                 do {
391                         i = sgl->cur;
392                         plen = min_t(size_t, len, PAGE_SIZE);
393
394                         sg_assign_page(sg + i, alloc_page(GFP_KERNEL));
395                         err = -ENOMEM;
396                         if (!sg_page(sg + i))
397                                 goto unlock;
398
399                         err = memcpy_from_msg(page_address(sg_page(sg + i)),
400                                               msg, plen);
401                         if (err) {
402                                 __free_page(sg_page(sg + i));
403                                 sg_assign_page(sg + i, NULL);
404                                 goto unlock;
405                         }
406
407                         sg[i].length = plen;
408                         len -= plen;
409                         ctx->used += plen;
410                         copied += plen;
411                         size -= plen;
412                         sgl->cur++;
413                 } while (len && sgl->cur < MAX_SGL_ENTS);
414
415                 if (!size)
416                         sg_mark_end(sg + sgl->cur - 1);
417
418                 ctx->merge = plen & (PAGE_SIZE - 1);
419         }
420
421         err = 0;
422
423         ctx->more = msg->msg_flags & MSG_MORE;
424
425 unlock:
426         skcipher_data_wakeup(sk);
427         release_sock(sk);
428
429         return copied ?: err;
430 }
431
432 static ssize_t skcipher_sendpage(struct socket *sock, struct page *page,
433                                  int offset, size_t size, int flags)
434 {
435         struct sock *sk = sock->sk;
436         struct alg_sock *ask = alg_sk(sk);
437         struct skcipher_ctx *ctx = ask->private;
438         struct skcipher_sg_list *sgl;
439         int err = -EINVAL;
440
441         if (flags & MSG_SENDPAGE_NOTLAST)
442                 flags |= MSG_MORE;
443
444         lock_sock(sk);
445         if (!ctx->more && ctx->used)
446                 goto unlock;
447
448         if (!size)
449                 goto done;
450
451         if (!skcipher_writable(sk)) {
452                 err = skcipher_wait_for_wmem(sk, flags);
453                 if (err)
454                         goto unlock;
455         }
456
457         err = skcipher_alloc_sgl(sk);
458         if (err)
459                 goto unlock;
460
461         ctx->merge = 0;
462         sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
463
464         if (sgl->cur)
465                 sg_unmark_end(sgl->sg + sgl->cur - 1);
466
467         sg_mark_end(sgl->sg + sgl->cur);
468         get_page(page);
469         sg_set_page(sgl->sg + sgl->cur, page, size, offset);
470         sgl->cur++;
471         ctx->used += size;
472
473 done:
474         ctx->more = flags & MSG_MORE;
475
476 unlock:
477         skcipher_data_wakeup(sk);
478         release_sock(sk);
479
480         return err ?: size;
481 }
482
483 static int skcipher_all_sg_nents(struct skcipher_ctx *ctx)
484 {
485         struct skcipher_sg_list *sgl;
486         struct scatterlist *sg;
487         int nents = 0;
488
489         list_for_each_entry(sgl, &ctx->tsgl, list) {
490                 sg = sgl->sg;
491
492                 while (!sg->length)
493                         sg++;
494
495                 nents += sg_nents(sg);
496         }
497         return nents;
498 }
499
500 static int skcipher_recvmsg_async(struct socket *sock, struct msghdr *msg,
501                                   int flags)
502 {
503         struct sock *sk = sock->sk;
504         struct alg_sock *ask = alg_sk(sk);
505         struct sock *psk = ask->parent;
506         struct alg_sock *pask = alg_sk(psk);
507         struct skcipher_ctx *ctx = ask->private;
508         struct skcipher_tfm *skc = pask->private;
509         struct crypto_skcipher *tfm = skc->skcipher;
510         struct skcipher_sg_list *sgl;
511         struct scatterlist *sg;
512         struct skcipher_async_req *sreq;
513         struct skcipher_request *req;
514         struct skcipher_async_rsgl *last_rsgl = NULL;
515         unsigned int txbufs = 0, len = 0, tx_nents;
516         unsigned int reqsize = crypto_skcipher_reqsize(tfm);
517         unsigned int ivsize = crypto_skcipher_ivsize(tfm);
518         int err = -ENOMEM;
519         bool mark = false;
520         char *iv;
521
522         sreq = kzalloc(sizeof(*sreq) + reqsize + ivsize, GFP_KERNEL);
523         if (unlikely(!sreq))
524                 goto out;
525
526         req = &sreq->req;
527         iv = (char *)(req + 1) + reqsize;
528         sreq->iocb = msg->msg_iocb;
529         INIT_LIST_HEAD(&sreq->list);
530         sreq->inflight = &ctx->inflight;
531
532         lock_sock(sk);
533         tx_nents = skcipher_all_sg_nents(ctx);
534         sreq->tsg = kcalloc(tx_nents, sizeof(*sg), GFP_KERNEL);
535         if (unlikely(!sreq->tsg))
536                 goto unlock;
537         sg_init_table(sreq->tsg, tx_nents);
538         memcpy(iv, ctx->iv, ivsize);
539         skcipher_request_set_tfm(req, tfm);
540         skcipher_request_set_callback(req, CRYPTO_TFM_REQ_MAY_SLEEP,
541                                       skcipher_async_cb, sreq);
542
543         while (iov_iter_count(&msg->msg_iter)) {
544                 struct skcipher_async_rsgl *rsgl;
545                 int used;
546
547                 if (!ctx->used) {
548                         err = skcipher_wait_for_data(sk, flags);
549                         if (err)
550                                 goto free;
551                 }
552                 sgl = list_first_entry(&ctx->tsgl,
553                                        struct skcipher_sg_list, list);
554                 sg = sgl->sg;
555
556                 while (!sg->length)
557                         sg++;
558
559                 used = min_t(unsigned long, ctx->used,
560                              iov_iter_count(&msg->msg_iter));
561                 used = min_t(unsigned long, used, sg->length);
562
563                 if (txbufs == tx_nents) {
564                         struct scatterlist *tmp;
565                         int x;
566                         /* Ran out of tx slots in async request
567                          * need to expand */
568                         tmp = kcalloc(tx_nents * 2, sizeof(*tmp),
569                                       GFP_KERNEL);
570                         if (!tmp) {
571                                 err = -ENOMEM;
572                                 goto free;
573                         }
574
575                         sg_init_table(tmp, tx_nents * 2);
576                         for (x = 0; x < tx_nents; x++)
577                                 sg_set_page(&tmp[x], sg_page(&sreq->tsg[x]),
578                                             sreq->tsg[x].length,
579                                             sreq->tsg[x].offset);
580                         kfree(sreq->tsg);
581                         sreq->tsg = tmp;
582                         tx_nents *= 2;
583                         mark = true;
584                 }
585                 /* Need to take over the tx sgl from ctx
586                  * to the asynch req - these sgls will be freed later */
587                 sg_set_page(sreq->tsg + txbufs++, sg_page(sg), sg->length,
588                             sg->offset);
589
590                 if (list_empty(&sreq->list)) {
591                         rsgl = &sreq->first_sgl;
592                         list_add_tail(&rsgl->list, &sreq->list);
593                 } else {
594                         rsgl = kmalloc(sizeof(*rsgl), GFP_KERNEL);
595                         if (!rsgl) {
596                                 err = -ENOMEM;
597                                 goto free;
598                         }
599                         list_add_tail(&rsgl->list, &sreq->list);
600                 }
601
602                 used = af_alg_make_sg(&rsgl->sgl, &msg->msg_iter, used);
603                 err = used;
604                 if (used < 0)
605                         goto free;
606                 if (last_rsgl)
607                         af_alg_link_sg(&last_rsgl->sgl, &rsgl->sgl);
608
609                 last_rsgl = rsgl;
610                 len += used;
611                 skcipher_pull_sgl(sk, used, 0);
612                 iov_iter_advance(&msg->msg_iter, used);
613         }
614
615         if (mark)
616                 sg_mark_end(sreq->tsg + txbufs - 1);
617
618         skcipher_request_set_crypt(req, sreq->tsg, sreq->first_sgl.sgl.sg,
619                                    len, iv);
620         err = ctx->enc ? crypto_skcipher_encrypt(req) :
621                          crypto_skcipher_decrypt(req);
622         if (err == -EINPROGRESS) {
623                 atomic_inc(&ctx->inflight);
624                 err = -EIOCBQUEUED;
625                 sreq = NULL;
626                 goto unlock;
627         }
628 free:
629         skcipher_free_async_sgls(sreq);
630 unlock:
631         skcipher_wmem_wakeup(sk);
632         release_sock(sk);
633         kzfree(sreq);
634 out:
635         return err;
636 }
637
638 static int skcipher_recvmsg_sync(struct socket *sock, struct msghdr *msg,
639                                  int flags)
640 {
641         struct sock *sk = sock->sk;
642         struct alg_sock *ask = alg_sk(sk);
643         struct sock *psk = ask->parent;
644         struct alg_sock *pask = alg_sk(psk);
645         struct skcipher_ctx *ctx = ask->private;
646         struct skcipher_tfm *skc = pask->private;
647         struct crypto_skcipher *tfm = skc->skcipher;
648         unsigned bs = crypto_skcipher_blocksize(tfm);
649         struct skcipher_sg_list *sgl;
650         struct scatterlist *sg;
651         int err = -EAGAIN;
652         int used;
653         long copied = 0;
654
655         lock_sock(sk);
656         while (msg_data_left(msg)) {
657                 if (!ctx->used) {
658                         err = skcipher_wait_for_data(sk, flags);
659                         if (err)
660                                 goto unlock;
661                 }
662
663                 used = min_t(unsigned long, ctx->used, msg_data_left(msg));
664
665                 used = af_alg_make_sg(&ctx->rsgl, &msg->msg_iter, used);
666                 err = used;
667                 if (err < 0)
668                         goto unlock;
669
670                 if (ctx->more || used < ctx->used)
671                         used -= used % bs;
672
673                 err = -EINVAL;
674                 if (!used)
675                         goto free;
676
677                 sgl = list_first_entry(&ctx->tsgl,
678                                        struct skcipher_sg_list, list);
679                 sg = sgl->sg;
680
681                 while (!sg->length)
682                         sg++;
683
684                 skcipher_request_set_crypt(&ctx->req, sg, ctx->rsgl.sg, used,
685                                            ctx->iv);
686
687                 err = af_alg_wait_for_completion(
688                                 ctx->enc ?
689                                         crypto_skcipher_encrypt(&ctx->req) :
690                                         crypto_skcipher_decrypt(&ctx->req),
691                                 &ctx->completion);
692
693 free:
694                 af_alg_free_sg(&ctx->rsgl);
695
696                 if (err)
697                         goto unlock;
698
699                 copied += used;
700                 skcipher_pull_sgl(sk, used, 1);
701                 iov_iter_advance(&msg->msg_iter, used);
702         }
703
704         err = 0;
705
706 unlock:
707         skcipher_wmem_wakeup(sk);
708         release_sock(sk);
709
710         return copied ?: err;
711 }
712
713 static int skcipher_recvmsg(struct socket *sock, struct msghdr *msg,
714                             size_t ignored, int flags)
715 {
716         return (msg->msg_iocb && !is_sync_kiocb(msg->msg_iocb)) ?
717                 skcipher_recvmsg_async(sock, msg, flags) :
718                 skcipher_recvmsg_sync(sock, msg, flags);
719 }
720
721 static unsigned int skcipher_poll(struct file *file, struct socket *sock,
722                                   poll_table *wait)
723 {
724         struct sock *sk = sock->sk;
725         struct alg_sock *ask = alg_sk(sk);
726         struct skcipher_ctx *ctx = ask->private;
727         unsigned int mask;
728
729         sock_poll_wait(file, sk_sleep(sk), wait);
730         mask = 0;
731
732         if (ctx->used)
733                 mask |= POLLIN | POLLRDNORM;
734
735         if (skcipher_writable(sk))
736                 mask |= POLLOUT | POLLWRNORM | POLLWRBAND;
737
738         return mask;
739 }
740
741 static struct proto_ops algif_skcipher_ops = {
742         .family         =       PF_ALG,
743
744         .connect        =       sock_no_connect,
745         .socketpair     =       sock_no_socketpair,
746         .getname        =       sock_no_getname,
747         .ioctl          =       sock_no_ioctl,
748         .listen         =       sock_no_listen,
749         .shutdown       =       sock_no_shutdown,
750         .getsockopt     =       sock_no_getsockopt,
751         .mmap           =       sock_no_mmap,
752         .bind           =       sock_no_bind,
753         .accept         =       sock_no_accept,
754         .setsockopt     =       sock_no_setsockopt,
755
756         .release        =       af_alg_release,
757         .sendmsg        =       skcipher_sendmsg,
758         .sendpage       =       skcipher_sendpage,
759         .recvmsg        =       skcipher_recvmsg,
760         .poll           =       skcipher_poll,
761 };
762
763 static int skcipher_check_key(struct socket *sock)
764 {
765         int err = 0;
766         struct sock *psk;
767         struct alg_sock *pask;
768         struct skcipher_tfm *tfm;
769         struct sock *sk = sock->sk;
770         struct alg_sock *ask = alg_sk(sk);
771
772         lock_sock(sk);
773         if (ask->refcnt)
774                 goto unlock_child;
775
776         psk = ask->parent;
777         pask = alg_sk(ask->parent);
778         tfm = pask->private;
779
780         err = -ENOKEY;
781         lock_sock_nested(psk, SINGLE_DEPTH_NESTING);
782         if (!tfm->has_key)
783                 goto unlock;
784
785         if (!pask->refcnt++)
786                 sock_hold(psk);
787
788         ask->refcnt = 1;
789         sock_put(psk);
790
791         err = 0;
792
793 unlock:
794         release_sock(psk);
795 unlock_child:
796         release_sock(sk);
797
798         return err;
799 }
800
801 static int skcipher_sendmsg_nokey(struct socket *sock, struct msghdr *msg,
802                                   size_t size)
803 {
804         int err;
805
806         err = skcipher_check_key(sock);
807         if (err)
808                 return err;
809
810         return skcipher_sendmsg(sock, msg, size);
811 }
812
813 static ssize_t skcipher_sendpage_nokey(struct socket *sock, struct page *page,
814                                        int offset, size_t size, int flags)
815 {
816         int err;
817
818         err = skcipher_check_key(sock);
819         if (err)
820                 return err;
821
822         return skcipher_sendpage(sock, page, offset, size, flags);
823 }
824
825 static int skcipher_recvmsg_nokey(struct socket *sock, struct msghdr *msg,
826                                   size_t ignored, int flags)
827 {
828         int err;
829
830         err = skcipher_check_key(sock);
831         if (err)
832                 return err;
833
834         return skcipher_recvmsg(sock, msg, ignored, flags);
835 }
836
837 static struct proto_ops algif_skcipher_ops_nokey = {
838         .family         =       PF_ALG,
839
840         .connect        =       sock_no_connect,
841         .socketpair     =       sock_no_socketpair,
842         .getname        =       sock_no_getname,
843         .ioctl          =       sock_no_ioctl,
844         .listen         =       sock_no_listen,
845         .shutdown       =       sock_no_shutdown,
846         .getsockopt     =       sock_no_getsockopt,
847         .mmap           =       sock_no_mmap,
848         .bind           =       sock_no_bind,
849         .accept         =       sock_no_accept,
850         .setsockopt     =       sock_no_setsockopt,
851
852         .release        =       af_alg_release,
853         .sendmsg        =       skcipher_sendmsg_nokey,
854         .sendpage       =       skcipher_sendpage_nokey,
855         .recvmsg        =       skcipher_recvmsg_nokey,
856         .poll           =       skcipher_poll,
857 };
858
859 static void *skcipher_bind(const char *name, u32 type, u32 mask)
860 {
861         struct skcipher_tfm *tfm;
862         struct crypto_skcipher *skcipher;
863
864         tfm = kzalloc(sizeof(*tfm), GFP_KERNEL);
865         if (!tfm)
866                 return ERR_PTR(-ENOMEM);
867
868         skcipher = crypto_alloc_skcipher(name, type, mask);
869         if (IS_ERR(skcipher)) {
870                 kfree(tfm);
871                 return ERR_CAST(skcipher);
872         }
873
874         tfm->skcipher = skcipher;
875
876         return tfm;
877 }
878
879 static void skcipher_release(void *private)
880 {
881         struct skcipher_tfm *tfm = private;
882
883         crypto_free_skcipher(tfm->skcipher);
884         kfree(tfm);
885 }
886
887 static int skcipher_setkey(void *private, const u8 *key, unsigned int keylen)
888 {
889         struct skcipher_tfm *tfm = private;
890         int err;
891
892         err = crypto_skcipher_setkey(tfm->skcipher, key, keylen);
893         tfm->has_key = !err;
894
895         return err;
896 }
897
898 static void skcipher_wait(struct sock *sk)
899 {
900         struct alg_sock *ask = alg_sk(sk);
901         struct skcipher_ctx *ctx = ask->private;
902         int ctr = 0;
903
904         while (atomic_read(&ctx->inflight) && ctr++ < 100)
905                 msleep(100);
906 }
907
908 static void skcipher_sock_destruct(struct sock *sk)
909 {
910         struct alg_sock *ask = alg_sk(sk);
911         struct skcipher_ctx *ctx = ask->private;
912         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(&ctx->req);
913
914         if (atomic_read(&ctx->inflight))
915                 skcipher_wait(sk);
916
917         skcipher_free_sgl(sk);
918         sock_kzfree_s(sk, ctx->iv, crypto_skcipher_ivsize(tfm));
919         sock_kfree_s(sk, ctx, ctx->len);
920         af_alg_release_parent(sk);
921 }
922
923 static int skcipher_accept_parent_nokey(void *private, struct sock *sk)
924 {
925         struct skcipher_ctx *ctx;
926         struct alg_sock *ask = alg_sk(sk);
927         struct skcipher_tfm *tfm = private;
928         struct crypto_skcipher *skcipher = tfm->skcipher;
929         unsigned int len = sizeof(*ctx) + crypto_skcipher_reqsize(skcipher);
930
931         ctx = sock_kmalloc(sk, len, GFP_KERNEL);
932         if (!ctx)
933                 return -ENOMEM;
934
935         ctx->iv = sock_kmalloc(sk, crypto_skcipher_ivsize(skcipher),
936                                GFP_KERNEL);
937         if (!ctx->iv) {
938                 sock_kfree_s(sk, ctx, len);
939                 return -ENOMEM;
940         }
941
942         memset(ctx->iv, 0, crypto_skcipher_ivsize(skcipher));
943
944         INIT_LIST_HEAD(&ctx->tsgl);
945         ctx->len = len;
946         ctx->used = 0;
947         ctx->more = 0;
948         ctx->merge = 0;
949         ctx->enc = 0;
950         atomic_set(&ctx->inflight, 0);
951         af_alg_init_completion(&ctx->completion);
952
953         ask->private = ctx;
954
955         skcipher_request_set_tfm(&ctx->req, skcipher);
956         skcipher_request_set_callback(&ctx->req, CRYPTO_TFM_REQ_MAY_SLEEP |
957                                                  CRYPTO_TFM_REQ_MAY_BACKLOG,
958                                       af_alg_complete, &ctx->completion);
959
960         sk->sk_destruct = skcipher_sock_destruct;
961
962         return 0;
963 }
964
965 static int skcipher_accept_parent(void *private, struct sock *sk)
966 {
967         struct skcipher_tfm *tfm = private;
968
969         if (!tfm->has_key && crypto_skcipher_has_setkey(tfm->skcipher))
970                 return -ENOKEY;
971
972         return skcipher_accept_parent_nokey(private, sk);
973 }
974
975 static const struct af_alg_type algif_type_skcipher = {
976         .bind           =       skcipher_bind,
977         .release        =       skcipher_release,
978         .setkey         =       skcipher_setkey,
979         .accept         =       skcipher_accept_parent,
980         .accept_nokey   =       skcipher_accept_parent_nokey,
981         .ops            =       &algif_skcipher_ops,
982         .ops_nokey      =       &algif_skcipher_ops_nokey,
983         .name           =       "skcipher",
984         .owner          =       THIS_MODULE
985 };
986
987 static int __init algif_skcipher_init(void)
988 {
989         return af_alg_register_type(&algif_type_skcipher);
990 }
991
992 static void __exit algif_skcipher_exit(void)
993 {
994         int err = af_alg_unregister_type(&algif_type_skcipher);
995         BUG_ON(err);
996 }
997
998 module_init(algif_skcipher_init);
999 module_exit(algif_skcipher_exit);
1000 MODULE_LICENSE("GPL");