fs/ksmbd/transport_rdma.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  *   Copyright (C) 2017, Microsoft Corporation.
4  *   Copyright (C) 2018, LG Electronics.
5  *
6  *   Author(s): Long Li <longli@microsoft.com>,
7  *              Hyunchul Lee <hyc.lee@gmail.com>
8  *
9  *   This program is free software;  you can redistribute it and/or modify
10  *   it under the terms of the GNU General Public License as published by
11  *   the Free Software Foundation; either version 2 of the License, or
12  *   (at your option) any later version.
13  *
14  *   This program is distributed in the hope that it will be useful,
15  *   but WITHOUT ANY WARRANTY;  without even the implied warranty of
16  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See
17  *   the GNU General Public License for more details.
18  */
19
20 #define SUBMOD_NAME     "smb_direct"
21
22 #include <linux/kthread.h>
23 #include <linux/list.h>
24 #include <linux/mempool.h>
25 #include <linux/highmem.h>
26 #include <linux/scatterlist.h>
27 #include <rdma/ib_verbs.h>
28 #include <rdma/rdma_cm.h>
29 #include <rdma/rw.h>
30
31 #include "glob.h"
32 #include "connection.h"
33 #include "smb_common.h"
34 #include "smbstatus.h"
35 #include "transport_rdma.h"
36
37 #define SMB_DIRECT_PORT 5445
38
39 #define SMB_DIRECT_VERSION_LE           cpu_to_le16(0x0100)
40
41 /* SMB_DIRECT negotiation timeout in seconds */
42 #define SMB_DIRECT_NEGOTIATE_TIMEOUT            120
43
44 #define SMB_DIRECT_MAX_SEND_SGES                8
45 #define SMB_DIRECT_MAX_RECV_SGES                1
46
47 /*
48  * Default maximum number of RDMA read/write outstanding on this connection
 49  * This value may be decreased during QP creation, depending on hardware limits
50  */
51 #define SMB_DIRECT_CM_INITIATOR_DEPTH           8
52
53 /* Maximum number of retries on data transfer operations */
54 #define SMB_DIRECT_CM_RETRY                     6
55 /* No need to retry on Receiver Not Ready since SMB_DIRECT manages credits */
56 #define SMB_DIRECT_CM_RNR_RETRY         0
57
58 /*
59  * User configurable initial values per SMB_DIRECT transport connection
60  * as defined in [MS-SMBD] 3.1.1.1
61  * Those may change after a SMB_DIRECT negotiation
62  */
63 /* The local peer's maximum number of credits to grant to the peer */
64 static int smb_direct_receive_credit_max = 255;
65
 66 /* The number of send credits the local peer requests from the remote peer */
67 static int smb_direct_send_credit_target = 255;
68
 69 /* The maximum single message size that can be sent to the remote peer */
70 static int smb_direct_max_send_size = 8192;
71
72 /*  The maximum fragmented upper-layer payload receive size supported */
73 static int smb_direct_max_fragmented_recv_size = 1024 * 1024;
74
75 /*  The maximum single-message size which can be received */
76 static int smb_direct_max_receive_size = 8192;
77
78 static int smb_direct_max_read_write_size = 1024 * 1024;
79
80 static int smb_direct_max_outstanding_rw_ops = 8;
81
82 static struct smb_direct_listener {
83         struct rdma_cm_id       *cm_id;
84 } smb_direct_listener;
85
86 static struct workqueue_struct *smb_direct_wq;
87
88 enum smb_direct_status {
89         SMB_DIRECT_CS_NEW = 0,
90         SMB_DIRECT_CS_CONNECTED,
91         SMB_DIRECT_CS_DISCONNECTING,
92         SMB_DIRECT_CS_DISCONNECTED,
93 };
94
95 struct smb_direct_transport {
96         struct ksmbd_transport  transport;
97
98         enum smb_direct_status  status;
99         bool                    full_packet_received;
100         wait_queue_head_t       wait_status;
101
102         struct rdma_cm_id       *cm_id;
103         struct ib_cq            *send_cq;
104         struct ib_cq            *recv_cq;
105         struct ib_pd            *pd;
106         struct ib_qp            *qp;
107
108         int                     max_send_size;
109         int                     max_recv_size;
110         int                     max_fragmented_send_size;
111         int                     max_fragmented_recv_size;
112         int                     max_rdma_rw_size;
113
114         spinlock_t              reassembly_queue_lock;
115         struct list_head        reassembly_queue;
116         int                     reassembly_data_length;
117         int                     reassembly_queue_length;
118         int                     first_entry_offset;
119         wait_queue_head_t       wait_reassembly_queue;
120
121         spinlock_t              receive_credit_lock;
122         int                     recv_credits;
123         int                     count_avail_recvmsg;
124         int                     recv_credit_max;
125         int                     recv_credit_target;
126
127         spinlock_t              recvmsg_queue_lock;
128         struct list_head        recvmsg_queue;
129
130         spinlock_t              empty_recvmsg_queue_lock;
131         struct list_head        empty_recvmsg_queue;
132
133         int                     send_credit_target;
134         atomic_t                send_credits;
135         spinlock_t              lock_new_recv_credits;
136         int                     new_recv_credits;
137         atomic_t                rw_avail_ops;
138
139         wait_queue_head_t       wait_send_credits;
140         wait_queue_head_t       wait_rw_avail_ops;
141
142         mempool_t               *sendmsg_mempool;
143         struct kmem_cache       *sendmsg_cache;
144         mempool_t               *recvmsg_mempool;
145         struct kmem_cache       *recvmsg_cache;
146
147         wait_queue_head_t       wait_send_payload_pending;
148         atomic_t                send_payload_pending;
149         wait_queue_head_t       wait_send_pending;
150         atomic_t                send_pending;
151
152         struct delayed_work     post_recv_credits_work;
153         struct work_struct      send_immediate_work;
154         struct work_struct      disconnect_work;
155
156         bool                    negotiation_requested;
157 };
158
159 #define KSMBD_TRANS(t) ((struct ksmbd_transport *)&((t)->transport))
160
161 enum {
162         SMB_DIRECT_MSG_NEGOTIATE_REQ = 0,
163         SMB_DIRECT_MSG_DATA_TRANSFER
164 };
165
166 static struct ksmbd_transport_ops ksmbd_smb_direct_transport_ops;
167
168 struct smb_direct_send_ctx {
169         struct list_head        msg_list;
170         int                     wr_cnt;
171         bool                    need_invalidate_rkey;
172         unsigned int            remote_key;
173 };
174
175 struct smb_direct_sendmsg {
176         struct smb_direct_transport     *transport;
177         struct ib_send_wr       wr;
178         struct list_head        list;
179         int                     num_sge;
180         struct ib_sge           sge[SMB_DIRECT_MAX_SEND_SGES];
181         struct ib_cqe           cqe;
182         u8                      packet[];
183 };
184
185 struct smb_direct_recvmsg {
186         struct smb_direct_transport     *transport;
187         struct list_head        list;
188         int                     type;
189         struct ib_sge           sge;
190         struct ib_cqe           cqe;
191         bool                    first_segment;
192         u8                      packet[];
193 };
194
195 struct smb_direct_rdma_rw_msg {
196         struct smb_direct_transport     *t;
197         struct ib_cqe           cqe;
198         struct completion       *completion;
199         struct rdma_rw_ctx      rw_ctx;
200         struct sg_table         sgt;
 201         struct scatterlist      sg_list[];
202 };
203
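/* Number of pages spanned by @size bytes starting at @buf */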
204 static inline int get_buf_page_count(void *buf, int size)
205 {
206         return DIV_ROUND_UP((uintptr_t)buf + size, PAGE_SIZE) -
207                 (uintptr_t)buf / PAGE_SIZE;
208 }
209
210 static void smb_direct_destroy_pools(struct smb_direct_transport *transport);
211 static void smb_direct_post_recv_credits(struct work_struct *work);
212 static int smb_direct_post_send_data(struct smb_direct_transport *t,
213                                      struct smb_direct_send_ctx *send_ctx,
214                                      struct kvec *iov, int niov,
215                                      int remaining_data_length);
216
217 static inline struct smb_direct_transport *
218 smb_trans_direct_transfort(struct ksmbd_transport *t)
219 {
220         return container_of(t, struct smb_direct_transport, transport);
221 }
222
223 static inline void
224 *smb_direct_recvmsg_payload(struct smb_direct_recvmsg *recvmsg)
225 {
226         return (void *)recvmsg->packet;
227 }
228
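/*
 * Repost receive buffers only when the outstanding receive credits have
 * dropped to at most 1/8 of the maximum and at least receive_credits/4
 * free receive messages are available to repost.
 */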
229 static inline bool is_receive_credit_post_required(int receive_credits,
230                                                    int avail_recvmsg_count)
231 {
232         return receive_credits <= (smb_direct_receive_credit_max >> 3) &&
233                 avail_recvmsg_count >= (receive_credits >> 2);
234 }
235
236 static struct
237 smb_direct_recvmsg *get_free_recvmsg(struct smb_direct_transport *t)
238 {
239         struct smb_direct_recvmsg *recvmsg = NULL;
240
241         spin_lock(&t->recvmsg_queue_lock);
242         if (!list_empty(&t->recvmsg_queue)) {
243                 recvmsg = list_first_entry(&t->recvmsg_queue,
244                                            struct smb_direct_recvmsg,
245                                            list);
246                 list_del(&recvmsg->list);
247         }
248         spin_unlock(&t->recvmsg_queue_lock);
249         return recvmsg;
250 }
251
252 static void put_recvmsg(struct smb_direct_transport *t,
253                         struct smb_direct_recvmsg *recvmsg)
254 {
255         ib_dma_unmap_single(t->cm_id->device, recvmsg->sge.addr,
256                             recvmsg->sge.length, DMA_FROM_DEVICE);
257
258         spin_lock(&t->recvmsg_queue_lock);
259         list_add(&recvmsg->list, &t->recvmsg_queue);
260         spin_unlock(&t->recvmsg_queue_lock);
261 }
262
263 static struct
264 smb_direct_recvmsg *get_empty_recvmsg(struct smb_direct_transport *t)
265 {
266         struct smb_direct_recvmsg *recvmsg = NULL;
267
268         spin_lock(&t->empty_recvmsg_queue_lock);
269         if (!list_empty(&t->empty_recvmsg_queue)) {
270                 recvmsg = list_first_entry(&t->empty_recvmsg_queue,
271                                            struct smb_direct_recvmsg, list);
272                 list_del(&recvmsg->list);
273         }
274         spin_unlock(&t->empty_recvmsg_queue_lock);
275         return recvmsg;
276 }
277
278 static void put_empty_recvmsg(struct smb_direct_transport *t,
279                               struct smb_direct_recvmsg *recvmsg)
280 {
281         ib_dma_unmap_single(t->cm_id->device, recvmsg->sge.addr,
282                             recvmsg->sge.length, DMA_FROM_DEVICE);
283
284         spin_lock(&t->empty_recvmsg_queue_lock);
285         list_add_tail(&recvmsg->list, &t->empty_recvmsg_queue);
286         spin_unlock(&t->empty_recvmsg_queue_lock);
287 }
288
289 static void enqueue_reassembly(struct smb_direct_transport *t,
290                                struct smb_direct_recvmsg *recvmsg,
291                                int data_length)
292 {
293         spin_lock(&t->reassembly_queue_lock);
294         list_add_tail(&recvmsg->list, &t->reassembly_queue);
295         t->reassembly_queue_length++;
296         /*
297          * Make sure reassembly_data_length is updated after list and
298          * reassembly_queue_length are updated. On the dequeue side
299          * reassembly_data_length is checked without a lock to determine
 300          * if reassembly_queue_length and the list are up to date
301          */
302         virt_wmb();
303         t->reassembly_data_length += data_length;
304         spin_unlock(&t->reassembly_queue_lock);
305 }
306
307 static struct smb_direct_recvmsg *get_first_reassembly(struct smb_direct_transport *t)
308 {
309         if (!list_empty(&t->reassembly_queue))
310                 return list_first_entry(&t->reassembly_queue,
311                                 struct smb_direct_recvmsg, list);
312         else
313                 return NULL;
314 }
315
316 static void smb_direct_disconnect_rdma_work(struct work_struct *work)
317 {
318         struct smb_direct_transport *t =
319                 container_of(work, struct smb_direct_transport,
320                              disconnect_work);
321
322         if (t->status == SMB_DIRECT_CS_CONNECTED) {
323                 t->status = SMB_DIRECT_CS_DISCONNECTING;
324                 rdma_disconnect(t->cm_id);
325         }
326 }
327
328 static void
329 smb_direct_disconnect_rdma_connection(struct smb_direct_transport *t)
330 {
331         if (t->status == SMB_DIRECT_CS_CONNECTED)
332                 queue_work(smb_direct_wq, &t->disconnect_work);
333 }
334
335 static void smb_direct_send_immediate_work(struct work_struct *work)
336 {
337         struct smb_direct_transport *t = container_of(work,
338                         struct smb_direct_transport, send_immediate_work);
339
340         if (t->status != SMB_DIRECT_CS_CONNECTED)
341                 return;
342
343         smb_direct_post_send_data(t, NULL, NULL, 0, 0);
344 }
345
346 static struct smb_direct_transport *alloc_transport(struct rdma_cm_id *cm_id)
347 {
348         struct smb_direct_transport *t;
349         struct ksmbd_conn *conn;
350
351         t = kzalloc(sizeof(*t), GFP_KERNEL);
352         if (!t)
353                 return NULL;
354
355         t->cm_id = cm_id;
356         cm_id->context = t;
357
358         t->status = SMB_DIRECT_CS_NEW;
359         init_waitqueue_head(&t->wait_status);
360
361         spin_lock_init(&t->reassembly_queue_lock);
362         INIT_LIST_HEAD(&t->reassembly_queue);
363         t->reassembly_data_length = 0;
364         t->reassembly_queue_length = 0;
365         init_waitqueue_head(&t->wait_reassembly_queue);
366         init_waitqueue_head(&t->wait_send_credits);
367         init_waitqueue_head(&t->wait_rw_avail_ops);
368
369         spin_lock_init(&t->receive_credit_lock);
370         spin_lock_init(&t->recvmsg_queue_lock);
371         INIT_LIST_HEAD(&t->recvmsg_queue);
372
373         spin_lock_init(&t->empty_recvmsg_queue_lock);
374         INIT_LIST_HEAD(&t->empty_recvmsg_queue);
375
376         init_waitqueue_head(&t->wait_send_payload_pending);
377         atomic_set(&t->send_payload_pending, 0);
378         init_waitqueue_head(&t->wait_send_pending);
379         atomic_set(&t->send_pending, 0);
380
381         spin_lock_init(&t->lock_new_recv_credits);
382
383         INIT_DELAYED_WORK(&t->post_recv_credits_work,
384                           smb_direct_post_recv_credits);
385         INIT_WORK(&t->send_immediate_work, smb_direct_send_immediate_work);
386         INIT_WORK(&t->disconnect_work, smb_direct_disconnect_rdma_work);
387
388         conn = ksmbd_conn_alloc();
389         if (!conn)
390                 goto err;
391         conn->transport = KSMBD_TRANS(t);
392         KSMBD_TRANS(t)->conn = conn;
393         KSMBD_TRANS(t)->ops = &ksmbd_smb_direct_transport_ops;
394         return t;
395 err:
396         kfree(t);
397         return NULL;
398 }
399
400 static void free_transport(struct smb_direct_transport *t)
401 {
402         struct smb_direct_recvmsg *recvmsg;
403
404         wake_up_interruptible(&t->wait_send_credits);
405
406         ksmbd_debug(RDMA, "wait for all send posted to IB to finish\n");
407         wait_event(t->wait_send_payload_pending,
408                    atomic_read(&t->send_payload_pending) == 0);
409         wait_event(t->wait_send_pending,
410                    atomic_read(&t->send_pending) == 0);
411
412         cancel_work_sync(&t->disconnect_work);
413         cancel_delayed_work_sync(&t->post_recv_credits_work);
414         cancel_work_sync(&t->send_immediate_work);
415
416         if (t->qp) {
417                 ib_drain_qp(t->qp);
418                 ib_destroy_qp(t->qp);
419         }
420
421         ksmbd_debug(RDMA, "drain the reassembly queue\n");
422         do {
423                 spin_lock(&t->reassembly_queue_lock);
424                 recvmsg = get_first_reassembly(t);
425                 if (recvmsg) {
426                         list_del(&recvmsg->list);
427                         spin_unlock(&t->reassembly_queue_lock);
428                         put_recvmsg(t, recvmsg);
429                 } else {
430                         spin_unlock(&t->reassembly_queue_lock);
431                 }
432         } while (recvmsg);
433         t->reassembly_data_length = 0;
434
435         if (t->send_cq)
436                 ib_free_cq(t->send_cq);
437         if (t->recv_cq)
438                 ib_free_cq(t->recv_cq);
439         if (t->pd)
440                 ib_dealloc_pd(t->pd);
441         if (t->cm_id)
442                 rdma_destroy_id(t->cm_id);
443
444         smb_direct_destroy_pools(t);
445         ksmbd_conn_free(KSMBD_TRANS(t)->conn);
446         kfree(t);
447 }
448
449 static struct smb_direct_sendmsg
450 *smb_direct_alloc_sendmsg(struct smb_direct_transport *t)
451 {
452         struct smb_direct_sendmsg *msg;
453
454         msg = mempool_alloc(t->sendmsg_mempool, GFP_KERNEL);
455         if (!msg)
456                 return ERR_PTR(-ENOMEM);
457         msg->transport = t;
458         INIT_LIST_HEAD(&msg->list);
459         msg->num_sge = 0;
460         return msg;
461 }
462
463 static void smb_direct_free_sendmsg(struct smb_direct_transport *t,
464                                     struct smb_direct_sendmsg *msg)
465 {
466         int i;
467
468         if (msg->num_sge > 0) {
469                 ib_dma_unmap_single(t->cm_id->device,
470                                     msg->sge[0].addr, msg->sge[0].length,
471                                     DMA_TO_DEVICE);
472                 for (i = 1; i < msg->num_sge; i++)
473                         ib_dma_unmap_page(t->cm_id->device,
474                                           msg->sge[i].addr, msg->sge[i].length,
475                                           DMA_TO_DEVICE);
476         }
477         mempool_free(msg, t->sendmsg_mempool);
478 }
479
480 static int smb_direct_check_recvmsg(struct smb_direct_recvmsg *recvmsg)
481 {
482         switch (recvmsg->type) {
483         case SMB_DIRECT_MSG_DATA_TRANSFER: {
484                 struct smb_direct_data_transfer *req =
485                         (struct smb_direct_data_transfer *)recvmsg->packet;
486                 struct smb2_hdr *hdr = (struct smb2_hdr *)(recvmsg->packet
487                                 + le32_to_cpu(req->data_offset) - 4);
488                 ksmbd_debug(RDMA,
489                             "CreditGranted: %u, CreditRequested: %u, DataLength: %u, RemainingDataLength: %u, SMB: %x, Command: %u\n",
490                             le16_to_cpu(req->credits_granted),
491                             le16_to_cpu(req->credits_requested),
 492                             le32_to_cpu(req->data_length), le32_to_cpu(req->remaining_data_length),
493                             hdr->ProtocolId, hdr->Command);
494                 break;
495         }
496         case SMB_DIRECT_MSG_NEGOTIATE_REQ: {
497                 struct smb_direct_negotiate_req *req =
498                         (struct smb_direct_negotiate_req *)recvmsg->packet;
499                 ksmbd_debug(RDMA,
500                             "MinVersion: %u, MaxVersion: %u, CreditRequested: %u, MaxSendSize: %u, MaxRecvSize: %u, MaxFragmentedSize: %u\n",
501                             le16_to_cpu(req->min_version),
502                             le16_to_cpu(req->max_version),
503                             le16_to_cpu(req->credits_requested),
504                             le32_to_cpu(req->preferred_send_size),
505                             le32_to_cpu(req->max_receive_size),
506                             le32_to_cpu(req->max_fragmented_size));
507                 if (le16_to_cpu(req->min_version) > 0x0100 ||
508                     le16_to_cpu(req->max_version) < 0x0100)
509                         return -EOPNOTSUPP;
510                 if (le16_to_cpu(req->credits_requested) <= 0 ||
511                     le32_to_cpu(req->max_receive_size) <= 128 ||
512                     le32_to_cpu(req->max_fragmented_size) <=
513                                         128 * 1024)
514                         return -ECONNABORTED;
515
516                 break;
517         }
518         default:
519                 return -EINVAL;
520         }
521         return 0;
522 }
523
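/*
 * Receive completion handler: wakes the negotiation waiter for negotiate
 * requests; for data transfer messages it queues the payload on the
 * reassembly queue and updates send/receive credit accounting.
 */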
524 static void recv_done(struct ib_cq *cq, struct ib_wc *wc)
525 {
526         struct smb_direct_recvmsg *recvmsg;
527         struct smb_direct_transport *t;
528
529         recvmsg = container_of(wc->wr_cqe, struct smb_direct_recvmsg, cqe);
530         t = recvmsg->transport;
531
532         if (wc->status != IB_WC_SUCCESS || wc->opcode != IB_WC_RECV) {
533                 if (wc->status != IB_WC_WR_FLUSH_ERR) {
534                         pr_err("Recv error. status='%s (%d)' opcode=%d\n",
535                                ib_wc_status_msg(wc->status), wc->status,
536                                wc->opcode);
537                         smb_direct_disconnect_rdma_connection(t);
538                 }
539                 put_empty_recvmsg(t, recvmsg);
540                 return;
541         }
542
543         ksmbd_debug(RDMA, "Recv completed. status='%s (%d)', opcode=%d\n",
544                     ib_wc_status_msg(wc->status), wc->status,
545                     wc->opcode);
546
547         ib_dma_sync_single_for_cpu(wc->qp->device, recvmsg->sge.addr,
548                                    recvmsg->sge.length, DMA_FROM_DEVICE);
549
550         switch (recvmsg->type) {
551         case SMB_DIRECT_MSG_NEGOTIATE_REQ:
552                 t->negotiation_requested = true;
553                 t->full_packet_received = true;
554                 wake_up_interruptible(&t->wait_status);
555                 break;
556         case SMB_DIRECT_MSG_DATA_TRANSFER: {
557                 struct smb_direct_data_transfer *data_transfer =
558                         (struct smb_direct_data_transfer *)recvmsg->packet;
559                 int data_length = le32_to_cpu(data_transfer->data_length);
560                 int avail_recvmsg_count, receive_credits;
561
562                 if (data_length) {
563                         if (t->full_packet_received)
564                                 recvmsg->first_segment = true;
565
566                         if (le32_to_cpu(data_transfer->remaining_data_length))
567                                 t->full_packet_received = false;
568                         else
569                                 t->full_packet_received = true;
570
571                         enqueue_reassembly(t, recvmsg, data_length);
572                         wake_up_interruptible(&t->wait_reassembly_queue);
573
574                         spin_lock(&t->receive_credit_lock);
575                         receive_credits = --(t->recv_credits);
576                         avail_recvmsg_count = t->count_avail_recvmsg;
577                         spin_unlock(&t->receive_credit_lock);
578                 } else {
579                         put_empty_recvmsg(t, recvmsg);
580
581                         spin_lock(&t->receive_credit_lock);
582                         receive_credits = --(t->recv_credits);
583                         avail_recvmsg_count = ++(t->count_avail_recvmsg);
584                         spin_unlock(&t->receive_credit_lock);
585                 }
586
587                 t->recv_credit_target =
588                                 le16_to_cpu(data_transfer->credits_requested);
589                 atomic_add(le16_to_cpu(data_transfer->credits_granted),
590                            &t->send_credits);
591
592                 if (le16_to_cpu(data_transfer->flags) &
593                     SMB_DIRECT_RESPONSE_REQUESTED)
594                         queue_work(smb_direct_wq, &t->send_immediate_work);
595
596                 if (atomic_read(&t->send_credits) > 0)
597                         wake_up_interruptible(&t->wait_send_credits);
598
599                 if (is_receive_credit_post_required(receive_credits, avail_recvmsg_count))
600                         mod_delayed_work(smb_direct_wq,
601                                          &t->post_recv_credits_work, 0);
602                 break;
603         }
604         default:
605                 break;
606         }
607 }
608
609 static int smb_direct_post_recv(struct smb_direct_transport *t,
610                                 struct smb_direct_recvmsg *recvmsg)
611 {
612         struct ib_recv_wr wr;
613         int ret;
614
615         recvmsg->sge.addr = ib_dma_map_single(t->cm_id->device,
616                                               recvmsg->packet, t->max_recv_size,
617                                               DMA_FROM_DEVICE);
618         ret = ib_dma_mapping_error(t->cm_id->device, recvmsg->sge.addr);
619         if (ret)
620                 return ret;
621         recvmsg->sge.length = t->max_recv_size;
622         recvmsg->sge.lkey = t->pd->local_dma_lkey;
623         recvmsg->cqe.done = recv_done;
624
625         wr.wr_cqe = &recvmsg->cqe;
626         wr.next = NULL;
627         wr.sg_list = &recvmsg->sge;
628         wr.num_sge = 1;
629
630         ret = ib_post_recv(t->qp, &wr, NULL);
631         if (ret) {
632                 pr_err("Can't post recv: %d\n", ret);
633                 ib_dma_unmap_single(t->cm_id->device,
634                                     recvmsg->sge.addr, recvmsg->sge.length,
635                                     DMA_FROM_DEVICE);
636                 smb_direct_disconnect_rdma_connection(t);
637                 return ret;
638         }
639         return ret;
640 }
641
642 static int smb_direct_read(struct ksmbd_transport *t, char *buf,
643                            unsigned int size)
644 {
645         struct smb_direct_recvmsg *recvmsg;
646         struct smb_direct_data_transfer *data_transfer;
647         int to_copy, to_read, data_read, offset;
648         u32 data_length, remaining_data_length, data_offset;
649         int rc;
650         struct smb_direct_transport *st = smb_trans_direct_transfort(t);
651
652 again:
653         if (st->status != SMB_DIRECT_CS_CONNECTED) {
654                 pr_err("disconnected\n");
655                 return -ENOTCONN;
656         }
657
658         /*
659          * No need to hold the reassembly queue lock all the time as we are
660          * the only one reading from the front of the queue. The transport
661          * may add more entries to the back of the queue at the same time
662          */
663         if (st->reassembly_data_length >= size) {
664                 int queue_length;
665                 int queue_removed = 0;
666
667                 /*
668                  * Need to make sure reassembly_data_length is read before
669                  * reading reassembly_queue_length and calling
670                  * get_first_reassembly. This call is lock free
 671                  * as we never read the entries at the end of the queue, which
 672                  * are being updated in softirq context as more data is received
673                  */
674                 virt_rmb();
675                 queue_length = st->reassembly_queue_length;
676                 data_read = 0;
677                 to_read = size;
678                 offset = st->first_entry_offset;
679                 while (data_read < size) {
680                         recvmsg = get_first_reassembly(st);
681                         data_transfer = smb_direct_recvmsg_payload(recvmsg);
682                         data_length = le32_to_cpu(data_transfer->data_length);
683                         remaining_data_length =
684                                 le32_to_cpu(data_transfer->remaining_data_length);
685                         data_offset = le32_to_cpu(data_transfer->data_offset);
686
687                         /*
688                          * The upper layer expects RFC1002 length at the
689                          * beginning of the payload. Return it to indicate
690                          * the total length of the packet. This minimize the
 691                  * the total length of the packet. This minimizes the
 692                  * changes to the upper-layer packet processing logic and
 693                  * will eventually be removed when an intermediate
694                          */
695                         if (recvmsg->first_segment && size == 4) {
696                                 unsigned int rfc1002_len =
697                                         data_length + remaining_data_length;
698                                 *((__be32 *)buf) = cpu_to_be32(rfc1002_len);
699                                 data_read = 4;
700                                 recvmsg->first_segment = false;
701                                 ksmbd_debug(RDMA,
702                                             "returning rfc1002 length %d\n",
703                                             rfc1002_len);
704                                 goto read_rfc1002_done;
705                         }
706
707                         to_copy = min_t(int, data_length - offset, to_read);
708                         memcpy(buf + data_read, (char *)data_transfer + data_offset + offset,
709                                to_copy);
710
711                         /* move on to the next buffer? */
712                         if (to_copy == data_length - offset) {
713                                 queue_length--;
714                                 /*
715                                  * No need to lock if we are not at the
716                                  * end of the queue
717                                  */
718                                 if (queue_length) {
719                                         list_del(&recvmsg->list);
720                                 } else {
721                                         spin_lock_irq(&st->reassembly_queue_lock);
722                                         list_del(&recvmsg->list);
723                                         spin_unlock_irq(&st->reassembly_queue_lock);
724                                 }
725                                 queue_removed++;
726                                 put_recvmsg(st, recvmsg);
727                                 offset = 0;
728                         } else {
729                                 offset += to_copy;
730                         }
731
732                         to_read -= to_copy;
733                         data_read += to_copy;
734                 }
735
736                 spin_lock_irq(&st->reassembly_queue_lock);
737                 st->reassembly_data_length -= data_read;
738                 st->reassembly_queue_length -= queue_removed;
739                 spin_unlock_irq(&st->reassembly_queue_lock);
740
741                 spin_lock(&st->receive_credit_lock);
742                 st->count_avail_recvmsg += queue_removed;
743                 if (is_receive_credit_post_required(st->recv_credits, st->count_avail_recvmsg)) {
744                         spin_unlock(&st->receive_credit_lock);
745                         mod_delayed_work(smb_direct_wq,
746                                          &st->post_recv_credits_work, 0);
747                 } else {
748                         spin_unlock(&st->receive_credit_lock);
749                 }
750
751                 st->first_entry_offset = offset;
752                 ksmbd_debug(RDMA,
753                             "returning to thread data_read=%d reassembly_data_length=%d first_entry_offset=%d\n",
754                             data_read, st->reassembly_data_length,
755                             st->first_entry_offset);
756 read_rfc1002_done:
757                 return data_read;
758         }
759
760         ksmbd_debug(RDMA, "wait_event on more data\n");
761         rc = wait_event_interruptible(st->wait_reassembly_queue,
762                                       st->reassembly_data_length >= size ||
763                                        st->status != SMB_DIRECT_CS_CONNECTED);
764         if (rc)
765                 return -EINTR;
766
767         goto again;
768 }
769
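/*
 * Repost receive buffers until the peer's requested credit target is met,
 * then schedule an immediate send so the newly granted credits are
 * advertised to the peer.
 */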
770 static void smb_direct_post_recv_credits(struct work_struct *work)
771 {
772         struct smb_direct_transport *t = container_of(work,
773                 struct smb_direct_transport, post_recv_credits_work.work);
774         struct smb_direct_recvmsg *recvmsg;
775         int receive_credits, credits = 0;
776         int ret;
777         int use_free = 1;
778
779         spin_lock(&t->receive_credit_lock);
780         receive_credits = t->recv_credits;
781         spin_unlock(&t->receive_credit_lock);
782
783         if (receive_credits < t->recv_credit_target) {
784                 while (true) {
785                         if (use_free)
786                                 recvmsg = get_free_recvmsg(t);
787                         else
788                                 recvmsg = get_empty_recvmsg(t);
789                         if (!recvmsg) {
790                                 if (use_free) {
791                                         use_free = 0;
792                                         continue;
793                                 } else {
794                                         break;
795                                 }
796                         }
797
798                         recvmsg->type = SMB_DIRECT_MSG_DATA_TRANSFER;
799                         recvmsg->first_segment = false;
800
801                         ret = smb_direct_post_recv(t, recvmsg);
802                         if (ret) {
803                                 pr_err("Can't post recv: %d\n", ret);
804                                 put_recvmsg(t, recvmsg);
805                                 break;
806                         }
807                         credits++;
808                 }
809         }
810
811         spin_lock(&t->receive_credit_lock);
812         t->recv_credits += credits;
813         t->count_avail_recvmsg -= credits;
814         spin_unlock(&t->receive_credit_lock);
815
816         spin_lock(&t->lock_new_recv_credits);
817         t->new_recv_credits += credits;
818         spin_unlock(&t->lock_new_recv_credits);
819
820         if (credits)
821                 queue_work(smb_direct_wq, &t->send_immediate_work);
822 }
823
824 static void send_done(struct ib_cq *cq, struct ib_wc *wc)
825 {
826         struct smb_direct_sendmsg *sendmsg, *sibling;
827         struct smb_direct_transport *t;
828         struct list_head *pos, *prev, *end;
829
830         sendmsg = container_of(wc->wr_cqe, struct smb_direct_sendmsg, cqe);
831         t = sendmsg->transport;
832
833         ksmbd_debug(RDMA, "Send completed. status='%s (%d)', opcode=%d\n",
834                     ib_wc_status_msg(wc->status), wc->status,
835                     wc->opcode);
836
837         if (wc->status != IB_WC_SUCCESS || wc->opcode != IB_WC_SEND) {
838                 pr_err("Send error. status='%s (%d)', opcode=%d\n",
839                        ib_wc_status_msg(wc->status), wc->status,
840                        wc->opcode);
841                 smb_direct_disconnect_rdma_connection(t);
842         }
843
844         if (sendmsg->num_sge > 1) {
845                 if (atomic_dec_and_test(&t->send_payload_pending))
846                         wake_up(&t->wait_send_payload_pending);
847         } else {
848                 if (atomic_dec_and_test(&t->send_pending))
849                         wake_up(&t->wait_send_pending);
850         }
851
 852         /* Iterate and free the list of messages in reverse order; the
 853          * list's head is no longer valid and cannot be walked normally.
 854          */
855         for (pos = &sendmsg->list, prev = pos->prev, end = sendmsg->list.next;
856              prev != end; pos = prev, prev = prev->prev) {
857                 sibling = container_of(pos, struct smb_direct_sendmsg, list);
858                 smb_direct_free_sendmsg(t, sibling);
859         }
860
861         sibling = container_of(pos, struct smb_direct_sendmsg, list);
862         smb_direct_free_sendmsg(t, sibling);
863 }
864
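/*
 * Consume the receive credits granted since the last send; the returned
 * count is advertised to the peer in the outgoing packet header.
 */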
865 static int manage_credits_prior_sending(struct smb_direct_transport *t)
866 {
867         int new_credits;
868
869         spin_lock(&t->lock_new_recv_credits);
870         new_credits = t->new_recv_credits;
871         t->new_recv_credits = 0;
872         spin_unlock(&t->lock_new_recv_credits);
873
874         return new_credits;
875 }
876
877 static int smb_direct_post_send(struct smb_direct_transport *t,
878                                 struct ib_send_wr *wr)
879 {
880         int ret;
881
882         if (wr->num_sge > 1)
883                 atomic_inc(&t->send_payload_pending);
884         else
885                 atomic_inc(&t->send_pending);
886
887         ret = ib_post_send(t->qp, wr, NULL);
888         if (ret) {
889                 pr_err("failed to post send: %d\n", ret);
890                 if (wr->num_sge > 1) {
891                         if (atomic_dec_and_test(&t->send_payload_pending))
892                                 wake_up(&t->wait_send_payload_pending);
893                 } else {
894                         if (atomic_dec_and_test(&t->send_pending))
895                                 wake_up(&t->wait_send_pending);
896                 }
897                 smb_direct_disconnect_rdma_connection(t);
898         }
899         return ret;
900 }
901
902 static void smb_direct_send_ctx_init(struct smb_direct_transport *t,
903                                      struct smb_direct_send_ctx *send_ctx,
904                                      bool need_invalidate_rkey,
905                                      unsigned int remote_key)
906 {
907         INIT_LIST_HEAD(&send_ctx->msg_list);
908         send_ctx->wr_cnt = 0;
909         send_ctx->need_invalidate_rkey = need_invalidate_rkey;
910         send_ctx->remote_key = remote_key;
911 }
912
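/*
 * Post the queued send messages as a single chain of work requests; only
 * the last WR is signaled and its completion frees the whole chain.
 */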
913 static int smb_direct_flush_send_list(struct smb_direct_transport *t,
914                                       struct smb_direct_send_ctx *send_ctx,
915                                       bool is_last)
916 {
917         struct smb_direct_sendmsg *first, *last;
918         int ret;
919
920         if (list_empty(&send_ctx->msg_list))
921                 return 0;
922
923         first = list_first_entry(&send_ctx->msg_list,
924                                  struct smb_direct_sendmsg,
925                                  list);
926         last = list_last_entry(&send_ctx->msg_list,
927                                struct smb_direct_sendmsg,
928                                list);
929
930         last->wr.send_flags = IB_SEND_SIGNALED;
931         last->wr.wr_cqe = &last->cqe;
932         if (is_last && send_ctx->need_invalidate_rkey) {
933                 last->wr.opcode = IB_WR_SEND_WITH_INV;
934                 last->wr.ex.invalidate_rkey = send_ctx->remote_key;
935         }
936
937         ret = smb_direct_post_send(t, &first->wr);
938         if (!ret) {
939                 smb_direct_send_ctx_init(t, send_ctx,
940                                          send_ctx->need_invalidate_rkey,
941                                          send_ctx->remote_key);
942         } else {
943                 atomic_add(send_ctx->wr_cnt, &t->send_credits);
944                 wake_up(&t->wait_send_credits);
945                 list_for_each_entry_safe(first, last, &send_ctx->msg_list,
946                                          list) {
947                         smb_direct_free_sendmsg(t, first);
948                 }
949         }
950         return ret;
951 }
952
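/*
 * Take one credit from @credits, sleeping until a credit becomes
 * available or the connection leaves the connected state.
 */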
953 static int wait_for_credits(struct smb_direct_transport *t,
954                             wait_queue_head_t *waitq, atomic_t *credits)
955 {
956         int ret;
957
958         do {
959                 if (atomic_dec_return(credits) >= 0)
960                         return 0;
961
962                 atomic_inc(credits);
963                 ret = wait_event_interruptible(*waitq,
964                                                atomic_read(credits) > 0 ||
965                                                 t->status != SMB_DIRECT_CS_CONNECTED);
966
967                 if (t->status != SMB_DIRECT_CS_CONNECTED)
968                         return -ENOTCONN;
969                 else if (ret < 0)
970                         return ret;
971         } while (true);
972 }
973
974 static int wait_for_send_credits(struct smb_direct_transport *t,
975                                  struct smb_direct_send_ctx *send_ctx)
976 {
977         int ret;
978
979         if (send_ctx &&
980             (send_ctx->wr_cnt >= 16 || atomic_read(&t->send_credits) <= 1)) {
981                 ret = smb_direct_flush_send_list(t, send_ctx, false);
982                 if (ret)
983                         return ret;
984         }
985
986         return wait_for_credits(t, &t->wait_send_credits, &t->send_credits);
987 }
988
989 static int smb_direct_create_header(struct smb_direct_transport *t,
990                                     int size, int remaining_data_length,
991                                     struct smb_direct_sendmsg **sendmsg_out)
992 {
993         struct smb_direct_sendmsg *sendmsg;
994         struct smb_direct_data_transfer *packet;
995         int header_length;
996         int ret;
997
998         sendmsg = smb_direct_alloc_sendmsg(t);
999         if (IS_ERR(sendmsg))
1000                 return PTR_ERR(sendmsg);
1001
1002         /* Fill in the packet header */
1003         packet = (struct smb_direct_data_transfer *)sendmsg->packet;
1004         packet->credits_requested = cpu_to_le16(t->send_credit_target);
1005         packet->credits_granted = cpu_to_le16(manage_credits_prior_sending(t));
1006
1007         packet->flags = 0;
1008         packet->reserved = 0;
1009         if (!size)
1010                 packet->data_offset = 0;
1011         else
1012                 packet->data_offset = cpu_to_le32(24);
1013         packet->data_length = cpu_to_le32(size);
1014         packet->remaining_data_length = cpu_to_le32(remaining_data_length);
1015         packet->padding = 0;
1016
1017         ksmbd_debug(RDMA,
1018                     "credits_requested=%d credits_granted=%d data_offset=%d data_length=%d remaining_data_length=%d\n",
1019                     le16_to_cpu(packet->credits_requested),
1020                     le16_to_cpu(packet->credits_granted),
1021                     le32_to_cpu(packet->data_offset),
1022                     le32_to_cpu(packet->data_length),
1023                     le32_to_cpu(packet->remaining_data_length));
1024
1025         /* Map the packet to DMA */
1026         header_length = sizeof(struct smb_direct_data_transfer);
1027         /* If this is a packet without payload, don't send padding */
1028         if (!size)
1029                 header_length =
1030                         offsetof(struct smb_direct_data_transfer, padding);
1031
1032         sendmsg->sge[0].addr = ib_dma_map_single(t->cm_id->device,
1033                                                  (void *)packet,
1034                                                  header_length,
1035                                                  DMA_TO_DEVICE);
1036         ret = ib_dma_mapping_error(t->cm_id->device, sendmsg->sge[0].addr);
1037         if (ret) {
1038                 smb_direct_free_sendmsg(t, sendmsg);
1039                 return ret;
1040         }
1041
1042         sendmsg->num_sge = 1;
1043         sendmsg->sge[0].length = header_length;
1044         sendmsg->sge[0].lkey = t->pd->local_dma_lkey;
1045
1046         *sendmsg_out = sendmsg;
1047         return 0;
1048 }
1049
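/*
 * Build a scatterlist over the pages backing @buf. Returns the number of
 * entries used, or -EINVAL if @nentries is too small for the buffer.
 */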
1050 static int get_sg_list(void *buf, int size, struct scatterlist *sg_list, int nentries)
1051 {
1052         bool high = is_vmalloc_addr(buf);
1053         struct page *page;
1054         int offset, len;
1055         int i = 0;
1056
1057         if (nentries < get_buf_page_count(buf, size))
1058                 return -EINVAL;
1059
1060         offset = offset_in_page(buf);
1061         buf -= offset;
1062         while (size > 0) {
1063                 len = min_t(int, PAGE_SIZE - offset, size);
1064                 if (high)
1065                         page = vmalloc_to_page(buf);
1066                 else
1067                         page = kmap_to_page(buf);
1068
1069                 if (!sg_list)
1070                         return -EINVAL;
1071                 sg_set_page(sg_list, page, len, offset);
1072                 sg_list = sg_next(sg_list);
1073
1074                 buf += PAGE_SIZE;
1075                 size -= len;
1076                 offset = 0;
1077                 i++;
1078         }
1079         return i;
1080 }
1081
1082 static int get_mapped_sg_list(struct ib_device *device, void *buf, int size,
1083                               struct scatterlist *sg_list, int nentries,
1084                               enum dma_data_direction dir)
1085 {
1086         int npages;
1087
1088         npages = get_sg_list(buf, size, sg_list, nentries);
1089         if (npages <= 0)
1090                 return -EINVAL;
1091         return ib_dma_map_sg(device, sg_list, npages, dir);
1092 }
1093
1094 static int post_sendmsg(struct smb_direct_transport *t,
1095                         struct smb_direct_send_ctx *send_ctx,
1096                         struct smb_direct_sendmsg *msg)
1097 {
1098         int i;
1099
1100         for (i = 0; i < msg->num_sge; i++)
1101                 ib_dma_sync_single_for_device(t->cm_id->device,
1102                                               msg->sge[i].addr, msg->sge[i].length,
1103                                               DMA_TO_DEVICE);
1104
1105         msg->cqe.done = send_done;
1106         msg->wr.opcode = IB_WR_SEND;
1107         msg->wr.sg_list = &msg->sge[0];
1108         msg->wr.num_sge = msg->num_sge;
1109         msg->wr.next = NULL;
1110
1111         if (send_ctx) {
1112                 msg->wr.wr_cqe = NULL;
1113                 msg->wr.send_flags = 0;
1114                 if (!list_empty(&send_ctx->msg_list)) {
1115                         struct smb_direct_sendmsg *last;
1116
1117                         last = list_last_entry(&send_ctx->msg_list,
1118                                                struct smb_direct_sendmsg,
1119                                                list);
1120                         last->wr.next = &msg->wr;
1121                 }
1122                 list_add_tail(&msg->list, &send_ctx->msg_list);
1123                 send_ctx->wr_cnt++;
1124                 return 0;
1125         }
1126
1127         msg->wr.wr_cqe = &msg->cqe;
1128         msg->wr.send_flags = IB_SEND_SIGNALED;
1129         return smb_direct_post_send(t, &msg->wr);
1130 }
1131
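/*
 * Consume a send credit, build one send WR carrying the SMB_DIRECT header
 * followed by the payload in @iov, and post it (or queue it on @send_ctx
 * for a later flush).
 */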
1132 static int smb_direct_post_send_data(struct smb_direct_transport *t,
1133                                      struct smb_direct_send_ctx *send_ctx,
1134                                      struct kvec *iov, int niov,
1135                                      int remaining_data_length)
1136 {
1137         int i, j, ret;
1138         struct smb_direct_sendmsg *msg;
1139         int data_length;
1140         struct scatterlist sg[SMB_DIRECT_MAX_SEND_SGES - 1];
1141
1142         ret = wait_for_send_credits(t, send_ctx);
1143         if (ret)
1144                 return ret;
1145
1146         data_length = 0;
1147         for (i = 0; i < niov; i++)
1148                 data_length += iov[i].iov_len;
1149
1150         ret = smb_direct_create_header(t, data_length, remaining_data_length,
1151                                        &msg);
1152         if (ret) {
1153                 atomic_inc(&t->send_credits);
1154                 return ret;
1155         }
1156
1157         for (i = 0; i < niov; i++) {
1158                 struct ib_sge *sge;
1159                 int sg_cnt;
1160
1161                 sg_init_table(sg, SMB_DIRECT_MAX_SEND_SGES - 1);
1162                 sg_cnt = get_mapped_sg_list(t->cm_id->device,
1163                                             iov[i].iov_base, iov[i].iov_len,
1164                                             sg, SMB_DIRECT_MAX_SEND_SGES - 1,
1165                                             DMA_TO_DEVICE);
1166                 if (sg_cnt <= 0) {
1167                         pr_err("failed to map buffer\n");
1168                         ret = -ENOMEM;
1169                         goto err;
1170                 } else if (sg_cnt + msg->num_sge > SMB_DIRECT_MAX_SEND_SGES) {
1171                         pr_err("buffer not fitted into sges\n");
1172                         ret = -E2BIG;
1173                         ib_dma_unmap_sg(t->cm_id->device, sg, sg_cnt,
1174                                         DMA_TO_DEVICE);
1175                         goto err;
1176                 }
1177
1178                 for (j = 0; j < sg_cnt; j++) {
1179                         sge = &msg->sge[msg->num_sge];
1180                         sge->addr = sg_dma_address(&sg[j]);
1181                         sge->length = sg_dma_len(&sg[j]);
1182                         sge->lkey  = t->pd->local_dma_lkey;
1183                         msg->num_sge++;
1184                 }
1185         }
1186
1187         ret = post_sendmsg(t, send_ctx, msg);
1188         if (ret)
1189                 goto err;
1190         return 0;
1191 err:
1192         smb_direct_free_sendmsg(t, msg);
1193         atomic_inc(&t->send_credits);
1194         return ret;
1195 }
1196
1197 static int smb_direct_writev(struct ksmbd_transport *t,
1198                              struct kvec *iov, int niovs, int buflen,
1199                              bool need_invalidate, unsigned int remote_key)
1200 {
1201         struct smb_direct_transport *st = smb_trans_direct_transfort(t);
1202         int remaining_data_length;
1203         int start, i, j;
1204         int max_iov_size = st->max_send_size -
1205                         sizeof(struct smb_direct_data_transfer);
1206         int ret;
1207         struct kvec vec;
1208         struct smb_direct_send_ctx send_ctx;
1209
1210         if (st->status != SMB_DIRECT_CS_CONNECTED)
1211                 return -ENOTCONN;
1212
1213         //FIXME: skip RFC1002 header..
1214         buflen -= 4;
1215         iov[0].iov_base += 4;
1216         iov[0].iov_len -= 4;
1217
1218         remaining_data_length = buflen;
1219         ksmbd_debug(RDMA, "Sending smb (RDMA): smb_len=%u\n", buflen);
1220
1221         smb_direct_send_ctx_init(st, &send_ctx, need_invalidate, remote_key);
1222         start = i = 0;
1223         buflen = 0;
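        /*
         * Pack as many iovecs as fit into one SMB_DIRECT message of at
         * most max_iov_size bytes; an iovec larger than that is split
         * across several messages.
         */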
1224         while (true) {
1225                 buflen += iov[i].iov_len;
1226                 if (buflen > max_iov_size) {
1227                         if (i > start) {
1228                                 remaining_data_length -=
1229                                         (buflen - iov[i].iov_len);
1230                                 ret = smb_direct_post_send_data(st, &send_ctx,
1231                                                                 &iov[start], i - start,
1232                                                                 remaining_data_length);
1233                                 if (ret)
1234                                         goto done;
1235                         } else {
1236                                 /* iov[start] is too big, break it */
1237                                 int nvec  = (buflen + max_iov_size - 1) /
1238                                                 max_iov_size;
1239
1240                                 for (j = 0; j < nvec; j++) {
1241                                         vec.iov_base =
1242                                                 (char *)iov[start].iov_base +
1243                                                 j * max_iov_size;
1244                                         vec.iov_len =
1245                                                 min_t(int, max_iov_size,
1246                                                       buflen - max_iov_size * j);
1247                                         remaining_data_length -= vec.iov_len;
1248                                         ret = smb_direct_post_send_data(st, &send_ctx, &vec, 1,
1249                                                                         remaining_data_length);
1250                                         if (ret)
1251                                                 goto done;
1252                                 }
1253                                 i++;
1254                                 if (i == niovs)
1255                                         break;
1256                         }
1257                         start = i;
1258                         buflen = 0;
1259                 } else {
1260                         i++;
1261                         if (i == niovs) {
1262                                 /* send out all remaining vecs */
1263                                 remaining_data_length -= buflen;
1264                                 ret = smb_direct_post_send_data(st, &send_ctx,
1265                                                                 &iov[start], i - start,
1266                                                                 remaining_data_length);
1267                                 if (ret)
1268                                         goto done;
1269                                 break;
1270                         }
1271                 }
1272         }
1273
1274 done:
1275         ret = smb_direct_flush_send_list(st, &send_ctx, true);
1276
1277         /*
1278          * As an optimization, we don't wait for individual I/O to finish
1279          * before sending the next one.
 1280          * Send them all and wait for the pending send count to reach 0,
 1281          * which means all the I/Os have gone out and we are good to return
1282          */
1283
1284         wait_event(st->wait_send_payload_pending,
1285                    atomic_read(&st->send_payload_pending) == 0);
1286         return ret;
1287 }
1288
1289 static void read_write_done(struct ib_cq *cq, struct ib_wc *wc,
1290                             enum dma_data_direction dir)
1291 {
1292         struct smb_direct_rdma_rw_msg *msg = container_of(wc->wr_cqe,
1293                                                           struct smb_direct_rdma_rw_msg, cqe);
1294         struct smb_direct_transport *t = msg->t;
1295
1296         if (wc->status != IB_WC_SUCCESS) {
1297                 pr_err("read/write error. opcode = %d, status = %s(%d)\n",
1298                        wc->opcode, ib_wc_status_msg(wc->status), wc->status);
1299                 smb_direct_disconnect_rdma_connection(t);
1300         }
1301
1302         if (atomic_inc_return(&t->rw_avail_ops) > 0)
1303                 wake_up(&t->wait_rw_avail_ops);
1304
1305         rdma_rw_ctx_destroy(&msg->rw_ctx, t->qp, t->qp->port,
1306                             msg->sg_list, msg->sgt.nents, dir);
1307         sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
1308         complete(msg->completion);
1309         kfree(msg);
1310 }
1311
1312 static void read_done(struct ib_cq *cq, struct ib_wc *wc)
1313 {
1314         read_write_done(cq, wc, DMA_FROM_DEVICE);
1315 }
1316
1317 static void write_done(struct ib_cq *cq, struct ib_wc *wc)
1318 {
1319         read_write_done(cq, wc, DMA_TO_DEVICE);
1320 }
1321
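/*
 * Perform a server-initiated RDMA read or write against the peer's buffer
 * described by remote_key/remote_offset and wait for it to complete.
 */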
1322 static int smb_direct_rdma_xmit(struct smb_direct_transport *t, void *buf,
1323                                 int buf_len, u32 remote_key, u64 remote_offset,
1324                                 u32 remote_len, bool is_read)
1325 {
1326         struct smb_direct_rdma_rw_msg *msg;
1327         int ret;
1328         DECLARE_COMPLETION_ONSTACK(completion);
1329         struct ib_send_wr *first_wr = NULL;
1330
1331         ret = wait_for_credits(t, &t->wait_rw_avail_ops, &t->rw_avail_ops);
1332         if (ret < 0)
1333                 return ret;
1334
1335         /* TODO: mempool */
1336         msg = kmalloc(offsetof(struct smb_direct_rdma_rw_msg, sg_list) +
1337                       sizeof(struct scatterlist) * SG_CHUNK_SIZE, GFP_KERNEL);
1338         if (!msg) {
1339                 atomic_inc(&t->rw_avail_ops);
1340                 return -ENOMEM;
1341         }
1342
1343         msg->sgt.sgl = &msg->sg_list[0];
1344         ret = sg_alloc_table_chained(&msg->sgt,
1345                                      get_buf_page_count(buf, buf_len),
1346                                      msg->sg_list, SG_CHUNK_SIZE);
1347         if (ret) {
1348                 atomic_inc(&t->rw_avail_ops);
1349                 kfree(msg);
1350                 return -ENOMEM;
1351         }
1352
1353         ret = get_sg_list(buf, buf_len, msg->sgt.sgl, msg->sgt.orig_nents);
1354         if (ret <= 0) {
1355                 pr_err("failed to get pages\n");
1356                 goto err;
1357         }
1358
1359         ret = rdma_rw_ctx_init(&msg->rw_ctx, t->qp, t->qp->port,
1360                                msg->sg_list, get_buf_page_count(buf, buf_len),
1361                                0, remote_offset, remote_key,
1362                                is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE);
1363         if (ret < 0) {
1364                 pr_err("failed to init rdma_rw_ctx: %d\n", ret);
1365                 goto err;
1366         }
1367
1368         msg->t = t;
1369         msg->cqe.done = is_read ? read_done : write_done;
1370         msg->completion = &completion;
1371         first_wr = rdma_rw_ctx_wrs(&msg->rw_ctx, t->qp, t->qp->port,
1372                                    &msg->cqe, NULL);
1373
1374         ret = ib_post_send(t->qp, first_wr, NULL);
1375         if (ret) {
1376                 pr_err("failed to post send wr: %d\n", ret);
1377                 goto err;
1378         }
1379
1380         wait_for_completion(&completion);
1381         return 0;
1382
1383 err:
1384         atomic_inc(&t->rw_avail_ops);
1385         if (first_wr)
1386                 rdma_rw_ctx_destroy(&msg->rw_ctx, t->qp, t->qp->port,
1387                                     msg->sg_list, msg->sgt.nents,
1388                                     is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE);
1389         sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
1390         kfree(msg);
1391         return ret;
1392 }
1393
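/* RDMA write/read entry points for the transport; thin wrappers around
 * smb_direct_rdma_xmit().
 */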
1394 static int smb_direct_rdma_write(struct ksmbd_transport *t, void *buf,
1395                                  unsigned int buflen, u32 remote_key,
1396                                  u64 remote_offset, u32 remote_len)
1397 {
1398         return smb_direct_rdma_xmit(smb_trans_direct_transfort(t), buf, buflen,
1399                                     remote_key, remote_offset,
1400                                     remote_len, false);
1401 }
1402
1403 static int smb_direct_rdma_read(struct ksmbd_transport *t, void *buf,
1404                                 unsigned int buflen, u32 remote_key,
1405                                 u64 remote_offset, u32 remote_len)
1406 {
1407         return smb_direct_rdma_xmit(smb_trans_direct_transfort(t), buf, buflen,
1408                                     remote_key, remote_offset,
1409                                     remote_len, true);
1410 }
1411
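/*
 * Transport disconnect callback: initiate the RDMA disconnect, wait for
 * the connection to reach SMB_DIRECT_CS_DISCONNECTED, then free the
 * transport.
 */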
1412 static void smb_direct_disconnect(struct ksmbd_transport *t)
1413 {
1414         struct smb_direct_transport *st = smb_trans_direct_transfort(t);
1415
1416         ksmbd_debug(RDMA, "Disconnecting cm_id=%p\n", st->cm_id);
1417
1418         smb_direct_disconnect_rdma_work(&st->disconnect_work);
1419         wait_event_interruptible(st->wait_status,
1420                                  st->status == SMB_DIRECT_CS_DISCONNECTED);
1421         free_transport(st);
1422 }
1423
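/*
 * Per-connection RDMA CM event handler: update the transport status and
 * wake up anyone waiting on a connection state change.
 */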
1424 static int smb_direct_cm_handler(struct rdma_cm_id *cm_id,
1425                                  struct rdma_cm_event *event)
1426 {
1427         struct smb_direct_transport *t = cm_id->context;
1428
1429         ksmbd_debug(RDMA, "RDMA CM event. cm_id=%p event=%s (%d)\n",
1430                     cm_id, rdma_event_msg(event->event), event->event);
1431
1432         switch (event->event) {
1433         case RDMA_CM_EVENT_ESTABLISHED: {
1434                 t->status = SMB_DIRECT_CS_CONNECTED;
1435                 wake_up_interruptible(&t->wait_status);
1436                 break;
1437         }
1438         case RDMA_CM_EVENT_DEVICE_REMOVAL:
1439         case RDMA_CM_EVENT_DISCONNECTED: {
1440                 t->status = SMB_DIRECT_CS_DISCONNECTED;
1441                 wake_up_interruptible(&t->wait_status);
1442                 wake_up_interruptible(&t->wait_reassembly_queue);
1443                 wake_up(&t->wait_send_credits);
1444                 break;
1445         }
1446         case RDMA_CM_EVENT_CONNECT_ERROR: {
1447                 t->status = SMB_DIRECT_CS_DISCONNECTED;
1448                 wake_up_interruptible(&t->wait_status);
1449                 break;
1450         }
1451         default:
1452                 pr_err("Unexpected RDMA CM event. cm_id=%p, event=%s (%d)\n",
1453                        cm_id, rdma_event_msg(event->event),
1454                        event->event);
1455                 break;
1456         }
1457         return 0;
1458 }
1459
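/* QP async event handler: drop the connection on fatal QP/CQ errors. */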
1460 static void smb_direct_qpair_handler(struct ib_event *event, void *context)
1461 {
1462         struct smb_direct_transport *t = context;
1463
1464         ksmbd_debug(RDMA, "Received QP event. cm_id=%p, event=%s (%d)\n",
1465                     t->cm_id, ib_event_msg(event->event), event->event);
1466
1467         switch (event->event) {
1468         case IB_EVENT_CQ_ERR:
1469         case IB_EVENT_QP_FATAL:
1470                 smb_direct_disconnect_rdma_connection(t);
1471                 break;
1472         default:
1473                 break;
1474         }
1475 }
1476
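/*
 * Build and post the SMB_DIRECT negotiate response, then wait for the
 * send to complete. If @failed is set, only the supported version range
 * and STATUS_NOT_SUPPORTED are returned.
 */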
1477 static int smb_direct_send_negotiate_response(struct smb_direct_transport *t,
1478                                               int failed)
1479 {
1480         struct smb_direct_sendmsg *sendmsg;
1481         struct smb_direct_negotiate_resp *resp;
1482         int ret;
1483
1484         sendmsg = smb_direct_alloc_sendmsg(t);
1485         if (IS_ERR(sendmsg))
1486                 return -ENOMEM;
1487
1488         resp = (struct smb_direct_negotiate_resp *)sendmsg->packet;
1489         if (failed) {
1490                 memset(resp, 0, sizeof(*resp));
1491                 resp->min_version = cpu_to_le16(0x0100);
1492                 resp->max_version = cpu_to_le16(0x0100);
1493                 resp->status = STATUS_NOT_SUPPORTED;
1494         } else {
1495                 resp->status = STATUS_SUCCESS;
1496                 resp->min_version = SMB_DIRECT_VERSION_LE;
1497                 resp->max_version = SMB_DIRECT_VERSION_LE;
1498                 resp->negotiated_version = SMB_DIRECT_VERSION_LE;
1499                 resp->reserved = 0;
1500                 resp->credits_requested =
1501                                 cpu_to_le16(t->send_credit_target);
1502                 resp->credits_granted = cpu_to_le16(manage_credits_prior_sending(t));
1503                 resp->max_readwrite_size = cpu_to_le32(t->max_rdma_rw_size);
1504                 resp->preferred_send_size = cpu_to_le32(t->max_send_size);
1505                 resp->max_receive_size = cpu_to_le32(t->max_recv_size);
1506                 resp->max_fragmented_size =
1507                                 cpu_to_le32(t->max_fragmented_recv_size);
1508         }
1509
1510         sendmsg->sge[0].addr = ib_dma_map_single(t->cm_id->device,
1511                                                  (void *)resp, sizeof(*resp),
1512                                                  DMA_TO_DEVICE);
1513         ret = ib_dma_mapping_error(t->cm_id->device, sendmsg->sge[0].addr);
1514         if (ret) {
1515                 smb_direct_free_sendmsg(t, sendmsg);
1516                 return ret;
1517         }
1518
1519         sendmsg->num_sge = 1;
1520         sendmsg->sge[0].length = sizeof(*resp);
1521         sendmsg->sge[0].lkey = t->pd->local_dma_lkey;
1522
1523         ret = post_sendmsg(t, NULL, sendmsg);
1524         if (ret) {
1525                 smb_direct_free_sendmsg(t, sendmsg);
1526                 return ret;
1527         }
1528
1529         wait_event(t->wait_send_pending,
1530                    atomic_read(&t->send_pending) == 0);
1531         return 0;
1532 }
1533
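/*
 * Accept the pending RDMA connection request. On iWARP ports the
 * initiator/responder resource values (ird_ord_hdr) are carried in the
 * connection's private data. Returns -ENOTCONN if the connection does
 * not reach the CONNECTED state.
 */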
1534 static int smb_direct_accept_client(struct smb_direct_transport *t)
1535 {
1536         struct rdma_conn_param conn_param;
1537         struct ib_port_immutable port_immutable;
1538         u32 ird_ord_hdr[2];
1539         int ret;
1540
1541         memset(&conn_param, 0, sizeof(conn_param));
1542         conn_param.initiator_depth = min_t(u8, t->cm_id->device->attrs.max_qp_rd_atom,
1543                                            SMB_DIRECT_CM_INITIATOR_DEPTH);
1544         conn_param.responder_resources = 0;
1545
1546         t->cm_id->device->ops.get_port_immutable(t->cm_id->device,
1547                                                  t->cm_id->port_num,
1548                                                  &port_immutable);
1549         if (port_immutable.core_cap_flags & RDMA_CORE_PORT_IWARP) {
1550                 ird_ord_hdr[0] = conn_param.responder_resources;
1551                 ird_ord_hdr[1] = 1;
1552                 conn_param.private_data = ird_ord_hdr;
1553                 conn_param.private_data_len = sizeof(ird_ord_hdr);
1554         } else {
1555                 conn_param.private_data = NULL;
1556                 conn_param.private_data_len = 0;
1557         }
1558         conn_param.retry_count = SMB_DIRECT_CM_RETRY;
1559         conn_param.rnr_retry_count = SMB_DIRECT_CM_RNR_RETRY;
1560         conn_param.flow_control = 0;
1561
1562         ret = rdma_accept(t->cm_id, &conn_param);
1563         if (ret) {
1564                 pr_err("error at rdma_accept: %d\n", ret);
1565                 return ret;
1566         }
1567
1568         wait_event_interruptible(t->wait_status,
1569                                  t->status != SMB_DIRECT_CS_NEW);
1570         if (t->status != SMB_DIRECT_CS_CONNECTED)
1571                 return -ENOTCONN;
1572         return 0;
1573 }
1574
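/*
 * Run the SMB_DIRECT negotiation: post a receive for the negotiate
 * request, accept the connection, wait (with a timeout) for the client's
 * request, adopt the client's advertised sizes and send the response.
 */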
1575 static int smb_direct_negotiate(struct smb_direct_transport *t)
1576 {
1577         int ret;
1578         struct smb_direct_recvmsg *recvmsg;
1579         struct smb_direct_negotiate_req *req;
1580
1581         recvmsg = get_free_recvmsg(t);
1582         if (!recvmsg)
1583                 return -ENOMEM;
1584         recvmsg->type = SMB_DIRECT_MSG_NEGOTIATE_REQ;
1585
1586         ret = smb_direct_post_recv(t, recvmsg);
1587         if (ret) {
1588                 pr_err("Can't post recv: %d\n", ret);
1589                 goto out;
1590         }
1591
1592         t->negotiation_requested = false;
1593         ret = smb_direct_accept_client(t);
1594         if (ret) {
1595                 pr_err("Can't accept client\n");
1596                 goto out;
1597         }
1598
1599         smb_direct_post_recv_credits(&t->post_recv_credits_work.work);
1600
1601         ksmbd_debug(RDMA, "Waiting for SMB_DIRECT negotiate request\n");
1602         ret = wait_event_interruptible_timeout(t->wait_status,
1603                                                t->negotiation_requested ||
1604                                                 t->status == SMB_DIRECT_CS_DISCONNECTED,
1605                                                SMB_DIRECT_NEGOTIATE_TIMEOUT * HZ);
1606         if (ret <= 0 || t->status == SMB_DIRECT_CS_DISCONNECTED) {
1607                 ret = ret < 0 ? ret : -ETIMEDOUT;
1608                 goto out;
1609         }
1610
1611         ret = smb_direct_check_recvmsg(recvmsg);
1612         if (ret == -ECONNABORTED)
1613                 goto out;
1614
1615         req = (struct smb_direct_negotiate_req *)recvmsg->packet;
1616         t->max_recv_size = min_t(int, t->max_recv_size,
1617                                  le32_to_cpu(req->preferred_send_size));
1618         t->max_send_size = min_t(int, t->max_send_size,
1619                                  le32_to_cpu(req->max_receive_size));
1620         t->max_fragmented_send_size =
1621                         le32_to_cpu(req->max_fragmented_size);
1622
1623         ret = smb_direct_send_negotiate_response(t, ret);
1624 out:
1625         if (recvmsg)
1626                 put_recvmsg(t, recvmsg);
1627         return ret;
1628 }
1629
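/*
 * Derive the per-connection credit and size parameters from the module
 * defaults, validate them against the device limits and fill in the QP
 * capabilities used when the queue pair is created.
 */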
1630 static int smb_direct_init_params(struct smb_direct_transport *t,
1631                                   struct ib_qp_cap *cap)
1632 {
1633         struct ib_device *device = t->cm_id->device;
1634         int max_send_sges, max_pages, max_rw_wrs, max_send_wrs;
1635
1636         /* We need two extra SGEs: one for the SMB_DIRECT header, which is
1637          * mapped separately, and one in case a send buffer is not page aligned.
1638          */
1639         t->max_send_size = smb_direct_max_send_size;
1640         max_send_sges = DIV_ROUND_UP(t->max_send_size, PAGE_SIZE) + 2;
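        /* e.g. with 4 KiB pages and an 8 KiB max_send_size this is
         * 8192 / 4096 + 2 = 4 SGEs, within SMB_DIRECT_MAX_SEND_SGES.
         */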
1641         if (max_send_sges > SMB_DIRECT_MAX_SEND_SGES) {
1642                 pr_err("max_send_size %d is too large\n", t->max_send_size);
1643                 return -EINVAL;
1644         }
1645
1646         /*
1647          * Allow smb_direct_max_outstanding_rw_ops in-flight RDMA
1648          * reads/writes. The HCA guarantees at least max_send_sge SGEs per
1649          * RDMA read/write work request, and when memory registration is
1650          * used each read/write also needs REG_MR and LOCAL_INV WRs.
1651          */
1652         t->max_rdma_rw_size = smb_direct_max_read_write_size;
1653         max_pages = DIV_ROUND_UP(t->max_rdma_rw_size, PAGE_SIZE) + 1;
1654         max_rw_wrs = DIV_ROUND_UP(max_pages, SMB_DIRECT_MAX_SEND_SGES);
1655         max_rw_wrs += rdma_rw_mr_factor(device, t->cm_id->port_num,
1656                         max_pages) * 2;
1657         max_rw_wrs *= smb_direct_max_outstanding_rw_ops;
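        /* e.g. with 4 KiB pages and a 1 MiB max_rdma_rw_size: max_pages is
         * 257, giving 33 WRs from the SGE term plus a device-dependent
         * number of REG_MR/LOCAL_INV WRs, all scaled by the number of
         * outstanding R/W ops.
         */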
1658
1659         max_send_wrs = smb_direct_send_credit_target + max_rw_wrs;
1660         if (max_send_wrs > device->attrs.max_cqe ||
1661             max_send_wrs > device->attrs.max_qp_wr) {
1662                 pr_err("consider lowering send_credit_target = %d, or max_outstanding_rw_ops = %d\n",
1663                        smb_direct_send_credit_target,
1664                        smb_direct_max_outstanding_rw_ops);
1665                 pr_err("Possible CQE overrun, device reporting max_cqe %d max_qp_wr %d\n",
1666                        device->attrs.max_cqe, device->attrs.max_qp_wr);
1667                 return -EINVAL;
1668         }
1669
1670         if (smb_direct_receive_credit_max > device->attrs.max_cqe ||
1671             smb_direct_receive_credit_max > device->attrs.max_qp_wr) {
1672                 pr_err("consider lowering receive_credit_max = %d\n",
1673                        smb_direct_receive_credit_max);
1674                 pr_err("Possible CQE overrun, device reporting max_cqe %d max_qp_wr %d\n",
1675                        device->attrs.max_cqe, device->attrs.max_qp_wr);
1676                 return -EINVAL;
1677         }
1678
1679         if (device->attrs.max_send_sge < SMB_DIRECT_MAX_SEND_SGES) {
1680                 pr_err("device max_send_sge = %d is too small\n",
1681                        device->attrs.max_send_sge);
1682                 return -EINVAL;
1683         }
1684         if (device->attrs.max_recv_sge < SMB_DIRECT_MAX_RECV_SGES) {
1685                 pr_err("device max_recv_sge = %d is too small\n",
1686                        device->attrs.max_recv_sge);
1687                 return -EINVAL;
1688         }
1689
1690         t->recv_credits = 0;
1691         t->count_avail_recvmsg = 0;
1692
1693         t->recv_credit_max = smb_direct_receive_credit_max;
1694         t->recv_credit_target = 10;
1695         t->new_recv_credits = 0;
1696
1697         t->send_credit_target = smb_direct_send_credit_target;
1698         atomic_set(&t->send_credits, 0);
1699         atomic_set(&t->rw_avail_ops, smb_direct_max_outstanding_rw_ops);
1700
1701         t->max_send_size = smb_direct_max_send_size;
1702         t->max_recv_size = smb_direct_max_receive_size;
1703         t->max_fragmented_recv_size = smb_direct_max_fragmented_recv_size;
1704
1705         cap->max_send_wr = max_send_wrs;
1706         cap->max_recv_wr = t->recv_credit_max;
1707         cap->max_send_sge = SMB_DIRECT_MAX_SEND_SGES;
1708         cap->max_recv_sge = SMB_DIRECT_MAX_RECV_SGES;
1709         cap->max_inline_data = 0;
1710         cap->max_rdma_ctxs = 0;
1711         return 0;
1712 }
1713
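/* Free all queued receive messages and destroy the send/receive message
 * mempools and slab caches.
 */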
1714 static void smb_direct_destroy_pools(struct smb_direct_transport *t)
1715 {
1716         struct smb_direct_recvmsg *recvmsg;
1717
1718         while ((recvmsg = get_free_recvmsg(t)))
1719                 mempool_free(recvmsg, t->recvmsg_mempool);
1720         while ((recvmsg = get_empty_recvmsg(t)))
1721                 mempool_free(recvmsg, t->recvmsg_mempool);
1722
1723         mempool_destroy(t->recvmsg_mempool);
1724         t->recvmsg_mempool = NULL;
1725
1726         kmem_cache_destroy(t->recvmsg_cache);
1727         t->recvmsg_cache = NULL;
1728
1729         mempool_destroy(t->sendmsg_mempool);
1730         t->sendmsg_mempool = NULL;
1731
1732         kmem_cache_destroy(t->sendmsg_cache);
1733         t->sendmsg_cache = NULL;
1734 }
1735
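/*
 * Create the send/receive message caches and mempools and pre-allocate
 * recv_credit_max receive messages onto the free list.
 */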
1736 static int smb_direct_create_pools(struct smb_direct_transport *t)
1737 {
1738         char name[80];
1739         int i;
1740         struct smb_direct_recvmsg *recvmsg;
1741
1742         snprintf(name, sizeof(name), "smb_direct_rqst_pool_%p", t);
1743         t->sendmsg_cache = kmem_cache_create(name,
1744                                              sizeof(struct smb_direct_sendmsg) +
1745                                               sizeof(struct smb_direct_negotiate_resp),
1746                                              0, SLAB_HWCACHE_ALIGN, NULL);
1747         if (!t->sendmsg_cache)
1748                 return -ENOMEM;
1749
1750         t->sendmsg_mempool = mempool_create(t->send_credit_target,
1751                                             mempool_alloc_slab, mempool_free_slab,
1752                                             t->sendmsg_cache);
1753         if (!t->sendmsg_mempool)
1754                 goto err;
1755
1756         snprintf(name, sizeof(name), "smb_direct_resp_%p", t);
1757         t->recvmsg_cache = kmem_cache_create(name,
1758                                              sizeof(struct smb_direct_recvmsg) +
1759                                               t->max_recv_size,
1760                                              0, SLAB_HWCACHE_ALIGN, NULL);
1761         if (!t->recvmsg_cache)
1762                 goto err;
1763
1764         t->recvmsg_mempool =
1765                 mempool_create(t->recv_credit_max, mempool_alloc_slab,
1766                                mempool_free_slab, t->recvmsg_cache);
1767         if (!t->recvmsg_mempool)
1768                 goto err;
1769
1770         INIT_LIST_HEAD(&t->recvmsg_queue);
1771
1772         for (i = 0; i < t->recv_credit_max; i++) {
1773                 recvmsg = mempool_alloc(t->recvmsg_mempool, GFP_KERNEL);
1774                 if (!recvmsg)
1775                         goto err;
1776                 recvmsg->transport = t;
1777                 list_add(&recvmsg->list, &t->recvmsg_queue);
1778         }
1779         t->count_avail_recvmsg = t->recv_credit_max;
1780
1781         return 0;
1782 err:
1783         smb_direct_destroy_pools(t);
1784         return -ENOMEM;
1785 }
1786
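/*
 * Allocate the protection domain, the send/receive completion queues and
 * the RC queue pair for this connection, unwinding everything on failure.
 */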
1787 static int smb_direct_create_qpair(struct smb_direct_transport *t,
1788                                    struct ib_qp_cap *cap)
1789 {
1790         int ret;
1791         struct ib_qp_init_attr qp_attr;
1792
1793         t->pd = ib_alloc_pd(t->cm_id->device, 0);
1794         if (IS_ERR(t->pd)) {
1795                 pr_err("Can't create RDMA PD\n");
1796                 ret = PTR_ERR(t->pd);
1797                 t->pd = NULL;
1798                 return ret;
1799         }
1800
1801         t->send_cq = ib_alloc_cq(t->cm_id->device, t,
1802                                  t->send_credit_target, 0, IB_POLL_WORKQUEUE);
1803         if (IS_ERR(t->send_cq)) {
1804                 pr_err("Can't create RDMA send CQ\n");
1805                 ret = PTR_ERR(t->send_cq);
1806                 t->send_cq = NULL;
1807                 goto err;
1808         }
1809
1810         t->recv_cq = ib_alloc_cq(t->cm_id->device, t,
1811                                  cap->max_send_wr + cap->max_rdma_ctxs,
1812                                  0, IB_POLL_WORKQUEUE);
1813         if (IS_ERR(t->recv_cq)) {
1814                 pr_err("Can't create RDMA recv CQ\n");
1815                 ret = PTR_ERR(t->recv_cq);
1816                 t->recv_cq = NULL;
1817                 goto err;
1818         }
1819
1820         memset(&qp_attr, 0, sizeof(qp_attr));
1821         qp_attr.event_handler = smb_direct_qpair_handler;
1822         qp_attr.qp_context = t;
1823         qp_attr.cap = *cap;
1824         qp_attr.sq_sig_type = IB_SIGNAL_REQ_WR;
1825         qp_attr.qp_type = IB_QPT_RC;
1826         qp_attr.send_cq = t->send_cq;
1827         qp_attr.recv_cq = t->recv_cq;
1828         qp_attr.port_num = ~0;
1829
1830         ret = rdma_create_qp(t->cm_id, t->pd, &qp_attr);
1831         if (ret) {
1832                 pr_err("Can't create RDMA QP: %d\n", ret);
1833                 goto err;
1834         }
1835
1836         t->qp = t->cm_id->qp;
1837         t->cm_id->event_handler = smb_direct_cm_handler;
1838
1839         return 0;
1840 err:
1841         if (t->qp) {
1842                 ib_destroy_qp(t->qp);
1843                 t->qp = NULL;
1844         }
1845         if (t->recv_cq) {
1846                 ib_destroy_cq(t->recv_cq);
1847                 t->recv_cq = NULL;
1848         }
1849         if (t->send_cq) {
1850                 ib_destroy_cq(t->send_cq);
1851                 t->send_cq = NULL;
1852         }
1853         if (t->pd) {
1854                 ib_dealloc_pd(t->pd);
1855                 t->pd = NULL;
1856         }
1857         return ret;
1858 }
1859
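/*
 * Transport prepare callback: set up connection parameters, message
 * pools and the queue pair, then negotiate SMB_DIRECT with the client.
 */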
1860 static int smb_direct_prepare(struct ksmbd_transport *t)
1861 {
1862         struct smb_direct_transport *st = smb_trans_direct_transfort(t);
1863         int ret;
1864         struct ib_qp_cap qp_cap;
1865
1866         ret = smb_direct_init_params(st, &qp_cap);
1867         if (ret) {
1868                 pr_err("Can't configure RDMA parameters\n");
1869                 return ret;
1870         }
1871
1872         ret = smb_direct_create_pools(st);
1873         if (ret) {
1874                 pr_err("Can't init RDMA pool: %d\n", ret);
1875                 return ret;
1876         }
1877
1878         ret = smb_direct_create_qpair(st, &qp_cap);
1879         if (ret) {
1880                 pr_err("Can't accept RDMA client: %d\n", ret);
1881                 return ret;
1882         }
1883
1884         ret = smb_direct_negotiate(st);
1885         if (ret) {
1886                 pr_err("Can't negotiate: %d\n", ret);
1887                 return ret;
1888         }
1889
1890         st->status = SMB_DIRECT_CS_CONNECTED;
1891         return 0;
1892 }
1893
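/* Check whether the device supports Fast Registration Work Requests. */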
1894 static bool rdma_frwr_is_supported(struct ib_device_attr *attrs)
1895 {
1896         if (!(attrs->device_cap_flags & IB_DEVICE_MEM_MGT_EXTENSIONS))
1897                 return false;
1898         if (attrs->max_fast_reg_page_list_len == 0)
1899                 return false;
1900         return true;
1901 }
1902
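/*
 * Handle an incoming connection request: require FRWR support, allocate
 * a transport and spawn the connection handler thread.
 */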
1903 static int smb_direct_handle_connect_request(struct rdma_cm_id *new_cm_id)
1904 {
1905         struct smb_direct_transport *t;
1906
1907         if (!rdma_frwr_is_supported(&new_cm_id->device->attrs)) {
1908                 ksmbd_debug(RDMA,
1909                             "Fast Registration Work Requests are not supported. device capabilities=%llx\n",
1910                             new_cm_id->device->attrs.device_cap_flags);
1911                 return -EPROTONOSUPPORT;
1912         }
1913
1914         t = alloc_transport(new_cm_id);
1915         if (!t)
1916                 return -ENOMEM;
1917
1918         KSMBD_TRANS(t)->handler = kthread_run(ksmbd_conn_handler_loop,
1919                                               KSMBD_TRANS(t)->conn, "ksmbd:r%u",
1920                                               SMB_DIRECT_PORT);
1921         if (IS_ERR(KSMBD_TRANS(t)->handler)) {
1922                 int ret = PTR_ERR(KSMBD_TRANS(t)->handler);
1923
1924                 pr_err("Can't start thread\n");
1925                 free_transport(t);
1926                 return ret;
1927         }
1928
1929         return 0;
1930 }
1931
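/* CM event handler for the listening cm_id; only connect requests are expected. */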
1932 static int smb_direct_listen_handler(struct rdma_cm_id *cm_id,
1933                                      struct rdma_cm_event *event)
1934 {
1935         switch (event->event) {
1936         case RDMA_CM_EVENT_CONNECT_REQUEST: {
1937                 int ret = smb_direct_handle_connect_request(cm_id);
1938
1939                 if (ret) {
1940                         pr_err("Can't create transport: %d\n", ret);
1941                         return ret;
1942                 }
1943
1944                 ksmbd_debug(RDMA, "Received connection request. cm_id=%p\n",
1945                             cm_id);
1946                 break;
1947         }
1948         default:
1949                 pr_err("Unexpected listen event. cm_id=%p, event=%s (%d)\n",
1950                        cm_id, rdma_event_msg(event->event), event->event);
1951                 break;
1952         }
1953         return 0;
1954 }
1955
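/*
 * Create the listening cm_id, bind it to the given port on INADDR_ANY
 * and start listening for SMB_DIRECT connection requests.
 */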
1956 static int smb_direct_listen(int port)
1957 {
1958         int ret;
1959         struct rdma_cm_id *cm_id;
1960         struct sockaddr_in sin = {
1961                 .sin_family             = AF_INET,
1962                 .sin_addr.s_addr        = htonl(INADDR_ANY),
1963                 .sin_port               = htons(port),
1964         };
1965
1966         cm_id = rdma_create_id(&init_net, smb_direct_listen_handler,
1967                                &smb_direct_listener, RDMA_PS_TCP, IB_QPT_RC);
1968         if (IS_ERR(cm_id)) {
1969                 pr_err("Can't create cm id: %ld\n", PTR_ERR(cm_id));
1970                 return PTR_ERR(cm_id);
1971         }
1972
1973         ret = rdma_bind_addr(cm_id, (struct sockaddr *)&sin);
1974         if (ret) {
1975                 pr_err("Can't bind: %d\n", ret);
1976                 goto err;
1977         }
1978
1979         smb_direct_listener.cm_id = cm_id;
1980
1981         ret = rdma_listen(cm_id, 10);
1982         if (ret) {
1983                 pr_err("Can't listen: %d\n", ret);
1984                 goto err;
1985         }
1986         return 0;
1987 err:
1988         smb_direct_listener.cm_id = NULL;
1989         rdma_destroy_id(cm_id);
1990         return ret;
1991 }
1992
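/* Set up the SMB_DIRECT transport: create the workqueue and start the
 * RDMA listener.
 */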
1993 int ksmbd_rdma_init(void)
1994 {
1995         int ret;
1996
1997         smb_direct_listener.cm_id = NULL;
1998
1999         /* When a client is running out of send credits, the server grants
2000          * new credits by sending a packet through this workqueue. This
2001          * avoids a situation where a client cannot send packets because it
2002          * has run out of credits.
2003          */
2004         smb_direct_wq = alloc_workqueue("ksmbd-smb_direct-wq",
2005                                         WQ_HIGHPRI | WQ_MEM_RECLAIM, 0);
2006         if (!smb_direct_wq)
2007                 return -ENOMEM;
2008
2009         ret = smb_direct_listen(SMB_DIRECT_PORT);
2010         if (ret) {
2011                 destroy_workqueue(smb_direct_wq);
2012                 smb_direct_wq = NULL;
2013                 pr_err("Can't listen: %d\n", ret);
2014                 return ret;
2015         }
2016
2017         ksmbd_debug(RDMA, "init RDMA listener. cm_id=%p\n",
2018                     smb_direct_listener.cm_id);
2019         return 0;
2020 }
2021
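/* Tear down the RDMA listener and flush/destroy the SMB_DIRECT workqueue. */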
2022 int ksmbd_rdma_destroy(void)
2023 {
2024         if (smb_direct_listener.cm_id)
2025                 rdma_destroy_id(smb_direct_listener.cm_id);
2026         smb_direct_listener.cm_id = NULL;
2027
2028         if (smb_direct_wq) {
2029                 flush_workqueue(smb_direct_wq);
2030                 destroy_workqueue(smb_direct_wq);
2031                 smb_direct_wq = NULL;
2032         }
2033         return 0;
2034 }
2035
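/*
 * Report whether @netdev is backed by an RDMA device with FRWR support,
 * i.e. whether SMB_DIRECT can be used on this interface.
 */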
2036 bool ksmbd_rdma_capable_netdev(struct net_device *netdev)
2037 {
2038         struct ib_device *ibdev;
2039         bool rdma_capable = false;
2040
2041         ibdev = ib_device_get_by_netdev(netdev, RDMA_DRIVER_UNKNOWN);
2042         if (ibdev) {
2043                 if (rdma_frwr_is_supported(&ibdev->attrs))
2044                         rdma_capable = true;
2045                 ib_device_put(ibdev);
2046         }
2047         return rdma_capable;
2048 }
2049
2050 static struct ksmbd_transport_ops ksmbd_smb_direct_transport_ops = {
2051         .prepare        = smb_direct_prepare,
2052         .disconnect     = smb_direct_disconnect,
2053         .writev         = smb_direct_writev,
2054         .read           = smb_direct_read,
2055         .rdma_read      = smb_direct_rdma_read,
2056         .rdma_write     = smb_direct_rdma_write,
2057 };