fs/ksmbd/transport_rdma.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  *   Copyright (C) 2017, Microsoft Corporation.
4  *   Copyright (C) 2018, LG Electronics.
5  *
6  *   Author(s): Long Li <longli@microsoft.com>,
7  *              Hyunchul Lee <hyc.lee@gmail.com>
8  *
9  *   This program is free software;  you can redistribute it and/or modify
10  *   it under the terms of the GNU General Public License as published by
11  *   the Free Software Foundation; either version 2 of the License, or
12  *   (at your option) any later version.
13  *
14  *   This program is distributed in the hope that it will be useful,
15  *   but WITHOUT ANY WARRANTY;  without even the implied warranty of
16  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See
17  *   the GNU General Public License for more details.
18  */
19
20 #define SUBMOD_NAME     "smb_direct"
21
22 #include <linux/kthread.h>
23 #include <linux/rwlock.h>
24 #include <linux/list.h>
25 #include <linux/mempool.h>
26 #include <linux/highmem.h>
27 #include <linux/scatterlist.h>
28 #include <rdma/ib_verbs.h>
29 #include <rdma/rdma_cm.h>
30 #include <rdma/rw.h>
31
32 #include "glob.h"
33 #include "connection.h"
34 #include "smb_common.h"
35 #include "smbstatus.h"
36 #include "transport_rdma.h"
37
38 #define SMB_DIRECT_PORT 5445
39
40 #define SMB_DIRECT_VERSION_LE           cpu_to_le16(0x0100)
41
42 /* SMB_DIRECT negotiation timeout in seconds */
43 #define SMB_DIRECT_NEGOTIATE_TIMEOUT            120
44
45 #define SMB_DIRECT_MAX_SEND_SGES                8
46 #define SMB_DIRECT_MAX_RECV_SGES                1
47
48 /*
49  * Default maximum number of RDMA read/write operations outstanding on this connection.
50  * This value may be lowered during QP creation if the hardware imposes a smaller limit.
51  */
52 #define SMB_DIRECT_CM_INITIATOR_DEPTH           8
53
54 /* Maximum number of retries on data transfer operations */
55 #define SMB_DIRECT_CM_RETRY                     6
56 /* No need to retry on Receiver Not Ready since SMB_DIRECT manages credits */
57 #define SMB_DIRECT_CM_RNR_RETRY         0
58
59 /*
60  * User configurable initial values per SMB_DIRECT transport connection
61  * as defined in [MS-SMBD] 3.1.1.1
62  * These may change after SMB_DIRECT negotiation
63  */
64 /* The maximum number of credits the local peer will grant to the remote peer */
65 static int smb_direct_receive_credit_max = 255;
66
67 /* The number of send credits to request from the remote peer */
68 static int smb_direct_send_credit_target = 255;
69
70 /* The maximum single-message size that can be sent to the remote peer */
71 static int smb_direct_max_send_size = 8192;
72
73 /*  The maximum fragmented upper-layer payload receive size supported */
74 static int smb_direct_max_fragmented_recv_size = 1024 * 1024;
75
76 /*  The maximum single-message size which can be received */
77 static int smb_direct_max_receive_size = 8192;
78
79 static int smb_direct_max_read_write_size = 1024 * 1024;
80
81 static int smb_direct_max_outstanding_rw_ops = 8;
82
83 static struct smb_direct_listener {
84         struct rdma_cm_id       *cm_id;
85 } smb_direct_listener;
86
87 static struct workqueue_struct *smb_direct_wq;
88
89 enum smb_direct_status {
90         SMB_DIRECT_CS_NEW = 0,
91         SMB_DIRECT_CS_CONNECTED,
92         SMB_DIRECT_CS_DISCONNECTING,
93         SMB_DIRECT_CS_DISCONNECTED,
94 };
95
96 struct smb_direct_transport {
97         struct ksmbd_transport  transport;
98
99         enum smb_direct_status  status;
100         bool                    full_packet_received;
101         wait_queue_head_t       wait_status;
102
103         struct rdma_cm_id       *cm_id;
104         struct ib_cq            *send_cq;
105         struct ib_cq            *recv_cq;
106         struct ib_pd            *pd;
107         struct ib_qp            *qp;
108
109         int                     max_send_size;
110         int                     max_recv_size;
111         int                     max_fragmented_send_size;
112         int                     max_fragmented_recv_size;
113         int                     max_rdma_rw_size;
114
115         spinlock_t              reassembly_queue_lock;
116         struct list_head        reassembly_queue;
117         int                     reassembly_data_length;
118         int                     reassembly_queue_length;
119         int                     first_entry_offset;
120         wait_queue_head_t       wait_reassembly_queue;
121
122         spinlock_t              receive_credit_lock;
123         int                     recv_credits;
124         int                     count_avail_recvmsg;
125         int                     recv_credit_max;
126         int                     recv_credit_target;
127
128         spinlock_t              recvmsg_queue_lock;
129         struct list_head        recvmsg_queue;
130
131         spinlock_t              empty_recvmsg_queue_lock;
132         struct list_head        empty_recvmsg_queue;
133
134         int                     send_credit_target;
135         atomic_t                send_credits;
136         spinlock_t              lock_new_recv_credits;
137         int                     new_recv_credits;
138         atomic_t                rw_avail_ops;
139
140         wait_queue_head_t       wait_send_credits;
141         wait_queue_head_t       wait_rw_avail_ops;
142
143         mempool_t               *sendmsg_mempool;
144         struct kmem_cache       *sendmsg_cache;
145         mempool_t               *recvmsg_mempool;
146         struct kmem_cache       *recvmsg_cache;
147
148         wait_queue_head_t       wait_send_payload_pending;
149         atomic_t                send_payload_pending;
150         wait_queue_head_t       wait_send_pending;
151         atomic_t                send_pending;
152
153         struct delayed_work     post_recv_credits_work;
154         struct work_struct      send_immediate_work;
155         struct work_struct      disconnect_work;
156
157         bool                    negotiation_requested;
158 };
159
160 #define KSMBD_TRANS(t) ((struct ksmbd_transport *)&((t)->transport))
161
162 enum {
163         SMB_DIRECT_MSG_NEGOTIATE_REQ = 0,
164         SMB_DIRECT_MSG_DATA_TRANSFER
165 };
166
167 static struct ksmbd_transport_ops ksmbd_smb_direct_transport_ops;
168
169 struct smb_direct_send_ctx {
170         struct list_head        msg_list;
171         int                     wr_cnt;
172         bool                    need_invalidate_rkey;
173         unsigned int            remote_key;
174 };
175
176 struct smb_direct_sendmsg {
177         struct smb_direct_transport     *transport;
178         struct ib_send_wr       wr;
179         struct list_head        list;
180         int                     num_sge;
181         struct ib_sge           sge[SMB_DIRECT_MAX_SEND_SGES];
182         struct ib_cqe           cqe;
183         u8                      packet[];
184 };
185
186 struct smb_direct_recvmsg {
187         struct smb_direct_transport     *transport;
188         struct list_head        list;
189         int                     type;
190         struct ib_sge           sge;
191         struct ib_cqe           cqe;
192         bool                    first_segment;
193         u8                      packet[];
194 };
195
196 struct smb_direct_rdma_rw_msg {
197         struct smb_direct_transport     *t;
198         struct ib_cqe           cqe;
199         struct completion       *completion;
200         struct rdma_rw_ctx      rw_ctx;
201         struct sg_table         sgt;
202         struct scatterlist      sg_list[];
203 };
204
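/* Number of pages spanned by the buffer [buf, buf + size) */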
205 static inline int get_buf_page_count(void *buf, int size)
206 {
207         return DIV_ROUND_UP((uintptr_t)buf + size, PAGE_SIZE) -
208                 (uintptr_t)buf / PAGE_SIZE;
209 }
210
211 static void smb_direct_destroy_pools(struct smb_direct_transport *transport);
212 static void smb_direct_post_recv_credits(struct work_struct *work);
213 static int smb_direct_post_send_data(struct smb_direct_transport *t,
214                                      struct smb_direct_send_ctx *send_ctx,
215                                      struct kvec *iov, int niov,
216                                      int remaining_data_length);
217
218 static inline struct smb_direct_transport *
219 smb_trans_direct_transfort(struct ksmbd_transport *t)
220 {
221         return container_of(t, struct smb_direct_transport, transport);
222 }
223
224 static inline void
225 *smb_direct_recvmsg_payload(struct smb_direct_recvmsg *recvmsg)
226 {
227         return (void *)recvmsg->packet;
228 }
229
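/*
 * Repost receive buffers once the outstanding receive credits drop to 1/8 of
 * the configured maximum and at least receive_credits / 4 buffers are
 * available to post.
 */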
230 static inline bool is_receive_credit_post_required(int receive_credits,
231                                                    int avail_recvmsg_count)
232 {
233         return receive_credits <= (smb_direct_receive_credit_max >> 3) &&
234                 avail_recvmsg_count >= (receive_credits >> 2);
235 }
236
237 static struct
238 smb_direct_recvmsg *get_free_recvmsg(struct smb_direct_transport *t)
239 {
240         struct smb_direct_recvmsg *recvmsg = NULL;
241
242         spin_lock(&t->recvmsg_queue_lock);
243         if (!list_empty(&t->recvmsg_queue)) {
244                 recvmsg = list_first_entry(&t->recvmsg_queue,
245                                            struct smb_direct_recvmsg,
246                                            list);
247                 list_del(&recvmsg->list);
248         }
249         spin_unlock(&t->recvmsg_queue_lock);
250         return recvmsg;
251 }
252
253 static void put_recvmsg(struct smb_direct_transport *t,
254                         struct smb_direct_recvmsg *recvmsg)
255 {
256         ib_dma_unmap_single(t->cm_id->device, recvmsg->sge.addr,
257                             recvmsg->sge.length, DMA_FROM_DEVICE);
258
259         spin_lock(&t->recvmsg_queue_lock);
260         list_add(&recvmsg->list, &t->recvmsg_queue);
261         spin_unlock(&t->recvmsg_queue_lock);
262 }
263
264 static struct
265 smb_direct_recvmsg *get_empty_recvmsg(struct smb_direct_transport *t)
266 {
267         struct smb_direct_recvmsg *recvmsg = NULL;
268
269         spin_lock(&t->empty_recvmsg_queue_lock);
270         if (!list_empty(&t->empty_recvmsg_queue)) {
271                 recvmsg = list_first_entry(&t->empty_recvmsg_queue,
272                                            struct smb_direct_recvmsg, list);
273                 list_del(&recvmsg->list);
274         }
275         spin_unlock(&t->empty_recvmsg_queue_lock);
276         return recvmsg;
277 }
278
279 static void put_empty_recvmsg(struct smb_direct_transport *t,
280                               struct smb_direct_recvmsg *recvmsg)
281 {
282         ib_dma_unmap_single(t->cm_id->device, recvmsg->sge.addr,
283                             recvmsg->sge.length, DMA_FROM_DEVICE);
284
285         spin_lock(&t->empty_recvmsg_queue_lock);
286         list_add_tail(&recvmsg->list, &t->empty_recvmsg_queue);
287         spin_unlock(&t->empty_recvmsg_queue_lock);
288 }
289
290 static void enqueue_reassembly(struct smb_direct_transport *t,
291                                struct smb_direct_recvmsg *recvmsg,
292                                int data_length)
293 {
294         spin_lock(&t->reassembly_queue_lock);
295         list_add_tail(&recvmsg->list, &t->reassembly_queue);
296         t->reassembly_queue_length++;
297         /*
298          * Make sure reassembly_data_length is updated after list and
299          * reassembly_queue_length are updated. On the dequeue side
300          * reassembly_data_length is checked without a lock to determine
301          * whether reassembly_queue_length and the list are up to date
302          */
303         virt_wmb();
304         t->reassembly_data_length += data_length;
305         spin_unlock(&t->reassembly_queue_lock);
306 }
307
308 static struct smb_direct_recvmsg *get_first_reassembly(struct smb_direct_transport *t)
309 {
310         if (!list_empty(&t->reassembly_queue))
311                 return list_first_entry(&t->reassembly_queue,
312                                 struct smb_direct_recvmsg, list);
313         else
314                 return NULL;
315 }
316
317 static void smb_direct_disconnect_rdma_work(struct work_struct *work)
318 {
319         struct smb_direct_transport *t =
320                 container_of(work, struct smb_direct_transport,
321                              disconnect_work);
322
323         if (t->status == SMB_DIRECT_CS_CONNECTED) {
324                 t->status = SMB_DIRECT_CS_DISCONNECTING;
325                 rdma_disconnect(t->cm_id);
326         }
327 }
328
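/*
 * Request the RDMA disconnect via the workqueue so that it can also be
 * triggered from completion handlers.
 */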
329 static void
330 smb_direct_disconnect_rdma_connection(struct smb_direct_transport *t)
331 {
332         if (t->status == SMB_DIRECT_CS_CONNECTED)
333                 queue_work(smb_direct_wq, &t->disconnect_work);
334 }
335
336 static void smb_direct_send_immediate_work(struct work_struct *work)
337 {
338         struct smb_direct_transport *t = container_of(work,
339                         struct smb_direct_transport, send_immediate_work);
340
341         if (t->status != SMB_DIRECT_CS_CONNECTED)
342                 return;
343
344         smb_direct_post_send_data(t, NULL, NULL, 0, 0);
345 }
346
347 static struct smb_direct_transport *alloc_transport(struct rdma_cm_id *cm_id)
348 {
349         struct smb_direct_transport *t;
350         struct ksmbd_conn *conn;
351
352         t = kzalloc(sizeof(*t), GFP_KERNEL);
353         if (!t)
354                 return NULL;
355
356         t->cm_id = cm_id;
357         cm_id->context = t;
358
359         t->status = SMB_DIRECT_CS_NEW;
360         init_waitqueue_head(&t->wait_status);
361
362         spin_lock_init(&t->reassembly_queue_lock);
363         INIT_LIST_HEAD(&t->reassembly_queue);
364         t->reassembly_data_length = 0;
365         t->reassembly_queue_length = 0;
366         init_waitqueue_head(&t->wait_reassembly_queue);
367         init_waitqueue_head(&t->wait_send_credits);
368         init_waitqueue_head(&t->wait_rw_avail_ops);
369
370         spin_lock_init(&t->receive_credit_lock);
371         spin_lock_init(&t->recvmsg_queue_lock);
372         INIT_LIST_HEAD(&t->recvmsg_queue);
373
374         spin_lock_init(&t->empty_recvmsg_queue_lock);
375         INIT_LIST_HEAD(&t->empty_recvmsg_queue);
376
377         init_waitqueue_head(&t->wait_send_payload_pending);
378         atomic_set(&t->send_payload_pending, 0);
379         init_waitqueue_head(&t->wait_send_pending);
380         atomic_set(&t->send_pending, 0);
381
382         spin_lock_init(&t->lock_new_recv_credits);
383
384         INIT_DELAYED_WORK(&t->post_recv_credits_work,
385                           smb_direct_post_recv_credits);
386         INIT_WORK(&t->send_immediate_work, smb_direct_send_immediate_work);
387         INIT_WORK(&t->disconnect_work, smb_direct_disconnect_rdma_work);
388
389         conn = ksmbd_conn_alloc();
390         if (!conn)
391                 goto err;
392         conn->transport = KSMBD_TRANS(t);
393         KSMBD_TRANS(t)->conn = conn;
394         KSMBD_TRANS(t)->ops = &ksmbd_smb_direct_transport_ops;
395         return t;
396 err:
397         kfree(t);
398         return NULL;
399 }
400
401 static void free_transport(struct smb_direct_transport *t)
402 {
403         struct smb_direct_recvmsg *recvmsg;
404
405         wake_up_interruptible(&t->wait_send_credits);
406
407         ksmbd_debug(RDMA, "wait for all send posted to IB to finish\n");
408         wait_event(t->wait_send_payload_pending,
409                    atomic_read(&t->send_payload_pending) == 0);
410         wait_event(t->wait_send_pending,
411                    atomic_read(&t->send_pending) == 0);
412
413         cancel_work_sync(&t->disconnect_work);
414         cancel_delayed_work_sync(&t->post_recv_credits_work);
415         cancel_work_sync(&t->send_immediate_work);
416
417         if (t->qp) {
418                 ib_drain_qp(t->qp);
419                 ib_destroy_qp(t->qp);
420         }
421
422         ksmbd_debug(RDMA, "drain the reassembly queue\n");
423         do {
424                 spin_lock(&t->reassembly_queue_lock);
425                 recvmsg = get_first_reassembly(t);
426                 if (recvmsg) {
427                         list_del(&recvmsg->list);
428                         spin_unlock(&t->reassembly_queue_lock);
429                         put_recvmsg(t, recvmsg);
430                 } else {
431                         spin_unlock(&t->reassembly_queue_lock);
432                 }
433         } while (recvmsg);
434         t->reassembly_data_length = 0;
435
436         if (t->send_cq)
437                 ib_free_cq(t->send_cq);
438         if (t->recv_cq)
439                 ib_free_cq(t->recv_cq);
440         if (t->pd)
441                 ib_dealloc_pd(t->pd);
442         if (t->cm_id)
443                 rdma_destroy_id(t->cm_id);
444
445         smb_direct_destroy_pools(t);
446         ksmbd_conn_free(KSMBD_TRANS(t)->conn);
447         kfree(t);
448 }
449
450 static struct smb_direct_sendmsg
451 *smb_direct_alloc_sendmsg(struct smb_direct_transport *t)
452 {
453         struct smb_direct_sendmsg *msg;
454
455         msg = mempool_alloc(t->sendmsg_mempool, GFP_KERNEL);
456         if (!msg)
457                 return ERR_PTR(-ENOMEM);
458         msg->transport = t;
459         INIT_LIST_HEAD(&msg->list);
460         msg->num_sge = 0;
461         return msg;
462 }
463
464 static void smb_direct_free_sendmsg(struct smb_direct_transport *t,
465                                     struct smb_direct_sendmsg *msg)
466 {
467         int i;
468
469         if (msg->num_sge > 0) {
470                 ib_dma_unmap_single(t->cm_id->device,
471                                     msg->sge[0].addr, msg->sge[0].length,
472                                     DMA_TO_DEVICE);
473                 for (i = 1; i < msg->num_sge; i++)
474                         ib_dma_unmap_page(t->cm_id->device,
475                                           msg->sge[i].addr, msg->sge[i].length,
476                                           DMA_TO_DEVICE);
477         }
478         mempool_free(msg, t->sendmsg_mempool);
479 }
480
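/*
 * Log the header of a received message and sanity-check it: negotiate
 * requests must cover protocol version 0x0100 and advertise reasonable
 * credit and size values; data transfers are only logged here.
 */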
481 static int smb_direct_check_recvmsg(struct smb_direct_recvmsg *recvmsg)
482 {
483         switch (recvmsg->type) {
484         case SMB_DIRECT_MSG_DATA_TRANSFER: {
485                 struct smb_direct_data_transfer *req =
486                         (struct smb_direct_data_transfer *)recvmsg->packet;
487                 struct smb2_hdr *hdr = (struct smb2_hdr *)(recvmsg->packet
488                                 + le32_to_cpu(req->data_offset) - 4);
489                 ksmbd_debug(RDMA,
490                             "CreditGranted: %u, CreditRequested: %u, DataLength: %u, RemainingDataLength: %u, SMB: %x, Command: %u\n",
491                             le16_to_cpu(req->credits_granted),
492                             le16_to_cpu(req->credits_requested),
493                             req->data_length, req->remaining_data_length,
494                             hdr->ProtocolId, hdr->Command);
495                 break;
496         }
497         case SMB_DIRECT_MSG_NEGOTIATE_REQ: {
498                 struct smb_direct_negotiate_req *req =
499                         (struct smb_direct_negotiate_req *)recvmsg->packet;
500                 ksmbd_debug(RDMA,
501                             "MinVersion: %u, MaxVersion: %u, CreditRequested: %u, MaxSendSize: %u, MaxRecvSize: %u, MaxFragmentedSize: %u\n",
502                             le16_to_cpu(req->min_version),
503                             le16_to_cpu(req->max_version),
504                             le16_to_cpu(req->credits_requested),
505                             le32_to_cpu(req->preferred_send_size),
506                             le32_to_cpu(req->max_receive_size),
507                             le32_to_cpu(req->max_fragmented_size));
508                 if (le16_to_cpu(req->min_version) > 0x0100 ||
509                     le16_to_cpu(req->max_version) < 0x0100)
510                         return -EOPNOTSUPP;
511                 if (le16_to_cpu(req->credits_requested) <= 0 ||
512                     le32_to_cpu(req->max_receive_size) <= 128 ||
513                     le32_to_cpu(req->max_fragmented_size) <=
514                                         128 * 1024)
515                         return -ECONNABORTED;
516
517                 break;
518         }
519         default:
520                 return -EINVAL;
521         }
522         return 0;
523 }
524
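/*
 * Receive completion handler: a negotiate request wakes up the negotiation
 * waiter; a data transfer is queued on the reassembly queue (or recycled if
 * it carries no data) and the credit values from its header are applied.
 */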
525 static void recv_done(struct ib_cq *cq, struct ib_wc *wc)
526 {
527         struct smb_direct_recvmsg *recvmsg;
528         struct smb_direct_transport *t;
529
530         recvmsg = container_of(wc->wr_cqe, struct smb_direct_recvmsg, cqe);
531         t = recvmsg->transport;
532
533         if (wc->status != IB_WC_SUCCESS || wc->opcode != IB_WC_RECV) {
534                 if (wc->status != IB_WC_WR_FLUSH_ERR) {
535                         pr_err("Recv error. status='%s (%d)' opcode=%d\n",
536                                ib_wc_status_msg(wc->status), wc->status,
537                                wc->opcode);
538                         smb_direct_disconnect_rdma_connection(t);
539                 }
540                 put_empty_recvmsg(t, recvmsg);
541                 return;
542         }
543
544         ksmbd_debug(RDMA, "Recv completed. status='%s (%d)', opcode=%d\n",
545                     ib_wc_status_msg(wc->status), wc->status,
546                     wc->opcode);
547
548         ib_dma_sync_single_for_cpu(wc->qp->device, recvmsg->sge.addr,
549                                    recvmsg->sge.length, DMA_FROM_DEVICE);
550
551         switch (recvmsg->type) {
552         case SMB_DIRECT_MSG_NEGOTIATE_REQ:
553                 t->negotiation_requested = true;
554                 t->full_packet_received = true;
555                 wake_up_interruptible(&t->wait_status);
556                 break;
557         case SMB_DIRECT_MSG_DATA_TRANSFER: {
558                 struct smb_direct_data_transfer *data_transfer =
559                         (struct smb_direct_data_transfer *)recvmsg->packet;
560                 int data_length = le32_to_cpu(data_transfer->data_length);
561                 int avail_recvmsg_count, receive_credits;
562
563                 if (data_length) {
564                         if (t->full_packet_received)
565                                 recvmsg->first_segment = true;
566
567                         if (le32_to_cpu(data_transfer->remaining_data_length))
568                                 t->full_packet_received = false;
569                         else
570                                 t->full_packet_received = true;
571
572                         enqueue_reassembly(t, recvmsg, data_length);
573                         wake_up_interruptible(&t->wait_reassembly_queue);
574
575                         spin_lock(&t->receive_credit_lock);
576                         receive_credits = --(t->recv_credits);
577                         avail_recvmsg_count = t->count_avail_recvmsg;
578                         spin_unlock(&t->receive_credit_lock);
579                 } else {
580                         put_empty_recvmsg(t, recvmsg);
581
582                         spin_lock(&t->receive_credit_lock);
583                         receive_credits = --(t->recv_credits);
584                         avail_recvmsg_count = ++(t->count_avail_recvmsg);
585                         spin_unlock(&t->receive_credit_lock);
586                 }
587
588                 t->recv_credit_target =
589                                 le16_to_cpu(data_transfer->credits_requested);
590                 atomic_add(le16_to_cpu(data_transfer->credits_granted),
591                            &t->send_credits);
592
593                 if (le16_to_cpu(data_transfer->flags) &
594                     SMB_DIRECT_RESPONSE_REQUESTED)
595                         queue_work(smb_direct_wq, &t->send_immediate_work);
596
597                 if (atomic_read(&t->send_credits) > 0)
598                         wake_up_interruptible(&t->wait_send_credits);
599
600                 if (is_receive_credit_post_required(receive_credits, avail_recvmsg_count))
601                         mod_delayed_work(smb_direct_wq,
602                                          &t->post_recv_credits_work, 0);
603                 break;
604         }
605         default:
606                 break;
607         }
608 }
609
610 static int smb_direct_post_recv(struct smb_direct_transport *t,
611                                 struct smb_direct_recvmsg *recvmsg)
612 {
613         struct ib_recv_wr wr;
614         int ret;
615
616         recvmsg->sge.addr = ib_dma_map_single(t->cm_id->device,
617                                               recvmsg->packet, t->max_recv_size,
618                                               DMA_FROM_DEVICE);
619         ret = ib_dma_mapping_error(t->cm_id->device, recvmsg->sge.addr);
620         if (ret)
621                 return ret;
622         recvmsg->sge.length = t->max_recv_size;
623         recvmsg->sge.lkey = t->pd->local_dma_lkey;
624         recvmsg->cqe.done = recv_done;
625
626         wr.wr_cqe = &recvmsg->cqe;
627         wr.next = NULL;
628         wr.sg_list = &recvmsg->sge;
629         wr.num_sge = 1;
630
631         ret = ib_post_recv(t->qp, &wr, NULL);
632         if (ret) {
633                 pr_err("Can't post recv: %d\n", ret);
634                 ib_dma_unmap_single(t->cm_id->device,
635                                     recvmsg->sge.addr, recvmsg->sge.length,
636                                     DMA_FROM_DEVICE);
637                 smb_direct_disconnect_rdma_connection(t);
638                 return ret;
639         }
640         return ret;
641 }
642
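/*
 * Copy @size bytes of upper-layer payload from the reassembly queue into
 * @buf, sleeping until enough data has arrived. A 4-byte read of a first
 * segment returns a synthesized RFC1002 length field instead.
 */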
643 static int smb_direct_read(struct ksmbd_transport *t, char *buf,
644                            unsigned int size)
645 {
646         struct smb_direct_recvmsg *recvmsg;
647         struct smb_direct_data_transfer *data_transfer;
648         int to_copy, to_read, data_read, offset;
649         u32 data_length, remaining_data_length, data_offset;
650         int rc;
651         struct smb_direct_transport *st = smb_trans_direct_transfort(t);
652
653 again:
654         if (st->status != SMB_DIRECT_CS_CONNECTED) {
655                 pr_err("disconnected\n");
656                 return -ENOTCONN;
657         }
658
659         /*
660          * No need to hold the reassembly queue lock all the time as we are
661          * the only one reading from the front of the queue. The transport
662          * may add more entries to the back of the queue at the same time
663          */
664         if (st->reassembly_data_length >= size) {
665                 int queue_length;
666                 int queue_removed = 0;
667
668                 /*
669                  * Need to make sure reassembly_data_length is read before
670                  * reading reassembly_queue_length and calling
671                  * get_first_reassembly. This call is lock free
672                  * as we never read the end of the queue, which is being
673                  * updated in softirq context as more data is received
674                  */
675                 virt_rmb();
676                 queue_length = st->reassembly_queue_length;
677                 data_read = 0;
678                 to_read = size;
679                 offset = st->first_entry_offset;
680                 while (data_read < size) {
681                         recvmsg = get_first_reassembly(st);
682                         data_transfer = smb_direct_recvmsg_payload(recvmsg);
683                         data_length = le32_to_cpu(data_transfer->data_length);
684                         remaining_data_length =
685                                 le32_to_cpu(data_transfer->remaining_data_length);
686                         data_offset = le32_to_cpu(data_transfer->data_offset);
687
688                         /*
689                          * The upper layer expects RFC1002 length at the
690                          * beginning of the payload. Return it to indicate
691                          * the total length of the packet. This minimizes the
692                          * changes to the upper-layer packet processing logic and
693                          * will eventually be removed when an intermediate
694                          * transport layer is added
695                          */
696                         if (recvmsg->first_segment && size == 4) {
697                                 unsigned int rfc1002_len =
698                                         data_length + remaining_data_length;
699                                 *((__be32 *)buf) = cpu_to_be32(rfc1002_len);
700                                 data_read = 4;
701                                 recvmsg->first_segment = false;
702                                 ksmbd_debug(RDMA,
703                                             "returning rfc1002 length %d\n",
704                                             rfc1002_len);
705                                 goto read_rfc1002_done;
706                         }
707
708                         to_copy = min_t(int, data_length - offset, to_read);
709                         memcpy(buf + data_read, (char *)data_transfer + data_offset + offset,
710                                to_copy);
711
712                         /* move on to the next buffer? */
713                         if (to_copy == data_length - offset) {
714                                 queue_length--;
715                                 /*
716                                  * No need to lock if we are not at the
717                                  * end of the queue
718                                  */
719                                 if (queue_length) {
720                                         list_del(&recvmsg->list);
721                                 } else {
722                                         spin_lock_irq(&st->reassembly_queue_lock);
723                                         list_del(&recvmsg->list);
724                                         spin_unlock_irq(&st->reassembly_queue_lock);
725                                 }
726                                 queue_removed++;
727                                 put_recvmsg(st, recvmsg);
728                                 offset = 0;
729                         } else {
730                                 offset += to_copy;
731                         }
732
733                         to_read -= to_copy;
734                         data_read += to_copy;
735                 }
736
737                 spin_lock_irq(&st->reassembly_queue_lock);
738                 st->reassembly_data_length -= data_read;
739                 st->reassembly_queue_length -= queue_removed;
740                 spin_unlock_irq(&st->reassembly_queue_lock);
741
742                 spin_lock(&st->receive_credit_lock);
743                 st->count_avail_recvmsg += queue_removed;
744                 if (is_receive_credit_post_required(st->recv_credits, st->count_avail_recvmsg)) {
745                         spin_unlock(&st->receive_credit_lock);
746                         mod_delayed_work(smb_direct_wq,
747                                          &st->post_recv_credits_work, 0);
748                 } else {
749                         spin_unlock(&st->receive_credit_lock);
750                 }
751
752                 st->first_entry_offset = offset;
753                 ksmbd_debug(RDMA,
754                             "returning to thread data_read=%d reassembly_data_length=%d first_entry_offset=%d\n",
755                             data_read, st->reassembly_data_length,
756                             st->first_entry_offset);
757 read_rfc1002_done:
758                 return data_read;
759         }
760
761         ksmbd_debug(RDMA, "wait_event on more data\n");
762         rc = wait_event_interruptible(st->wait_reassembly_queue,
763                                       st->reassembly_data_length >= size ||
764                                        st->status != SMB_DIRECT_CS_CONNECTED);
765         if (rc)
766                 return -EINTR;
767
768         goto again;
769 }
770
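/*
 * Delayed work that reposts receive buffers (free ones first, then recycled
 * empty ones) until the peer's credit target is reached, then schedules an
 * immediate send so that the new credits are granted to the peer.
 */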
771 static void smb_direct_post_recv_credits(struct work_struct *work)
772 {
773         struct smb_direct_transport *t = container_of(work,
774                 struct smb_direct_transport, post_recv_credits_work.work);
775         struct smb_direct_recvmsg *recvmsg;
776         int receive_credits, credits = 0;
777         int ret;
778         int use_free = 1;
779
780         spin_lock(&t->receive_credit_lock);
781         receive_credits = t->recv_credits;
782         spin_unlock(&t->receive_credit_lock);
783
784         if (receive_credits < t->recv_credit_target) {
785                 while (true) {
786                         if (use_free)
787                                 recvmsg = get_free_recvmsg(t);
788                         else
789                                 recvmsg = get_empty_recvmsg(t);
790                         if (!recvmsg) {
791                                 if (use_free) {
792                                         use_free = 0;
793                                         continue;
794                                 } else {
795                                         break;
796                                 }
797                         }
798
799                         recvmsg->type = SMB_DIRECT_MSG_DATA_TRANSFER;
800                         recvmsg->first_segment = false;
801
802                         ret = smb_direct_post_recv(t, recvmsg);
803                         if (ret) {
804                                 pr_err("Can't post recv: %d\n", ret);
805                                 put_recvmsg(t, recvmsg);
806                                 break;
807                         }
808                         credits++;
809                 }
810         }
811
812         spin_lock(&t->receive_credit_lock);
813         t->recv_credits += credits;
814         t->count_avail_recvmsg -= credits;
815         spin_unlock(&t->receive_credit_lock);
816
817         spin_lock(&t->lock_new_recv_credits);
818         t->new_recv_credits += credits;
819         spin_unlock(&t->lock_new_recv_credits);
820
821         if (credits)
822                 queue_work(smb_direct_wq, &t->send_immediate_work);
823 }
824
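/*
 * Send completion handler: drop the pending-send counters and free every
 * sendmsg chained behind this signaled work request.
 */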
825 static void send_done(struct ib_cq *cq, struct ib_wc *wc)
826 {
827         struct smb_direct_sendmsg *sendmsg, *sibling;
828         struct smb_direct_transport *t;
829         struct list_head *pos, *prev, *end;
830
831         sendmsg = container_of(wc->wr_cqe, struct smb_direct_sendmsg, cqe);
832         t = sendmsg->transport;
833
834         ksmbd_debug(RDMA, "Send completed. status='%s (%d)', opcode=%d\n",
835                     ib_wc_status_msg(wc->status), wc->status,
836                     wc->opcode);
837
838         if (wc->status != IB_WC_SUCCESS || wc->opcode != IB_WC_SEND) {
839                 pr_err("Send error. status='%s (%d)', opcode=%d\n",
840                        ib_wc_status_msg(wc->status), wc->status,
841                        wc->opcode);
842                 smb_direct_disconnect_rdma_connection(t);
843         }
844
845         if (sendmsg->num_sge > 1) {
846                 if (atomic_dec_and_test(&t->send_payload_pending))
847                         wake_up(&t->wait_send_payload_pending);
848         } else {
849                 if (atomic_dec_and_test(&t->send_pending))
850                         wake_up(&t->wait_send_pending);
851         }
852
853         /* Iterate over the list of messages in reverse and free them; the
854          * list's head is invalid.
855          */
856         for (pos = &sendmsg->list, prev = pos->prev, end = sendmsg->list.next;
857              prev != end; pos = prev, prev = prev->prev) {
858                 sibling = container_of(pos, struct smb_direct_sendmsg, list);
859                 smb_direct_free_sendmsg(t, sibling);
860         }
861
862         sibling = container_of(pos, struct smb_direct_sendmsg, list);
863         smb_direct_free_sendmsg(t, sibling);
864 }
865
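/*
 * Atomically consume the receive credits accumulated since the last send;
 * the caller advertises them to the peer as CreditsGranted.
 */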
866 static int manage_credits_prior_sending(struct smb_direct_transport *t)
867 {
868         int new_credits;
869
870         spin_lock(&t->lock_new_recv_credits);
871         new_credits = t->new_recv_credits;
872         t->new_recv_credits = 0;
873         spin_unlock(&t->lock_new_recv_credits);
874
875         return new_credits;
876 }
877
878 static int smb_direct_post_send(struct smb_direct_transport *t,
879                                 struct ib_send_wr *wr)
880 {
881         int ret;
882
883         if (wr->num_sge > 1)
884                 atomic_inc(&t->send_payload_pending);
885         else
886                 atomic_inc(&t->send_pending);
887
888         ret = ib_post_send(t->qp, wr, NULL);
889         if (ret) {
890                 pr_err("failed to post send: %d\n", ret);
891                 if (wr->num_sge > 1) {
892                         if (atomic_dec_and_test(&t->send_payload_pending))
893                                 wake_up(&t->wait_send_payload_pending);
894                 } else {
895                         if (atomic_dec_and_test(&t->send_pending))
896                                 wake_up(&t->wait_send_pending);
897                 }
898                 smb_direct_disconnect_rdma_connection(t);
899         }
900         return ret;
901 }
902
903 static void smb_direct_send_ctx_init(struct smb_direct_transport *t,
904                                      struct smb_direct_send_ctx *send_ctx,
905                                      bool need_invalidate_rkey,
906                                      unsigned int remote_key)
907 {
908         INIT_LIST_HEAD(&send_ctx->msg_list);
909         send_ctx->wr_cnt = 0;
910         send_ctx->need_invalidate_rkey = need_invalidate_rkey;
911         send_ctx->remote_key = remote_key;
912 }
913
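/*
 * Post the chain of queued send work requests in one call. Only the last WR
 * is signaled and, on the final flush, may carry the remote key
 * invalidation. On failure the reserved send credits are returned and the
 * queued messages are freed.
 */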
914 static int smb_direct_flush_send_list(struct smb_direct_transport *t,
915                                       struct smb_direct_send_ctx *send_ctx,
916                                       bool is_last)
917 {
918         struct smb_direct_sendmsg *first, *last;
919         int ret;
920
921         if (list_empty(&send_ctx->msg_list))
922                 return 0;
923
924         first = list_first_entry(&send_ctx->msg_list,
925                                  struct smb_direct_sendmsg,
926                                  list);
927         last = list_last_entry(&send_ctx->msg_list,
928                                struct smb_direct_sendmsg,
929                                list);
930
931         last->wr.send_flags = IB_SEND_SIGNALED;
932         last->wr.wr_cqe = &last->cqe;
933         if (is_last && send_ctx->need_invalidate_rkey) {
934                 last->wr.opcode = IB_WR_SEND_WITH_INV;
935                 last->wr.ex.invalidate_rkey = send_ctx->remote_key;
936         }
937
938         ret = smb_direct_post_send(t, &first->wr);
939         if (!ret) {
940                 smb_direct_send_ctx_init(t, send_ctx,
941                                          send_ctx->need_invalidate_rkey,
942                                          send_ctx->remote_key);
943         } else {
944                 atomic_add(send_ctx->wr_cnt, &t->send_credits);
945                 wake_up(&t->wait_send_credits);
946                 list_for_each_entry_safe(first, last, &send_ctx->msg_list,
947                                          list) {
948                         smb_direct_free_sendmsg(t, first);
949                 }
950         }
951         return ret;
952 }
953
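/*
 * Take one credit or sleep until one becomes available or the connection is
 * torn down. The optimistic decrement is undone before sleeping.
 */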
954 static int wait_for_credits(struct smb_direct_transport *t,
955                             wait_queue_head_t *waitq, atomic_t *credits)
956 {
957         int ret;
958
959         do {
960                 if (atomic_dec_return(credits) >= 0)
961                         return 0;
962
963                 atomic_inc(credits);
964                 ret = wait_event_interruptible(*waitq,
965                                                atomic_read(credits) > 0 ||
966                                                 t->status != SMB_DIRECT_CS_CONNECTED);
967
968                 if (t->status != SMB_DIRECT_CS_CONNECTED)
969                         return -ENOTCONN;
970                 else if (ret < 0)
971                         return ret;
972         } while (true);
973 }
974
975 static int wait_for_send_credits(struct smb_direct_transport *t,
976                                  struct smb_direct_send_ctx *send_ctx)
977 {
978         int ret;
979
980         if (send_ctx &&
981             (send_ctx->wr_cnt >= 16 || atomic_read(&t->send_credits) <= 1)) {
982                 ret = smb_direct_flush_send_list(t, send_ctx, false);
983                 if (ret)
984                         return ret;
985         }
986
987         return wait_for_credits(t, &t->wait_send_credits, &t->send_credits);
988 }
989
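/*
 * Allocate a sendmsg and fill in the SMB_DIRECT data transfer header
 * (credits, data offset/length, remaining length), then DMA-map it as the
 * first SGE. The padding field is omitted for messages without payload.
 */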
990 static int smb_direct_create_header(struct smb_direct_transport *t,
991                                     int size, int remaining_data_length,
992                                     struct smb_direct_sendmsg **sendmsg_out)
993 {
994         struct smb_direct_sendmsg *sendmsg;
995         struct smb_direct_data_transfer *packet;
996         int header_length;
997         int ret;
998
999         sendmsg = smb_direct_alloc_sendmsg(t);
1000         if (IS_ERR(sendmsg))
1001                 return PTR_ERR(sendmsg);
1002
1003         /* Fill in the packet header */
1004         packet = (struct smb_direct_data_transfer *)sendmsg->packet;
1005         packet->credits_requested = cpu_to_le16(t->send_credit_target);
1006         packet->credits_granted = cpu_to_le16(manage_credits_prior_sending(t));
1007
1008         packet->flags = 0;
1009         packet->reserved = 0;
1010         if (!size)
1011                 packet->data_offset = 0;
1012         else
1013                 packet->data_offset = cpu_to_le32(24);
1014         packet->data_length = cpu_to_le32(size);
1015         packet->remaining_data_length = cpu_to_le32(remaining_data_length);
1016         packet->padding = 0;
1017
1018         ksmbd_debug(RDMA,
1019                     "credits_requested=%d credits_granted=%d data_offset=%d data_length=%d remaining_data_length=%d\n",
1020                     le16_to_cpu(packet->credits_requested),
1021                     le16_to_cpu(packet->credits_granted),
1022                     le32_to_cpu(packet->data_offset),
1023                     le32_to_cpu(packet->data_length),
1024                     le32_to_cpu(packet->remaining_data_length));
1025
1026         /* Map the packet to DMA */
1027         header_length = sizeof(struct smb_direct_data_transfer);
1028         /* If this is a packet without payload, don't send padding */
1029         if (!size)
1030                 header_length =
1031                         offsetof(struct smb_direct_data_transfer, padding);
1032
1033         sendmsg->sge[0].addr = ib_dma_map_single(t->cm_id->device,
1034                                                  (void *)packet,
1035                                                  header_length,
1036                                                  DMA_TO_DEVICE);
1037         ret = ib_dma_mapping_error(t->cm_id->device, sendmsg->sge[0].addr);
1038         if (ret) {
1039                 smb_direct_free_sendmsg(t, sendmsg);
1040                 return ret;
1041         }
1042
1043         sendmsg->num_sge = 1;
1044         sendmsg->sge[0].length = header_length;
1045         sendmsg->sge[0].lkey = t->pd->local_dma_lkey;
1046
1047         *sendmsg_out = sendmsg;
1048         return 0;
1049 }
1050
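/*
 * Build a scatterlist covering a kernel buffer, one entry per page, handling
 * both vmalloc and page-backed kernel addresses. Returns the number of
 * entries used, or -EINVAL if @nentries is too small.
 */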
1051 static int get_sg_list(void *buf, int size, struct scatterlist *sg_list, int nentries)
1052 {
1053         bool high = is_vmalloc_addr(buf);
1054         struct page *page;
1055         int offset, len;
1056         int i = 0;
1057
1058         if (nentries < get_buf_page_count(buf, size))
1059                 return -EINVAL;
1060
1061         offset = offset_in_page(buf);
1062         buf -= offset;
1063         while (size > 0) {
1064                 len = min_t(int, PAGE_SIZE - offset, size);
1065                 if (high)
1066                         page = vmalloc_to_page(buf);
1067                 else
1068                         page = kmap_to_page(buf);
1069
1070                 if (!sg_list)
1071                         return -EINVAL;
1072                 sg_set_page(sg_list, page, len, offset);
1073                 sg_list = sg_next(sg_list);
1074
1075                 buf += PAGE_SIZE;
1076                 size -= len;
1077                 offset = 0;
1078                 i++;
1079         }
1080         return i;
1081 }
1082
1083 static int get_mapped_sg_list(struct ib_device *device, void *buf, int size,
1084                               struct scatterlist *sg_list, int nentries,
1085                               enum dma_data_direction dir)
1086 {
1087         int npages;
1088
1089         npages = get_sg_list(buf, size, sg_list, nentries);
1090         if (npages <= 0)
1091                 return -EINVAL;
1092         return ib_dma_map_sg(device, sg_list, npages, dir);
1093 }
1094
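/*
 * Sync the message SGEs for the device, then either chain the WR onto the
 * send context for a later batched post or, without a context, post it
 * immediately as a signaled send.
 */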
1095 static int post_sendmsg(struct smb_direct_transport *t,
1096                         struct smb_direct_send_ctx *send_ctx,
1097                         struct smb_direct_sendmsg *msg)
1098 {
1099         int i;
1100
1101         for (i = 0; i < msg->num_sge; i++)
1102                 ib_dma_sync_single_for_device(t->cm_id->device,
1103                                               msg->sge[i].addr, msg->sge[i].length,
1104                                               DMA_TO_DEVICE);
1105
1106         msg->cqe.done = send_done;
1107         msg->wr.opcode = IB_WR_SEND;
1108         msg->wr.sg_list = &msg->sge[0];
1109         msg->wr.num_sge = msg->num_sge;
1110         msg->wr.next = NULL;
1111
1112         if (send_ctx) {
1113                 msg->wr.wr_cqe = NULL;
1114                 msg->wr.send_flags = 0;
1115                 if (!list_empty(&send_ctx->msg_list)) {
1116                         struct smb_direct_sendmsg *last;
1117
1118                         last = list_last_entry(&send_ctx->msg_list,
1119                                                struct smb_direct_sendmsg,
1120                                                list);
1121                         last->wr.next = &msg->wr;
1122                 }
1123                 list_add_tail(&msg->list, &send_ctx->msg_list);
1124                 send_ctx->wr_cnt++;
1125                 return 0;
1126         }
1127
1128         msg->wr.wr_cqe = &msg->cqe;
1129         msg->wr.send_flags = IB_SEND_SIGNALED;
1130         return smb_direct_post_send(t, &msg->wr);
1131 }
1132
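/*
 * Build one SMB_DIRECT message from @iov: reserve a send credit, create the
 * data transfer header, map each vector into SGEs, then post (or queue) the
 * resulting work request.
 */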
1133 static int smb_direct_post_send_data(struct smb_direct_transport *t,
1134                                      struct smb_direct_send_ctx *send_ctx,
1135                                      struct kvec *iov, int niov,
1136                                      int remaining_data_length)
1137 {
1138         int i, j, ret;
1139         struct smb_direct_sendmsg *msg;
1140         int data_length;
1141         struct scatterlist sg[SMB_DIRECT_MAX_SEND_SGES - 1];
1142
1143         ret = wait_for_send_credits(t, send_ctx);
1144         if (ret)
1145                 return ret;
1146
1147         data_length = 0;
1148         for (i = 0; i < niov; i++)
1149                 data_length += iov[i].iov_len;
1150
1151         ret = smb_direct_create_header(t, data_length, remaining_data_length,
1152                                        &msg);
1153         if (ret) {
1154                 atomic_inc(&t->send_credits);
1155                 return ret;
1156         }
1157
1158         for (i = 0; i < niov; i++) {
1159                 struct ib_sge *sge;
1160                 int sg_cnt;
1161
1162                 sg_init_table(sg, SMB_DIRECT_MAX_SEND_SGES - 1);
1163                 sg_cnt = get_mapped_sg_list(t->cm_id->device,
1164                                             iov[i].iov_base, iov[i].iov_len,
1165                                             sg, SMB_DIRECT_MAX_SEND_SGES - 1,
1166                                             DMA_TO_DEVICE);
1167                 if (sg_cnt <= 0) {
1168                         pr_err("failed to map buffer\n");
1169                         ret = -ENOMEM;
1170                         goto err;
1171                 } else if (sg_cnt + msg->num_sge > SMB_DIRECT_MAX_SEND_SGES) {
1172                         pr_err("buffer not fitted into sges\n");
1173                         ret = -E2BIG;
1174                         ib_dma_unmap_sg(t->cm_id->device, sg, sg_cnt,
1175                                         DMA_TO_DEVICE);
1176                         goto err;
1177                 }
1178
1179                 for (j = 0; j < sg_cnt; j++) {
1180                         sge = &msg->sge[msg->num_sge];
1181                         sge->addr = sg_dma_address(&sg[j]);
1182                         sge->length = sg_dma_len(&sg[j]);
1183                         sge->lkey  = t->pd->local_dma_lkey;
1184                         msg->num_sge++;
1185                 }
1186         }
1187
1188         ret = post_sendmsg(t, send_ctx, msg);
1189         if (ret)
1190                 goto err;
1191         return 0;
1192 err:
1193         smb_direct_free_sendmsg(t, msg);
1194         atomic_inc(&t->send_credits);
1195         return ret;
1196 }
1197
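/*
 * ksmbd write path: strip the RFC1002 header, split the iovec into
 * SMB_DIRECT messages no larger than max_send_size, then flush the batch and
 * wait for all payload sends to complete before returning.
 */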
1198 static int smb_direct_writev(struct ksmbd_transport *t,
1199                              struct kvec *iov, int niovs, int buflen,
1200                              bool need_invalidate, unsigned int remote_key)
1201 {
1202         struct smb_direct_transport *st = smb_trans_direct_transfort(t);
1203         int remaining_data_length;
1204         int start, i, j;
1205         int max_iov_size = st->max_send_size -
1206                         sizeof(struct smb_direct_data_transfer);
1207         int ret;
1208         struct kvec vec;
1209         struct smb_direct_send_ctx send_ctx;
1210
1211         if (st->status != SMB_DIRECT_CS_CONNECTED)
1212                 return -ENOTCONN;
1213
1214         //FIXME: skip RFC1002 header..
1215         buflen -= 4;
1216         iov[0].iov_base += 4;
1217         iov[0].iov_len -= 4;
1218
1219         remaining_data_length = buflen;
1220         ksmbd_debug(RDMA, "Sending smb (RDMA): smb_len=%u\n", buflen);
1221
1222         smb_direct_send_ctx_init(st, &send_ctx, need_invalidate, remote_key);
1223         start = i = 0;
1224         buflen = 0;
1225         while (true) {
1226                 buflen += iov[i].iov_len;
1227                 if (buflen > max_iov_size) {
1228                         if (i > start) {
1229                                 remaining_data_length -=
1230                                         (buflen - iov[i].iov_len);
1231                                 ret = smb_direct_post_send_data(st, &send_ctx,
1232                                                                 &iov[start], i - start,
1233                                                                 remaining_data_length);
1234                                 if (ret)
1235                                         goto done;
1236                         } else {
1237                                 /* iov[start] is too big, break it */
1238                                 int nvec  = (buflen + max_iov_size - 1) /
1239                                                 max_iov_size;
1240
1241                                 for (j = 0; j < nvec; j++) {
1242                                         vec.iov_base =
1243                                                 (char *)iov[start].iov_base +
1244                                                 j * max_iov_size;
1245                                         vec.iov_len =
1246                                                 min_t(int, max_iov_size,
1247                                                       buflen - max_iov_size * j);
1248                                         remaining_data_length -= vec.iov_len;
1249                                         ret = smb_direct_post_send_data(st, &send_ctx, &vec, 1,
1250                                                                         remaining_data_length);
1251                                         if (ret)
1252                                                 goto done;
1253                                 }
1254                                 i++;
1255                                 if (i == niovs)
1256                                         break;
1257                         }
1258                         start = i;
1259                         buflen = 0;
1260                 } else {
1261                         i++;
1262                         if (i == niovs) {
1263                                 /* send out all remaining vecs */
1264                                 remaining_data_length -= buflen;
1265                                 ret = smb_direct_post_send_data(st, &send_ctx,
1266                                                                 &iov[start], i - start,
1267                                                                 remaining_data_length);
1268                                 if (ret)
1269                                         goto done;
1270                                 break;
1271                         }
1272                 }
1273         }
1274
1275 done:
1276         ret = smb_direct_flush_send_list(st, &send_ctx, true);
1277
1278         /*
1279          * As an optimization, we don't wait for individual I/O to finish
1280          * before sending the next one.
1281          * Send them all and wait for the pending send count to reach 0,
1282          * which means all the I/Os have gone out and we are good to return
1283          */
1284
1285         wait_event(st->wait_send_payload_pending,
1286                    atomic_read(&st->send_payload_pending) == 0);
1287         return ret;
1288 }
1289
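/*
 * Common completion for RDMA read/write: release the R/W slot, tear down the
 * rdma_rw context and scatterlist, and complete the waiter.
 */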
1290 static void read_write_done(struct ib_cq *cq, struct ib_wc *wc,
1291                             enum dma_data_direction dir)
1292 {
1293         struct smb_direct_rdma_rw_msg *msg = container_of(wc->wr_cqe,
1294                                                           struct smb_direct_rdma_rw_msg, cqe);
1295         struct smb_direct_transport *t = msg->t;
1296
1297         if (wc->status != IB_WC_SUCCESS) {
1298                 pr_err("read/write error. opcode = %d, status = %s(%d)\n",
1299                        wc->opcode, ib_wc_status_msg(wc->status), wc->status);
1300                 smb_direct_disconnect_rdma_connection(t);
1301         }
1302
1303         if (atomic_inc_return(&t->rw_avail_ops) > 0)
1304                 wake_up(&t->wait_rw_avail_ops);
1305
1306         rdma_rw_ctx_destroy(&msg->rw_ctx, t->qp, t->qp->port,
1307                             msg->sg_list, msg->sgt.nents, dir);
1308         sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
1309         complete(msg->completion);
1310         kfree(msg);
1311 }
1312
1313 static void read_done(struct ib_cq *cq, struct ib_wc *wc)
1314 {
1315         read_write_done(cq, wc, DMA_FROM_DEVICE);
1316 }
1317
1318 static void write_done(struct ib_cq *cq, struct ib_wc *wc)
1319 {
1320         read_write_done(cq, wc, DMA_TO_DEVICE);
1321 }
1322
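/*
 * Synchronously transfer @buf to or from the peer's advertised buffer with
 * the rdma_rw API: reserve an R/W slot, map a scatterlist, post the
 * generated WR chain, and wait for read_done()/write_done() to complete it.
 */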
1323 static int smb_direct_rdma_xmit(struct smb_direct_transport *t, void *buf,
1324                                 int buf_len, u32 remote_key, u64 remote_offset,
1325                                 u32 remote_len, bool is_read)
1326 {
1327         struct smb_direct_rdma_rw_msg *msg;
1328         int ret;
1329         DECLARE_COMPLETION_ONSTACK(completion);
1330         struct ib_send_wr *first_wr = NULL;
1331
1332         ret = wait_for_credits(t, &t->wait_rw_avail_ops, &t->rw_avail_ops);
1333         if (ret < 0)
1334                 return ret;
1335
1336         /* TODO: mempool */
1337         msg = kmalloc(offsetof(struct smb_direct_rdma_rw_msg, sg_list) +
1338                       sizeof(struct scatterlist) * SG_CHUNK_SIZE, GFP_KERNEL);
1339         if (!msg) {
1340                 atomic_inc(&t->rw_avail_ops);
1341                 return -ENOMEM;
1342         }
1343
1344         msg->sgt.sgl = &msg->sg_list[0];
1345         ret = sg_alloc_table_chained(&msg->sgt,
1346                                      get_buf_page_count(buf, buf_len),
1347                                      msg->sg_list, SG_CHUNK_SIZE);
1348         if (ret) {
1349                 atomic_inc(&t->rw_avail_ops);
1350                 kfree(msg);
1351                 return -ENOMEM;
1352         }
1353
1354         ret = get_sg_list(buf, buf_len, msg->sgt.sgl, msg->sgt.orig_nents);
1355         if (ret <= 0) {
1356                 pr_err("failed to get pages\n");
1357                 goto err;
1358         }
1359
1360         ret = rdma_rw_ctx_init(&msg->rw_ctx, t->qp, t->qp->port,
1361                                msg->sg_list, get_buf_page_count(buf, buf_len),
1362                                0, remote_offset, remote_key,
1363                                is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE);
1364         if (ret < 0) {
1365                 pr_err("failed to init rdma_rw_ctx: %d\n", ret);
1366                 goto err;
1367         }
1368
1369         msg->t = t;
1370         msg->cqe.done = is_read ? read_done : write_done;
1371         msg->completion = &completion;
1372         first_wr = rdma_rw_ctx_wrs(&msg->rw_ctx, t->qp, t->qp->port,
1373                                    &msg->cqe, NULL);
1374
1375         ret = ib_post_send(t->qp, first_wr, NULL);
1376         if (ret) {
1377                 pr_err("failed to post send wr: %d\n", ret);
1378                 goto err;
1379         }
1380
1381         wait_for_completion(&completion);
1382         return 0;
1383
1384 err:
1385         atomic_inc(&t->rw_avail_ops);
1386         if (first_wr)
1387                 rdma_rw_ctx_destroy(&msg->rw_ctx, t->qp, t->qp->port,
1388                                     msg->sg_list, msg->sgt.nents,
1389                                     is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE);
1390         sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
1391         kfree(msg);
1392         return ret;
1393 }
1394
1395 static int smb_direct_rdma_write(struct ksmbd_transport *t, void *buf,
1396                                  unsigned int buflen, u32 remote_key,
1397                                  u64 remote_offset, u32 remote_len)
1398 {
1399         return smb_direct_rdma_xmit(smb_trans_direct_transfort(t), buf, buflen,
1400                                     remote_key, remote_offset,
1401                                     remote_len, false);
1402 }
1403
1404 static int smb_direct_rdma_read(struct ksmbd_transport *t, void *buf,
1405                                 unsigned int buflen, u32 remote_key,
1406                                 u64 remote_offset, u32 remote_len)
1407 {
1408         return smb_direct_rdma_xmit(smb_trans_direct_transfort(t), buf, buflen,
1409                                     remote_key, remote_offset,
1410                                     remote_len, true);
1411 }
1412
1413 static void smb_direct_disconnect(struct ksmbd_transport *t)
1414 {
1415         struct smb_direct_transport *st = smb_trans_direct_transfort(t);
1416
1417         ksmbd_debug(RDMA, "Disconnecting cm_id=%p\n", st->cm_id);
1418
1419         smb_direct_disconnect_rdma_work(&st->disconnect_work);
1420         wait_event_interruptible(st->wait_status,
1421                                  st->status == SMB_DIRECT_CS_DISCONNECTED);
1422         free_transport(st);
1423 }
1424
1425 static int smb_direct_cm_handler(struct rdma_cm_id *cm_id,
1426                                  struct rdma_cm_event *event)
1427 {
1428         struct smb_direct_transport *t = cm_id->context;
1429
1430         ksmbd_debug(RDMA, "RDMA CM event. cm_id=%p event=%s (%d)\n",
1431                     cm_id, rdma_event_msg(event->event), event->event);
1432
1433         switch (event->event) {
1434         case RDMA_CM_EVENT_ESTABLISHED: {
1435                 t->status = SMB_DIRECT_CS_CONNECTED;
1436                 wake_up_interruptible(&t->wait_status);
1437                 break;
1438         }
1439         case RDMA_CM_EVENT_DEVICE_REMOVAL:
1440         case RDMA_CM_EVENT_DISCONNECTED: {
1441                 t->status = SMB_DIRECT_CS_DISCONNECTED;
1442                 wake_up_interruptible(&t->wait_status);
1443                 wake_up_interruptible(&t->wait_reassembly_queue);
1444                 wake_up(&t->wait_send_credits);
1445                 break;
1446         }
1447         case RDMA_CM_EVENT_CONNECT_ERROR: {
1448                 t->status = SMB_DIRECT_CS_DISCONNECTED;
1449                 wake_up_interruptible(&t->wait_status);
1450                 break;
1451         }
1452         default:
1453                 pr_err("Unexpected RDMA CM event. cm_id=%p, event=%s (%d)\n",
1454                        cm_id, rdma_event_msg(event->event),
1455                        event->event);
1456                 break;
1457         }
1458         return 0;
1459 }
1460
1461 static void smb_direct_qpair_handler(struct ib_event *event, void *context)
1462 {
1463         struct smb_direct_transport *t = context;
1464
1465         ksmbd_debug(RDMA, "Received QP event. cm_id=%p, event=%s (%d)\n",
1466                     t->cm_id, ib_event_msg(event->event), event->event);
1467
1468         switch (event->event) {
1469         case IB_EVENT_CQ_ERR:
1470         case IB_EVENT_QP_FATAL:
1471                 smb_direct_disconnect_rdma_connection(t);
1472                 break;
1473         default:
1474                 break;
1475         }
1476 }
1477
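/*
 * Build and send the SMB_DIRECT negotiate response. On failure only the
 * supported version range and STATUS_NOT_SUPPORTED are reported; on success
 * the negotiated version, granted credits and the transport's size limits
 * are advertised. The response buffer is DMA-mapped, posted, and the call
 * waits until the send is no longer pending.
 */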
1478 static int smb_direct_send_negotiate_response(struct smb_direct_transport *t,
1479                                               int failed)
1480 {
1481         struct smb_direct_sendmsg *sendmsg;
1482         struct smb_direct_negotiate_resp *resp;
1483         int ret;
1484
1485         sendmsg = smb_direct_alloc_sendmsg(t);
1486         if (IS_ERR(sendmsg))
1487                 return -ENOMEM;
1488
1489         resp = (struct smb_direct_negotiate_resp *)sendmsg->packet;
1490         if (failed) {
1491                 memset(resp, 0, sizeof(*resp));
1492                 resp->min_version = cpu_to_le16(0x0100);
1493                 resp->max_version = cpu_to_le16(0x0100);
1494                 resp->status = STATUS_NOT_SUPPORTED;
1495         } else {
1496                 resp->status = STATUS_SUCCESS;
1497                 resp->min_version = SMB_DIRECT_VERSION_LE;
1498                 resp->max_version = SMB_DIRECT_VERSION_LE;
1499                 resp->negotiated_version = SMB_DIRECT_VERSION_LE;
1500                 resp->reserved = 0;
1501                 resp->credits_requested =
1502                                 cpu_to_le16(t->send_credit_target);
1503                 resp->credits_granted = cpu_to_le16(manage_credits_prior_sending(t));
1504                 resp->max_readwrite_size = cpu_to_le32(t->max_rdma_rw_size);
1505                 resp->preferred_send_size = cpu_to_le32(t->max_send_size);
1506                 resp->max_receive_size = cpu_to_le32(t->max_recv_size);
1507                 resp->max_fragmented_size =
1508                                 cpu_to_le32(t->max_fragmented_recv_size);
1509         }
1510
1511         sendmsg->sge[0].addr = ib_dma_map_single(t->cm_id->device,
1512                                                  (void *)resp, sizeof(*resp),
1513                                                  DMA_TO_DEVICE);
1514         ret = ib_dma_mapping_error(t->cm_id->device, sendmsg->sge[0].addr);
1515         if (ret) {
1516                 smb_direct_free_sendmsg(t, sendmsg);
1517                 return ret;
1518         }
1519
1520         sendmsg->num_sge = 1;
1521         sendmsg->sge[0].length = sizeof(*resp);
1522         sendmsg->sge[0].lkey = t->pd->local_dma_lkey;
1523
1524         ret = post_sendmsg(t, NULL, sendmsg);
1525         if (ret) {
1526                 smb_direct_free_sendmsg(t, sendmsg);
1527                 return ret;
1528         }
1529
1530         wait_event(t->wait_send_pending,
1531                    atomic_read(&t->send_pending) == 0);
1532         return 0;
1533 }
1534
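/*
 * Accept the incoming RDMA CM connection. For iWARP ports a two-word
 * IRD/ORD hint is passed as connection private data; otherwise no private
 * data is sent. The call then waits for the CM event handler to move the
 * transport out of SMB_DIRECT_CS_NEW.
 */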
1535 static int smb_direct_accept_client(struct smb_direct_transport *t)
1536 {
1537         struct rdma_conn_param conn_param;
1538         struct ib_port_immutable port_immutable;
1539         u32 ird_ord_hdr[2];
1540         int ret;
1541
1542         memset(&conn_param, 0, sizeof(conn_param));
1543         conn_param.initiator_depth = min_t(u8, t->cm_id->device->attrs.max_qp_rd_atom,
1544                                            SMB_DIRECT_CM_INITIATOR_DEPTH);
1545         conn_param.responder_resources = 0;
1546
1547         t->cm_id->device->ops.get_port_immutable(t->cm_id->device,
1548                                                  t->cm_id->port_num,
1549                                                  &port_immutable);
1550         if (port_immutable.core_cap_flags & RDMA_CORE_PORT_IWARP) {
1551                 ird_ord_hdr[0] = conn_param.responder_resources;
1552                 ird_ord_hdr[1] = 1;
1553                 conn_param.private_data = ird_ord_hdr;
1554                 conn_param.private_data_len = sizeof(ird_ord_hdr);
1555         } else {
1556                 conn_param.private_data = NULL;
1557                 conn_param.private_data_len = 0;
1558         }
1559         conn_param.retry_count = SMB_DIRECT_CM_RETRY;
1560         conn_param.rnr_retry_count = SMB_DIRECT_CM_RNR_RETRY;
1561         conn_param.flow_control = 0;
1562
1563         ret = rdma_accept(t->cm_id, &conn_param);
1564         if (ret) {
1565                 pr_err("error at rdma_accept: %d\n", ret);
1566                 return ret;
1567         }
1568
1569         wait_event_interruptible(t->wait_status,
1570                                  t->status != SMB_DIRECT_CS_NEW);
1571         if (t->status != SMB_DIRECT_CS_CONNECTED)
1572                 return -ENOTCONN;
1573         return 0;
1574 }
1575
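/*
 * SMB_DIRECT negotiation: a receive for the client's negotiate request is
 * posted before the connection is accepted, then the code waits up to
 * SMB_DIRECT_NEGOTIATE_TIMEOUT seconds for that request, clamps the local
 * send/receive sizes to what the client advertised, and replies with a
 * negotiate response.
 */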
1576 static int smb_direct_negotiate(struct smb_direct_transport *t)
1577 {
1578         int ret;
1579         struct smb_direct_recvmsg *recvmsg;
1580         struct smb_direct_negotiate_req *req;
1581
1582         recvmsg = get_free_recvmsg(t);
1583         if (!recvmsg)
1584                 return -ENOMEM;
1585         recvmsg->type = SMB_DIRECT_MSG_NEGOTIATE_REQ;
1586
1587         ret = smb_direct_post_recv(t, recvmsg);
1588         if (ret) {
1589                 pr_err("Can't post recv: %d\n", ret);
1590                 goto out;
1591         }
1592
1593         t->negotiation_requested = false;
1594         ret = smb_direct_accept_client(t);
1595         if (ret) {
1596                 pr_err("Can't accept client\n");
1597                 goto out;
1598         }
1599
1600         smb_direct_post_recv_credits(&t->post_recv_credits_work.work);
1601
1602         ksmbd_debug(RDMA, "Waiting for SMB_DIRECT negotiate request\n");
1603         ret = wait_event_interruptible_timeout(t->wait_status,
1604                                                t->negotiation_requested ||
1605                                                 t->status == SMB_DIRECT_CS_DISCONNECTED,
1606                                                SMB_DIRECT_NEGOTIATE_TIMEOUT * HZ);
1607         if (ret <= 0 || t->status == SMB_DIRECT_CS_DISCONNECTED) {
1608                 ret = ret < 0 ? ret : -ETIMEDOUT;
1609                 goto out;
1610         }
1611
1612         ret = smb_direct_check_recvmsg(recvmsg);
1613         if (ret == -ECONNABORTED)
1614                 goto out;
1615
1616         req = (struct smb_direct_negotiate_req *)recvmsg->packet;
1617         t->max_recv_size = min_t(int, t->max_recv_size,
1618                                  le32_to_cpu(req->preferred_send_size));
1619         t->max_send_size = min_t(int, t->max_send_size,
1620                                  le32_to_cpu(req->max_receive_size));
1621         t->max_fragmented_send_size =
1622                         le32_to_cpu(req->max_fragmented_size);
1623
1624         ret = smb_direct_send_negotiate_response(t, ret);
1625 out:
1626         if (recvmsg)
1627                 put_recvmsg(t, recvmsg);
1628         return ret;
1629 }
1630
1631 static int smb_direct_init_params(struct smb_direct_transport *t,
1632                                   struct ib_qp_cap *cap)
1633 {
1634         struct ib_device *device = t->cm_id->device;
1635         int max_send_sges, max_pages, max_rw_wrs, max_send_wrs;
1636
1637         /* need 2 more SGEs: one because a SMB_DIRECT header will be mapped,
1638          * and one because a send buffer may not be page aligned.
1639          */
1640         t->max_send_size = smb_direct_max_send_size;
1641         max_send_sges = DIV_ROUND_UP(t->max_send_size, PAGE_SIZE) + 2;
1642         if (max_send_sges > SMB_DIRECT_MAX_SEND_SGES) {
1643                 pr_err("max_send_size %d is too large\n", t->max_send_size);
1644                 return -EINVAL;
1645         }
1646
1647         /*
1648          * Allow smb_direct_max_outstanding_rw_ops in-flight RDMA
1649          * reads/writes. The HCA guarantees at least max_send_sge SGEs per
1650          * RDMA read/write work request, and if memory registration is used,
1651          * a REG_MR and a LOCAL_INV WR are needed for each read/write.
1652          */
1653         t->max_rdma_rw_size = smb_direct_max_read_write_size;
1654         max_pages = DIV_ROUND_UP(t->max_rdma_rw_size, PAGE_SIZE) + 1;
1655         max_rw_wrs = DIV_ROUND_UP(max_pages, SMB_DIRECT_MAX_SEND_SGES);
1656         max_rw_wrs += rdma_rw_mr_factor(device, t->cm_id->port_num,
1657                         max_pages) * 2;
1658         max_rw_wrs *= smb_direct_max_outstanding_rw_ops;
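        /*
         * For illustration (assuming a 4 KiB PAGE_SIZE and the default 1 MiB
         * max_rdma_rw_size): max_pages = 256 + 1 = 257, so
         * DIV_ROUND_UP(257, 8) = 33 WRs, plus a device-dependent
         * rdma_rw_mr_factor() term, all multiplied by the default of 8
         * outstanding rw ops.
         */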
1659
1660         max_send_wrs = smb_direct_send_credit_target + max_rw_wrs;
1661         if (max_send_wrs > device->attrs.max_cqe ||
1662             max_send_wrs > device->attrs.max_qp_wr) {
1663                 pr_err("consider lowering send_credit_target = %d, or max_outstanding_rw_ops = %d\n",
1664                        smb_direct_send_credit_target,
1665                        smb_direct_max_outstanding_rw_ops);
1666                 pr_err("Possible CQE overrun, device reporting max_cqe %d max_qp_wr %d\n",
1667                        device->attrs.max_cqe, device->attrs.max_qp_wr);
1668                 return -EINVAL;
1669         }
1670
1671         if (smb_direct_receive_credit_max > device->attrs.max_cqe ||
1672             smb_direct_receive_credit_max > device->attrs.max_qp_wr) {
1673                 pr_err("consider lowering receive_credit_max = %d\n",
1674                        smb_direct_receive_credit_max);
1675                 pr_err("Possible CQE overrun, device reporting max_cqe %d max_qp_wr %d\n",
1676                        device->attrs.max_cqe, device->attrs.max_qp_wr);
1677                 return -EINVAL;
1678         }
1679
1680         if (device->attrs.max_send_sge < SMB_DIRECT_MAX_SEND_SGES) {
1681                 pr_err("warning: device max_send_sge = %d too small\n",
1682                        device->attrs.max_send_sge);
1683                 return -EINVAL;
1684         }
1685         if (device->attrs.max_recv_sge < SMB_DIRECT_MAX_RECV_SGES) {
1686                 pr_err("warning: device max_recv_sge = %d too small\n",
1687                        device->attrs.max_recv_sge);
1688                 return -EINVAL;
1689         }
1690
1691         t->recv_credits = 0;
1692         t->count_avail_recvmsg = 0;
1693
1694         t->recv_credit_max = smb_direct_receive_credit_max;
1695         t->recv_credit_target = 10;
1696         t->new_recv_credits = 0;
1697
1698         t->send_credit_target = smb_direct_send_credit_target;
1699         atomic_set(&t->send_credits, 0);
1700         atomic_set(&t->rw_avail_ops, smb_direct_max_outstanding_rw_ops);
1701
1702         t->max_send_size = smb_direct_max_send_size;
1703         t->max_recv_size = smb_direct_max_receive_size;
1704         t->max_fragmented_recv_size = smb_direct_max_fragmented_recv_size;
1705
1706         cap->max_send_wr = max_send_wrs;
1707         cap->max_recv_wr = t->recv_credit_max;
1708         cap->max_send_sge = SMB_DIRECT_MAX_SEND_SGES;
1709         cap->max_recv_sge = SMB_DIRECT_MAX_RECV_SGES;
1710         cap->max_inline_data = 0;
1711         cap->max_rdma_ctxs = 0;
1712         return 0;
1713 }
1714
1715 static void smb_direct_destroy_pools(struct smb_direct_transport *t)
1716 {
1717         struct smb_direct_recvmsg *recvmsg;
1718
1719         while ((recvmsg = get_free_recvmsg(t)))
1720                 mempool_free(recvmsg, t->recvmsg_mempool);
1721         while ((recvmsg = get_empty_recvmsg(t)))
1722                 mempool_free(recvmsg, t->recvmsg_mempool);
1723
1724         mempool_destroy(t->recvmsg_mempool);
1725         t->recvmsg_mempool = NULL;
1726
1727         kmem_cache_destroy(t->recvmsg_cache);
1728         t->recvmsg_cache = NULL;
1729
1730         mempool_destroy(t->sendmsg_mempool);
1731         t->sendmsg_mempool = NULL;
1732
1733         kmem_cache_destroy(t->sendmsg_cache);
1734         t->sendmsg_cache = NULL;
1735 }
1736
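/*
 * Per-transport message pools: a slab cache plus mempool for send messages
 * sized by the send credit target, and one for receive messages (header plus
 * max_recv_size payload) sized by the receive credit maximum. All receive
 * buffers are pre-allocated onto recvmsg_queue.
 */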
1737 static int smb_direct_create_pools(struct smb_direct_transport *t)
1738 {
1739         char name[80];
1740         int i;
1741         struct smb_direct_recvmsg *recvmsg;
1742
1743         snprintf(name, sizeof(name), "smb_direct_rqst_pool_%p", t);
1744         t->sendmsg_cache = kmem_cache_create(name,
1745                                              sizeof(struct smb_direct_sendmsg) +
1746                                               sizeof(struct smb_direct_negotiate_resp),
1747                                              0, SLAB_HWCACHE_ALIGN, NULL);
1748         if (!t->sendmsg_cache)
1749                 return -ENOMEM;
1750
1751         t->sendmsg_mempool = mempool_create(t->send_credit_target,
1752                                             mempool_alloc_slab, mempool_free_slab,
1753                                             t->sendmsg_cache);
1754         if (!t->sendmsg_mempool)
1755                 goto err;
1756
1757         snprintf(name, sizeof(name), "smb_direct_resp_%p", t);
1758         t->recvmsg_cache = kmem_cache_create(name,
1759                                              sizeof(struct smb_direct_recvmsg) +
1760                                               t->max_recv_size,
1761                                              0, SLAB_HWCACHE_ALIGN, NULL);
1762         if (!t->recvmsg_cache)
1763                 goto err;
1764
1765         t->recvmsg_mempool =
1766                 mempool_create(t->recv_credit_max, mempool_alloc_slab,
1767                                mempool_free_slab, t->recvmsg_cache);
1768         if (!t->recvmsg_mempool)
1769                 goto err;
1770
1771         INIT_LIST_HEAD(&t->recvmsg_queue);
1772
1773         for (i = 0; i < t->recv_credit_max; i++) {
1774                 recvmsg = mempool_alloc(t->recvmsg_mempool, GFP_KERNEL);
1775                 if (!recvmsg)
1776                         goto err;
1777                 recvmsg->transport = t;
1778                 list_add(&recvmsg->list, &t->recvmsg_queue);
1779         }
1780         t->count_avail_recvmsg = t->recv_credit_max;
1781
1782         return 0;
1783 err:
1784         smb_direct_destroy_pools(t);
1785         return -ENOMEM;
1786 }
1787
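/*
 * Allocate the protection domain, the send and receive completion queues
 * (both polled from workqueue context) and an RC queue pair on the
 * connection's cm_id; on any failure the objects already created are torn
 * down in reverse order.
 */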
1788 static int smb_direct_create_qpair(struct smb_direct_transport *t,
1789                                    struct ib_qp_cap *cap)
1790 {
1791         int ret;
1792         struct ib_qp_init_attr qp_attr;
1793
1794         t->pd = ib_alloc_pd(t->cm_id->device, 0);
1795         if (IS_ERR(t->pd)) {
1796                 pr_err("Can't create RDMA PD\n");
1797                 ret = PTR_ERR(t->pd);
1798                 t->pd = NULL;
1799                 return ret;
1800         }
1801
1802         t->send_cq = ib_alloc_cq(t->cm_id->device, t,
1803                                  t->send_credit_target, 0, IB_POLL_WORKQUEUE);
1804         if (IS_ERR(t->send_cq)) {
1805                 pr_err("Can't create RDMA send CQ\n");
1806                 ret = PTR_ERR(t->send_cq);
1807                 t->send_cq = NULL;
1808                 goto err;
1809         }
1810
1811         t->recv_cq = ib_alloc_cq(t->cm_id->device, t,
1812                                  cap->max_send_wr + cap->max_rdma_ctxs,
1813                                  0, IB_POLL_WORKQUEUE);
1814         if (IS_ERR(t->recv_cq)) {
1815                 pr_err("Can't create RDMA recv CQ\n");
1816                 ret = PTR_ERR(t->recv_cq);
1817                 t->recv_cq = NULL;
1818                 goto err;
1819         }
1820
1821         memset(&qp_attr, 0, sizeof(qp_attr));
1822         qp_attr.event_handler = smb_direct_qpair_handler;
1823         qp_attr.qp_context = t;
1824         qp_attr.cap = *cap;
1825         qp_attr.sq_sig_type = IB_SIGNAL_REQ_WR;
1826         qp_attr.qp_type = IB_QPT_RC;
1827         qp_attr.send_cq = t->send_cq;
1828         qp_attr.recv_cq = t->recv_cq;
1829         qp_attr.port_num = ~0;
1830
1831         ret = rdma_create_qp(t->cm_id, t->pd, &qp_attr);
1832         if (ret) {
1833                 pr_err("Can't create RDMA QP: %d\n", ret);
1834                 goto err;
1835         }
1836
1837         t->qp = t->cm_id->qp;
1838         t->cm_id->event_handler = smb_direct_cm_handler;
1839
1840         return 0;
1841 err:
1842         if (t->qp) {
1843                 ib_destroy_qp(t->qp);
1844                 t->qp = NULL;
1845         }
1846         if (t->recv_cq) {
1847                 ib_destroy_cq(t->recv_cq);
1848                 t->recv_cq = NULL;
1849         }
1850         if (t->send_cq) {
1851                 ib_destroy_cq(t->send_cq);
1852                 t->send_cq = NULL;
1853         }
1854         if (t->pd) {
1855                 ib_dealloc_pd(t->pd);
1856                 t->pd = NULL;
1857         }
1858         return ret;
1859 }
1860
1861 static int smb_direct_prepare(struct ksmbd_transport *t)
1862 {
1863         struct smb_direct_transport *st = smb_trans_direct_transfort(t);
1864         int ret;
1865         struct ib_qp_cap qp_cap;
1866
1867         ret = smb_direct_init_params(st, &qp_cap);
1868         if (ret) {
1869                 pr_err("Can't configure RDMA parameters\n");
1870                 return ret;
1871         }
1872
1873         ret = smb_direct_create_pools(st);
1874         if (ret) {
1875                 pr_err("Can't init RDMA pool: %d\n", ret);
1876                 return ret;
1877         }
1878
1879         ret = smb_direct_create_qpair(st, &qp_cap);
1880         if (ret) {
1881                 pr_err("Can't accept RDMA client: %d\n", ret);
1882                 return ret;
1883         }
1884
1885         ret = smb_direct_negotiate(st);
1886         if (ret) {
1887                 pr_err("Can't negotiate: %d\n", ret);
1888                 return ret;
1889         }
1890
1891         st->status = SMB_DIRECT_CS_CONNECTED;
1892         return 0;
1893 }
1894
1895 static bool rdma_frwr_is_supported(struct ib_device_attr *attrs)
1896 {
1897         if (!(attrs->device_cap_flags & IB_DEVICE_MEM_MGT_EXTENSIONS))
1898                 return false;
1899         if (attrs->max_fast_reg_page_list_len == 0)
1900                 return false;
1901         return true;
1902 }
1903
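/*
 * Handle an incoming connection request: devices without Fast Registration
 * Work Request support are rejected, otherwise a transport is allocated and
 * a dedicated ksmbd connection-handler thread is started for it.
 */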
1904 static int smb_direct_handle_connect_request(struct rdma_cm_id *new_cm_id)
1905 {
1906         struct smb_direct_transport *t;
1907
1908         if (!rdma_frwr_is_supported(&new_cm_id->device->attrs)) {
1909                 ksmbd_debug(RDMA,
1910                             "Fast Registration Work Requests are not supported. device capabilities=%llx\n",
1911                             new_cm_id->device->attrs.device_cap_flags);
1912                 return -EPROTONOSUPPORT;
1913         }
1914
1915         t = alloc_transport(new_cm_id);
1916         if (!t)
1917                 return -ENOMEM;
1918
1919         KSMBD_TRANS(t)->handler = kthread_run(ksmbd_conn_handler_loop,
1920                                               KSMBD_TRANS(t)->conn, "ksmbd:r%u",
1921                                               SMB_DIRECT_PORT);
1922         if (IS_ERR(KSMBD_TRANS(t)->handler)) {
1923                 int ret = PTR_ERR(KSMBD_TRANS(t)->handler);
1924
1925                 pr_err("Can't start thread\n");
1926                 free_transport(t);
1927                 return ret;
1928         }
1929
1930         return 0;
1931 }
1932
1933 static int smb_direct_listen_handler(struct rdma_cm_id *cm_id,
1934                                      struct rdma_cm_event *event)
1935 {
1936         switch (event->event) {
1937         case RDMA_CM_EVENT_CONNECT_REQUEST: {
1938                 int ret = smb_direct_handle_connect_request(cm_id);
1939
1940                 if (ret) {
1941                         pr_err("Can't create transport: %d\n", ret);
1942                         return ret;
1943                 }
1944
1945                 ksmbd_debug(RDMA, "Received connection request. cm_id=%p\n",
1946                             cm_id);
1947                 break;
1948         }
1949         default:
1950                 pr_err("Unexpected listen event. cm_id=%p, event=%s (%d)\n",
1951                        cm_id, rdma_event_msg(event->event), event->event);
1952                 break;
1953         }
1954         return 0;
1955 }
1956
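/*
 * Create the RDMA CM listener: as written it binds an IPv4 wildcard address
 * on the given port and listens with a backlog of 10.
 */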
1957 static int smb_direct_listen(int port)
1958 {
1959         int ret;
1960         struct rdma_cm_id *cm_id;
1961         struct sockaddr_in sin = {
1962                 .sin_family             = AF_INET,
1963                 .sin_addr.s_addr        = htonl(INADDR_ANY),
1964                 .sin_port               = htons(port),
1965         };
1966
1967         cm_id = rdma_create_id(&init_net, smb_direct_listen_handler,
1968                                &smb_direct_listener, RDMA_PS_TCP, IB_QPT_RC);
1969         if (IS_ERR(cm_id)) {
1970                 pr_err("Can't create cm id: %ld\n", PTR_ERR(cm_id));
1971                 return PTR_ERR(cm_id);
1972         }
1973
1974         ret = rdma_bind_addr(cm_id, (struct sockaddr *)&sin);
1975         if (ret) {
1976                 pr_err("Can't bind: %d\n", ret);
1977                 goto err;
1978         }
1979
1980         smb_direct_listener.cm_id = cm_id;
1981
1982         ret = rdma_listen(cm_id, 10);
1983         if (ret) {
1984                 pr_err("Can't listen: %d\n", ret);
1985                 goto err;
1986         }
1987         return 0;
1988 err:
1989         smb_direct_listener.cm_id = NULL;
1990         rdma_destroy_id(cm_id);
1991         return ret;
1992 }
1993
1994 int ksmbd_rdma_init(void)
1995 {
1996         int ret;
1997
1998         smb_direct_listener.cm_id = NULL;
1999
2000         /* When a client is running out of send credits, the server grants
2001          * more credits by sending a packet through this queue.
2002          * This avoids a situation where clients cannot send packets for
2003          * lack of credits.
2004          */
2005         smb_direct_wq = alloc_workqueue("ksmbd-smb_direct-wq",
2006                                         WQ_HIGHPRI | WQ_MEM_RECLAIM, 0);
2007         if (!smb_direct_wq)
2008                 return -ENOMEM;
2009
2010         ret = smb_direct_listen(SMB_DIRECT_PORT);
2011         if (ret) {
2012                 destroy_workqueue(smb_direct_wq);
2013                 smb_direct_wq = NULL;
2014                 pr_err("Can't listen: %d\n", ret);
2015                 return ret;
2016         }
2017
2018         ksmbd_debug(RDMA, "init RDMA listener. cm_id=%p\n",
2019                     smb_direct_listener.cm_id);
2020         return 0;
2021 }
2022
2023 int ksmbd_rdma_destroy(void)
2024 {
2025         if (smb_direct_listener.cm_id)
2026                 rdma_destroy_id(smb_direct_listener.cm_id);
2027         smb_direct_listener.cm_id = NULL;
2028
2029         if (smb_direct_wq) {
2030                 flush_workqueue(smb_direct_wq);
2031                 destroy_workqueue(smb_direct_wq);
2032                 smb_direct_wq = NULL;
2033         }
2034         return 0;
2035 }
2036
2037 bool ksmbd_rdma_capable_netdev(struct net_device *netdev)
2038 {
2039         struct ib_device *ibdev;
2040         bool rdma_capable = false;
2041
2042         ibdev = ib_device_get_by_netdev(netdev, RDMA_DRIVER_UNKNOWN);
2043         if (ibdev) {
2044                 if (rdma_frwr_is_supported(&ibdev->attrs))
2045                         rdma_capable = true;
2046                 ib_device_put(ibdev);
2047         }
2048         return rdma_capable;
2049 }
2050
2051 static struct ksmbd_transport_ops ksmbd_smb_direct_transport_ops = {
2052         .prepare        = smb_direct_prepare,
2053         .disconnect     = smb_direct_disconnect,
2054         .writev         = smb_direct_writev,
2055         .read           = smb_direct_read,
2056         .rdma_read      = smb_direct_rdma_read,
2057         .rdma_write     = smb_direct_rdma_write,
2058 };