ksmbd: fix wrong smbd max read/write size check
authorNamjae Jeon <linkinjeon@kernel.org>
Mon, 16 May 2022 07:23:28 +0000 (16:23 +0900)
committerSteve French <stfrench@microsoft.com>
Sat, 21 May 2022 20:01:43 +0000 (15:01 -0500)
smb-direct max read/write size can be different with smb2 max read/write
size. So smb2_read() can return error by wrong max read/write size check.
This patch use smb_direct_max_read_write_size for this check in
smb-direct read/write().

Signed-off-by: Namjae Jeon <linkinjeon@kernel.org>
Reviewed-by: Hyunchul Lee <hyc.lee@gmail.com>
Signed-off-by: Steve French <stfrench@microsoft.com>
fs/ksmbd/smb2pdu.c
fs/ksmbd/transport_rdma.c
fs/ksmbd/transport_rdma.h

index 6bc30dd..e6f4ccc 100644 (file)
@@ -6183,6 +6183,8 @@ int smb2_read(struct ksmbd_work *work)
        size_t length, mincount;
        ssize_t nbytes = 0, remain_bytes = 0;
        int err = 0;
+       bool is_rdma_channel = false;
+       unsigned int max_read_size = conn->vals->max_read_size;
 
        WORK_BUFFERS(work, req, rsp);
 
@@ -6194,6 +6196,11 @@ int smb2_read(struct ksmbd_work *work)
 
        if (req->Channel == SMB2_CHANNEL_RDMA_V1_INVALIDATE ||
            req->Channel == SMB2_CHANNEL_RDMA_V1) {
+               is_rdma_channel = true;
+               max_read_size = get_smbd_max_read_write_size();
+       }
+
+       if (is_rdma_channel == true) {
                unsigned int ch_offset = le16_to_cpu(req->ReadChannelInfoOffset);
 
                if (ch_offset < offsetof(struct smb2_read_req, Buffer)) {
@@ -6225,9 +6232,9 @@ int smb2_read(struct ksmbd_work *work)
        length = le32_to_cpu(req->Length);
        mincount = le32_to_cpu(req->MinimumCount);
 
-       if (length > conn->vals->max_read_size) {
+       if (length > max_read_size) {
                ksmbd_debug(SMB, "limiting read size to max size(%u)\n",
-                           conn->vals->max_read_size);
+                           max_read_size);
                err = -EINVAL;
                goto out;
        }
@@ -6259,8 +6266,7 @@ int smb2_read(struct ksmbd_work *work)
        ksmbd_debug(SMB, "nbytes %zu, offset %lld mincount %zu\n",
                    nbytes, offset, mincount);
 
-       if (req->Channel == SMB2_CHANNEL_RDMA_V1_INVALIDATE ||
-           req->Channel == SMB2_CHANNEL_RDMA_V1) {
+       if (is_rdma_channel == true) {
                /* write data to the client using rdma channel */
                remain_bytes = smb2_read_rdma_channel(work, req,
                                                      work->aux_payload_buf,
@@ -6421,8 +6427,9 @@ int smb2_write(struct ksmbd_work *work)
        size_t length;
        ssize_t nbytes;
        char *data_buf;
-       bool writethrough = false;
+       bool writethrough = false, is_rdma_channel = false;
        int err = 0;
+       unsigned int max_write_size = work->conn->vals->max_write_size;
 
        WORK_BUFFERS(work, req, rsp);
 
@@ -6431,8 +6438,17 @@ int smb2_write(struct ksmbd_work *work)
                return smb2_write_pipe(work);
        }
 
+       offset = le64_to_cpu(req->Offset);
+       length = le32_to_cpu(req->Length);
+
        if (req->Channel == SMB2_CHANNEL_RDMA_V1 ||
            req->Channel == SMB2_CHANNEL_RDMA_V1_INVALIDATE) {
+               is_rdma_channel = true;
+               max_write_size = get_smbd_max_read_write_size();
+               length = le32_to_cpu(req->RemainingBytes);
+       }
+
+       if (is_rdma_channel == true) {
                unsigned int ch_offset = le16_to_cpu(req->WriteChannelInfoOffset);
 
                if (req->Length != 0 || req->DataOffset != 0 ||
@@ -6467,12 +6483,9 @@ int smb2_write(struct ksmbd_work *work)
                goto out;
        }
 
-       offset = le64_to_cpu(req->Offset);
-       length = le32_to_cpu(req->Length);
-
-       if (length > work->conn->vals->max_write_size) {
+       if (length > max_write_size) {
                ksmbd_debug(SMB, "limiting write size to max size(%u)\n",
-                           work->conn->vals->max_write_size);
+                           max_write_size);
                err = -EINVAL;
                goto out;
        }
@@ -6480,8 +6493,7 @@ int smb2_write(struct ksmbd_work *work)
        if (le32_to_cpu(req->Flags) & SMB2_WRITEFLAG_WRITE_THROUGH)
                writethrough = true;
 
-       if (req->Channel != SMB2_CHANNEL_RDMA_V1 &&
-           req->Channel != SMB2_CHANNEL_RDMA_V1_INVALIDATE) {
+       if (is_rdma_channel == false) {
                if ((u64)le16_to_cpu(req->DataOffset) + length >
                    get_rfc1002_len(work->request_buf)) {
                        pr_err("invalid write data offset %u, smb_len %u\n",
@@ -6507,8 +6519,7 @@ int smb2_write(struct ksmbd_work *work)
                /* read data from the client using rdma channel, and
                 * write the data.
                 */
-               nbytes = smb2_write_rdma_channel(work, req, fp, offset,
-                                                le32_to_cpu(req->RemainingBytes),
+               nbytes = smb2_write_rdma_channel(work, req, fp, offset, length,
                                                 writethrough);
                if (nbytes < 0) {
                        err = (int)nbytes;
index 6d652ff..0741fd1 100644 (file)
@@ -220,6 +220,11 @@ void init_smbd_max_io_size(unsigned int sz)
        smb_direct_max_read_write_size = sz;
 }
 
+unsigned int get_smbd_max_read_write_size(void)
+{
+       return smb_direct_max_read_write_size;
+}
+
 static inline int get_buf_page_count(void *buf, int size)
 {
        return DIV_ROUND_UP((uintptr_t)buf + size, PAGE_SIZE) -
index e7b4e67..77aee4e 100644 (file)
@@ -57,11 +57,13 @@ int ksmbd_rdma_init(void);
 void ksmbd_rdma_destroy(void);
 bool ksmbd_rdma_capable_netdev(struct net_device *netdev);
 void init_smbd_max_io_size(unsigned int sz);
+unsigned int get_smbd_max_read_write_size(void);
 #else
 static inline int ksmbd_rdma_init(void) { return 0; }
 static inline int ksmbd_rdma_destroy(void) { return 0; }
 static inline bool ksmbd_rdma_capable_netdev(struct net_device *netdev) { return false; }
 static inline void init_smbd_max_io_size(unsigned int sz) { }
+static inline unsigned int get_smbd_max_read_write_size(void) { return 0; }
 #endif
 
 #endif /* __KSMBD_TRANSPORT_RDMA_H__ */