Merge tag 'xfs-5.16-merge-5' of git://git.kernel.org/pub/scm/fs/xfs/xfs-linux
[platform/kernel/linux-starfive.git] / fs / cifs / smb2transport.c
1 // SPDX-License-Identifier: LGPL-2.1
2 /*
3  *
4  *   Copyright (C) International Business Machines  Corp., 2002, 2011
5  *                 Etersoft, 2012
6  *   Author(s): Steve French (sfrench@us.ibm.com)
7  *              Jeremy Allison (jra@samba.org) 2006
8  *              Pavel Shilovsky (pshilovsky@samba.org) 2012
9  *
10  */
11
12 #include <linux/fs.h>
13 #include <linux/list.h>
14 #include <linux/wait.h>
15 #include <linux/net.h>
16 #include <linux/delay.h>
17 #include <linux/uaccess.h>
18 #include <asm/processor.h>
19 #include <linux/mempool.h>
20 #include <linux/highmem.h>
21 #include <crypto/aead.h>
22 #include "cifsglob.h"
23 #include "cifsproto.h"
24 #include "smb2proto.h"
25 #include "cifs_debug.h"
26 #include "smb2status.h"
27 #include "smb2glob.h"
28
29 static int
30 smb3_crypto_shash_allocate(struct TCP_Server_Info *server)
31 {
32         struct cifs_secmech *p = &server->secmech;
33         int rc;
34
35         rc = cifs_alloc_hash("hmac(sha256)",
36                              &p->hmacsha256,
37                              &p->sdeschmacsha256);
38         if (rc)
39                 goto err;
40
41         rc = cifs_alloc_hash("cmac(aes)", &p->cmacaes, &p->sdesccmacaes);
42         if (rc)
43                 goto err;
44
45         return 0;
46 err:
47         cifs_free_hash(&p->hmacsha256, &p->sdeschmacsha256);
48         return rc;
49 }
50
51 int
52 smb311_crypto_shash_allocate(struct TCP_Server_Info *server)
53 {
54         struct cifs_secmech *p = &server->secmech;
55         int rc = 0;
56
57         rc = cifs_alloc_hash("hmac(sha256)",
58                              &p->hmacsha256,
59                              &p->sdeschmacsha256);
60         if (rc)
61                 return rc;
62
63         rc = cifs_alloc_hash("cmac(aes)", &p->cmacaes, &p->sdesccmacaes);
64         if (rc)
65                 goto err;
66
67         rc = cifs_alloc_hash("sha512", &p->sha512, &p->sdescsha512);
68         if (rc)
69                 goto err;
70
71         return 0;
72
73 err:
74         cifs_free_hash(&p->cmacaes, &p->sdesccmacaes);
75         cifs_free_hash(&p->hmacsha256, &p->sdeschmacsha256);
76         return rc;
77 }
78
79
80 static
81 int smb2_get_sign_key(__u64 ses_id, struct TCP_Server_Info *server, u8 *key)
82 {
83         struct cifs_chan *chan;
84         struct cifs_ses *ses = NULL;
85         struct TCP_Server_Info *it = NULL;
86         int i;
87         int rc = 0;
88
89         spin_lock(&cifs_tcp_ses_lock);
90
91         list_for_each_entry(it, &cifs_tcp_ses_list, tcp_ses_list) {
92                 list_for_each_entry(ses, &it->smb_ses_list, smb_ses_list) {
93                         if (ses->Suid == ses_id)
94                                 goto found;
95                 }
96         }
97         cifs_server_dbg(VFS, "%s: Could not find session 0x%llx\n",
98                         __func__, ses_id);
99         rc = -ENOENT;
100         goto out;
101
102 found:
103         if (ses->binding) {
104                 /*
105                  * If we are in the process of binding a new channel
106                  * to an existing session, use the master connection
107                  * session key
108                  */
109                 memcpy(key, ses->smb3signingkey, SMB3_SIGN_KEY_SIZE);
110                 goto out;
111         }
112
113         /*
114          * Otherwise, use the channel key.
115          */
116
117         for (i = 0; i < ses->chan_count; i++) {
118                 chan = ses->chans + i;
119                 if (chan->server == server) {
120                         memcpy(key, chan->signkey, SMB3_SIGN_KEY_SIZE);
121                         goto out;
122                 }
123         }
124
125         cifs_dbg(VFS,
126                  "%s: Could not find channel signing key for session 0x%llx\n",
127                  __func__, ses_id);
128         rc = -ENOENT;
129
130 out:
131         spin_unlock(&cifs_tcp_ses_lock);
132         return rc;
133 }
134
135 static struct cifs_ses *
136 smb2_find_smb_ses_unlocked(struct TCP_Server_Info *server, __u64 ses_id)
137 {
138         struct cifs_ses *ses;
139
140         list_for_each_entry(ses, &server->smb_ses_list, smb_ses_list) {
141                 if (ses->Suid != ses_id)
142                         continue;
143                 ++ses->ses_count;
144                 return ses;
145         }
146
147         return NULL;
148 }
149
150 struct cifs_ses *
151 smb2_find_smb_ses(struct TCP_Server_Info *server, __u64 ses_id)
152 {
153         struct cifs_ses *ses;
154
155         spin_lock(&cifs_tcp_ses_lock);
156         ses = smb2_find_smb_ses_unlocked(server, ses_id);
157         spin_unlock(&cifs_tcp_ses_lock);
158
159         return ses;
160 }
161
162 static struct cifs_tcon *
163 smb2_find_smb_sess_tcon_unlocked(struct cifs_ses *ses, __u32  tid)
164 {
165         struct cifs_tcon *tcon;
166
167         list_for_each_entry(tcon, &ses->tcon_list, tcon_list) {
168                 if (tcon->tid != tid)
169                         continue;
170                 ++tcon->tc_count;
171                 return tcon;
172         }
173
174         return NULL;
175 }
176
177 /*
178  * Obtain tcon corresponding to the tid in the given
179  * cifs_ses
180  */
181
182 struct cifs_tcon *
183 smb2_find_smb_tcon(struct TCP_Server_Info *server, __u64 ses_id, __u32  tid)
184 {
185         struct cifs_ses *ses;
186         struct cifs_tcon *tcon;
187
188         spin_lock(&cifs_tcp_ses_lock);
189         ses = smb2_find_smb_ses_unlocked(server, ses_id);
190         if (!ses) {
191                 spin_unlock(&cifs_tcp_ses_lock);
192                 return NULL;
193         }
194         tcon = smb2_find_smb_sess_tcon_unlocked(ses, tid);
195         if (!tcon) {
196                 cifs_put_smb_ses(ses);
197                 spin_unlock(&cifs_tcp_ses_lock);
198                 return NULL;
199         }
200         spin_unlock(&cifs_tcp_ses_lock);
201         /* tcon already has a ref to ses, so we don't need ses anymore */
202         cifs_put_smb_ses(ses);
203
204         return tcon;
205 }
206
207 int
208 smb2_calc_signature(struct smb_rqst *rqst, struct TCP_Server_Info *server,
209                         bool allocate_crypto)
210 {
211         int rc;
212         unsigned char smb2_signature[SMB2_HMACSHA256_SIZE];
213         unsigned char *sigptr = smb2_signature;
214         struct kvec *iov = rqst->rq_iov;
215         struct smb2_hdr *shdr = (struct smb2_hdr *)iov[0].iov_base;
216         struct cifs_ses *ses;
217         struct shash_desc *shash;
218         struct crypto_shash *hash;
219         struct sdesc *sdesc = NULL;
220         struct smb_rqst drqst;
221
222         ses = smb2_find_smb_ses(server, le64_to_cpu(shdr->SessionId));
223         if (!ses) {
224                 cifs_server_dbg(VFS, "%s: Could not find session\n", __func__);
225                 return 0;
226         }
227
228         memset(smb2_signature, 0x0, SMB2_HMACSHA256_SIZE);
229         memset(shdr->Signature, 0x0, SMB2_SIGNATURE_SIZE);
230
231         if (allocate_crypto) {
232                 rc = cifs_alloc_hash("hmac(sha256)", &hash, &sdesc);
233                 if (rc) {
234                         cifs_server_dbg(VFS,
235                                         "%s: sha256 alloc failed\n", __func__);
236                         goto out;
237                 }
238                 shash = &sdesc->shash;
239         } else {
240                 hash = server->secmech.hmacsha256;
241                 shash = &server->secmech.sdeschmacsha256->shash;
242         }
243
244         rc = crypto_shash_setkey(hash, ses->auth_key.response,
245                         SMB2_NTLMV2_SESSKEY_SIZE);
246         if (rc) {
247                 cifs_server_dbg(VFS,
248                                 "%s: Could not update with response\n",
249                                 __func__);
250                 goto out;
251         }
252
253         rc = crypto_shash_init(shash);
254         if (rc) {
255                 cifs_server_dbg(VFS, "%s: Could not init sha256", __func__);
256                 goto out;
257         }
258
259         /*
260          * For SMB2+, __cifs_calc_signature() expects to sign only the actual
261          * data, that is, iov[0] should not contain a rfc1002 length.
262          *
263          * Sign the rfc1002 length prior to passing the data (iov[1-N]) down to
264          * __cifs_calc_signature().
265          */
266         drqst = *rqst;
267         if (drqst.rq_nvec >= 2 && iov[0].iov_len == 4) {
268                 rc = crypto_shash_update(shash, iov[0].iov_base,
269                                          iov[0].iov_len);
270                 if (rc) {
271                         cifs_server_dbg(VFS,
272                                         "%s: Could not update with payload\n",
273                                         __func__);
274                         goto out;
275                 }
276                 drqst.rq_iov++;
277                 drqst.rq_nvec--;
278         }
279
280         rc = __cifs_calc_signature(&drqst, server, sigptr, shash);
281         if (!rc)
282                 memcpy(shdr->Signature, sigptr, SMB2_SIGNATURE_SIZE);
283
284 out:
285         if (allocate_crypto)
286                 cifs_free_hash(&hash, &sdesc);
287         if (ses)
288                 cifs_put_smb_ses(ses);
289         return rc;
290 }
291
292 static int generate_key(struct cifs_ses *ses, struct kvec label,
293                         struct kvec context, __u8 *key, unsigned int key_size)
294 {
295         unsigned char zero = 0x0;
296         __u8 i[4] = {0, 0, 0, 1};
297         __u8 L128[4] = {0, 0, 0, 128};
298         __u8 L256[4] = {0, 0, 1, 0};
299         int rc = 0;
300         unsigned char prfhash[SMB2_HMACSHA256_SIZE];
301         unsigned char *hashptr = prfhash;
302         struct TCP_Server_Info *server = ses->server;
303
304         memset(prfhash, 0x0, SMB2_HMACSHA256_SIZE);
305         memset(key, 0x0, key_size);
306
307         rc = smb3_crypto_shash_allocate(server);
308         if (rc) {
309                 cifs_server_dbg(VFS, "%s: crypto alloc failed\n", __func__);
310                 goto smb3signkey_ret;
311         }
312
313         rc = crypto_shash_setkey(server->secmech.hmacsha256,
314                 ses->auth_key.response, SMB2_NTLMV2_SESSKEY_SIZE);
315         if (rc) {
316                 cifs_server_dbg(VFS, "%s: Could not set with session key\n", __func__);
317                 goto smb3signkey_ret;
318         }
319
320         rc = crypto_shash_init(&server->secmech.sdeschmacsha256->shash);
321         if (rc) {
322                 cifs_server_dbg(VFS, "%s: Could not init sign hmac\n", __func__);
323                 goto smb3signkey_ret;
324         }
325
326         rc = crypto_shash_update(&server->secmech.sdeschmacsha256->shash,
327                                 i, 4);
328         if (rc) {
329                 cifs_server_dbg(VFS, "%s: Could not update with n\n", __func__);
330                 goto smb3signkey_ret;
331         }
332
333         rc = crypto_shash_update(&server->secmech.sdeschmacsha256->shash,
334                                 label.iov_base, label.iov_len);
335         if (rc) {
336                 cifs_server_dbg(VFS, "%s: Could not update with label\n", __func__);
337                 goto smb3signkey_ret;
338         }
339
340         rc = crypto_shash_update(&server->secmech.sdeschmacsha256->shash,
341                                 &zero, 1);
342         if (rc) {
343                 cifs_server_dbg(VFS, "%s: Could not update with zero\n", __func__);
344                 goto smb3signkey_ret;
345         }
346
347         rc = crypto_shash_update(&server->secmech.sdeschmacsha256->shash,
348                                 context.iov_base, context.iov_len);
349         if (rc) {
350                 cifs_server_dbg(VFS, "%s: Could not update with context\n", __func__);
351                 goto smb3signkey_ret;
352         }
353
354         if ((server->cipher_type == SMB2_ENCRYPTION_AES256_CCM) ||
355                 (server->cipher_type == SMB2_ENCRYPTION_AES256_GCM)) {
356                 rc = crypto_shash_update(&server->secmech.sdeschmacsha256->shash,
357                                 L256, 4);
358         } else {
359                 rc = crypto_shash_update(&server->secmech.sdeschmacsha256->shash,
360                                 L128, 4);
361         }
362         if (rc) {
363                 cifs_server_dbg(VFS, "%s: Could not update with L\n", __func__);
364                 goto smb3signkey_ret;
365         }
366
367         rc = crypto_shash_final(&server->secmech.sdeschmacsha256->shash,
368                                 hashptr);
369         if (rc) {
370                 cifs_server_dbg(VFS, "%s: Could not generate sha256 hash\n", __func__);
371                 goto smb3signkey_ret;
372         }
373
374         memcpy(key, hashptr, key_size);
375
376 smb3signkey_ret:
377         return rc;
378 }
379
/*
 * Inputs to a single SMB3 key derivation (see generate_key()): the KDF
 * label and context, each given as a pointer/length pair.
 */
