io_uring: fix ltimeout unprep
[platform/kernel/linux-starfive.git] / fs / ksmbd / auth.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  *   Copyright (C) 2016 Namjae Jeon <linkinjeon@kernel.org>
4  *   Copyright (C) 2018 Samsung Electronics Co., Ltd.
5  */
6
7 #include <linux/kernel.h>
8 #include <linux/fs.h>
9 #include <linux/uaccess.h>
10 #include <linux/backing-dev.h>
11 #include <linux/writeback.h>
12 #include <linux/uio.h>
13 #include <linux/xattr.h>
14 #include <crypto/hash.h>
15 #include <crypto/aead.h>
16 #include <linux/random.h>
17 #include <linux/scatterlist.h>
18
19 #include "auth.h"
20 #include "glob.h"
21
22 #include <linux/fips.h>
23 #include <crypto/des.h>
24
25 #include "server.h"
26 #include "smb_common.h"
27 #include "connection.h"
28 #include "mgmt/user_session.h"
29 #include "mgmt/user_config.h"
30 #include "crypto_ctx.h"
31 #include "transport_ipc.h"
32
33 /*
34  * Fixed format data defining GSS header and fixed string
35  * "not_defined_in_RFC4178@please_ignore".
36  * So sec blob data in neg phase could be generated statically.
37  */
38 static char NEGOTIATE_GSS_HEADER[AUTH_GSS_LENGTH] = {
39 #ifdef CONFIG_SMB_SERVER_KERBEROS5
40         0x60, 0x5e, 0x06, 0x06, 0x2b, 0x06, 0x01, 0x05,
41         0x05, 0x02, 0xa0, 0x54, 0x30, 0x52, 0xa0, 0x24,
42         0x30, 0x22, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86,
43         0xf7, 0x12, 0x01, 0x02, 0x02, 0x06, 0x09, 0x2a,
44         0x86, 0x48, 0x82, 0xf7, 0x12, 0x01, 0x02, 0x02,
45         0x06, 0x0a, 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82,
46         0x37, 0x02, 0x02, 0x0a, 0xa3, 0x2a, 0x30, 0x28,
47         0xa0, 0x26, 0x1b, 0x24, 0x6e, 0x6f, 0x74, 0x5f,
48         0x64, 0x65, 0x66, 0x69, 0x6e, 0x65, 0x64, 0x5f,
49         0x69, 0x6e, 0x5f, 0x52, 0x46, 0x43, 0x34, 0x31,
50         0x37, 0x38, 0x40, 0x70, 0x6c, 0x65, 0x61, 0x73,
51         0x65, 0x5f, 0x69, 0x67, 0x6e, 0x6f, 0x72, 0x65
52 #else
53         0x60, 0x48, 0x06, 0x06, 0x2b, 0x06, 0x01, 0x05,
54         0x05, 0x02, 0xa0, 0x3e, 0x30, 0x3c, 0xa0, 0x0e,
55         0x30, 0x0c, 0x06, 0x0a, 0x2b, 0x06, 0x01, 0x04,
56         0x01, 0x82, 0x37, 0x02, 0x02, 0x0a, 0xa3, 0x2a,
57         0x30, 0x28, 0xa0, 0x26, 0x1b, 0x24, 0x6e, 0x6f,
58         0x74, 0x5f, 0x64, 0x65, 0x66, 0x69, 0x6e, 0x65,
59         0x64, 0x5f, 0x69, 0x6e, 0x5f, 0x52, 0x46, 0x43,
60         0x34, 0x31, 0x37, 0x38, 0x40, 0x70, 0x6c, 0x65,
61         0x61, 0x73, 0x65, 0x5f, 0x69, 0x67, 0x6e, 0x6f,
62         0x72, 0x65
63 #endif
64 };
65
66 void ksmbd_copy_gss_neg_header(void *buf)
67 {
68         memcpy(buf, NEGOTIATE_GSS_HEADER, AUTH_GSS_LENGTH);
69 }
70
/*
 * str_to_key() - spread a 7-byte (56-bit) key over 8 DES key bytes
 * @str:	7 input bytes
 * @key:	8 output bytes
 *
 * The 56 input bits are read most-significant first, seven at a time.
 * Each group lands in the top seven bits of one output byte, and the low
 * (DES parity) bit of every output byte is left as zero.
 */
static void
str_to_key(unsigned char *str, unsigned char *key)
{
	unsigned long long bits = 0;
	int n;

	/* Pack the 7 bytes big-endian into the low 56 bits. */
	for (n = 0; n < 7; n++)
		bits = (bits << 8) | str[n];

	/* Group n occupies bits [55 - 7n .. 49 - 7n]. */
	for (n = 0; n < 8; n++)
		key[n] = (unsigned char)(((bits >> (49 - 7 * n)) & 0x7F) << 1);
}
87
88 static int
89 smbhash(unsigned char *out, const unsigned char *in, unsigned char *key)
90 {
91         unsigned char key2[8];
92         struct des_ctx ctx;
93
94         if (fips_enabled) {
95                 ksmbd_debug(AUTH, "FIPS compliance enabled: DES not permitted\n");
96                 return -ENOENT;
97         }
98
99         str_to_key(key, key2);
100         des_expand_key(&ctx, key2, DES_KEY_SIZE);
101         des_encrypt(&ctx, out, in);
102         memzero_explicit(&ctx, sizeof(ctx));
103         return 0;
104 }
105
/*
 * ksmbd_enc_p24() - compute a 24-byte DES challenge response
 * @p21:	21 bytes of key material (callers pass a zero-padded NT hash)
 * @c8:		8-byte challenge
 * @p24:	24-byte output
 *
 * The key is split into three 7-byte chunks.  Each chunk encrypts @c8
 * into the matching 8-byte slice of @p24.
 *
 * Return: 0 on success, or the first error reported by smbhash().
 */
