cifs: update init_sg, crypt_message to take an array of rqst
authorRonnie Sahlberg <lsahlber@redhat.com>
Tue, 31 Jul 2018 23:26:11 +0000 (09:26 +1000)
committerSteve French <stfrench@microsoft.com>
Tue, 7 Aug 2018 19:21:18 +0000 (14:21 -0500)
These are used for SMB3 encryption and compounded requests.
Update these functions and the other functions related to SMB3 encryption to
take an array of requests.

Signed-off-by: Ronnie Sahlberg <lsahlber@redhat.com>
Signed-off-by: Steve French <stfrench@microsoft.com>
Reviewed-by: Pavel Shilovsky <pshilov@microsoft.com>
fs/cifs/cifsglob.h
fs/cifs/smb2ops.c
fs/cifs/transport.c

index 4a3a737..0553929 100644 (file)
@@ -454,10 +454,8 @@ struct smb_version_operations {
        long (*fallocate)(struct file *, struct cifs_tcon *, int, loff_t,
                          loff_t);
        /* init transform request - used for encryption for now */
-       int (*init_transform_rq)(struct TCP_Server_Info *, struct smb_rqst *,
-                                struct smb_rqst *);
-       /* free transform request */
-       void (*free_transform_rq)(struct smb_rqst *);
+       int (*init_transform_rq)(struct TCP_Server_Info *, int num_rqst,
+                                struct smb_rqst *, struct smb_rqst *);
        int (*is_transform_hdr)(void *buf);
        int (*receive_transform)(struct TCP_Server_Info *,
                                 struct mid_q_entry **);
@@ -1023,6 +1021,7 @@ struct tcon_link {
 };
 
 extern struct tcon_link *cifs_sb_tlink(struct cifs_sb_info *cifs_sb);
+extern void smb3_free_compound_rqst(int num_rqst, struct smb_rqst *rqst);
 
 static inline struct cifs_tcon *
 tlink_tcon(struct tcon_link *tlink)
index 85e8480..ebc13eb 100644 (file)
@@ -2378,35 +2378,51 @@ static inline void smb2_sg_set_buf(struct scatterlist *sg, const void *buf,
        sg_set_page(sg, virt_to_page(buf), buflen, offset_in_page(buf));
 }
 
-/* Assumes:
- * rqst->rq_iov[0]  is transform header
- * rqst->rq_iov[1+] data to be encrypted/decrypted
+/* Assumes the first rqst has a transform header as the first iov.
+ * I.e.
+ * rqst[0].rq_iov[0]  is transform header
+ * rqst[0].rq_iov[1+] data to be encrypted/decrypted
+ * rqst[1+].rq_iov[0+] data to be encrypted/decrypted
  */
 static struct scatterlist *
-init_sg(struct smb_rqst *rqst, u8 *sign)
+init_sg(int num_rqst, struct smb_rqst *rqst, u8 *sign)
 {
-       unsigned int sg_len = rqst->rq_nvec + rqst->rq_npages + 1;
-       unsigned int assoc_data_len = sizeof(struct smb2_transform_hdr) - 20;
+       unsigned int sg_len;
        struct scatterlist *sg;
        unsigned int i;
        unsigned int j;
+       unsigned int idx = 0;
+       int skip;
+
+       sg_len = 1;
+       for (i = 0; i < num_rqst; i++)
+               sg_len += rqst[i].rq_nvec + rqst[i].rq_npages;
 
        sg = kmalloc_array(sg_len, sizeof(struct scatterlist), GFP_KERNEL);
        if (!sg)
                return NULL;
 
        sg_init_table(sg, sg_len);
-       smb2_sg_set_buf(&sg[0], rqst->rq_iov[0].iov_base + 20, assoc_data_len);
-       for (i = 1; i < rqst->rq_nvec; i++)
-               smb2_sg_set_buf(&sg[i], rqst->rq_iov[i].iov_base,
-                                               rqst->rq_iov[i].iov_len);
-       for (j = 0; i < sg_len - 1; i++, j++) {
-               unsigned int len, offset;
+       for (i = 0; i < num_rqst; i++) {
+               for (j = 0; j < rqst[i].rq_nvec; j++) {
+                       /*
+                        * The first rqst has a transform header where the
+                        * first 20 bytes are not part of the encrypted blob
+                        */
+                       skip = (i == 0) && (j == 0) ? 20 : 0;
+                       smb2_sg_set_buf(&sg[idx++],
+                                       rqst[i].rq_iov[j].iov_base + skip,
+                                       rqst[i].rq_iov[j].iov_len - skip);
+               }
+
+               for (j = 0; j < rqst[i].rq_npages; j++) {
+                       unsigned int len, offset;
 
-               rqst_page_get_length(rqst, j, &len, &offset);
-               sg_set_page(&sg[i], rqst->rq_pages[j], len, offset);
+                       rqst_page_get_length(&rqst[i], j, &len, &offset);
+                       sg_set_page(&sg[idx++], rqst[i].rq_pages[j], len, offset);
+               }
        }
-       smb2_sg_set_buf(&sg[sg_len - 1], sign, SMB2_SIGNATURE_SIZE);
+       smb2_sg_set_buf(&sg[idx], sign, SMB2_SIGNATURE_SIZE);
        return sg;
 }
 
@@ -2438,10 +2454,11 @@ smb2_get_enc_key(struct TCP_Server_Info *server, __u64 ses_id, int enc, u8 *key)
  * untouched.
  */
 static int
-crypt_message(struct TCP_Server_Info *server, struct smb_rqst *rqst, int enc)
+crypt_message(struct TCP_Server_Info *server, int num_rqst,
+             struct smb_rqst *rqst, int enc)
 {
        struct smb2_transform_hdr *tr_hdr =
-                       (struct smb2_transform_hdr *)rqst->rq_iov[0].iov_base;
+               (struct smb2_transform_hdr *)rqst[0].rq_iov[0].iov_base;
        unsigned int assoc_data_len = sizeof(struct smb2_transform_hdr) - 20;
        int rc = 0;
        struct scatterlist *sg;
@@ -2492,7 +2509,7 @@ crypt_message(struct TCP_Server_Info *server, struct smb_rqst *rqst, int enc)
                crypt_len += SMB2_SIGNATURE_SIZE;
        }
 
-       sg = init_sg(rqst, sign);
+       sg = init_sg(num_rqst, rqst, sign);
        if (!sg) {
                cifs_dbg(VFS, "%s: Failed to init sg", __func__);
                rc = -ENOMEM;
@@ -2529,103 +2546,98 @@ free_req:
        return rc;
 }
 
+void
+smb3_free_compound_rqst(int num_rqst, struct smb_rqst *rqst)
+{
+       int i, j;
+
+       for (i = 0; i < num_rqst; i++) {
+               if (rqst[i].rq_pages) {
+                       for (j = rqst[i].rq_npages - 1; j >= 0; j--)
+                               put_page(rqst[i].rq_pages[j]);
+                       kfree(rqst[i].rq_pages);
+               }
+       }
+}
+
+/*
+ * This function will initialize new_rq and encrypt the content.
+ * The first entry, new_rq[0], only contains a single iov which contains
+ * a smb2_transform_hdr and is pre-allocated by the caller.
+ * This function then populates new_rq[1+] with the content from olq_rq[0+].
+ *
+ * The end result is an array of smb_rqst structures where the first structure
+ * only contains a single iov for the transform header which we then can pass
+ * to crypt_message().
+ *
+ * new_rq[0].rq_iov[0] :  smb2_transform_hdr pre-allocated by the caller
+ * new_rq[1+].rq_iov[*] == old_rq[0+].rq_iov[*] : SMB2/3 requests
+ */
 static int
-smb3_init_transform_rq(struct TCP_Server_Info *server, struct smb_rqst *new_rq,
-                      struct smb_rqst *old_rq)
+smb3_init_transform_rq(struct TCP_Server_Info *server, int num_rqst,
+                      struct smb_rqst *new_rq, struct smb_rqst *old_rq)
 {
-       struct kvec *iov;
        struct page **pages;
-       struct smb2_transform_hdr *tr_hdr;
-       unsigned int npages = old_rq->rq_npages;
-       unsigned int orig_len;
-       int i;
+       struct smb2_transform_hdr *tr_hdr = new_rq[0].rq_iov[0].iov_base;
+       unsigned int npages;
+       unsigned int orig_len = 0;
+       int i, j;
        int rc = -ENOMEM;
 
-       pages = kmalloc_array(npages, sizeof(struct page *), GFP_KERNEL);
-       if (!pages)
-               return rc;
-
-       new_rq->rq_pages = pages;
-       new_rq->rq_offset = old_rq->rq_offset;
-       new_rq->rq_npages = old_rq->rq_npages;
-       new_rq->rq_pagesz = old_rq->rq_pagesz;
-       new_rq->rq_tailsz = old_rq->rq_tailsz;
-
-       for (i = 0; i < npages; i++) {
-               pages[i] = alloc_page(GFP_KERNEL|__GFP_HIGHMEM);
-               if (!pages[i])
-                       goto err_free_pages;
-       }
-
-       iov = kmalloc_array(old_rq->rq_nvec + 1, sizeof(struct kvec),
-                           GFP_KERNEL);
-       if (!iov)
-               goto err_free_pages;
+       for (i = 1; i < num_rqst; i++) {
+               npages = old_rq[i - 1].rq_npages;
+               pages = kmalloc_array(npages, sizeof(struct page *),
+                                     GFP_KERNEL);
+               if (!pages)
+                       goto err_free;
+
+               new_rq[i].rq_pages = pages;
+               new_rq[i].rq_npages = npages;
+               new_rq[i].rq_offset = old_rq[i - 1].rq_offset;
+               new_rq[i].rq_pagesz = old_rq[i - 1].rq_pagesz;
+               new_rq[i].rq_tailsz = old_rq[i - 1].rq_tailsz;
+               new_rq[i].rq_iov = old_rq[i - 1].rq_iov;
+               new_rq[i].rq_nvec = old_rq[i - 1].rq_nvec;
+
+               orig_len += smb_rqst_len(server, &old_rq[i - 1]);
+
+               for (j = 0; j < npages; j++) {
+                       pages[j] = alloc_page(GFP_KERNEL|__GFP_HIGHMEM);
+                       if (!pages[j])
+                               goto err_free;
+               }
 
-       /* copy all iovs from the old */
-       memcpy(&iov[1], &old_rq->rq_iov[0],
-                               sizeof(struct kvec) * old_rq->rq_nvec);
+               /* copy pages form the old */
+               for (j = 0; j < npages; j++) {
+                       char *dst, *src;
+                       unsigned int offset, len;
 
-       new_rq->rq_iov = iov;
-       new_rq->rq_nvec = old_rq->rq_nvec + 1;
+                       rqst_page_get_length(&new_rq[i], j, &len, &offset);
 
-       tr_hdr = kmalloc(sizeof(struct smb2_transform_hdr), GFP_KERNEL);
-       if (!tr_hdr)
-               goto err_free_iov;
+                       dst = (char *) kmap(new_rq[i].rq_pages[j]) + offset;
+                       src = (char *) kmap(old_rq[i - 1].rq_pages[j]) + offset;
 
-       orig_len = smb_rqst_len(server, old_rq);
+                       memcpy(dst, src, len);
+                       kunmap(new_rq[i].rq_pages[j]);
+                       kunmap(old_rq[i - 1].rq_pages[j]);
+               }
+       }
 
-       /* fill the 2nd iov with a transform header */
+       /* fill the 1st iov with a transform header */
        fill_transform_hdr(tr_hdr, orig_len, old_rq);
-       new_rq->rq_iov[0].iov_base = tr_hdr;
-       new_rq->rq_iov[0].iov_len = sizeof(struct smb2_transform_hdr);
-
-       /* copy pages form the old */
-       for (i = 0; i < npages; i++) {
-               char *dst, *src;
-               unsigned int offset, len;
-
-               rqst_page_get_length(new_rq, i, &len, &offset);
 
-               dst = (char *) kmap(new_rq->rq_pages[i]) + offset;
-               src = (char *) kmap(old_rq->rq_pages[i]) + offset;
-
-               memcpy(dst, src, len);
-               kunmap(new_rq->rq_pages[i]);
-               kunmap(old_rq->rq_pages[i]);
-       }
-
-       rc = crypt_message(server, new_rq, 1);
+       rc = crypt_message(server, num_rqst, new_rq, 1);
        cifs_dbg(FYI, "encrypt message returned %d", rc);
        if (rc)
-               goto err_free_tr_hdr;
+               goto err_free;
 
        return rc;
 
-err_free_tr_hdr:
-       kfree(tr_hdr);
-err_free_iov:
-       kfree(iov);
-err_free_pages:
-       for (i = i - 1; i >= 0; i--)
-               put_page(pages[i]);
-       kfree(pages);
+err_free:
+       smb3_free_compound_rqst(num_rqst - 1, &new_rq[1]);
        return rc;
 }
 
-static void
-smb3_free_transform_rq(struct smb_rqst *rqst)
-{
-       int i = rqst->rq_npages - 1;
-
-       for (; i >= 0; i--)
-               put_page(rqst->rq_pages[i]);
-       kfree(rqst->rq_pages);
-       /* free transform header */
-       kfree(rqst->rq_iov[0].iov_base);
-       kfree(rqst->rq_iov);
-}
-
 static int
 smb3_is_transform_hdr(void *buf)
 {
@@ -2655,7 +2667,7 @@ decrypt_raw_data(struct TCP_Server_Info *server, char *buf,
        rqst.rq_pagesz = PAGE_SIZE;
        rqst.rq_tailsz = (page_data_size % PAGE_SIZE) ? : PAGE_SIZE;
 
-       rc = crypt_message(server, &rqst, 0);
+       rc = crypt_message(server, 1, &rqst, 0);
        cifs_dbg(FYI, "decrypt message returned %d\n", rc);
 
        if (rc)
@@ -3302,7 +3314,6 @@ struct smb_version_operations smb30_operations = {
        .fallocate = smb3_fallocate,
        .enum_snapshots = smb3_enum_snapshots,
        .init_transform_rq = smb3_init_transform_rq,
-       .free_transform_rq = smb3_free_transform_rq,
        .is_transform_hdr = smb3_is_transform_hdr,
        .receive_transform = smb3_receive_transform,
        .get_dfs_refer = smb2_get_dfs_refer,
@@ -3408,7 +3419,6 @@ struct smb_version_operations smb311_operations = {
        .fallocate = smb3_fallocate,
        .enum_snapshots = smb3_enum_snapshots,
        .init_transform_rq = smb3_init_transform_rq,
-       .free_transform_rq = smb3_free_transform_rq,
        .is_transform_hdr = smb3_is_transform_hdr,
        .receive_transform = smb3_receive_transform,
        .get_dfs_refer = smb2_get_dfs_refer,
index 357d253..8039c93 100644 (file)
@@ -374,27 +374,40 @@ smbd_done:
        return rc;
 }
 
+#define MAX_COMPOUND 2
+
 static int
 smb_send_rqst(struct TCP_Server_Info *server, struct smb_rqst *rqst, int flags)
 {
-       struct smb_rqst cur_rqst;
+       struct kvec iov;
+       struct smb2_transform_hdr tr_hdr;
+       struct smb_rqst cur_rqst[MAX_COMPOUND];
        int rc;
 
        if (!(flags & CIFS_TRANSFORM_REQ))
                return __smb_send_rqst(server, 1, rqst);
 
-       if (!server->ops->init_transform_rq ||
-           !server->ops->free_transform_rq) {
-               cifs_dbg(VFS, "Encryption requested but transform callbacks are missed\n");
+       memset(&cur_rqst[0], 0, sizeof(cur_rqst));
+       memset(&iov, 0, sizeof(iov));
+       memset(&tr_hdr, 0, sizeof(tr_hdr));
+
+       iov.iov_base = &tr_hdr;
+       iov.iov_len = sizeof(tr_hdr);
+       cur_rqst[0].rq_iov = &iov;
+       cur_rqst[0].rq_nvec = 1;
+
+       if (!server->ops->init_transform_rq) {
+               cifs_dbg(VFS, "Encryption requested but transform callback "
+                        "is missing\n");
                return -EIO;
        }
 
-       rc = server->ops->init_transform_rq(server, &cur_rqst, rqst);
+       rc = server->ops->init_transform_rq(server, 2, &cur_rqst[0], rqst);
        if (rc)
                return rc;
 
-       rc = __smb_send_rqst(server, 1, &cur_rqst);
-       server->ops->free_transform_rq(&cur_rqst);
+       rc = __smb_send_rqst(server, 2, &cur_rqst[0]);
+       smb3_free_compound_rqst(1, &cur_rqst[1]);
        return rc;
 }