struct derivation {
	struct kvec label;
	struct kvec context;
};

/*
 * The three derivations consumed by generate_smb3signingkey().
 * - The signing key is derived in every case.
 * - When a new channel is being bound, only .signing is used.
 * - Otherwise .encryption and .decryption are also derived.
 */
struct derivation_triplet {
	struct derivation signing;
	struct derivation encryption;
	struct derivation decryption;
};
390
391 static int
392 generate_smb3signingkey(struct cifs_ses *ses,
393                         const struct derivation_triplet *ptriplet)
394 {
395         int rc;
396 #ifdef CONFIG_CIFS_DEBUG_DUMP_KEYS
397         struct TCP_Server_Info *server = ses->server;
398 #endif
399
400         /*
401          * All channels use the same encryption/decryption keys but
402          * they have their own signing key.
403          *
404          * When we generate the keys, check if it is for a new channel
405          * (binding) in which case we only need to generate a signing
406          * key and store it in the channel as to not overwrite the
407          * master connection signing key stored in the session
408          */
409
410         if (ses->binding) {
411                 rc = generate_key(ses, ptriplet->signing.label,
412                                   ptriplet->signing.context,
413                                   cifs_ses_binding_channel(ses)->signkey,
414                                   SMB3_SIGN_KEY_SIZE);
415                 if (rc)
416                         return rc;
417         } else {
418                 rc = generate_key(ses, ptriplet->signing.label,
419                                   ptriplet->signing.context,
420                                   ses->smb3signingkey,
421                                   SMB3_SIGN_KEY_SIZE);
422                 if (rc)
423                         return rc;
424
425                 memcpy(ses->chans[0].signkey, ses->smb3signingkey,
426                        SMB3_SIGN_KEY_SIZE);
427
428                 rc = generate_key(ses, ptriplet->encryption.label,
429                                   ptriplet->encryption.context,
430                                   ses->smb3encryptionkey,
431                                   SMB3_ENC_DEC_KEY_SIZE);
432                 rc = generate_key(ses, ptriplet->decryption.label,
433                                   ptriplet->decryption.context,
434                                   ses->smb3decryptionkey,
435                                   SMB3_ENC_DEC_KEY_SIZE);
436                 if (rc)
437                         return rc;
438         }
439
440         if (rc)
441                 return rc;
442
443 #ifdef CONFIG_CIFS_DEBUG_DUMP_KEYS
444         cifs_dbg(VFS, "%s: dumping generated AES session keys\n", __func__);
445         /*
446          * The session id is opaque in terms of endianness, so we can't
447          * print it as a long long. we dump it as we got it on the wire
448          */
449         cifs_dbg(VFS, "Session Id    %*ph\n", (int)sizeof(ses->Suid),
450                         &ses->Suid);
451         cifs_dbg(VFS, "Cipher type   %d\n", server->cipher_type);
452         cifs_dbg(VFS, "Session Key   %*ph\n",
453                  SMB2_NTLMV2_SESSKEY_SIZE, ses->auth_key.response);
454         cifs_dbg(VFS, "Signing Key   %*ph\n",
455                  SMB3_SIGN_KEY_SIZE, ses->smb3signingkey);
456         if ((server->cipher_type == SMB2_ENCRYPTION_AES256_CCM) ||
457                 (server->cipher_type == SMB2_ENCRYPTION_AES256_GCM)) {
458                 cifs_dbg(VFS, "ServerIn Key  %*ph\n",
459                                 SMB3_GCM256_CRYPTKEY_SIZE, ses->smb3encryptionkey);
460                 cifs_dbg(VFS, "ServerOut Key %*ph\n",
461                                 SMB3_GCM256_CRYPTKEY_SIZE, ses->smb3decryptionkey);
462         } else {
463                 cifs_dbg(VFS, "ServerIn Key  %*ph\n",
464                                 SMB3_GCM128_CRYPTKEY_SIZE, ses->smb3encryptionkey);
465                 cifs_dbg(VFS, "ServerOut Key %*ph\n",
466                                 SMB3_GCM128_CRYPTKEY_SIZE, ses->smb3decryptionkey);
467         }
468 #endif
469         return rc;
470 }
471
472 int
473 generate_smb30signingkey(struct cifs_ses *ses)
474
475 {
476         struct derivation_triplet triplet;
477         struct derivation *d;
478
479         d = &triplet.signing;
480         d->label.iov_base = "SMB2AESCMAC";
481         d->label.iov_len = 12;
482         d->context.iov_base = "SmbSign";
483         d->context.iov_len = 8;
484
485         d = &triplet.encryption;
486         d->label.iov_base = "SMB2AESCCM";
487         d->label.iov_len = 11;
488         d->context.iov_base = "ServerIn ";
489         d->context.iov_len = 10;
490
491         d = &triplet.decryption;
492         d->label.iov_base = "SMB2AESCCM";
493         d->label.iov_len = 11;
494         d->context.iov_base = "ServerOut";
495         d->context.iov_len = 10;
496
497         return generate_smb3signingkey(ses, &triplet);
498 }
499
500 int
501 generate_smb311signingkey(struct cifs_ses *ses)
502
503 {
504         struct derivation_triplet triplet;
505         struct derivation *d;
506
507         d = &triplet.signing;
508         d->label.iov_base = "SMBSigningKey";
509         d->label.iov_len = 14;
510         d->context.iov_base = ses->preauth_sha_hash;
511         d->context.iov_len = 64;
512
513         d = &triplet.encryption;
514         d->label.iov_base = "SMBC2SCipherKey";
515         d->label.iov_len = 16;
516         d->context.iov_base = ses->preauth_sha_hash;
517         d->context.iov_len = 64;
518
519         d = &triplet.decryption;
520         d->label.iov_base = "SMBS2CCipherKey";
521         d->label.iov_len = 16;
522         d->context.iov_base = ses->preauth_sha_hash;
523         d->context.iov_len = 64;
524
525         return generate_smb3signingkey(ses, &triplet);
526 }
527
528 int
529 smb3_calc_signature(struct smb_rqst *rqst, struct TCP_Server_Info *server,
530                         bool allocate_crypto)
531 {
532         int rc;
533         unsigned char smb3_signature[SMB2_CMACAES_SIZE];
534         unsigned char *sigptr = smb3_signature;
535         struct kvec *iov = rqst->rq_iov;
536         struct smb2_hdr *shdr = (struct smb2_hdr *)iov[0].iov_base;
537         struct shash_desc *shash;
538         struct crypto_shash *hash;
539         struct sdesc *sdesc = NULL;
540         struct smb_rqst drqst;
541         u8 key[SMB3_SIGN_KEY_SIZE];
542
543         rc = smb2_get_sign_key(le64_to_cpu(shdr->SessionId), server, key);
544         if (rc)
545                 return 0;
546
547         if (allocate_crypto) {
548                 rc = cifs_alloc_hash("cmac(aes)", &hash, &sdesc);
549                 if (rc)
550                         return rc;
551
552                 shash = &sdesc->shash;
553         } else {
554                 hash = server->secmech.cmacaes;
555                 shash = &server->secmech.sdesccmacaes->shash;
556         }
557
558         memset(smb3_signature, 0x0, SMB2_CMACAES_SIZE);
559         memset(shdr->Signature, 0x0, SMB2_SIGNATURE_SIZE);
560
561         rc = crypto_shash_setkey(hash, key, SMB2_CMACAES_SIZE);
562         if (rc) {
563                 cifs_server_dbg(VFS, "%s: Could not set key for cmac aes\n", __func__);
564                 goto out;
565         }
566
567         /*
568          * we already allocate sdesccmacaes when we init smb3 signing key,
569          * so unlike smb2 case we do not have to check here if secmech are
570          * initialized
571          */
572         rc = crypto_shash_init(shash);
573         if (rc) {
574                 cifs_server_dbg(VFS, "%s: Could not init cmac aes\n", __func__);
575                 goto out;
576         }
577
578         /*
579          * For SMB2+, __cifs_calc_signature() expects to sign only the actual
580          * data, that is, iov[0] should not contain a rfc1002 length.
581          *
582          * Sign the rfc1002 length prior to passing the data (iov[1-N]) down to
583          * __cifs_calc_signature().
584          */
585         drqst = *rqst;
586         if (drqst.rq_nvec >= 2 && iov[0].iov_len == 4) {
587                 rc = crypto_shash_update(shash, iov[0].iov_base,
588                                          iov[0].iov_len);
589                 if (rc) {
590                         cifs_server_dbg(VFS, "%s: Could not update with payload\n",
591                                  __func__);
592                         goto out;
593                 }
594                 drqst.rq_iov++;
595                 drqst.rq_nvec--;
596         }
597
598         rc = __cifs_calc_signature(&drqst, server, sigptr, shash);
599         if (!rc)
600                 memcpy(shdr->Signature, sigptr, SMB2_SIGNATURE_SIZE);
601
602 out:
603         if (allocate_crypto)
604                 cifs_free_hash(&hash, &sdesc);
605         return rc;
606 }
607
608 /* must be called with server->srv_mutex held */
609 static int
610 smb2_sign_rqst(struct smb_rqst *rqst, struct TCP_Server_Info *server)
611 {
612         int rc = 0;
613         struct smb2_hdr *shdr;
614         struct smb2_sess_setup_req *ssr;
615         bool is_binding;
616         bool is_signed;
617
618         shdr = (struct smb2_hdr *)rqst->rq_iov[0].iov_base;
619         ssr = (struct smb2_sess_setup_req *)shdr;
620
621         is_binding = shdr->Command == SMB2_SESSION_SETUP &&
622                 (ssr->Flags & SMB2_SESSION_REQ_FLAG_BINDING);
623         is_signed = shdr->Flags & SMB2_FLAGS_SIGNED;
624
625         if (!is_signed)
626                 return 0;
627         if (server->tcpStatus == CifsNeedNegotiate)
628                 return 0;
629         if (!is_binding && !server->session_estab) {
630                 strncpy(shdr->Signature, "BSRSPYL", 8);
631                 return 0;
632         }
633
634         rc = server->ops->calc_signature(rqst, server, false);
635
636         return rc;
637 }
638
639 int
640 smb2_verify_signature(struct smb_rqst *rqst, struct TCP_Server_Info *server)
641 {
642         unsigned int rc;
643         char server_response_sig[SMB2_SIGNATURE_SIZE];
644         struct smb2_hdr *shdr =
645                         (struct smb2_hdr *)rqst->rq_iov[0].iov_base;
646
647         if ((shdr->Command == SMB2_NEGOTIATE) ||
648             (shdr->Command == SMB2_SESSION_SETUP) ||
649             (shdr->Command == SMB2_OPLOCK_BREAK) ||
650             server->ignore_signature ||
651             (!server->session_estab))
652                 return 0;
653
654         /*
655          * BB what if signatures are supposed to be on for session but
656          * server does not send one? BB
657          */
658
659         /* Do not need to verify session setups with signature "BSRSPYL " */
660         if (memcmp(shdr->Signature, "BSRSPYL ", 8) == 0)
661                 cifs_dbg(FYI, "dummy signature received for smb command 0x%x\n",
662                          shdr->Command);
663
664         /*
665          * Save off the origiginal signature so we can modify the smb and check
666          * our calculated signature against what the server sent.
667          */
668         memcpy(server_response_sig, shdr->Signature, SMB2_SIGNATURE_SIZE);
669
670         memset(shdr->Signature, 0, SMB2_SIGNATURE_SIZE);
671
672         rc = server->ops->calc_signature(rqst, server, true);
673
674         if (rc)
675                 return rc;
676
677         if (memcmp(server_response_sig, shdr->Signature, SMB2_SIGNATURE_SIZE)) {
678                 cifs_dbg(VFS, "sign fail cmd 0x%x message id 0x%llx\n",
679                         shdr->Command, shdr->MessageId);
680                 return -EACCES;
681         } else
682                 return 0;
683 }
684
685 /*
686  * Set message id for the request. Should be called after wait_for_free_request
687  * and when srv_mutex is held.
688  */
689 static inline void
690 smb2_seq_num_into_buf(struct TCP_Server_Info *server,
691                       struct smb2_hdr *shdr)
692 {
693         unsigned int i, num = le16_to_cpu(shdr->CreditCharge);
694
695         shdr->MessageId = get_next_mid64(server);
696         /* skip message numbers according to CreditCharge field */
697         for (i = 1; i < num; i++)
698                 get_next_mid(server);
699 }
700
701 static struct mid_q_entry *
702 smb2_mid_entry_alloc(const struct smb2_hdr *shdr,
703                      struct TCP_Server_Info *server)
704 {
705         struct mid_q_entry *temp;
706         unsigned int credits = le16_to_cpu(shdr->CreditCharge);
707
708         if (server == NULL) {
709                 cifs_dbg(VFS, "Null TCP session in smb2_mid_entry_alloc\n");
710                 return NULL;
711         }
712
713         temp = mempool_alloc(cifs_mid_poolp, GFP_NOFS);
714         memset(temp, 0, sizeof(struct mid_q_entry));
715         kref_init(&temp->refcount);
716         temp->mid = le64_to_cpu(shdr->MessageId);
717         temp->credits = credits > 0 ? credits : 1;
718         temp->pid = current->pid;
719         temp->command = shdr->Command; /* Always LE */
720         temp->when_alloc = jiffies;
721         temp->server = server;
722
723         /*
724          * The default is for the mid to be synchronous, so the
725          * default callback just wakes up the current task.
726          */
727         get_task_struct(current);
728         temp->creator = current;
729         temp->callback = cifs_wake_up_task;
730         temp->callback_data = current;
731
732         atomic_inc(&midCount);
733         temp->mid_state = MID_REQUEST_ALLOCATED;
734         trace_smb3_cmd_enter(le32_to_cpu(shdr->Id.SyncId.TreeId),
735                              le64_to_cpu(shdr->SessionId),
736                              le16_to_cpu(shdr->Command), temp->mid);
737         return temp;
738 }
739
740 static int
741 smb2_get_mid_entry(struct cifs_ses *ses, struct TCP_Server_Info *server,
742                    struct smb2_hdr *shdr, struct mid_q_entry **mid)
743 {
744         if (server->tcpStatus == CifsExiting)
745                 return -ENOENT;
746
747         if (server->tcpStatus == CifsNeedReconnect) {
748                 cifs_dbg(FYI, "tcp session dead - return to caller to retry\n");
749                 return -EAGAIN;
750         }
751
752         if (server->tcpStatus == CifsNeedNegotiate &&
753            shdr->Command != SMB2_NEGOTIATE)
754                 return -EAGAIN;
755
756         if (ses->status == CifsNew) {
757                 if ((shdr->Command != SMB2_SESSION_SETUP) &&
758                     (shdr->Command != SMB2_NEGOTIATE))
759                         return -EAGAIN;
760                 /* else ok - we are setting up session */
761         }
762
763         if (ses->status == CifsExiting) {
764                 if (shdr->Command != SMB2_LOGOFF)
765                         return -EAGAIN;
766                 /* else ok - we are shutting down the session */
767         }
768
769         *mid = smb2_mid_entry_alloc(shdr, server);
770         if (*mid == NULL)
771                 return -ENOMEM;
772         spin_lock(&GlobalMid_Lock);
773         list_add_tail(&(*mid)->qhead, &server->pending_mid_q);
774         spin_unlock(&GlobalMid_Lock);
775
776         return 0;
777 }
778
779 int
780 smb2_check_receive(struct mid_q_entry *mid, struct TCP_Server_Info *server,
781                    bool log_error)
782 {
783         unsigned int len = mid->resp_buf_size;
784         struct kvec iov[1];
785         struct smb_rqst rqst = { .rq_iov = iov,
786                                  .rq_nvec = 1 };
787
788         iov[0].iov_base = (char *)mid->resp_buf;
789         iov[0].iov_len = len;
790
791         dump_smb(mid->resp_buf, min_t(u32, 80, len));
792         /* convert the length into a more usable form */
793         if (len > 24 && server->sign && !mid->decrypted) {
794                 int rc;
795
796                 rc = smb2_verify_signature(&rqst, server);
797                 if (rc)
798                         cifs_server_dbg(VFS, "SMB signature verification returned error = %d\n",
799                                  rc);
800         }
801
802         return map_smb2_to_linux_error(mid->resp_buf, log_error);
803 }
804
805 struct mid_q_entry *
806 smb2_setup_request(struct cifs_ses *ses, struct TCP_Server_Info *server,
807                    struct smb_rqst *rqst)
808 {
809         int rc;
810         struct smb2_hdr *shdr =
811                         (struct smb2_hdr *)rqst->rq_iov[0].iov_base;
812         struct mid_q_entry *mid;
813
814         smb2_seq_num_into_buf(server, shdr);
815
816         rc = smb2_get_mid_entry(ses, server, shdr, &mid);
817         if (rc) {
818                 revert_current_mid_from_hdr(server, shdr);
819                 return ERR_PTR(rc);
820         }
821
822         rc = smb2_sign_rqst(rqst, server);
823         if (rc) {
824                 revert_current_mid_from_hdr(server, shdr);
825                 cifs_delete_mid(mid);
826                 return ERR_PTR(rc);
827         }
828
829         return mid;
830 }
831
832 struct mid_q_entry *
833 smb2_setup_async_request(struct TCP_Server_Info *server, struct smb_rqst *rqst)
834 {
835         int rc;
836         struct smb2_hdr *shdr =
837                         (struct smb2_hdr *)rqst->rq_iov[0].iov_base;
838         struct mid_q_entry *mid;
839
840         if (server->tcpStatus == CifsNeedNegotiate &&
841            shdr->Command != SMB2_NEGOTIATE)
842                 return ERR_PTR(-EAGAIN);
843
844         smb2_seq_num_into_buf(server, shdr);
845
846         mid = smb2_mid_entry_alloc(shdr, server);
847         if (mid == NULL) {
848                 revert_current_mid_from_hdr(server, shdr);
849                 return ERR_PTR(-ENOMEM);
850         }
851
852         rc = smb2_sign_rqst(rqst, server);
853         if (rc) {
854                 revert_current_mid_from_hdr(server, shdr);
855                 DeleteMidQEntry(mid);
856                 return ERR_PTR(rc);
857         }
858
859         return mid;
860 }
861
862 int
863 smb3_crypto_aead_allocate(struct TCP_Server_Info *server)
864 {
865         struct crypto_aead *tfm;
866
867         if (!server->secmech.ccmaesencrypt) {
868                 if ((server->cipher_type == SMB2_ENCRYPTION_AES128_GCM) ||
869                     (server->cipher_type == SMB2_ENCRYPTION_AES256_GCM))
870                         tfm = crypto_alloc_aead("gcm(aes)", 0, 0);
871                 else
872                         tfm = crypto_alloc_aead("ccm(aes)", 0, 0);
873                 if (IS_ERR(tfm)) {
874                         cifs_server_dbg(VFS, "%s: Failed alloc encrypt aead\n",
875                                  __func__);
876                         return PTR_ERR(tfm);
877                 }
878                 server->secmech.ccmaesencrypt = tfm;
879         }
880
881         if (!server->secmech.ccmaesdecrypt) {
882                 if ((server->cipher_type == SMB2_ENCRYPTION_AES128_GCM) ||
883                     (server->cipher_type == SMB2_ENCRYPTION_AES256_GCM))
884                         tfm = crypto_alloc_aead("gcm(aes)", 0, 0);
885                 else
886                         tfm = crypto_alloc_aead("ccm(aes)", 0, 0);
887                 if (IS_ERR(tfm)) {
888                         crypto_free_aead(server->secmech.ccmaesencrypt);
889                         server->secmech.ccmaesencrypt = NULL;
890                         cifs_server_dbg(VFS, "%s: Failed to alloc decrypt aead\n",
891                                  __func__);
892                         return PTR_ERR(tfm);
893                 }
894                 server->secmech.ccmaesdecrypt = tfm;
895         }
896
897         return 0;
898 }