cifs: update receive_encrypted_standard to handle compounded responses
authorRonnie Sahlberg <lsahlber@redhat.com>
Wed, 8 Aug 2018 05:07:45 +0000 (15:07 +1000)
committerSteve French <stfrench@microsoft.com>
Fri, 10 Aug 2018 02:19:45 +0000 (21:19 -0500)
Signed-off-by: Ronnie Sahlberg <lsahlber@redhat.com>
Signed-off-by: Steve French <stfrench@microsoft.com>
Reviewed-by: Paulo Alcantara <palcantara@suse.com>
Reviewed-by: Pavel Shilovsky <pshilov@microsoft.com>
fs/cifs/cifsglob.h
fs/cifs/connect.c
fs/cifs/smb2ops.c
fs/cifs/transport.c

index 41803d3..0c9ab62 100644 (file)
@@ -76,6 +76,9 @@
 #define SMB_ECHO_INTERVAL_MAX 600
 #define SMB_ECHO_INTERVAL_DEFAULT 60
 
+/* maximum number of PDUs in one compound */
+#define MAX_COMPOUND 5
+
 /*
  * Default number of credits to keep available for SMB3.
  * This value is chosen somewhat arbitrarily. The Windows client
@@ -458,7 +461,7 @@ struct smb_version_operations {
                                 struct smb_rqst *, struct smb_rqst *);
        int (*is_transform_hdr)(void *buf);
        int (*receive_transform)(struct TCP_Server_Info *,
-                                struct mid_q_entry **);
+                                struct mid_q_entry **, char **, int *);
        enum securityEnum (*select_sectype)(struct TCP_Server_Info *,
                            enum securityEnum);
        int (*next_header)(char *);
index d9bd10d..c832a8a 100644 (file)
@@ -850,13 +850,14 @@ cifs_handle_standard(struct TCP_Server_Info *server, struct mid_q_entry *mid)
 static int
 cifs_demultiplex_thread(void *p)
 {
-       int length;
+       int i, num_mids, length;
        struct TCP_Server_Info *server = p;
        unsigned int pdu_length;
        unsigned int next_offset;
        char *buf = NULL;
        struct task_struct *task_to_wake = NULL;
-       struct mid_q_entry *mid_entry;
+       struct mid_q_entry *mids[MAX_COMPOUND];
+       char *bufs[MAX_COMPOUND];
 
        current->flags |= PF_MEMALLOC;
        cifs_dbg(FYI, "Demultiplex PID: %d\n", task_pid_nr(current));
@@ -923,58 +924,75 @@ next_pdu:
                                server->pdu_size = next_offset;
                }
 
-               mid_entry = NULL;
+               memset(mids, 0, sizeof(mids));
+               memset(bufs, 0, sizeof(bufs));
+               num_mids = 0;
+
                if (server->ops->is_transform_hdr &&
                    server->ops->receive_transform &&
                    server->ops->is_transform_hdr(buf)) {
                        length = server->ops->receive_transform(server,
-                                                               &mid_entry);
+                                                               mids,
+                                                               bufs,
+                                                               &num_mids);
                } else {
-                       mid_entry = server->ops->find_mid(server, buf);
+                       mids[0] = server->ops->find_mid(server, buf);
+                       bufs[0] = buf;
+                       if (mids[0])
+                               num_mids = 1;
 
-                       if (!mid_entry || !mid_entry->receive)
-                               length = standard_receive3(server, mid_entry);
+                       if (!mids[0] || !mids[0]->receive)
+                               length = standard_receive3(server, mids[0]);
                        else
-                               length = mid_entry->receive(server, mid_entry);
+                               length = mids[0]->receive(server, mids[0]);
                }
 
                if (length < 0) {
-                       if (mid_entry)
-                               cifs_mid_q_entry_release(mid_entry);
+                       for (i = 0; i < num_mids; i++)
+                               if (mids[i])
+                                       cifs_mid_q_entry_release(mids[i]);
                        continue;
                }
 
                if (server->large_buf)
                        buf = server->bigbuf;
 
+
                server->lstrp = jiffies;
-               if (mid_entry != NULL) {
-                       mid_entry->resp_buf_size = server->pdu_size;
-                       if ((mid_entry->mid_flags & MID_WAIT_CANCELLED) &&
-                            mid_entry->mid_state == MID_RESPONSE_RECEIVED &&
-                                       server->ops->handle_cancelled_mid)
-                               server->ops->handle_cancelled_mid(
-                                                       mid_entry->resp_buf,
-                                                       server);
 
-                       if (!mid_entry->multiRsp || mid_entry->multiEnd)
-                               mid_entry->callback(mid_entry);
+               for (i = 0; i < num_mids; i++) {
+                       if (mids[i] != NULL) {
+                               mids[i]->resp_buf_size = server->pdu_size;
+                               if ((mids[i]->mid_flags & MID_WAIT_CANCELLED) &&
+                                   mids[i]->mid_state == MID_RESPONSE_RECEIVED &&
+                                   server->ops->handle_cancelled_mid)
+                                       server->ops->handle_cancelled_mid(
+                                                       mids[i]->resp_buf,
+                                                       server);
 
-                       cifs_mid_q_entry_release(mid_entry);
-               } else if (server->ops->is_oplock_break &&
-                          server->ops->is_oplock_break(buf, server)) {
-                       cifs_dbg(FYI, "Received oplock break\n");
-               } else {
-                       cifs_dbg(VFS, "No task to wake, unknown frame received! NumMids %d\n",
-                                atomic_read(&midCount));
-                       cifs_dump_mem("Received Data is: ", buf,
-                                     HEADER_SIZE(server));
+                               if (!mids[i]->multiRsp || mids[i]->multiEnd)
+                                       mids[i]->callback(mids[i]);
+
+                               cifs_mid_q_entry_release(mids[i]);
+                       } else if (server->ops->is_oplock_break &&
+                                  server->ops->is_oplock_break(bufs[i],
+                                                               server)) {
+                               cifs_dbg(FYI, "Received oplock break\n");
+                       } else {
+                               cifs_dbg(VFS, "No task to wake, unknown frame "
+                                        "received! NumMids %d\n",
+                                        atomic_read(&midCount));
+                               cifs_dump_mem("Received Data is: ", bufs[i],
+                                             HEADER_SIZE(server));
 #ifdef CONFIG_CIFS_DEBUG2
-                       if (server->ops->dump_detail)
-                               server->ops->dump_detail(buf, server);
-                       cifs_dump_mids(server);
+                               if (server->ops->dump_detail)
+                                       server->ops->dump_detail(bufs[i],
+                                                                server);
+                               cifs_dump_mids(server);
 #endif /* CIFS_DEBUG2 */