static int ksmbd_enc_p24(unsigned char *p21, const unsigned char *c8, unsigned char *p24)
{
	int blk, err;

	for (blk = 0; blk < 3; blk++) {
		err = smbhash(p24 + 8 * blk, c8, p21 + 7 * blk);
		if (err)
			return err;
	}
	return 0;
}
118
119 /* produce a md4 message digest from data of length n bytes */
120 static int ksmbd_enc_md4(unsigned char *md4_hash, unsigned char *link_str,
121                          int link_len)
122 {
123         int rc;
124         struct ksmbd_crypto_ctx *ctx;
125
126         ctx = ksmbd_crypto_ctx_find_md4();
127         if (!ctx) {
128                 ksmbd_debug(AUTH, "Crypto md4 allocation error\n");
129                 return -ENOMEM;
130         }
131
132         rc = crypto_shash_init(CRYPTO_MD4(ctx));
133         if (rc) {
134                 ksmbd_debug(AUTH, "Could not init md4 shash\n");
135                 goto out;
136         }
137
138         rc = crypto_shash_update(CRYPTO_MD4(ctx), link_str, link_len);
139         if (rc) {
140                 ksmbd_debug(AUTH, "Could not update with link_str\n");
141                 goto out;
142         }
143
144         rc = crypto_shash_final(CRYPTO_MD4(ctx), md4_hash);
145         if (rc)
146                 ksmbd_debug(AUTH, "Could not generate md4 hash\n");
147 out:
148         ksmbd_release_crypto_ctx(ctx);
149         return rc;
150 }
151
152 static int ksmbd_enc_update_sess_key(unsigned char *md5_hash, char *nonce,
153                                      char *server_challenge, int len)
154 {
155         int rc;
156         struct ksmbd_crypto_ctx *ctx;
157
158         ctx = ksmbd_crypto_ctx_find_md5();
159         if (!ctx) {
160                 ksmbd_debug(AUTH, "Crypto md5 allocation error\n");
161                 return -ENOMEM;
162         }
163
164         rc = crypto_shash_init(CRYPTO_MD5(ctx));
165         if (rc) {
166                 ksmbd_debug(AUTH, "Could not init md5 shash\n");
167                 goto out;
168         }
169
170         rc = crypto_shash_update(CRYPTO_MD5(ctx), server_challenge, len);
171         if (rc) {
172                 ksmbd_debug(AUTH, "Could not update with challenge\n");
173                 goto out;
174         }
175
176         rc = crypto_shash_update(CRYPTO_MD5(ctx), nonce, len);
177         if (rc) {
178                 ksmbd_debug(AUTH, "Could not update with nonce\n");
179                 goto out;
180         }
181
182         rc = crypto_shash_final(CRYPTO_MD5(ctx), md5_hash);
183         if (rc)
184                 ksmbd_debug(AUTH, "Could not generate md5 hash\n");
185 out:
186         ksmbd_release_crypto_ctx(ctx);
187         return rc;
188 }
189
190 /**
191  * ksmbd_gen_sess_key() - function to generate session key
192  * @sess:       session of connection
193  * @hash:       source hash value to be used for find session key
194  * @hmac:       source hmac value to be used for finding session key
195  *
196  */
197 static int ksmbd_gen_sess_key(struct ksmbd_session *sess, char *hash,
198                               char *hmac)
199 {
200         struct ksmbd_crypto_ctx *ctx;
201         int rc;
202
203         ctx = ksmbd_crypto_ctx_find_hmacmd5();
204         if (!ctx) {
205                 ksmbd_debug(AUTH, "could not crypto alloc hmacmd5\n");
206                 return -ENOMEM;
207         }
208
209         rc = crypto_shash_setkey(CRYPTO_HMACMD5_TFM(ctx),
210                                  hash,
211                                  CIFS_HMAC_MD5_HASH_SIZE);
212         if (rc) {
213                 ksmbd_debug(AUTH, "hmacmd5 set key fail error %d\n", rc);
214                 goto out;
215         }
216
217         rc = crypto_shash_init(CRYPTO_HMACMD5(ctx));
218         if (rc) {
219                 ksmbd_debug(AUTH, "could not init hmacmd5 error %d\n", rc);
220                 goto out;
221         }
222
223         rc = crypto_shash_update(CRYPTO_HMACMD5(ctx),
224                                  hmac,
225                                  SMB2_NTLMV2_SESSKEY_SIZE);
226         if (rc) {
227                 ksmbd_debug(AUTH, "Could not update with response error %d\n", rc);
228                 goto out;
229         }
230
231         rc = crypto_shash_final(CRYPTO_HMACMD5(ctx), sess->sess_key);
232         if (rc) {
233                 ksmbd_debug(AUTH, "Could not generate hmacmd5 hash error %d\n", rc);
234                 goto out;
235         }
236
237 out:
238         ksmbd_release_crypto_ctx(ctx);
239         return rc;
240 }
241
242 static int calc_ntlmv2_hash(struct ksmbd_session *sess, char *ntlmv2_hash,
243                             char *dname)
244 {
245         int ret, len, conv_len;
246         wchar_t *domain = NULL;
247         __le16 *uniname = NULL;
248         struct ksmbd_crypto_ctx *ctx;
249
250         ctx = ksmbd_crypto_ctx_find_hmacmd5();
251         if (!ctx) {
252                 ksmbd_debug(AUTH, "can't generate ntlmv2 hash\n");
253                 return -ENOMEM;
254         }
255
256         ret = crypto_shash_setkey(CRYPTO_HMACMD5_TFM(ctx),
257                                   user_passkey(sess->user),
258                                   CIFS_ENCPWD_SIZE);
259         if (ret) {
260                 ksmbd_debug(AUTH, "Could not set NT Hash as a key\n");
261                 goto out;
262         }
263
264         ret = crypto_shash_init(CRYPTO_HMACMD5(ctx));
265         if (ret) {
266                 ksmbd_debug(AUTH, "could not init hmacmd5\n");
267                 goto out;
268         }
269
270         /* convert user_name to unicode */
271         len = strlen(user_name(sess->user));
272         uniname = kzalloc(2 + UNICODE_LEN(len), GFP_KERNEL);
273         if (!uniname) {
274                 ret = -ENOMEM;
275                 goto out;
276         }
277
278         conv_len = smb_strtoUTF16(uniname, user_name(sess->user), len,
279                                   sess->conn->local_nls);
280         if (conv_len < 0 || conv_len > len) {
281                 ret = -EINVAL;
282                 goto out;
283         }
284         UniStrupr(uniname);
285
286         ret = crypto_shash_update(CRYPTO_HMACMD5(ctx),
287                                   (char *)uniname,
288                                   UNICODE_LEN(conv_len));
289         if (ret) {
290                 ksmbd_debug(AUTH, "Could not update with user\n");
291                 goto out;
292         }
293
294         /* Convert domain name or conn name to unicode and uppercase */
295         len = strlen(dname);
296         domain = kzalloc(2 + UNICODE_LEN(len), GFP_KERNEL);
297         if (!domain) {
298                 ret = -ENOMEM;
299                 goto out;
300         }
301
302         conv_len = smb_strtoUTF16((__le16 *)domain, dname, len,
303                                   sess->conn->local_nls);
304         if (conv_len < 0 || conv_len > len) {
305                 ret = -EINVAL;
306                 goto out;
307         }
308
309         ret = crypto_shash_update(CRYPTO_HMACMD5(ctx),
310                                   (char *)domain,
311                                   UNICODE_LEN(conv_len));
312         if (ret) {
313                 ksmbd_debug(AUTH, "Could not update with domain\n");
314                 goto out;
315         }
316
317         ret = crypto_shash_final(CRYPTO_HMACMD5(ctx), ntlmv2_hash);
318         if (ret)
319                 ksmbd_debug(AUTH, "Could not generate md5 hash\n");
320 out:
321         kfree(uniname);
322         kfree(domain);
323         ksmbd_release_crypto_ctx(ctx);
324         return ret;
325 }
326
327 /**
328  * ksmbd_auth_ntlm() - NTLM authentication handler
329  * @sess:       session of connection
330  * @pw_buf:     NTLM challenge response
331  * @passkey:    user password
332  *
333  * Return:      0 on success, error number on error
334  */
335 int ksmbd_auth_ntlm(struct ksmbd_session *sess, char *pw_buf)
336 {
337         int rc;
338         unsigned char p21[21];
339         char key[CIFS_AUTH_RESP_SIZE];
340
341         memset(p21, '\0', 21);
342         memcpy(p21, user_passkey(sess->user), CIFS_NTHASH_SIZE);
343         rc = ksmbd_enc_p24(p21, sess->ntlmssp.cryptkey, key);
344         if (rc) {
345                 pr_err("password processing failed\n");
346                 return rc;
347         }
348
349         ksmbd_enc_md4(sess->sess_key, user_passkey(sess->user),
350                       CIFS_SMB1_SESSKEY_SIZE);
351         memcpy(sess->sess_key + CIFS_SMB1_SESSKEY_SIZE, key,
352                CIFS_AUTH_RESP_SIZE);
353         sess->sequence_number = 1;
354
355         if (strncmp(pw_buf, key, CIFS_AUTH_RESP_SIZE) != 0) {
356                 ksmbd_debug(AUTH, "ntlmv1 authentication failed\n");
357                 return -EINVAL;
358         }
359
360         ksmbd_debug(AUTH, "ntlmv1 authentication pass\n");
361         return 0;
362 }
363
364 /**
365  * ksmbd_auth_ntlmv2() - NTLMv2 authentication handler
366  * @sess:       session of connection
367  * @ntlmv2:             NTLMv2 challenge response
368  * @blen:               NTLMv2 blob length
369  * @domain_name:        domain name
370  *
371  * Return:      0 on success, error number on error
372  */
373 int ksmbd_auth_ntlmv2(struct ksmbd_session *sess, struct ntlmv2_resp *ntlmv2,
374                       int blen, char *domain_name)
375 {
376         char ntlmv2_hash[CIFS_ENCPWD_SIZE];
377         char ntlmv2_rsp[CIFS_HMAC_MD5_HASH_SIZE];
378         struct ksmbd_crypto_ctx *ctx;
379         char *construct = NULL;
380         int rc, len;
381
382         ctx = ksmbd_crypto_ctx_find_hmacmd5();
383         if (!ctx) {
384                 ksmbd_debug(AUTH, "could not crypto alloc hmacmd5\n");
385                 return -ENOMEM;
386         }
387
388         rc = calc_ntlmv2_hash(sess, ntlmv2_hash, domain_name);
389         if (rc) {
390                 ksmbd_debug(AUTH, "could not get v2 hash rc %d\n", rc);
391                 goto out;
392         }
393
394         rc = crypto_shash_setkey(CRYPTO_HMACMD5_TFM(ctx),
395                                  ntlmv2_hash,
396                                  CIFS_HMAC_MD5_HASH_SIZE);
397         if (rc) {
398                 ksmbd_debug(AUTH, "Could not set NTLMV2 Hash as a key\n");
399                 goto out;
400         }
401
402         rc = crypto_shash_init(CRYPTO_HMACMD5(ctx));
403         if (rc) {
404                 ksmbd_debug(AUTH, "Could not init hmacmd5\n");
405                 goto out;
406         }
407
408         len = CIFS_CRYPTO_KEY_SIZE + blen;
409         construct = kzalloc(len, GFP_KERNEL);
410         if (!construct) {
411                 rc = -ENOMEM;
412                 goto out;
413         }
414
415         memcpy(construct, sess->ntlmssp.cryptkey, CIFS_CRYPTO_KEY_SIZE);
416         memcpy(construct + CIFS_CRYPTO_KEY_SIZE, &ntlmv2->blob_signature, blen);
417
418         rc = crypto_shash_update(CRYPTO_HMACMD5(ctx), construct, len);
419         if (rc) {
420                 ksmbd_debug(AUTH, "Could not update with response\n");
421                 goto out;
422         }
423
424         rc = crypto_shash_final(CRYPTO_HMACMD5(ctx), ntlmv2_rsp);
425         if (rc) {
426                 ksmbd_debug(AUTH, "Could not generate md5 hash\n");
427                 goto out;
428         }
429
430         rc = ksmbd_gen_sess_key(sess, ntlmv2_hash, ntlmv2_rsp);
431         if (rc) {
432                 ksmbd_debug(AUTH, "Could not generate sess key\n");
433                 goto out;
434         }
435
436         if (memcmp(ntlmv2->ntlmv2_hash, ntlmv2_rsp, CIFS_HMAC_MD5_HASH_SIZE) != 0)
437                 rc = -EINVAL;
438 out:
439         ksmbd_release_crypto_ctx(ctx);
440         kfree(construct);
441         return rc;
442 }
443
444 /**
445  * __ksmbd_auth_ntlmv2() - NTLM2(extended security) authentication handler
446  * @sess:       session of connection
447  * @client_nonce:       client nonce from LM response.
448  * @ntlm_resp:          ntlm response data from client.
449  *
450  * Return:      0 on success, error number on error
451  */
452 static int __ksmbd_auth_ntlmv2(struct ksmbd_session *sess, char *client_nonce,
453                                char *ntlm_resp)
454 {
455         char sess_key[CIFS_SMB1_SESSKEY_SIZE] = {0};
456         int rc;
457         unsigned char p21[21];
458         char key[CIFS_AUTH_RESP_SIZE];
459
460         rc = ksmbd_enc_update_sess_key(sess_key,
461                                        client_nonce,
462                                        (char *)sess->ntlmssp.cryptkey, 8);
463         if (rc) {
464                 pr_err("password processing failed\n");
465                 goto out;
466         }
467
468         memset(p21, '\0', 21);
469         memcpy(p21, user_passkey(sess->user), CIFS_NTHASH_SIZE);
470         rc = ksmbd_enc_p24(p21, sess_key, key);
471         if (rc) {
472                 pr_err("password processing failed\n");
473                 goto out;
474         }
475
476         if (memcmp(ntlm_resp, key, CIFS_AUTH_RESP_SIZE) != 0)
477                 rc = -EINVAL;
478 out:
479         return rc;
480 }
481
482 /**
483  * ksmbd_decode_ntlmssp_auth_blob() - helper function to construct
484  * authenticate blob
485  * @authblob:   authenticate blob source pointer
486  * @usr:        user details
487  * @sess:       session of connection
488  *
489  * Return:      0 on success, error number on error
490  */
491 int ksmbd_decode_ntlmssp_auth_blob(struct authenticate_message *authblob,
492                                    int blob_len, struct ksmbd_session *sess)
493 {
494         char *domain_name;
495         unsigned int lm_off, nt_off;
496         unsigned short nt_len;
497         int ret;
498
499         if (blob_len < sizeof(struct authenticate_message)) {
500                 ksmbd_debug(AUTH, "negotiate blob len %d too small\n",
501                             blob_len);
502                 return -EINVAL;
503         }
504
505         if (memcmp(authblob->Signature, "NTLMSSP", 8)) {
506                 ksmbd_debug(AUTH, "blob signature incorrect %s\n",
507                             authblob->Signature);
508                 return -EINVAL;
509         }
510
511         lm_off = le32_to_cpu(authblob->LmChallengeResponse.BufferOffset);
512         nt_off = le32_to_cpu(authblob->NtChallengeResponse.BufferOffset);
513         nt_len = le16_to_cpu(authblob->NtChallengeResponse.Length);
514
515         /* process NTLM authentication */
516         if (nt_len == CIFS_AUTH_RESP_SIZE) {
517                 if (le32_to_cpu(authblob->NegotiateFlags) &
518                     NTLMSSP_NEGOTIATE_EXTENDED_SEC)
519                         return __ksmbd_auth_ntlmv2(sess, (char *)authblob +
520                                 lm_off, (char *)authblob + nt_off);
521                 else
522                         return ksmbd_auth_ntlm(sess, (char *)authblob +
523                                 nt_off);
524         }
525
526         /* TODO : use domain name that imported from configuration file */
527         domain_name = smb_strndup_from_utf16((const char *)authblob +
528                         le32_to_cpu(authblob->DomainName.BufferOffset),
529                         le16_to_cpu(authblob->DomainName.Length), true,
530                         sess->conn->local_nls);
531         if (IS_ERR(domain_name))
532                 return PTR_ERR(domain_name);
533
534         /* process NTLMv2 authentication */
535         ksmbd_debug(AUTH, "decode_ntlmssp_authenticate_blob dname%s\n",
536                     domain_name);
537         ret = ksmbd_auth_ntlmv2(sess, (struct ntlmv2_resp *)((char *)authblob + nt_off),
538                                 nt_len - CIFS_ENCPWD_SIZE,
539                                 domain_name);
540         kfree(domain_name);
541         return ret;
542 }
543
544 /**
545  * ksmbd_decode_ntlmssp_neg_blob() - helper function to construct
546  * negotiate blob
547  * @negblob: negotiate blob source pointer
548  * @rsp:     response header pointer to be updated
549  * @sess:    session of connection
550  *
551  */
552 int ksmbd_decode_ntlmssp_neg_blob(struct negotiate_message *negblob,
553                                   int blob_len, struct ksmbd_session *sess)
554 {
555         if (blob_len < sizeof(struct negotiate_message)) {
556                 ksmbd_debug(AUTH, "negotiate blob len %d too small\n",
557                             blob_len);
558                 return -EINVAL;
559         }
560
561         if (memcmp(negblob->Signature, "NTLMSSP", 8)) {
562                 ksmbd_debug(AUTH, "blob signature incorrect %s\n",
563                             negblob->Signature);
564                 return -EINVAL;
565         }
566
567         sess->ntlmssp.client_flags = le32_to_cpu(negblob->NegotiateFlags);
568         return 0;
569 }
570
571 /**
572  * ksmbd_build_ntlmssp_challenge_blob() - helper function to construct
573  * challenge blob
574  * @chgblob: challenge blob source pointer to initialize
575  * @rsp:     response header pointer to be updated
576  * @sess:    session of connection
577  *
578  */
579 unsigned int
580 ksmbd_build_ntlmssp_challenge_blob(struct challenge_message *chgblob,
581                                    struct ksmbd_session *sess)
582 {
583         struct target_info *tinfo;
584         wchar_t *name;
585         __u8 *target_name;
586         unsigned int flags, blob_off, blob_len, type, target_info_len = 0;
587         int len, uni_len, conv_len;
588         int cflags = sess->ntlmssp.client_flags;
589
590         memcpy(chgblob->Signature, NTLMSSP_SIGNATURE, 8);
591         chgblob->MessageType = NtLmChallenge;
592
593         flags = NTLMSSP_NEGOTIATE_UNICODE |
594                 NTLMSSP_NEGOTIATE_NTLM | NTLMSSP_TARGET_TYPE_SERVER |
595                 NTLMSSP_NEGOTIATE_TARGET_INFO;
596
597         if (cflags & NTLMSSP_NEGOTIATE_SIGN) {
598                 flags |= NTLMSSP_NEGOTIATE_SIGN;
599                 flags |= cflags & (NTLMSSP_NEGOTIATE_128 |
600                                    NTLMSSP_NEGOTIATE_56);
601         }
602
603         if (cflags & NTLMSSP_NEGOTIATE_ALWAYS_SIGN)
604                 flags |= NTLMSSP_NEGOTIATE_ALWAYS_SIGN;
605
606         if (cflags & NTLMSSP_REQUEST_TARGET)
607                 flags |= NTLMSSP_REQUEST_TARGET;
608
609         if (sess->conn->use_spnego &&
610             (cflags & NTLMSSP_NEGOTIATE_EXTENDED_SEC))
611                 flags |= NTLMSSP_NEGOTIATE_EXTENDED_SEC;
612
613         chgblob->NegotiateFlags = cpu_to_le32(flags);
614         len = strlen(ksmbd_netbios_name());
615         name = kmalloc(2 + UNICODE_LEN(len), GFP_KERNEL);
616         if (!name)
617                 return -ENOMEM;
618
619         conv_len = smb_strtoUTF16((__le16 *)name, ksmbd_netbios_name(), len,
620                                   sess->conn->local_nls);
621         if (conv_len < 0 || conv_len > len) {
622                 kfree(name);
623                 return -EINVAL;
624         }
625
626         uni_len = UNICODE_LEN(conv_len);
627
628         blob_off = sizeof(struct challenge_message);
629         blob_len = blob_off + uni_len;
630
631         chgblob->TargetName.Length = cpu_to_le16(uni_len);
632         chgblob->TargetName.MaximumLength = cpu_to_le16(uni_len);
633         chgblob->TargetName.BufferOffset = cpu_to_le32(blob_off);
634
635         /* Initialize random conn challenge */
636         get_random_bytes(sess->ntlmssp.cryptkey, sizeof(__u64));
637         memcpy(chgblob->Challenge, sess->ntlmssp.cryptkey,
638                CIFS_CRYPTO_KEY_SIZE);
639
640         /* Add Target Information to security buffer */
641         chgblob->TargetInfoArray.BufferOffset = cpu_to_le32(blob_len);
642
643         target_name = (__u8 *)chgblob + blob_off;
644         memcpy(target_name, name, uni_len);
645         tinfo = (struct target_info *)(target_name + uni_len);
646
647         chgblob->TargetInfoArray.Length = 0;
648         /* Add target info list for NetBIOS/DNS settings */
649         for (type = NTLMSSP_AV_NB_COMPUTER_NAME;
650              type <= NTLMSSP_AV_DNS_DOMAIN_NAME; type++) {
651                 tinfo->Type = cpu_to_le16(type);
652                 tinfo->Length = cpu_to_le16(uni_len);
653                 memcpy(tinfo->Content, name, uni_len);
654                 tinfo = (struct target_info *)((char *)tinfo + 4 + uni_len);
655                 target_info_len += 4 + uni_len;
656         }
657
658         /* Add terminator subblock */
659         tinfo->Type = 0;
660         tinfo->Length = 0;
661         target_info_len += 4;
662
663         chgblob->TargetInfoArray.Length = cpu_to_le16(target_info_len);
664         chgblob->TargetInfoArray.MaximumLength = cpu_to_le16(target_info_len);
665         blob_len += target_info_len;
666         kfree(name);
667         ksmbd_debug(AUTH, "NTLMSSP SecurityBufferLength %d\n", blob_len);
668         return blob_len;
669 }
670
671 #ifdef CONFIG_SMB_SERVER_KERBEROS5
672 int ksmbd_krb5_authenticate(struct ksmbd_session *sess, char *in_blob,
673                             int in_len, char *out_blob, int *out_len)
674 {
675         struct ksmbd_spnego_authen_response *resp;
676         struct ksmbd_user *user = NULL;
677         int retval;
678
679         resp = ksmbd_ipc_spnego_authen_request(in_blob, in_len);
680         if (!resp) {
681                 ksmbd_debug(AUTH, "SPNEGO_AUTHEN_REQUEST failure\n");
682                 return -EINVAL;
683         }
684
685         if (!(resp->login_response.status & KSMBD_USER_FLAG_OK)) {
686                 ksmbd_debug(AUTH, "krb5 authentication failure\n");
687                 retval = -EPERM;
688                 goto out;
689         }
690
691         if (*out_len <= resp->spnego_blob_len) {
692                 ksmbd_debug(AUTH, "buf len %d, but blob len %d\n",
693                             *out_len, resp->spnego_blob_len);
694                 retval = -EINVAL;
695                 goto out;
696         }
697
698         if (resp->session_key_len > sizeof(sess->sess_key)) {
699                 ksmbd_debug(AUTH, "session key is too long\n");
700                 retval = -EINVAL;
701                 goto out;
702         }
703
704         user = ksmbd_alloc_user(&resp->login_response);
705         if (!user) {
706                 ksmbd_debug(AUTH, "login failure\n");
707                 retval = -ENOMEM;
708                 goto out;
709         }
710         sess->user = user;
711
712         memcpy(sess->sess_key, resp->payload, resp->session_key_len);
713         memcpy(out_blob, resp->payload + resp->session_key_len,
714                resp->spnego_blob_len);
715         *out_len = resp->spnego_blob_len;
716         retval = 0;
717 out:
718         kvfree(resp);
719         return retval;
720 }
721 #else
722 int ksmbd_krb5_authenticate(struct ksmbd_session *sess, char *in_blob,
723                             int in_len, char *out_blob, int *out_len)
724 {
725         return -EOPNOTSUPP;
726 }
727 #endif
728
729 /**
730  * ksmbd_sign_smb2_pdu() - function to generate packet signing
731  * @conn:       connection
732  * @key:        signing key
733  * @iov:        buffer iov array
734  * @n_vec:      number of iovecs
735  * @sig:        signature value generated for client request packet
736  *
737  */
738 int ksmbd_sign_smb2_pdu(struct ksmbd_conn *conn, char *key, struct kvec *iov,
739                         int n_vec, char *sig)
740 {
741         struct ksmbd_crypto_ctx *ctx;
742         int rc, i;
743
744         ctx = ksmbd_crypto_ctx_find_hmacsha256();
745         if (!ctx) {
746                 ksmbd_debug(AUTH, "could not crypto alloc hmacmd5\n");
747                 return -ENOMEM;
748         }
749
750         rc = crypto_shash_setkey(CRYPTO_HMACSHA256_TFM(ctx),
751                                  key,
752                                  SMB2_NTLMV2_SESSKEY_SIZE);
753         if (rc)
754                 goto out;
755
756         rc = crypto_shash_init(CRYPTO_HMACSHA256(ctx));
757         if (rc) {
758                 ksmbd_debug(AUTH, "hmacsha256 init error %d\n", rc);
759                 goto out;
760         }
761
762         for (i = 0; i < n_vec; i++) {
763                 rc = crypto_shash_update(CRYPTO_HMACSHA256(ctx),
764                                          iov[i].iov_base,
765                                          iov[i].iov_len);
766                 if (rc) {
767                         ksmbd_debug(AUTH, "hmacsha256 update error %d\n", rc);
768                         goto out;
769                 }
770         }
771
772         rc = crypto_shash_final(CRYPTO_HMACSHA256(ctx), sig);
773         if (rc)
774                 ksmbd_debug(AUTH, "hmacsha256 generation error %d\n", rc);
775 out:
776         ksmbd_release_crypto_ctx(ctx);
777         return rc;
778 }
779
780 /**
781  * ksmbd_sign_smb3_pdu() - function to generate packet signing
782  * @conn:       connection
783  * @key:        signing key
784  * @iov:        buffer iov array
785  * @n_vec:      number of iovecs
786  * @sig:        signature value generated for client request packet
787  *
788  */
789 int ksmbd_sign_smb3_pdu(struct ksmbd_conn *conn, char *key, struct kvec *iov,
790                         int n_vec, char *sig)
791 {
792         struct ksmbd_crypto_ctx *ctx;
793         int rc, i;
794
795         ctx = ksmbd_crypto_ctx_find_cmacaes();
796         if (!ctx) {
797                 ksmbd_debug(AUTH, "could not crypto alloc cmac\n");
798                 return -ENOMEM;
799         }
800
801         rc = crypto_shash_setkey(CRYPTO_CMACAES_TFM(ctx),
802                                  key,
803                                  SMB2_CMACAES_SIZE);
804         if (rc)
805                 goto out;
806
807         rc = crypto_shash_init(CRYPTO_CMACAES(ctx));
808         if (rc) {
809                 ksmbd_debug(AUTH, "cmaces init error %d\n", rc);
810                 goto out;
811         }
812
813         for (i = 0; i < n_vec; i++) {
814                 rc = crypto_shash_update(CRYPTO_CMACAES(ctx),
815                                          iov[i].iov_base,
816                                          iov[i].iov_len);
817                 if (rc) {
818                         ksmbd_debug(AUTH, "cmaces update error %d\n", rc);
819                         goto out;
820                 }
821         }
822
823         rc = crypto_shash_final(CRYPTO_CMACAES(ctx), sig);
824         if (rc)
825                 ksmbd_debug(AUTH, "cmaces generation error %d\n", rc);
826 out:
827         ksmbd_release_crypto_ctx(ctx);
828         return rc;
829 }
830
831 struct derivation {
832         struct kvec label;
833         struct kvec context;
834         bool binding;
835 };
836
837 static int generate_key(struct ksmbd_session *sess, struct kvec label,
838                         struct kvec context, __u8 *key, unsigned int key_size)
839 {
840         unsigned char zero = 0x0;
841         __u8 i[4] = {0, 0, 0, 1};
842         __u8 L128[4] = {0, 0, 0, 128};
843         __u8 L256[4] = {0, 0, 1, 0};
844         int rc;
845         unsigned char prfhash[SMB2_HMACSHA256_SIZE];
846         unsigned char *hashptr = prfhash;
847         struct ksmbd_crypto_ctx *ctx;
848
849         memset(prfhash, 0x0, SMB2_HMACSHA256_SIZE);
850         memset(key, 0x0, key_size);
851
852         ctx = ksmbd_crypto_ctx_find_hmacsha256();
853         if (!ctx) {
854                 ksmbd_debug(AUTH, "could not crypto alloc hmacmd5\n");
855                 return -ENOMEM;
856         }
857
858         rc = crypto_shash_setkey(CRYPTO_HMACSHA256_TFM(ctx),
859                                  sess->sess_key,
860                                  SMB2_NTLMV2_SESSKEY_SIZE);
861         if (rc)
862                 goto smb3signkey_ret;
863
864         rc = crypto_shash_init(CRYPTO_HMACSHA256(ctx));
865         if (rc) {
866                 ksmbd_debug(AUTH, "hmacsha256 init error %d\n", rc);
867                 goto smb3signkey_ret;
868         }
869
870         rc = crypto_shash_update(CRYPTO_HMACSHA256(ctx), i, 4);
871         if (rc) {
872                 ksmbd_debug(AUTH, "could not update with n\n");
873                 goto smb3signkey_ret;
874         }
875
876         rc = crypto_shash_update(CRYPTO_HMACSHA256(ctx),
877                                  label.iov_base,
878                                  label.iov_len);
879         if (rc) {
880                 ksmbd_debug(AUTH, "could not update with label\n");
881                 goto smb3signkey_ret;
882         }
883
884         rc = crypto_shash_update(CRYPTO_HMACSHA256(ctx), &zero, 1);
885         if (rc) {
886                 ksmbd_debug(AUTH, "could not update with zero\n");
887                 goto smb3signkey_ret;
888         }
889
890         rc = crypto_shash_update(CRYPTO_HMACSHA256(ctx),
891                                  context.iov_base,
892                                  context.iov_len);
893         if (rc) {
894                 ksmbd_debug(AUTH, "could not update with context\n");
895                 goto smb3signkey_ret;
896         }
897
898         if (sess->conn->cipher_type == SMB2_ENCRYPTION_AES256_CCM ||
899             sess->conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM)
900                 rc = crypto_shash_update(CRYPTO_HMACSHA256(ctx), L256, 4);
901         else
902                 rc = crypto_shash_update(CRYPTO_HMACSHA256(ctx), L128, 4);
903         if (rc) {
904                 ksmbd_debug(AUTH, "could not update with L\n");
905                 goto smb3signkey_ret;
906         }
907
908         rc = crypto_shash_final(CRYPTO_HMACSHA256(ctx), hashptr);
909         if (rc) {
910                 ksmbd_debug(AUTH, "Could not generate hmacmd5 hash error %d\n",
911                             rc);
912                 goto smb3signkey_ret;
913         }
914
915         memcpy(key, hashptr, key_size);
916
917 smb3signkey_ret:
918         ksmbd_release_crypto_ctx(ctx);
919         return rc;
920 }
921
922 static int generate_smb3signingkey(struct ksmbd_session *sess,
923                                    struct ksmbd_conn *conn,
924                                    const struct derivation *signing)
925 {
926         int rc;
927         struct channel *chann;
928         char *key;
929
930         chann = lookup_chann_list(sess, conn);
931         if (!chann)
932                 return 0;
933
934         if (sess->conn->dialect >= SMB30_PROT_ID && signing->binding)
935                 key = chann->smb3signingkey;
936         else
937                 key = sess->smb3signingkey;
938
939         rc = generate_key(sess, signing->label, signing->context, key,
940                           SMB3_SIGN_KEY_SIZE);
941         if (rc)
942                 return rc;
943
944         if (!(sess->conn->dialect >= SMB30_PROT_ID && signing->binding))
945                 memcpy(chann->smb3signingkey, key, SMB3_SIGN_KEY_SIZE);
946
947         ksmbd_debug(AUTH, "dumping generated AES signing keys\n");
948         ksmbd_debug(AUTH, "Session Id    %llu\n", sess->id);
949         ksmbd_debug(AUTH, "Session Key   %*ph\n",
950                     SMB2_NTLMV2_SESSKEY_SIZE, sess->sess_key);
951         ksmbd_debug(AUTH, "Signing Key   %*ph\n",
952                     SMB3_SIGN_KEY_SIZE, key);
953         return 0;
954 }
955
956 int ksmbd_gen_smb30_signingkey(struct ksmbd_session *sess,
957                                struct ksmbd_conn *conn)
958 {
959         struct derivation d;
960
961         d.label.iov_base = "SMB2AESCMAC";
962         d.label.iov_len = 12;
963         d.context.iov_base = "SmbSign";
964         d.context.iov_len = 8;
965         d.binding = conn->binding;
966
967         return generate_smb3signingkey(sess, conn, &d);
968 }
969
970 int ksmbd_gen_smb311_signingkey(struct ksmbd_session *sess,
971                                 struct ksmbd_conn *conn)
972 {
973         struct derivation d;
974
975         d.label.iov_base = "SMBSigningKey";
976         d.label.iov_len = 14;
977         if (conn->binding) {
978                 struct preauth_session *preauth_sess;
979
980                 preauth_sess = ksmbd_preauth_session_lookup(conn, sess->id);
981                 if (!preauth_sess)
982                         return -ENOENT;
983                 d.context.iov_base = preauth_sess->Preauth_HashValue;
984         } else {
985                 d.context.iov_base = sess->Preauth_HashValue;
986         }
987         d.context.iov_len = 64;
988         d.binding = conn->binding;
989
990         return generate_smb3signingkey(sess, conn, &d);
991 }
992
993 struct derivation_twin {
994         struct derivation encryption;
995         struct derivation decryption;
996 };
997
998 static int generate_smb3encryptionkey(struct ksmbd_session *sess,
999                                       const struct derivation_twin *ptwin)
1000 {
1001         int rc;
1002
1003         rc = generate_key(sess, ptwin->encryption.label,
1004                           ptwin->encryption.context, sess->smb3encryptionkey,
1005                           SMB3_ENC_DEC_KEY_SIZE);
1006         if (rc)
1007                 return rc;
1008
1009         rc = generate_key(sess, ptwin->decryption.label,
1010                           ptwin->decryption.context,
1011                           sess->smb3decryptionkey, SMB3_ENC_DEC_KEY_SIZE);
1012         if (rc)
1013                 return rc;
1014
1015         ksmbd_debug(AUTH, "dumping generated AES encryption keys\n");
1016         ksmbd_debug(AUTH, "Cipher type   %d\n", sess->conn->cipher_type);
1017         ksmbd_debug(AUTH, "Session Id    %llu\n", sess->id);
1018         ksmbd_debug(AUTH, "Session Key   %*ph\n",
1019                     SMB2_NTLMV2_SESSKEY_SIZE, sess->sess_key);
1020         if (sess->conn->cipher_type == SMB2_ENCRYPTION_AES256_CCM ||
1021             sess->conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM) {
1022                 ksmbd_debug(AUTH, "ServerIn Key  %*ph\n",
1023                             SMB3_GCM256_CRYPTKEY_SIZE, sess->smb3encryptionkey);
1024                 ksmbd_debug(AUTH, "ServerOut Key %*ph\n",
1025                             SMB3_GCM256_CRYPTKEY_SIZE, sess->smb3decryptionkey);
1026         } else {
1027                 ksmbd_debug(AUTH, "ServerIn Key  %*ph\n",
1028                             SMB3_GCM128_CRYPTKEY_SIZE, sess->smb3encryptionkey);
1029                 ksmbd_debug(AUTH, "ServerOut Key %*ph\n",
1030                             SMB3_GCM128_CRYPTKEY_SIZE, sess->smb3decryptionkey);
1031         }
1032         return 0;
1033 }
1034
1035 int ksmbd_gen_smb30_encryptionkey(struct ksmbd_session *sess)
1036 {
1037         struct derivation_twin twin;
1038         struct derivation *d;
1039
1040         d = &twin.encryption;
1041         d->label.iov_base = "SMB2AESCCM";
1042         d->label.iov_len = 11;
1043         d->context.iov_base = "ServerOut";
1044         d->context.iov_len = 10;
1045
1046         d = &twin.decryption;
1047         d->label.iov_base = "SMB2AESCCM";
1048         d->label.iov_len = 11;
1049         d->context.iov_base = "ServerIn ";
1050         d->context.iov_len = 10;
1051
1052         return generate_smb3encryptionkey(sess, &twin);
1053 }
1054
1055 int ksmbd_gen_smb311_encryptionkey(struct ksmbd_session *sess)
1056 {
1057         struct derivation_twin twin;
1058         struct derivation *d;
1059
1060         d = &twin.encryption;
1061         d->label.iov_base = "SMBS2CCipherKey";
1062         d->label.iov_len = 16;
1063         d->context.iov_base = sess->Preauth_HashValue;
1064         d->context.iov_len = 64;
1065
1066         d = &twin.decryption;
1067         d->label.iov_base = "SMBC2SCipherKey";
1068         d->label.iov_len = 16;
1069         d->context.iov_base = sess->Preauth_HashValue;
1070         d->context.iov_len = 64;
1071
1072         return generate_smb3encryptionkey(sess, &twin);
1073 }
1074
1075 int ksmbd_gen_preauth_integrity_hash(struct ksmbd_conn *conn, char *buf,
1076                                      __u8 *pi_hash)
1077 {
1078         int rc;
1079         struct smb2_hdr *rcv_hdr = (struct smb2_hdr *)buf;
1080         char *all_bytes_msg = (char *)&rcv_hdr->ProtocolId;
1081         int msg_size = be32_to_cpu(rcv_hdr->smb2_buf_length);
1082         struct ksmbd_crypto_ctx *ctx = NULL;
1083
1084         if (conn->preauth_info->Preauth_HashId !=
1085             SMB2_PREAUTH_INTEGRITY_SHA512)
1086                 return -EINVAL;
1087
1088         ctx = ksmbd_crypto_ctx_find_sha512();
1089         if (!ctx) {
1090                 ksmbd_debug(AUTH, "could not alloc sha512\n");
1091                 return -ENOMEM;
1092         }
1093
1094         rc = crypto_shash_init(CRYPTO_SHA512(ctx));
1095         if (rc) {
1096                 ksmbd_debug(AUTH, "could not init shashn");
1097                 goto out;
1098         }
1099
1100         rc = crypto_shash_update(CRYPTO_SHA512(ctx), pi_hash, 64);
1101         if (rc) {
1102                 ksmbd_debug(AUTH, "could not update with n\n");
1103                 goto out;
1104         }
1105
1106         rc = crypto_shash_update(CRYPTO_SHA512(ctx), all_bytes_msg, msg_size);
1107         if (rc) {
1108                 ksmbd_debug(AUTH, "could not update with n\n");
1109                 goto out;
1110         }
1111
1112         rc = crypto_shash_final(CRYPTO_SHA512(ctx), pi_hash);
1113         if (rc) {
1114                 ksmbd_debug(AUTH, "Could not generate hash err : %d\n", rc);
1115                 goto out;
1116         }
1117 out:
1118         ksmbd_release_crypto_ctx(ctx);
1119         return rc;
1120 }
1121
1122 int ksmbd_gen_sd_hash(struct ksmbd_conn *conn, char *sd_buf, int len,
1123                       __u8 *pi_hash)
1124 {
1125         int rc;
1126         struct ksmbd_crypto_ctx *ctx = NULL;
1127
1128         ctx = ksmbd_crypto_ctx_find_sha256();
1129         if (!ctx) {
1130                 ksmbd_debug(AUTH, "could not alloc sha256\n");
1131                 return -ENOMEM;
1132         }
1133
1134         rc = crypto_shash_init(CRYPTO_SHA256(ctx));
1135         if (rc) {
1136                 ksmbd_debug(AUTH, "could not init shashn");
1137                 goto out;
1138         }
1139
1140         rc = crypto_shash_update(CRYPTO_SHA256(ctx), sd_buf, len);
1141         if (rc) {
1142                 ksmbd_debug(AUTH, "could not update with n\n");
1143                 goto out;
1144         }
1145
1146         rc = crypto_shash_final(CRYPTO_SHA256(ctx), pi_hash);
1147         if (rc) {
1148                 ksmbd_debug(AUTH, "Could not generate hash err : %d\n", rc);
1149                 goto out;
1150         }
1151 out:
1152         ksmbd_release_crypto_ctx(ctx);
1153         return rc;
1154 }
1155
1156 static int ksmbd_get_encryption_key(struct ksmbd_conn *conn, __u64 ses_id,
1157                                     int enc, u8 *key)
1158 {
1159         struct ksmbd_session *sess;
1160         u8 *ses_enc_key;
1161
1162         sess = ksmbd_session_lookup_all(conn, ses_id);
1163         if (!sess)
1164                 return -EINVAL;
1165
1166         ses_enc_key = enc ? sess->smb3encryptionkey :
1167                 sess->smb3decryptionkey;
1168         memcpy(key, ses_enc_key, SMB3_ENC_DEC_KEY_SIZE);
1169
1170         return 0;
1171 }
1172
/*
 * Point one scatterlist entry at @buflen bytes of @buf, resolving the page
 * through vmalloc_to_page() for vmalloc memory and virt_to_page() otherwise.
 * The caller must ensure the range does not cross a page boundary when @buf
 * is vmalloc memory.
 */