+                       }
                }
+
                if (pdu_length > server->pdu_size) {
                        if (!allocate_buffers(server))
                                continue;
index ebc13eb..d237150 100644 (file)
@@ -2942,13 +2942,20 @@ discard_data:
 
 static int
 receive_encrypted_standard(struct TCP_Server_Info *server,
-                          struct mid_q_entry **mid)
+                          struct mid_q_entry **mids, char **bufs,
+                          int *num_mids)
 {
-       int length;
+       int ret, length;
        char *buf = server->smallbuf;
+       char *tmpbuf;
+       struct smb2_sync_hdr *shdr;
        unsigned int pdu_length = server->pdu_size;
        unsigned int buf_size;
        struct mid_q_entry *mid_entry;
+       int next_is_large;
+       char *next_buffer = NULL;
+
+       *num_mids = 0;
 
        /* switch to large buffer if too big for a small one */
        if (pdu_length > MAX_CIFS_SMALL_BUFFER_SIZE) {
@@ -2969,24 +2976,61 @@ receive_encrypted_standard(struct TCP_Server_Info *server,
        if (length)
                return length;
 
+       next_is_large = server->large_buf;
+ one_more:
+       shdr = (struct smb2_sync_hdr *)buf;
+       if (shdr->NextCommand) {
+               if (next_is_large) {
+                       tmpbuf = server->bigbuf;
+                       next_buffer = (char *)cifs_buf_get();
+               } else {
+                       tmpbuf = server->smallbuf;
+                       next_buffer = (char *)cifs_small_buf_get();
+               }
+               memcpy(next_buffer,
+                      tmpbuf + le32_to_cpu(shdr->NextCommand),
+                      pdu_length - le32_to_cpu(shdr->NextCommand));
+       }
+
        mid_entry = smb2_find_mid(server, buf);
        if (mid_entry == NULL)
                cifs_dbg(FYI, "mid not found\n");
        else {
                cifs_dbg(FYI, "mid found\n");
                mid_entry->decrypted = true;
+               mid_entry->resp_buf_size = server->pdu_size;
        }
 
-       *mid = mid_entry;
+       if (*num_mids >= MAX_COMPOUND) {
+               cifs_dbg(VFS, "too many PDUs in compound\n");
+               return -1;
+       }
+       bufs[*num_mids] = buf;
+       mids[(*num_mids)++] = mid_entry;
 
        if (mid_entry && mid_entry->handle)
-               return mid_entry->handle(server, mid_entry);
+               ret = mid_entry->handle(server, mid_entry);
+       else
+               ret = cifs_handle_standard(server, mid_entry);
+
+       if (ret == 0 && shdr->NextCommand) {
+               pdu_length -= le32_to_cpu(shdr->NextCommand);
+               server->large_buf = next_is_large;
+               if (next_is_large)
+                       server->bigbuf = next_buffer;
+               else
+                       server->smallbuf = next_buffer;
+
+               buf += le32_to_cpu(shdr->NextCommand);
+               goto one_more;
+       }
 
-       return cifs_handle_standard(server, mid_entry);
+       return ret;
 }
 
 static int
-smb3_receive_transform(struct TCP_Server_Info *server, struct mid_q_entry **mid)
+smb3_receive_transform(struct TCP_Server_Info *server,
+                      struct mid_q_entry **mids, char **bufs, int *num_mids)
 {
        char *buf = server->smallbuf;
        unsigned int pdu_length = server->pdu_size;
@@ -3009,10 +3053,11 @@ smb3_receive_transform(struct TCP_Server_Info *server, struct mid_q_entry **mid)
                return -ECONNABORTED;
        }
 
+       /* TODO: add support for compounds containing READ. */
        if (pdu_length > CIFSMaxBufSize + MAX_HEADER_SIZE(server))
-               return receive_encrypted_read(server, mid);
+               return receive_encrypted_read(server, &mids[0]);
 
-       return receive_encrypted_standard(server, mid);
+       return receive_encrypted_standard(server, mids, bufs, num_mids);
 }
 
 int
index c53c090..78f96fa 100644 (file)
@@ -383,8 +383,6 @@ smbd_done:
        return rc;
 }
 
-#define MAX_COMPOUND 5
-
 static int
 smb_send_rqst(struct TCP_Server_Info *server, int num_rqst,
              struct smb_rqst *rqst, int flags)