static inline void smb2_sg_set_buf(struct scatterlist *sg, const void *buf,
				   unsigned int buflen)
{
	struct page *page = is_vmalloc_addr(buf) ? vmalloc_to_page(buf) :
						   virt_to_page(buf);

	sg_set_page(sg, page, buflen, offset_in_page(buf));
}
1184
1185 static struct scatterlist *ksmbd_init_sg(struct kvec *iov, unsigned int nvec,
1186                                          u8 *sign)
1187 {
1188         struct scatterlist *sg;
1189         unsigned int assoc_data_len = sizeof(struct smb2_transform_hdr) - 24;
1190         int i, nr_entries[3] = {0}, total_entries = 0, sg_idx = 0;
1191
1192         if (!nvec)
1193                 return NULL;
1194
1195         for (i = 0; i < nvec - 1; i++) {
1196                 unsigned long kaddr = (unsigned long)iov[i + 1].iov_base;
1197
1198                 if (is_vmalloc_addr(iov[i + 1].iov_base)) {
1199                         nr_entries[i] = ((kaddr + iov[i + 1].iov_len +
1200                                         PAGE_SIZE - 1) >> PAGE_SHIFT) -
1201                                 (kaddr >> PAGE_SHIFT);
1202                 } else {
1203                         nr_entries[i]++;
1204                 }
1205                 total_entries += nr_entries[i];
1206         }
1207
1208         /* Add two entries for transform header and signature */
1209         total_entries += 2;
1210
1211         sg = kmalloc_array(total_entries, sizeof(struct scatterlist), GFP_KERNEL);
1212         if (!sg)
1213                 return NULL;
1214
1215         sg_init_table(sg, total_entries);
1216         smb2_sg_set_buf(&sg[sg_idx++], iov[0].iov_base + 24, assoc_data_len);
1217         for (i = 0; i < nvec - 1; i++) {
1218                 void *data = iov[i + 1].iov_base;
1219                 int len = iov[i + 1].iov_len;
1220
1221                 if (is_vmalloc_addr(data)) {
1222                         int j, offset = offset_in_page(data);
1223
1224                         for (j = 0; j < nr_entries[i]; j++) {
1225                                 unsigned int bytes = PAGE_SIZE - offset;
1226
1227                                 if (!len)
1228                                         break;
1229
1230                                 if (bytes > len)
1231                                         bytes = len;
1232
1233                                 sg_set_page(&sg[sg_idx++],
1234                                             vmalloc_to_page(data), bytes,
1235                                             offset_in_page(data));
1236
1237                                 data += bytes;
1238                                 len -= bytes;
1239                                 offset = 0;
1240                         }
1241                 } else {
1242                         sg_set_page(&sg[sg_idx++], virt_to_page(data), len,
1243                                     offset_in_page(data));
1244                 }
1245         }
1246         smb2_sg_set_buf(&sg[sg_idx], sign, SMB2_SIGNATURE_SIZE);
1247         return sg;
1248 }
1249
1250 int ksmbd_crypt_message(struct ksmbd_conn *conn, struct kvec *iov,
1251                         unsigned int nvec, int enc)
1252 {
1253         struct smb2_transform_hdr *tr_hdr =
1254                 (struct smb2_transform_hdr *)iov[0].iov_base;
1255         unsigned int assoc_data_len = sizeof(struct smb2_transform_hdr) - 24;
1256         int rc;
1257         struct scatterlist *sg;
1258         u8 sign[SMB2_SIGNATURE_SIZE] = {};
1259         u8 key[SMB3_ENC_DEC_KEY_SIZE];
1260         struct aead_request *req;
1261         char *iv;
1262         unsigned int iv_len;
1263         struct crypto_aead *tfm;
1264         unsigned int crypt_len = le32_to_cpu(tr_hdr->OriginalMessageSize);
1265         struct ksmbd_crypto_ctx *ctx;
1266
1267         rc = ksmbd_get_encryption_key(conn,
1268                                       le64_to_cpu(tr_hdr->SessionId),
1269                                       enc,
1270                                       key);
1271         if (rc) {
1272                 pr_err("Could not get %scryption key\n", enc ? "en" : "de");
1273                 return rc;
1274         }
1275
1276         if (conn->cipher_type == SMB2_ENCRYPTION_AES128_GCM ||
1277             conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM)
1278                 ctx = ksmbd_crypto_ctx_find_gcm();
1279         else
1280                 ctx = ksmbd_crypto_ctx_find_ccm();
1281         if (!ctx) {
1282                 pr_err("crypto alloc failed\n");
1283                 return -ENOMEM;
1284         }
1285
1286         if (conn->cipher_type == SMB2_ENCRYPTION_AES128_GCM ||
1287             conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM)
1288                 tfm = CRYPTO_GCM(ctx);
1289         else
1290                 tfm = CRYPTO_CCM(ctx);
1291
1292         if (conn->cipher_type == SMB2_ENCRYPTION_AES256_CCM ||
1293             conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM)
1294                 rc = crypto_aead_setkey(tfm, key, SMB3_GCM256_CRYPTKEY_SIZE);
1295         else
1296                 rc = crypto_aead_setkey(tfm, key, SMB3_GCM128_CRYPTKEY_SIZE);
1297         if (rc) {
1298                 pr_err("Failed to set aead key %d\n", rc);
1299                 goto free_ctx;
1300         }
1301
1302         rc = crypto_aead_setauthsize(tfm, SMB2_SIGNATURE_SIZE);
1303         if (rc) {
1304                 pr_err("Failed to set authsize %d\n", rc);
1305                 goto free_ctx;
1306         }
1307
1308         req = aead_request_alloc(tfm, GFP_KERNEL);
1309         if (!req) {
1310                 rc = -ENOMEM;
1311                 goto free_ctx;
1312         }
1313
1314         if (!enc) {
1315                 memcpy(sign, &tr_hdr->Signature, SMB2_SIGNATURE_SIZE);
1316                 crypt_len += SMB2_SIGNATURE_SIZE;
1317         }
1318
1319         sg = ksmbd_init_sg(iov, nvec, sign);
1320         if (!sg) {
1321                 pr_err("Failed to init sg\n");
1322                 rc = -ENOMEM;
1323                 goto free_req;
1324         }
1325
1326         iv_len = crypto_aead_ivsize(tfm);
1327         iv = kzalloc(iv_len, GFP_KERNEL);
1328         if (!iv) {
1329                 rc = -ENOMEM;
1330                 goto free_sg;
1331         }
1332
1333         if (conn->cipher_type == SMB2_ENCRYPTION_AES128_GCM ||
1334             conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM) {
1335                 memcpy(iv, (char *)tr_hdr->Nonce, SMB3_AES_GCM_NONCE);
1336         } else {
1337                 iv[0] = 3;
1338                 memcpy(iv + 1, (char *)tr_hdr->Nonce, SMB3_AES_CCM_NONCE);
1339         }
1340
1341         aead_request_set_crypt(req, sg, sg, crypt_len, iv);
1342         aead_request_set_ad(req, assoc_data_len);
1343         aead_request_set_callback(req, CRYPTO_TFM_REQ_MAY_SLEEP, NULL, NULL);
1344
1345         if (enc)
1346                 rc = crypto_aead_encrypt(req);
1347         else
1348                 rc = crypto_aead_decrypt(req);
1349         if (rc)
1350                 goto free_iv;
1351
1352         if (enc)
1353                 memcpy(&tr_hdr->Signature, sign, SMB2_SIGNATURE_SIZE);
1354
1355 free_iv:
1356         kfree(iv);
1357 free_sg:
1358         kfree(sg);
1359 free_req:
1360         kfree(req);
1361 free_ctx:
1362         ksmbd_release_crypto_ctx(ctx);
1363         return rc;
1364 }