fs: dlm: retry accept() until -EAGAIN or error returns
fs/dlm/lowcomms.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /******************************************************************************
3 *******************************************************************************
4 **
5 **  Copyright (C) Sistina Software, Inc.  1997-2003  All rights reserved.
6 **  Copyright (C) 2004-2009 Red Hat, Inc.  All rights reserved.
7 **
8 **
9 *******************************************************************************
10 ******************************************************************************/
11
12 /*
13  * lowcomms.c
14  *
15  * This is the "low-level" comms layer.
16  *
17  * It is responsible for sending messages to, and receiving
18  * messages from, other nodes in the cluster.
19  *
20  * Cluster nodes are referred to by their nodeids. To the locking
21  * module, nodeids are simply 32-bit numbers - if they need to
22  * be expanded for the cluster infrastructure then that is its
23  * responsibility. It is this layer's responsibility to resolve
24  * nodeids into IP addresses or whatever else it needs for
25  * inter-node communication.
26  *
27  * The comms level consists of two kernel threads: one that deals
28  * mainly with receiving messages from other nodes and passing them
29  * up to the mid-level comms layer (which understands the message
30  * format) for execution by the locking core, and a send thread
31  * which does all the setting up of connections to remote nodes
32  * and the sending of data. Other threads are not allowed to send
33  * their own data because they could be made to wait under high
34  * load. This also lets the send thread batch messages bound for
35  * one node and send them in one block.
36  *
37  * lowcomms will choose to use either TCP or SCTP as its transport layer
38  * depending on the configuration variable 'protocol'. This should be set
39  * to 0 (default) for TCP or 1 for SCTP. It should be configured using a
40  * cluster-wide mechanism as it must be the same on all nodes of the cluster
41  * for the DLM to function.
42  *
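 *  For example, with the usual dlm configfs layout (this is normally
 *  written by the cluster manager rather than by hand):
 *
 *      echo 1 > /sys/kernel/config/dlm/cluster/protocol    (use SCTP)
 *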
43  */
44
45 #include <asm/ioctls.h>
46 #include <net/sock.h>
47 #include <net/tcp.h>
48 #include <linux/pagemap.h>
49 #include <linux/file.h>
50 #include <linux/mutex.h>
51 #include <linux/sctp.h>
52 #include <linux/slab.h>
53 #include <net/sctp/sctp.h>
54 #include <net/ipv6.h>
55
56 #include "dlm_internal.h"
57 #include "lowcomms.h"
58 #include "midcomms.h"
59 #include "config.h"
60
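/* receive buffer size we want on our sockets (4 MiB); assumed to be
 * applied via rcvbuf tuning when sockets are set up */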
61 #define NEEDED_RMEM (4*1024*1024)
62
63 /* Number of messages to send before rescheduling */
64 #define MAX_SEND_MSG_COUNT 25
65 #define DLM_SHUTDOWN_WAIT_TIMEOUT msecs_to_jiffies(10000)
66
67 struct connection {
68         struct socket *sock;    /* NULL if not connected */
69         uint32_t nodeid;        /* So we know who we are in the list */
70         struct mutex sock_mutex;
71         unsigned long flags;
72 #define CF_READ_PENDING 1
73 #define CF_WRITE_PENDING 2
74 #define CF_INIT_PENDING 4
75 #define CF_IS_OTHERCON 5
76 #define CF_CLOSE 6
77 #define CF_APP_LIMITED 7
78 #define CF_CLOSING 8
79 #define CF_SHUTDOWN 9
80 #define CF_CONNECTED 10
81 #define CF_RECONNECT 11
82 #define CF_DELAY_CONNECT 12
83 #define CF_EOF 13
84         struct list_head writequeue;  /* List of outgoing writequeue_entries */
85         spinlock_t writequeue_lock;
86         atomic_t writequeue_cnt;
87         struct mutex wq_alloc;
88         int retries;
89 #define MAX_CONNECT_RETRIES 3
90         struct hlist_node list;
91         struct connection *othercon;
92         struct connection *sendcon;
93         struct work_struct rwork; /* Receive workqueue */
94         struct work_struct swork; /* Send workqueue */
95         wait_queue_head_t shutdown_wait; /* wait for graceful shutdown */
96         unsigned char *rx_buf;
97         int rx_buflen;
98         int rx_leftover;
99         struct rcu_head rcu;
100 };
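/* retrieve the connection owning a socket; sk_user_data is set to the
 * connection in add_sock() */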
101 #define sock2con(x) ((struct connection *)(x)->sk_user_data)
102
103 struct listen_connection {
104         struct socket *sock;
105         struct work_struct rwork;
106 };
107
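/* a writequeue_entry buffers outgoing data in one page; the unsent
 * data occupies the byte range [offset, end) within that page */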
108 #define DLM_WQ_REMAIN_BYTES(e) (PAGE_SIZE - e->end)
109 #define DLM_WQ_LENGTH_BYTES(e) (e->end - e->offset)
110
111 /* An entry waiting to be sent */
112 struct writequeue_entry {
113         struct list_head list;
114         struct page *page;
115         int offset;
116         int len;
117         int end;
118         int users;
119         bool dirty;
120         struct connection *con;
121         struct list_head msgs;
122         struct kref ref;
123 };
124
125 struct dlm_msg {
126         struct writequeue_entry *entry;
127         struct dlm_msg *orig_msg;
128         bool retransmit;
129         void *ppc;
130         int len;
131         int idx; /* new()/commit() idx exchange */
132
133         struct list_head list;
134         struct kref ref;
135 };
136
137 struct dlm_node_addr {
138         struct list_head list;
139         int nodeid;
140         int mark;
141         int addr_count;
142         int curr_addr_index;
143         struct sockaddr_storage *addr[DLM_MAX_ADDR_COUNT];
144 };
145
146 struct dlm_proto_ops {
147         bool try_new_addr;
148         const char *name;
149         int proto;
150
151         int (*connect)(struct connection *con, struct socket *sock,
152                        struct sockaddr *addr, int addr_len);
153         void (*sockopts)(struct socket *sock);
154         int (*bind)(struct socket *sock);
155         int (*listen_validate)(void);
156         void (*listen_sockopts)(struct socket *sock);
157         int (*listen_bind)(struct socket *sock);
158         /* What to do to shutdown */
159         void (*shutdown_action)(struct connection *con);
160         /* What to do to eof check */
161         bool (*eof_condition)(struct connection *con);
162 };
163
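/* the original socket callbacks, saved by save_listen_callbacks() and
 * put back by restore_callbacks() when a socket is closed */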
164 static struct listen_sock_callbacks {
165         void (*sk_error_report)(struct sock *);
166         void (*sk_data_ready)(struct sock *);
167         void (*sk_state_change)(struct sock *);
168         void (*sk_write_space)(struct sock *);
169 } listen_sock;
170
171 static LIST_HEAD(dlm_node_addrs);
172 static DEFINE_SPINLOCK(dlm_node_addrs_spin);
173
174 static struct listen_connection listen_con;
175 static struct sockaddr_storage *dlm_local_addr[DLM_MAX_ADDR_COUNT];
176 static int dlm_local_count;
177 int dlm_allow_conn;
178
179 /* Work queues */
180 static struct workqueue_struct *recv_workqueue;
181 static struct workqueue_struct *send_workqueue;
182
183 static struct hlist_head connection_hash[CONN_HASH_SIZE];
184 static DEFINE_SPINLOCK(connections_lock);
185 DEFINE_STATIC_SRCU(connections_srcu);
186
187 static const struct dlm_proto_ops *dlm_proto_ops;
188
189 static void process_recv_sockets(struct work_struct *work);
190 static void process_send_sockets(struct work_struct *work);
191
192 /* must be called with writequeue_lock held */
193 static struct writequeue_entry *con_next_wq(struct connection *con)
194 {
195         struct writequeue_entry *e;
196
197         if (list_empty(&con->writequeue))
198                 return NULL;
199
200         e = list_first_entry(&con->writequeue, struct writequeue_entry,
201                              list);
202         if (e->len == 0)
203                 return NULL;
204
205         return e;
206 }
207
208 static struct connection *__find_con(int nodeid, int r)
209 {
210         struct connection *con;
211
212         hlist_for_each_entry_rcu(con, &connection_hash[r], list) {
213                 if (con->nodeid == nodeid)
214                         return con;
215         }
216
217         return NULL;
218 }
219
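/* TCP only: after the peer sent EOF, delay closing while writes are
 * still queued; receive_from_sock() then sets CF_EOF and send_to_sock()
 * closes once the writequeue has drained */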
220 static bool tcp_eof_condition(struct connection *con)
221 {
222         return atomic_read(&con->writequeue_cnt);
223 }
224
225 static int dlm_con_init(struct connection *con, int nodeid)
226 {
227         con->rx_buflen = dlm_config.ci_buffer_size;
228         con->rx_buf = kmalloc(con->rx_buflen, GFP_NOFS);
229         if (!con->rx_buf)
230                 return -ENOMEM;
231
232         con->nodeid = nodeid;
233         mutex_init(&con->sock_mutex);
234         INIT_LIST_HEAD(&con->writequeue);
235         spin_lock_init(&con->writequeue_lock);
236         atomic_set(&con->writequeue_cnt, 0);
237         INIT_WORK(&con->swork, process_send_sockets);
238         INIT_WORK(&con->rwork, process_recv_sockets);
239         init_waitqueue_head(&con->shutdown_wait);
240
241         return 0;
242 }
243
244 /*
245  * If 'alloc' is zero then we don't attempt to create a new
246  * connection structure for this node.
247  */
248 static struct connection *nodeid2con(int nodeid, gfp_t alloc)
249 {
250         struct connection *con, *tmp;
251         int r, ret;
252
253         r = nodeid_hash(nodeid);
254         con = __find_con(nodeid, r);
255         if (con || !alloc)
256                 return con;
257
258         con = kzalloc(sizeof(*con), alloc);
259         if (!con)
260                 return NULL;
261
262         ret = dlm_con_init(con, nodeid);
263         if (ret) {
264                 kfree(con);
265                 return NULL;
266         }
267
268         mutex_init(&con->wq_alloc);
269
270         spin_lock(&connections_lock);
271         /* Because multiple workqueues/threads call this function it can
272          * race on multiple CPUs. Instead of locking the __find_con() hot
273          * path we just re-check, under connections_lock, for the rare case
274          * that a connection for this node was added meanwhile. If so, we
275          * abort our connection creation and return the existing connection.
276          */
277         tmp = __find_con(nodeid, r);
278         if (tmp) {
279                 spin_unlock(&connections_lock);
280                 kfree(con->rx_buf);
281                 kfree(con);
282                 return tmp;
283         }
284
285         hlist_add_head_rcu(&con->list, &connection_hash[r]);
286         spin_unlock(&connections_lock);
287
288         return con;
289 }
290
291 /* Loop round all connections */
292 static void foreach_conn(void (*conn_func)(struct connection *c))
293 {
294         int i;
295         struct connection *con;
296
297         for (i = 0; i < CONN_HASH_SIZE; i++) {
298                 hlist_for_each_entry_rcu(con, &connection_hash[i], list)
299                         conn_func(con);
300         }
301 }
302
303 static struct dlm_node_addr *find_node_addr(int nodeid)
304 {
305         struct dlm_node_addr *na;
306
307         list_for_each_entry(na, &dlm_node_addrs, list) {
308                 if (na->nodeid == nodeid)
309                         return na;
310         }
311         return NULL;
312 }
313
314 static int addr_compare(const struct sockaddr_storage *x,
315                         const struct sockaddr_storage *y)
316 {
317         switch (x->ss_family) {
318         case AF_INET: {
319                 struct sockaddr_in *sinx = (struct sockaddr_in *)x;
320                 struct sockaddr_in *siny = (struct sockaddr_in *)y;
321                 if (sinx->sin_addr.s_addr != siny->sin_addr.s_addr)
322                         return 0;
323                 if (sinx->sin_port != siny->sin_port)
324                         return 0;
325                 break;
326         }
327         case AF_INET6: {
328                 struct sockaddr_in6 *sinx = (struct sockaddr_in6 *)x;
329                 struct sockaddr_in6 *siny = (struct sockaddr_in6 *)y;
330                 if (!ipv6_addr_equal(&sinx->sin6_addr, &siny->sin6_addr))
331                         return 0;
332                 if (sinx->sin6_port != siny->sin6_port)
333                         return 0;
334                 break;
335         }
336         default:
337                 return 0;
338         }
339         return 1;
340 }
341
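/* resolve a nodeid to one of its configured addresses; with
 * try_new_addr the index rotates, so successive lookups cycle through
 * all known addresses of the node */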
342 static int nodeid_to_addr(int nodeid, struct sockaddr_storage *sas_out,
343                           struct sockaddr *sa_out, bool try_new_addr,
344                           unsigned int *mark)
345 {
346         struct sockaddr_storage sas;
347         struct dlm_node_addr *na;
348
349         if (!dlm_local_count)
350                 return -1;
351
352         spin_lock(&dlm_node_addrs_spin);
353         na = find_node_addr(nodeid);
354         if (na && na->addr_count) {
355                 memcpy(&sas, na->addr[na->curr_addr_index],
356                        sizeof(struct sockaddr_storage));
357
358                 if (try_new_addr) {
359                         na->curr_addr_index++;
360                         if (na->curr_addr_index == na->addr_count)
361                                 na->curr_addr_index = 0;
362                 }
363         }
364         spin_unlock(&dlm_node_addrs_spin);
365
366         if (!na)
367                 return -EEXIST;
368
369         if (!na->addr_count)
370                 return -ENOENT;
371
372         *mark = na->mark;
373
374         if (sas_out)
375                 memcpy(sas_out, &sas, sizeof(struct sockaddr_storage));
376
377         if (!sa_out)
378                 return 0;
379
380         if (dlm_local_addr[0]->ss_family == AF_INET) {
381                 struct sockaddr_in *in4  = (struct sockaddr_in *) &sas;
382                 struct sockaddr_in *ret4 = (struct sockaddr_in *) sa_out;
383                 ret4->sin_addr.s_addr = in4->sin_addr.s_addr;
384         } else {
385                 struct sockaddr_in6 *in6  = (struct sockaddr_in6 *) &sas;
386                 struct sockaddr_in6 *ret6 = (struct sockaddr_in6 *) sa_out;
387                 ret6->sin6_addr = in6->sin6_addr;
388         }
389
390         return 0;
391 }
392
393 static int addr_to_nodeid(struct sockaddr_storage *addr, int *nodeid,
394                           unsigned int *mark)
395 {
396         struct dlm_node_addr *na;
397         int rv = -EEXIST;
398         int addr_i;
399
400         spin_lock(&dlm_node_addrs_spin);
401         list_for_each_entry(na, &dlm_node_addrs, list) {
402                 if (!na->addr_count)
403                         continue;
404
405                 for (addr_i = 0; addr_i < na->addr_count; addr_i++) {
406                         if (addr_compare(na->addr[addr_i], addr)) {
407                                 *nodeid = na->nodeid;
408                                 *mark = na->mark;
409                                 rv = 0;
410                                 goto unlock;
411                         }
412                 }
413         }
414 unlock:
415         spin_unlock(&dlm_node_addrs_spin);
416         return rv;
417 }
418
419 /* caller must hold the dlm_node_addrs_spin lock */
420 static bool dlm_lowcomms_na_has_addr(const struct dlm_node_addr *na,
421                                      const struct sockaddr_storage *addr)
422 {
423         int i;
424
425         for (i = 0; i < na->addr_count; i++) {
426                 if (addr_compare(na->addr[i], addr))
427                         return true;
428         }
429
430         return false;
431 }
432
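/* register an additional address for a node; returns -EEXIST if the
 * address is already known and -ENOSPC once DLM_MAX_ADDR_COUNT
 * addresses are set */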
433 int dlm_lowcomms_addr(int nodeid, struct sockaddr_storage *addr, int len)
434 {
435         struct sockaddr_storage *new_addr;
436         struct dlm_node_addr *new_node, *na;
437         bool ret;
438
439         new_node = kzalloc(sizeof(struct dlm_node_addr), GFP_NOFS);
440         if (!new_node)
441                 return -ENOMEM;
442
443         new_addr = kzalloc(sizeof(struct sockaddr_storage), GFP_NOFS);
444         if (!new_addr) {
445                 kfree(new_node);
446                 return -ENOMEM;
447         }
448
449         memcpy(new_addr, addr, len);
450
451         spin_lock(&dlm_node_addrs_spin);
452         na = find_node_addr(nodeid);
453         if (!na) {
454                 new_node->nodeid = nodeid;
455                 new_node->addr[0] = new_addr;
456                 new_node->addr_count = 1;
457                 new_node->mark = dlm_config.ci_mark;
458                 list_add(&new_node->list, &dlm_node_addrs);
459                 spin_unlock(&dlm_node_addrs_spin);
460                 return 0;
461         }
462
463         ret = dlm_lowcomms_na_has_addr(na, addr);
464         if (ret) {
465                 spin_unlock(&dlm_node_addrs_spin);
466                 kfree(new_addr);
467                 kfree(new_node);
468                 return -EEXIST;
469         }
470
471         if (na->addr_count >= DLM_MAX_ADDR_COUNT) {
472                 spin_unlock(&dlm_node_addrs_spin);
473                 kfree(new_addr);
474                 kfree(new_node);
475                 return -ENOSPC;
476         }
477
478         na->addr[na->addr_count++] = new_addr;
479         spin_unlock(&dlm_node_addrs_spin);
480         kfree(new_node);
481         return 0;
482 }
483
484 /* Data available on socket or listen socket received a connect */
485 static void lowcomms_data_ready(struct sock *sk)
486 {
487         struct connection *con;
488
489         read_lock_bh(&sk->sk_callback_lock);
490         con = sock2con(sk);
491         if (con && !test_and_set_bit(CF_READ_PENDING, &con->flags))
492                 queue_work(recv_workqueue, &con->rwork);
493         read_unlock_bh(&sk->sk_callback_lock);
494 }
495
496 static void lowcomms_listen_data_ready(struct sock *sk)
497 {
498         if (!dlm_allow_conn)
499                 return;
500
501         queue_work(recv_workqueue, &listen_con.rwork);
502 }
503
504 static void lowcomms_write_space(struct sock *sk)
505 {
506         struct connection *con;
507
508         read_lock_bh(&sk->sk_callback_lock);
509         con = sock2con(sk);
510         if (!con)
511                 goto out;
512
513         if (!test_and_set_bit(CF_CONNECTED, &con->flags)) {
514                 log_print("successfully connected to node %d", con->nodeid);
515                 queue_work(send_workqueue, &con->swork);
516                 goto out;
517         }
518
519         clear_bit(SOCK_NOSPACE, &con->sock->flags);
520
521         if (test_and_clear_bit(CF_APP_LIMITED, &con->flags)) {
522                 con->sock->sk->sk_write_pending--;
523                 clear_bit(SOCKWQ_ASYNC_NOSPACE, &con->sock->flags);
524         }
525
526         queue_work(send_workqueue, &con->swork);
527 out:
528         read_unlock_bh(&sk->sk_callback_lock);
529 }
530
531 static inline void lowcomms_connect_sock(struct connection *con)
532 {
533         if (test_bit(CF_CLOSE, &con->flags))
534                 return;
535         queue_work(send_workqueue, &con->swork);
536         cond_resched();
537 }
538
539 static void lowcomms_state_change(struct sock *sk)
540 {
541         /* The SCTP layer does not call sk_data_ready when the connection
542          * is finished, so we catch that signal here. It also does not
543          * change the socket state when entering shutdown, so we skip
544          * the write-space handling in that case.
545          */
546         if (sk->sk_shutdown) {
547                 if (sk->sk_shutdown == RCV_SHUTDOWN)
548                         lowcomms_data_ready(sk);
549         } else if (sk->sk_state == TCP_ESTABLISHED) {
550                 lowcomms_write_space(sk);
551         }
552 }
553
554 int dlm_lowcomms_connect_node(int nodeid)
555 {
556         struct connection *con;
557         int idx;
558
559         if (nodeid == dlm_our_nodeid())
560                 return 0;
561
562         idx = srcu_read_lock(&connections_srcu);
563         con = nodeid2con(nodeid, GFP_NOFS);
564         if (!con) {
565                 srcu_read_unlock(&connections_srcu, idx);
566                 return -ENOMEM;
567         }
568
569         lowcomms_connect_sock(con);
570         srcu_read_unlock(&connections_srcu, idx);
571
572         return 0;
573 }
574
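/* update the skb mark applied (via sock_set_mark()) to future sockets
 * for this node */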
575 int dlm_lowcomms_nodes_set_mark(int nodeid, unsigned int mark)
576 {
577         struct dlm_node_addr *na;
578
579         spin_lock(&dlm_node_addrs_spin);
580         na = find_node_addr(nodeid);
581         if (!na) {
582                 spin_unlock(&dlm_node_addrs_spin);
583                 return -ENOENT;
584         }
585
586         na->mark = mark;
587         spin_unlock(&dlm_node_addrs_spin);
588
589         return 0;
590 }
591
592 static void lowcomms_error_report(struct sock *sk)
593 {
594         struct connection *con;
595         void (*orig_report)(struct sock *) = NULL;
596         struct inet_sock *inet;
597
598         read_lock_bh(&sk->sk_callback_lock);
599         con = sock2con(sk);
600         if (con == NULL)
601                 goto out;
602
603         orig_report = listen_sock.sk_error_report;
604
605         inet = inet_sk(sk);
606         switch (sk->sk_family) {
607         case AF_INET:
608                 printk_ratelimited(KERN_ERR "dlm: node %d: socket error "
609                                    "sending to node %d at %pI4, dport %d, "
610                                    "sk_err=%d/%d\n", dlm_our_nodeid(),
611                                    con->nodeid, &inet->inet_daddr,
612                                    ntohs(inet->inet_dport), sk->sk_err,
613                                    sk->sk_err_soft);
614                 break;
615 #if IS_ENABLED(CONFIG_IPV6)
616         case AF_INET6:
617                 printk_ratelimited(KERN_ERR "dlm: node %d: socket error "
618                                    "sending to node %d at %pI6c, "
619                                    "dport %d, sk_err=%d/%d\n", dlm_our_nodeid(),
620                                    con->nodeid, &sk->sk_v6_daddr,
621                                    ntohs(inet->inet_dport), sk->sk_err,
622                                    sk->sk_err_soft);
623                 break;
624 #endif
625         default:
626                 printk_ratelimited(KERN_ERR "dlm: node %d: socket error "
627                                    "invalid socket family %d set, "
628                                    "sk_err=%d/%d\n", dlm_our_nodeid(),
629                                    sk->sk_family, sk->sk_err, sk->sk_err_soft);
630                 goto out;
631         }
632
633         /* the handling below applies only to the sendcon */
634         if (test_bit(CF_IS_OTHERCON, &con->flags))
635                 con = con->sendcon;
636
637         switch (sk->sk_err) {
638         case ECONNREFUSED:
639                 set_bit(CF_DELAY_CONNECT, &con->flags);
640                 break;
641         default:
642                 break;
643         }
644
645         if (!test_and_set_bit(CF_RECONNECT, &con->flags))
646                 queue_work(send_workqueue, &con->swork);
647
648 out:
649         read_unlock_bh(&sk->sk_callback_lock);
650         if (orig_report)
651                 orig_report(sk);
652 }
653
654 /* Note: sk_callback_lock must be locked before calling this function. */
655 static void save_listen_callbacks(struct socket *sock)
656 {
657         struct sock *sk = sock->sk;
658
659         listen_sock.sk_data_ready = sk->sk_data_ready;
660         listen_sock.sk_state_change = sk->sk_state_change;
661         listen_sock.sk_write_space = sk->sk_write_space;
662         listen_sock.sk_error_report = sk->sk_error_report;
663 }
664
665 static void restore_callbacks(struct socket *sock)
666 {
667         struct sock *sk = sock->sk;
668
669         write_lock_bh(&sk->sk_callback_lock);
670         sk->sk_user_data = NULL;
671         sk->sk_data_ready = listen_sock.sk_data_ready;
672         sk->sk_state_change = listen_sock.sk_state_change;
673         sk->sk_write_space = listen_sock.sk_write_space;
674         sk->sk_error_report = listen_sock.sk_error_report;
675         write_unlock_bh(&sk->sk_callback_lock);
676 }
677
678 static void add_listen_sock(struct socket *sock, struct listen_connection *con)
679 {
680         struct sock *sk = sock->sk;
681
682         write_lock_bh(&sk->sk_callback_lock);
683         save_listen_callbacks(sock);
684         con->sock = sock;
685
686         sk->sk_user_data = con;
687         sk->sk_allocation = GFP_NOFS;
688         /* Install a data_ready callback */
689         sk->sk_data_ready = lowcomms_listen_data_ready;
690         write_unlock_bh(&sk->sk_callback_lock);
691 }
692
693 /* Make a socket active */
694 static void add_sock(struct socket *sock, struct connection *con)
695 {
696         struct sock *sk = sock->sk;
697
698         write_lock_bh(&sk->sk_callback_lock);
699         con->sock = sock;
700
701         sk->sk_user_data = con;
702         /* Install the lowcomms socket callbacks */
703         sk->sk_data_ready = lowcomms_data_ready;
704         sk->sk_write_space = lowcomms_write_space;
705         sk->sk_state_change = lowcomms_state_change;
706         sk->sk_allocation = GFP_NOFS;
707         sk->sk_error_report = lowcomms_error_report;
708         write_unlock_bh(&sk->sk_callback_lock);
709 }
710
711 /* Add the port number to an IPv4 or IPv6 sockaddr and store the
712    resulting address length in *addr_len */
713 static void make_sockaddr(struct sockaddr_storage *saddr, uint16_t port,
714                           int *addr_len)
715 {
716         saddr->ss_family =  dlm_local_addr[0]->ss_family;
717         if (saddr->ss_family == AF_INET) {
718                 struct sockaddr_in *in4_addr = (struct sockaddr_in *)saddr;
719                 in4_addr->sin_port = cpu_to_be16(port);
720                 *addr_len = sizeof(struct sockaddr_in);
721                 memset(&in4_addr->sin_zero, 0, sizeof(in4_addr->sin_zero));
722         } else {
723                 struct sockaddr_in6 *in6_addr = (struct sockaddr_in6 *)saddr;
724                 in6_addr->sin6_port = cpu_to_be16(port);
725                 *addr_len = sizeof(struct sockaddr_in6);
726         }
727         memset((char *)saddr + *addr_len, 0, sizeof(struct sockaddr_storage) - *addr_len);
728 }
729
730 static void dlm_page_release(struct kref *kref)
731 {
732         struct writequeue_entry *e = container_of(kref, struct writequeue_entry,
733                                                   ref);
734
735         __free_page(e->page);
736         kfree(e);
737 }
738
739 static void dlm_msg_release(struct kref *kref)
740 {
741         struct dlm_msg *msg = container_of(kref, struct dlm_msg, ref);
742
743         kref_put(&msg->entry->ref, dlm_page_release);
744         kfree(msg);
745 }
746
747 static void free_entry(struct writequeue_entry *e)
748 {
749         struct dlm_msg *msg, *tmp;
750
751         list_for_each_entry_safe(msg, tmp, &e->msgs, list) {
752                 if (msg->orig_msg) {
753                         msg->orig_msg->retransmit = false;
754                         kref_put(&msg->orig_msg->ref, dlm_msg_release);
755                 }
756
757                 list_del(&msg->list);
758                 kref_put(&msg->ref, dlm_msg_release);
759         }
760
761         list_del(&e->list);
762         atomic_dec(&e->con->writequeue_cnt);
763         kref_put(&e->ref, dlm_page_release);
764 }
765
766 static void dlm_close_sock(struct socket **sock)
767 {
768         if (*sock) {
769                 restore_callbacks(*sock);
770                 sock_release(*sock);
771                 *sock = NULL;
772         }
773 }
774
775 /* Close a remote connection and tidy up */
776 static void close_connection(struct connection *con, bool and_other,
777                              bool tx, bool rx)
778 {
779         bool closing = test_and_set_bit(CF_CLOSING, &con->flags);
780         struct writequeue_entry *e;
781
782         if (tx && !closing && cancel_work_sync(&con->swork)) {
783                 log_print("canceled swork for node %d", con->nodeid);
784                 clear_bit(CF_WRITE_PENDING, &con->flags);
785         }
786         if (rx && !closing && cancel_work_sync(&con->rwork)) {
787                 log_print("canceled rwork for node %d", con->nodeid);
788                 clear_bit(CF_READ_PENDING, &con->flags);
789         }
790
791         mutex_lock(&con->sock_mutex);
792         dlm_close_sock(&con->sock);
793
794         if (con->othercon && and_other) {
795                 /* Will only re-enter once. */
796                 close_connection(con->othercon, false, tx, rx);
797         }
798
799         /* if we sent a writequeue entry only half way, we drop the
800          * whole entry: after a reconnect we must not resume in the
801          * middle of a message, as that would confuse the other end.
802          *
803          * dropping whole messages is always safe because they will be
804          * retransmitted; what we cannot allow is transmitting half a
805          * message that might be processed at the other side.
806          *
807          * our policy is to start from a clean state on disconnect, since
808          * we don't know what was sent/received at the transport layer.
809          */
810         spin_lock(&con->writequeue_lock);
811         if (!list_empty(&con->writequeue)) {
812                 e = list_first_entry(&con->writequeue, struct writequeue_entry,
813                                      list);
814                 if (e->dirty)
815                         free_entry(e);
816         }
817         spin_unlock(&con->writequeue_lock);
818
819         con->rx_leftover = 0;
820         con->retries = 0;
821         clear_bit(CF_APP_LIMITED, &con->flags);
822         clear_bit(CF_CONNECTED, &con->flags);
823         clear_bit(CF_DELAY_CONNECT, &con->flags);
824         clear_bit(CF_RECONNECT, &con->flags);
825         clear_bit(CF_EOF, &con->flags);
826         mutex_unlock(&con->sock_mutex);
827         clear_bit(CF_CLOSING, &con->flags);
828 }
829
830 static void shutdown_connection(struct connection *con)
831 {
832         int ret;
833
834         flush_work(&con->swork);
835
836         mutex_lock(&con->sock_mutex);
837         /* nothing to shutdown */
838         if (!con->sock) {
839                 mutex_unlock(&con->sock_mutex);
840                 return;
841         }
842
843         set_bit(CF_SHUTDOWN, &con->flags);
844         ret = kernel_sock_shutdown(con->sock, SHUT_WR);
845         mutex_unlock(&con->sock_mutex);
846         if (ret) {
847                 log_print("Connection %p failed to shutdown: %d, will force close",
848                           con, ret);
849                 goto force_close;
850         } else {
851                 ret = wait_event_timeout(con->shutdown_wait,
852                                          !test_bit(CF_SHUTDOWN, &con->flags),
853                                          DLM_SHUTDOWN_WAIT_TIMEOUT);
854                 if (ret == 0) {
855                         log_print("Connection %p shutdown timed out, will force close",
856                                   con);
857                         goto force_close;
858                 }
859         }
860
861         return;
862
863 force_close:
864         clear_bit(CF_SHUTDOWN, &con->flags);
865         close_connection(con, false, true, true);
866 }
867
868 static void dlm_tcp_shutdown(struct connection *con)
869 {
870         if (con->othercon)
871                 shutdown_connection(con->othercon);
872         shutdown_connection(con);
873 }
874
875 static int con_realloc_receive_buf(struct connection *con, int newlen)
876 {
877         unsigned char *newbuf;
878
879         newbuf = kmalloc(newlen, GFP_NOFS);
880         if (!newbuf)
881                 return -ENOMEM;
882
883         /* copy any leftover from last receive */
884         if (con->rx_leftover)
885                 memmove(newbuf, con->rx_buf, con->rx_leftover);
886
887         /* swap to new buffer space */
888         kfree(con->rx_buf);
889         con->rx_buflen = newlen;
890         con->rx_buf = newbuf;
891
892         return 0;
893 }
894
895 /* Data received from remote end */
896 static int receive_from_sock(struct connection *con)
897 {
898         struct msghdr msg;
899         struct kvec iov;
900         int ret, buflen;
901
902         mutex_lock(&con->sock_mutex);
903
904         if (con->sock == NULL) {
905                 ret = -EAGAIN;
906                 goto out_close;
907         }
908
909         /* realloc if the configured buffer size has changed */
910         buflen = dlm_config.ci_buffer_size;
911         if (con->rx_buflen != buflen && con->rx_leftover <= buflen) {
912                 ret = con_realloc_receive_buf(con, buflen);
913                 if (ret < 0)
914                         goto out_resched;
915         }
916
917         for (;;) {
918                 /* calculate the buffer parameters, accounting for any
919                  * leftover bytes from the last receive
920                  */
921                 iov.iov_base = con->rx_buf + con->rx_leftover;
922                 iov.iov_len = con->rx_buflen - con->rx_leftover;
923
924                 memset(&msg, 0, sizeof(msg));
925                 msg.msg_flags = MSG_DONTWAIT | MSG_NOSIGNAL;
926                 ret = kernel_recvmsg(con->sock, &msg, &iov, 1, iov.iov_len,
927                                      msg.msg_flags);
928                 if (ret == -EAGAIN)
929                         break;
930                 else if (ret <= 0)
931                         goto out_close;
932
933                 /* new buflen: bytes just read plus leftover from the last receive */
934                 buflen = ret + con->rx_leftover;
935                 ret = dlm_process_incoming_buffer(con->nodeid, con->rx_buf, buflen);
936                 if (ret < 0)
937                         goto out_close;
938
939                 /* move any unprocessed leftover bytes to the beginning of
940                  * the receive buffer, so the next receive continues the
941                  * message at the start address of the receive buffer.
942                  */
943                 con->rx_leftover = buflen - ret;
944                 if (con->rx_leftover) {
945                         memmove(con->rx_buf, con->rx_buf + ret,
946                                 con->rx_leftover);
947                 }
948         }
949
950         dlm_midcomms_receive_done(con->nodeid);
951         mutex_unlock(&con->sock_mutex);
952         return 0;
953
954 out_resched:
955         if (!test_and_set_bit(CF_READ_PENDING, &con->flags))
956                 queue_work(recv_workqueue, &con->rwork);
957         mutex_unlock(&con->sock_mutex);
958         return -EAGAIN;
959
960 out_close:
961         if (ret == 0) {
962                 log_print("connection %p got EOF from %d",
963                           con, con->nodeid);
964
965                 if (dlm_proto_ops->eof_condition &&
966                     dlm_proto_ops->eof_condition(con)) {
967                         set_bit(CF_EOF, &con->flags);
968                         mutex_unlock(&con->sock_mutex);
969                 } else {
970                         mutex_unlock(&con->sock_mutex);
971                         close_connection(con, false, true, false);
972
973                         /* handling for tcp shutdown */
974                         clear_bit(CF_SHUTDOWN, &con->flags);
975                         wake_up(&con->shutdown_wait);
976                 }
977
978                 /* signal that the receive worker should stop */
979                 ret = -1;
980         } else {
981                 mutex_unlock(&con->sock_mutex);
982         }
983         return ret;
984 }
985
986 /* The listen socket reported activity: accept a pending connection */
987 static int accept_from_sock(struct listen_connection *con)
988 {
989         int result;
990         struct sockaddr_storage peeraddr;
991         struct socket *newsock;
992         int len, idx;
993         int nodeid;
994         struct connection *newcon;
995         struct connection *addcon;
996         unsigned int mark;
997
998         if (!con->sock)
999                 return -ENOTCONN;
1000
1001         result = kernel_accept(con->sock, &newsock, O_NONBLOCK);
1002         if (result < 0)
1003                 goto accept_err;
1004
1005         /* Get the connected socket's peer */
1006         memset(&peeraddr, 0, sizeof(peeraddr));
1007         len = newsock->ops->getname(newsock, (struct sockaddr *)&peeraddr, 2);
1008         if (len < 0) {
1009                 result = -ECONNABORTED;
1010                 goto accept_err;
1011         }
1012
1013         /* Get the new node's NODEID */
1014         make_sockaddr(&peeraddr, 0, &len);
1015         if (addr_to_nodeid(&peeraddr, &nodeid, &mark)) {
1016                 unsigned char *b = (unsigned char *)&peeraddr;
1017                 log_print("connect from non-cluster node");
1018                 print_hex_dump_bytes("ss: ", DUMP_PREFIX_NONE,
1019                                      b, sizeof(struct sockaddr_storage));
1020                 sock_release(newsock);
1021                 return -1;
1022         }
1023
1024         log_print("got connection from %d", nodeid);
1025
1026         /*  Check to see if we already have a connection to this node. This
1027          *  could happen if the two nodes initiate a connection at roughly
1028          *  the same time and the connections cross on the wire.
1029          *  In this case we store the incoming one in "othercon"
1030          */
1031         idx = srcu_read_lock(&connections_srcu);
1032         newcon = nodeid2con(nodeid, GFP_NOFS);
1033         if (!newcon) {
1034                 srcu_read_unlock(&connections_srcu, idx);
1035                 result = -ENOMEM;
1036                 goto accept_err;
1037         }
1038
1039         sock_set_mark(newsock->sk, mark);
1040
1041         mutex_lock(&newcon->sock_mutex);
1042         if (newcon->sock) {
1043                 struct connection *othercon = newcon->othercon;
1044
1045                 if (!othercon) {
1046                         othercon = kzalloc(sizeof(*othercon), GFP_NOFS);
1047                         if (!othercon) {
1048                                 log_print("failed to allocate incoming socket");
1049                                 mutex_unlock(&newcon->sock_mutex);
1050                                 srcu_read_unlock(&connections_srcu, idx);
1051                                 result = -ENOMEM;
1052                                 goto accept_err;
1053                         }
1054
1055                         result = dlm_con_init(othercon, nodeid);
1056                         if (result < 0) {
1057                                 kfree(othercon);
1058                                 mutex_unlock(&newcon->sock_mutex);
1059                                 srcu_read_unlock(&connections_srcu, idx);
1060                                 goto accept_err;
1061                         }
1062
1063                         lockdep_set_subclass(&othercon->sock_mutex, 1);
1064                         set_bit(CF_IS_OTHERCON, &othercon->flags);
1065                         newcon->othercon = othercon;
1066                         othercon->sendcon = newcon;
1067                 } else {
1068                         /* close other sock con if we have something new */
1069                         close_connection(othercon, false, true, false);
1070                 }
1071
1072                 mutex_lock(&othercon->sock_mutex);
1073                 add_sock(newsock, othercon);
1074                 addcon = othercon;
1075                 mutex_unlock(&othercon->sock_mutex);
1076         }
1077         else {
1078                 /* accept copies the sk after we've saved the callbacks, so we
1079                    don't want to save them a second time or comm errors will
1080                    result in calling sk_error_report recursively. */
1081                 add_sock(newsock, newcon);
1082                 addcon = newcon;
1083         }
1084
1085         set_bit(CF_CONNECTED, &addcon->flags);
1086         mutex_unlock(&newcon->sock_mutex);
1087
1088         /*
1089          * Add it to the active queue in case we got data
1090          * between processing the accept and adding the socket
1091          * to the read_sockets list
1092          */
1093         if (!test_and_set_bit(CF_READ_PENDING, &addcon->flags))
1094                 queue_work(recv_workqueue, &addcon->rwork);
1095
1096         srcu_read_unlock(&connections_srcu, idx);
1097
1098         return 0;
1099
1100 accept_err:
1101         if (newsock)
1102                 sock_release(newsock);
1103
1104         if (result != -EAGAIN)
1105                 log_print("error accepting connection from node: %d", result);
1106         return result;
1107 }
1108
1109 /*
1110  * writequeue_entry_complete - try to delete and free write queue entry
1111  * @e: write queue entry to try to delete
1112  * @completed: bytes completed
1113  *
1114  * writequeue_lock must be held.
1115  */
1116 static void writequeue_entry_complete(struct writequeue_entry *e, int completed)
1117 {
1118         e->offset += completed;
1119         e->len -= completed;
1120         /* signal that the page was partially transmitted */
1121         e->dirty = true;
1122
1123         if (e->len == 0 && e->users == 0)
1124                 free_entry(e);
1125 }
1126
1127 /*
1128  * sctp_bind_addrs - bind a SCTP socket to all our addresses
1129  */
1130 static int sctp_bind_addrs(struct socket *sock, uint16_t port)
1131 {
1132         struct sockaddr_storage localaddr;
1133         struct sockaddr *addr = (struct sockaddr *)&localaddr;
1134         int i, addr_len, result = 0;
1135
1136         for (i = 0; i < dlm_local_count; i++) {
1137                 memcpy(&localaddr, dlm_local_addr[i], sizeof(localaddr));
1138                 make_sockaddr(&localaddr, port, &addr_len);
1139
1140                 if (!i)
1141                         result = kernel_bind(sock, addr, addr_len);
1142                 else
1143                         result = sock_bind_add(sock->sk, addr, addr_len);
1144
1145                 if (result < 0) {
1146                         log_print("Can't bind to %d addr number %d, %d.\n",
1147                                   port, i + 1, result);
1148                         break;
1149                 }
1150         }
1151         return result;
1152 }
1153
1154 /* Get local addresses */
1155 static void init_local(void)
1156 {
1157         struct sockaddr_storage sas, *addr;
1158         int i;
1159
1160         dlm_local_count = 0;
1161         for (i = 0; i < DLM_MAX_ADDR_COUNT; i++) {
1162                 if (dlm_our_addr(&sas, i))
1163                         break;
1164
1165                 addr = kmemdup(&sas, sizeof(*addr), GFP_NOFS);
1166                 if (!addr)
1167                         break;
1168                 dlm_local_addr[dlm_local_count++] = addr;
1169         }
1170 }
1171
1172 static void deinit_local(void)
1173 {
1174         int i;
1175
1176         for (i = 0; i < dlm_local_count; i++)
1177                 kfree(dlm_local_addr[i]);
1178 }
1179
1180 static struct writequeue_entry *new_writequeue_entry(struct connection *con,
1181                                                      gfp_t allocation)
1182 {
1183         struct writequeue_entry *entry;
1184
1185         entry = kzalloc(sizeof(*entry), allocation);
1186         if (!entry)
1187                 return NULL;
1188
1189         entry->page = alloc_page(allocation | __GFP_ZERO);
1190         if (!entry->page) {
1191                 kfree(entry);
1192                 return NULL;
1193         }
1194
1195         entry->con = con;
1196         entry->users = 1;
1197         kref_init(&entry->ref);
1198         INIT_LIST_HEAD(&entry->msgs);
1199
1200         return entry;
1201 }
1202
1203 static struct writequeue_entry *new_wq_entry(struct connection *con, int len,
1204                                              gfp_t allocation, char **ppc,
1205                                              void (*cb)(struct dlm_mhandle *mh),
1206                                              struct dlm_mhandle *mh)
1207 {
1208         struct writequeue_entry *e;
1209
1210         spin_lock(&con->writequeue_lock);
1211         if (!list_empty(&con->writequeue)) {
1212                 e = list_last_entry(&con->writequeue, struct writequeue_entry, list);
1213                 if (DLM_WQ_REMAIN_BYTES(e) >= len) {
1214                         kref_get(&e->ref);
1215
1216                         *ppc = page_address(e->page) + e->end;
1217                         if (cb)
1218                                 cb(mh);
1219
1220                         e->end += len;
1221                         e->users++;
1222                         spin_unlock(&con->writequeue_lock);
1223
1224                         return e;
1225                 }
1226         }
1227         spin_unlock(&con->writequeue_lock);
1228
1229         e = new_writequeue_entry(con, allocation);
1230         if (!e)
1231                 return NULL;
1232
1233         kref_get(&e->ref);
1234         *ppc = page_address(e->page);
1235         e->end += len;
1236         atomic_inc(&con->writequeue_cnt);
1237
1238         spin_lock(&con->writequeue_lock);
1239         if (cb)
1240                 cb(mh);
1241
1242         list_add_tail(&e->list, &con->writequeue);
1243         spin_unlock(&con->writequeue_lock);
1244
1245         return e;
1246 }
1247
1248 static struct dlm_msg *dlm_lowcomms_new_msg_con(struct connection *con, int len,
1249                                                 gfp_t allocation, char **ppc,
1250                                                 void (*cb)(struct dlm_mhandle *mh),
1251                                                 struct dlm_mhandle *mh)
1252 {
1253         struct writequeue_entry *e;
1254         struct dlm_msg *msg;
1255         bool sleepable;
1256
1257         msg = kzalloc(sizeof(*msg), allocation);
1258         if (!msg)
1259                 return NULL;
1260
1261         /* this mutex serializes "fast" allocations of new writequeue
1262          * page list entries in new_wq_entry() during normal operation,
1263          * which is sleepable context. Without it we could end up with
1264          * multiple writequeue entries for a single dlm message, because
1265          * multiple callers were waiting at the writequeue_lock in
1266          * new_wq_entry().
1267          */
1268         sleepable = gfpflags_normal_context(allocation);
1269         if (sleepable)
1270                 mutex_lock(&con->wq_alloc);
1271
1272         kref_init(&msg->ref);
1273
1274         e = new_wq_entry(con, len, allocation, ppc, cb, mh);
1275         if (!e) {
1276                 if (sleepable)
1277                         mutex_unlock(&con->wq_alloc);
1278
1279                 kfree(msg);
1280                 return NULL;
1281         }
1282
1283         if (sleepable)
1284                 mutex_unlock(&con->wq_alloc);
1285
1286         msg->ppc = *ppc;
1287         msg->len = len;
1288         msg->entry = e;
1289
1290         return msg;
1291 }
1292
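/* A sketch of the expected calling pattern (midcomms is the real user;
 * nodeid/data/len are placeholders, not a verbatim caller):
 *
 *     char *ppc;
 *     struct dlm_msg *msg;
 *
 *     msg = dlm_lowcomms_new_msg(nodeid, len, GFP_NOFS, &ppc, NULL, NULL);
 *     if (msg) {
 *             memcpy(ppc, data, len);         // fill the reserved space
 *             dlm_lowcomms_commit_msg(msg);   // queue it for sending
 *             dlm_lowcomms_put_msg(msg);      // drop the caller's reference
 *     }
 */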
1293 struct dlm_msg *dlm_lowcomms_new_msg(int nodeid, int len, gfp_t allocation,
1294                                      char **ppc, void (*cb)(struct dlm_mhandle *mh),
1295                                      struct dlm_mhandle *mh)
1296 {
1297         struct connection *con;
1298         struct dlm_msg *msg;
1299         int idx;
1300
1301         if (len > DLM_MAX_SOCKET_BUFSIZE ||
1302             len < sizeof(struct dlm_header)) {
1303                 BUILD_BUG_ON(PAGE_SIZE < DLM_MAX_SOCKET_BUFSIZE);
1304                 log_print("failed to allocate a buffer of size %d", len);
1305                 WARN_ON(1);
1306                 return NULL;
1307         }
1308
1309         idx = srcu_read_lock(&connections_srcu);
1310         con = nodeid2con(nodeid, allocation);
1311         if (!con) {
1312                 srcu_read_unlock(&connections_srcu, idx);
1313                 return NULL;
1314         }
1315
1316         msg = dlm_lowcomms_new_msg_con(con, len, allocation, ppc, cb, mh);
1317         if (!msg) {
1318                 srcu_read_unlock(&connections_srcu, idx);
1319                 return NULL;
1320         }
1321
1322         /* take an extra reference for dlm_lowcomms_commit_msg() */
1323         kref_get(&msg->ref);
1324         /* on success, the caller must follow up with commit */
1325         msg->idx = idx;
1326         return msg;
1327 }
1328
1329 static void _dlm_lowcomms_commit_msg(struct dlm_msg *msg)
1330 {
1331         struct writequeue_entry *e = msg->entry;
1332         struct connection *con = e->con;
1333         int users;
1334
1335         spin_lock(&con->writequeue_lock);
1336         kref_get(&msg->ref);
1337         list_add(&msg->list, &e->msgs);
1338
1339         users = --e->users;
1340         if (users)
1341                 goto out;
1342
1343         e->len = DLM_WQ_LENGTH_BYTES(e);
1344         spin_unlock(&con->writequeue_lock);
1345
1346         queue_work(send_workqueue, &con->swork);
1347         return;
1348
1349 out:
1350         spin_unlock(&con->writequeue_lock);
1351         return;
1352 }
1353
1354 void dlm_lowcomms_commit_msg(struct dlm_msg *msg)
1355 {
1356         _dlm_lowcomms_commit_msg(msg);
1357         srcu_read_unlock(&connections_srcu, msg->idx);
1358         /* drop the reference taken in dlm_lowcomms_new_msg() */
1359         kref_put(&msg->ref, dlm_msg_release);
1360 }
1361
1362 void dlm_lowcomms_put_msg(struct dlm_msg *msg)
1363 {
1364         kref_put(&msg->ref, dlm_msg_release);
1365 }
1366
1367 /* does not hold connections_srcu; for workqueue usage only */
1368 int dlm_lowcomms_resend_msg(struct dlm_msg *msg)
1369 {
1370         struct dlm_msg *msg_resend;
1371         char *ppc;
1372
1373         if (msg->retransmit)
1374                 return 1;
1375
1376         msg_resend = dlm_lowcomms_new_msg_con(msg->entry->con, msg->len,
1377                                               GFP_ATOMIC, &ppc, NULL, NULL);
1378         if (!msg_resend)
1379                 return -ENOMEM;
1380
1381         msg->retransmit = true;
1382         kref_get(&msg->ref);
1383         msg_resend->orig_msg = msg;
1384
1385         memcpy(ppc, msg->ppc, msg->len);
1386         _dlm_lowcomms_commit_msg(msg_resend);
1387         dlm_lowcomms_put_msg(msg_resend);
1388
1389         return 0;
1390 }
1391
1392 /* Send a message */
1393 static void send_to_sock(struct connection *con)
1394 {
1395         const int msg_flags = MSG_DONTWAIT | MSG_NOSIGNAL;
1396         struct writequeue_entry *e;
1397         int len, offset, ret;
1398         int count = 0;
1399
1400         mutex_lock(&con->sock_mutex);
1401         if (con->sock == NULL)
1402                 goto out_connect;
1403
1404         spin_lock(&con->writequeue_lock);
1405         for (;;) {
1406                 e = con_next_wq(con);
1407                 if (!e)
1408                         break;
1409
1410                 e = list_first_entry(&con->writequeue, struct writequeue_entry, list);
1411                 len = e->len;
1412                 offset = e->offset;
1413                 BUG_ON(len == 0 && e->users == 0);
1414                 spin_unlock(&con->writequeue_lock);
1415
1416                 ret = kernel_sendpage(con->sock, e->page, offset, len,
1417                                       msg_flags);
1418                 if (ret == -EAGAIN || ret == 0) {
1419                         if (ret == -EAGAIN &&
1420                             test_bit(SOCKWQ_ASYNC_NOSPACE, &con->sock->flags) &&
1421                             !test_and_set_bit(CF_APP_LIMITED, &con->flags)) {
1422                                 /* Notify TCP that we're limited by the
1423                                  * application window size.
1424                                  */
1425                                 set_bit(SOCK_NOSPACE, &con->sock->flags);
1426                                 con->sock->sk->sk_write_pending++;
1427                         }
1428                         cond_resched();
1429                         goto out;
1430                 } else if (ret < 0)
1431                         goto out;
1432
1433                 /* Don't starve people filling buffers */
1434                 if (++count >= MAX_SEND_MSG_COUNT) {
1435                         cond_resched();
1436                         count = 0;
1437                 }
1438
1439                 spin_lock(&con->writequeue_lock);
1440                 writequeue_entry_complete(e, ret);
1441         }
1442         spin_unlock(&con->writequeue_lock);
1443
1444         /* close if we got EOF */
1445         if (test_and_clear_bit(CF_EOF, &con->flags)) {
1446                 mutex_unlock(&con->sock_mutex);
1447                 close_connection(con, false, false, true);
1448
1449                 /* handling for tcp shutdown */
1450                 clear_bit(CF_SHUTDOWN, &con->flags);
1451                 wake_up(&con->shutdown_wait);
1452         } else {
1453                 mutex_unlock(&con->sock_mutex);
1454         }
1455
1456         return;
1457
1458 out:
1459         mutex_unlock(&con->sock_mutex);
1460         return;
1461
1462 out_connect:
1463         mutex_unlock(&con->sock_mutex);
1464         queue_work(send_workqueue, &con->swork);
1465         cond_resched();
1466 }
1467
1468 static void clean_one_writequeue(struct connection *con)
1469 {
1470         struct writequeue_entry *e, *safe;
1471
1472         spin_lock(&con->writequeue_lock);
1473         list_for_each_entry_safe(e, safe, &con->writequeue, list) {
1474                 free_entry(e);
1475         }
1476         spin_unlock(&con->writequeue_lock);
1477 }
1478
1479 /* Called from recovery when it knows that a node has
1480    left the cluster */
1481 int dlm_lowcomms_close(int nodeid)
1482 {
1483         struct connection *con;
1484         struct dlm_node_addr *na;
1485         int idx;
1486
1487         log_print("closing connection to node %d", nodeid);
1488         idx = srcu_read_lock(&connections_srcu);
1489         con = nodeid2con(nodeid, 0);
1490         if (con) {
1491                 set_bit(CF_CLOSE, &con->flags);
1492                 close_connection(con, true, true, true);
1493                 clean_one_writequeue(con);
1494                 if (con->othercon)
1495                         clean_one_writequeue(con->othercon);
1496         }
1497         srcu_read_unlock(&connections_srcu, idx);
1498
1499         spin_lock(&dlm_node_addrs_spin);
1500         na = find_node_addr(nodeid);
1501         if (na) {
1502                 list_del(&na->list);
1503                 while (na->addr_count--)
1504                         kfree(na->addr[na->addr_count]);
1505                 kfree(na);
1506         }
1507         spin_unlock(&dlm_node_addrs_spin);
1508
1509         return 0;
1510 }
1511
1512 /* Receive workqueue function */
1513 static void process_recv_sockets(struct work_struct *work)
1514 {
1515         struct connection *con = container_of(work, struct connection, rwork);
1516
1517         clear_bit(CF_READ_PENDING, &con->flags);
1518         receive_from_sock(con);
1519 }
1520
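/* keep accepting until accept_from_sock() reports -EAGAIN (backlog
 * drained) or another error, so every pending connection on the
 * listen socket is handled in one pass */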
1521 static void process_listen_recv_socket(struct work_struct *work)
1522 {
1523         int ret;
1524
1525         do {
1526                 ret = accept_from_sock(&listen_con);
1527         } while (!ret);
1528 }
1529
1530 static void dlm_connect(struct connection *con)
1531 {
1532         struct sockaddr_storage addr;
1533         int result, addr_len;
1534         struct socket *sock;
1535         unsigned int mark;
1536
1537         /* Some odd races can cause double-connects, ignore them */
1538         if (con->retries++ > MAX_CONNECT_RETRIES)
1539                 return;
1540
1541         if (con->sock) {
1542                 log_print("node %d already connected.", con->nodeid);
1543                 return;
1544         }
1545
1546         memset(&addr, 0, sizeof(addr));
1547         result = nodeid_to_addr(con->nodeid, &addr, NULL,
1548                                 dlm_proto_ops->try_new_addr, &mark);
1549         if (result < 0) {
1550                 log_print("no address for nodeid %d", con->nodeid);
1551                 return;
1552         }
1553
1554         /* Create a socket to communicate with */
1555         result = sock_create_kern(&init_net, dlm_local_addr[0]->ss_family,
1556                                   SOCK_STREAM, dlm_proto_ops->proto, &sock);
1557         if (result < 0)
1558                 goto socket_err;
1559
1560         sock_set_mark(sock->sk, mark);
1561         dlm_proto_ops->sockopts(sock);
1562
1563         add_sock(sock, con);
1564
1565         result = dlm_proto_ops->bind(sock);
1566         if (result < 0)
1567                 goto add_sock_err;
1568
1569         log_print_ratelimited("connecting to %d", con->nodeid);
1570         make_sockaddr(&addr, dlm_config.ci_tcp_port, &addr_len);
1571         result = dlm_proto_ops->connect(con, sock, (struct sockaddr *)&addr,
1572                                         addr_len);
1573         if (result < 0)
1574                 goto add_sock_err;
1575
1576         return;
1577
1578 add_sock_err:
1579         dlm_close_sock(&con->sock);
1580
1581 socket_err:
1582         /*
1583          * Some errors are fatal and this list might need adjusting. For other
1584          * errors we try again until the max number of retries is reached.
1585          */
1586         if (result != -EHOSTUNREACH &&
1587             result != -ENETUNREACH &&
1588             result != -ENETDOWN &&
1589             result != -EINVAL &&
1590             result != -EPROTONOSUPPORT) {
1591                 log_print("connect %d try %d error %d", con->nodeid,
1592                           con->retries, result);
1593                 msleep(1000);
1594                 lowcomms_connect_sock(con);
1595         }
1596 }
1597
1598 /* Send workqueue function */
1599 static void process_send_sockets(struct work_struct *work)
1600 {
1601         struct connection *con = container_of(work, struct connection, swork);
1602
1603         WARN_ON(test_bit(CF_IS_OTHERCON, &con->flags));
1604
1605         clear_bit(CF_WRITE_PENDING, &con->flags);
1606
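        /* A reconnect request drops the current socket and asks the
         * midcomms layer to re-queue any unacknowledged messages.
         */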
1607         if (test_and_clear_bit(CF_RECONNECT, &con->flags)) {
1608                 close_connection(con, false, false, true);
1609                 dlm_midcomms_unack_msg_resend(con->nodeid);
1610         }
1611
1612         if (con->sock == NULL) {
1613                 if (test_and_clear_bit(CF_DELAY_CONNECT, &con->flags))
1614                         msleep(1000);
1615
1616                 mutex_lock(&con->sock_mutex);
1617                 dlm_connect(con);
1618                 mutex_unlock(&con->sock_mutex);
1619         }
1620
1621         if (!list_empty(&con->writequeue))
1622                 send_to_sock(con);
1623 }
1624
1625 static void work_stop(void)
1626 {
1627         if (recv_workqueue) {
1628                 destroy_workqueue(recv_workqueue);
1629                 recv_workqueue = NULL;
1630         }
1631
1632         if (send_workqueue) {
1633                 destroy_workqueue(send_workqueue);
1634                 send_workqueue = NULL;
1635         }
1636 }
1637
1638 static int work_start(void)
1639 {
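        /* Ordered workqueues execute at most one work item at a time,
         * and WQ_MEM_RECLAIM provides a rescuer thread so the queues
         * keep making progress under memory pressure.
         */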
1640         recv_workqueue = alloc_ordered_workqueue("dlm_recv", WQ_MEM_RECLAIM);
1641         if (!recv_workqueue) {
1642                 log_print("can't start dlm_recv");
1643                 return -ENOMEM;
1644         }
1645
1646         send_workqueue = alloc_ordered_workqueue("dlm_send", WQ_MEM_RECLAIM);
1647         if (!send_workqueue) {
1648                 log_print("can't start dlm_send");
1649                 destroy_workqueue(recv_workqueue);
1650                 recv_workqueue = NULL;
1651                 return -ENOMEM;
1652         }
1653
1654         return 0;
1655 }
1656
1657 static void shutdown_conn(struct connection *con)
1658 {
1659         if (dlm_proto_ops->shutdown_action)
1660                 dlm_proto_ops->shutdown_action(con);
1661 }
1662
1663 void dlm_lowcomms_shutdown(void)
1664 {
1665         int idx;
1666
1667         /* Clear the connection-allowed flag to prevent any new
1668          * socket activity.
1669          */
1670         dlm_allow_conn = 0;
1671
1672         if (recv_workqueue)
1673                 flush_workqueue(recv_workqueue);
1674         if (send_workqueue)
1675                 flush_workqueue(send_workqueue);
1676
1677         dlm_close_sock(&listen_con.sock);
1678
1679         idx = srcu_read_lock(&connections_srcu);
1680         foreach_conn(shutdown_conn);
1681         srcu_read_unlock(&connections_srcu, idx);
1682 }
1683
1684 static void _stop_conn(struct connection *con, bool and_other)
1685 {
1686         mutex_lock(&con->sock_mutex);
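        /* Leave CF_READ_PENDING and CF_WRITE_PENDING set: the worker
         * functions clear these bits, so work_flush() can detect that a
         * work item ran after this stop and must retry.
         */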
1687         set_bit(CF_CLOSE, &con->flags);
1688         set_bit(CF_READ_PENDING, &con->flags);
1689         set_bit(CF_WRITE_PENDING, &con->flags);
1690         if (con->sock && con->sock->sk) {
1691                 write_lock_bh(&con->sock->sk->sk_callback_lock);
1692                 con->sock->sk->sk_user_data = NULL;
1693                 write_unlock_bh(&con->sock->sk->sk_callback_lock);
1694         }
1695         if (con->othercon && and_other)
1696                 _stop_conn(con->othercon, false);
1697         mutex_unlock(&con->sock_mutex);
1698 }
1699
1700 static void stop_conn(struct connection *con)
1701 {
1702         _stop_conn(con, true);
1703 }
1704
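/* RCU callback, invoked via call_srcu() once all connections_srcu
 * readers that might still reference this connection have finished.
 */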
1705 static void connection_release(struct rcu_head *rcu)
1706 {
1707         struct connection *con = container_of(rcu, struct connection, rcu);
1708
1709         kfree(con->rx_buf);
1710         kfree(con);
1711 }
1712
1713 static void free_conn(struct connection *con)
1714 {
1715         close_connection(con, true, true, true);
1716         spin_lock(&connections_lock);
1717         hlist_del_rcu(&con->list);
1718         spin_unlock(&connections_lock);
1719         if (con->othercon) {
1720                 clean_one_writequeue(con->othercon);
1721                 call_srcu(&connections_srcu, &con->othercon->rcu,
1722                           connection_release);
1723         }
1724         clean_one_writequeue(con);
1725         call_srcu(&connections_srcu, &con->rcu, connection_release);
1726 }
1727
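/* Stop every connection and flush both workqueues, repeating until one
 * full pass finds all pending flags still set, i.e. no receive or send
 * work ran while we were flushing.
 */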
1728 static void work_flush(void)
1729 {
1730         int ok;
1731         int i;
1732         struct connection *con;
1733
1734         do {
1735                 ok = 1;
1736                 foreach_conn(stop_conn);
1737                 if (recv_workqueue)
1738                         flush_workqueue(recv_workqueue);
1739                 if (send_workqueue)
1740                         flush_workqueue(send_workqueue);
1741                 for (i = 0; i < CONN_HASH_SIZE && ok; i++) {
1742                         hlist_for_each_entry_rcu(con, &connection_hash[i],
1743                                                  list) {
1744                                 ok &= test_bit(CF_READ_PENDING, &con->flags);
1745                                 ok &= test_bit(CF_WRITE_PENDING, &con->flags);
1746                                 if (con->othercon) {
1747                                         ok &= test_bit(CF_READ_PENDING,
1748                                                        &con->othercon->flags);
1749                                         ok &= test_bit(CF_WRITE_PENDING,
1750                                                        &con->othercon->flags);
1751                                 }
1752                         }
1753                 }
1754         } while (!ok);
1755 }
1756
1757 void dlm_lowcomms_stop(void)
1758 {
1759         int idx;
1760
1761         idx = srcu_read_lock(&connections_srcu);
1762         work_flush();
1763         foreach_conn(free_conn);
1764         srcu_read_unlock(&connections_srcu, idx);
1765         work_stop();
1766         deinit_local();
1767
1768         dlm_proto_ops = NULL;
1769 }
1770
1771 static int dlm_listen_for_all(void)
1772 {
1773         struct socket *sock;
1774         int result;
1775
1776         log_print("Using %s for communications",
1777                   dlm_proto_ops->name);
1778
1779         result = dlm_proto_ops->listen_validate();
1780         if (result < 0)
1781                 return result;
1782
1783         result = sock_create_kern(&init_net, dlm_local_addr[0]->ss_family,
1784                                   SOCK_STREAM, dlm_proto_ops->proto, &sock);
1785         if (result < 0) {
1786                 log_print("Can't create comms socket, check SCTP is loaded");
1787                 return result;
1788         }
1789
1790         sock_set_mark(sock->sk, dlm_config.ci_mark);
1791         dlm_proto_ops->listen_sockopts(sock);
1792
1793         result = dlm_proto_ops->listen_bind(sock);
1794         if (result < 0)
1795                 goto out;
1796
1797         save_listen_callbacks(sock);
1798         add_listen_sock(sock, &listen_con);
1799
1800         INIT_WORK(&listen_con.rwork, process_listen_recv_socket);
1801         result = sock->ops->listen(sock, 5);
1802         if (result < 0) {
1803                 dlm_close_sock(&listen_con.sock);
1804                 return result;
1805         }
1806
1807         return 0;
1808
1809 out:
1810         sock_release(sock);
1811         return result;
1812 }
1813
1814 static int dlm_tcp_bind(struct socket *sock)
1815 {
1816         struct sockaddr_storage src_addr;
1817         int result, addr_len;
1818
1819         /* Bind to our cluster-known address when connecting, to
1820          * avoid routing problems.
1821          */
1822         memcpy(&src_addr, dlm_local_addr[0], sizeof(src_addr));
1823         make_sockaddr(&src_addr, 0, &addr_len);
1824
1825         result = sock->ops->bind(sock, (struct sockaddr *)&src_addr,
1826                                  addr_len);
1827         if (result < 0) {
1828                 /* This *may* not indicate a critical error */
1829                 log_print("could not bind for connect: %d", result);
1830         }
1831
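        /* Tolerate the failure: returning 0 lets the caller proceed to
         * connect(), with the kernel choosing the source address
         * instead.
         */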
1832         return 0;
1833 }
1834
1835 static int dlm_tcp_connect(struct connection *con, struct socket *sock,
1836                            struct sockaddr *addr, int addr_len)
1837 {
1838         int ret;
1839
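        /* A nonblocking connect normally returns -EINPROGRESS; treat it
         * like immediate success, as completion is reported
         * asynchronously through the socket callbacks.
         */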
1840         ret = sock->ops->connect(sock, addr, addr_len, O_NONBLOCK);
1841         switch (ret) {
1842         case -EINPROGRESS:
1843                 fallthrough;
1844         case 0:
1845                 return 0;
1846         }
1847
1848         return ret;
1849 }
1850
1851 static int dlm_tcp_listen_validate(void)
1852 {
1853         /* We don't support multi-homed hosts */
1854         if (dlm_local_count > 1) {
1855                 log_print("TCP protocol can't handle multi-homed hosts, try SCTP");
1856                 return -EINVAL;
1857         }
1858
1859         return 0;
1860 }
1861
1862 static void dlm_tcp_sockopts(struct socket *sock)
1863 {
1864         /* Turn off Nagle's algorithm */
1865         tcp_sock_set_nodelay(sock->sk);
1866 }
1867
1868 static void dlm_tcp_listen_sockopts(struct socket *sock)
1869 {
1870         dlm_tcp_sockopts(sock);
1871         sock_set_reuseaddr(sock->sk);
1872 }
1873
1874 static int dlm_tcp_listen_bind(struct socket *sock)
1875 {
1876         int addr_len;
1877
1878         /* Bind to our port */
1879         make_sockaddr(dlm_local_addr[0], dlm_config.ci_tcp_port, &addr_len);
1880         return sock->ops->bind(sock, (struct sockaddr *)dlm_local_addr[0],
1881                                addr_len);
1882 }
1883
1884 static const struct dlm_proto_ops dlm_tcp_ops = {
1885         .name = "TCP",
1886         .proto = IPPROTO_TCP,
1887         .connect = dlm_tcp_connect,
1888         .sockopts = dlm_tcp_sockopts,
1889         .bind = dlm_tcp_bind,
1890         .listen_validate = dlm_tcp_listen_validate,
1891         .listen_sockopts = dlm_tcp_listen_sockopts,
1892         .listen_bind = dlm_tcp_listen_bind,
1893         .shutdown_action = dlm_tcp_shutdown,
1894         .eof_condition = tcp_eof_condition,
1895 };
1896
1897 static int dlm_sctp_bind(struct socket *sock)
1898 {
1899         return sctp_bind_addrs(sock, 0);
1900 }
1901
1902 static int dlm_sctp_connect(struct connection *con, struct socket *sock,
1903                             struct sockaddr *addr, int addr_len)
1904 {
1905         int ret;
1906
1907         /*
1908          * Make sock->ops->connect() return within the specified time,
1909          * since the O_NONBLOCK argument to connect() does not work here.
1910          * Afterwards, restore the default value of this attribute.
1911          */
1912         sock_set_sndtimeo(sock->sk, 5);
1913         ret = sock->ops->connect(sock, addr, addr_len, 0);
1914         sock_set_sndtimeo(sock->sk, 0);
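        /* a timeout of 0 resets sk_sndtimeo to MAX_SCHEDULE_TIMEOUT,
         * i.e. the default of no send timeout
         */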
1915         if (ret < 0)
1916                 return ret;
1917
1918         if (!test_and_set_bit(CF_CONNECTED, &con->flags))
1919                 log_print("successfully connected to node %d", con->nodeid);
1920
1921         return 0;
1922 }
1923
1924 static int dlm_sctp_listen_validate(void)
1925 {
1926         if (!IS_ENABLED(CONFIG_IP_SCTP)) {
1927                 log_print("SCTP is not enabled by this kernel");
1928                 return -EOPNOTSUPP;
1929         }
1930
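        /* Best-effort module load; if SCTP is still unavailable, the
         * later sock_create_kern() call will fail instead.
         */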
1931         request_module("sctp");
1932         return 0;
1933 }
1934
1935 static int dlm_sctp_bind_listen(struct socket *sock)
1936 {
1937         return sctp_bind_addrs(sock, dlm_config.ci_tcp_port);
1938 }
1939
1940 static void dlm_sctp_sockopts(struct socket *sock)
1941 {
1942         /* Turn off Nagle's algorithm */
1943         sctp_sock_set_nodelay(sock->sk);
1944         sock_set_rcvbuf(sock->sk, NEEDED_RMEM);
1945 }
1946
1947 static const struct dlm_proto_ops dlm_sctp_ops = {
1948         .name = "SCTP",
1949         .proto = IPPROTO_SCTP,
1950         .try_new_addr = true,
1951         .connect = dlm_sctp_connect,
1952         .sockopts = dlm_sctp_sockopts,
1953         .bind = dlm_sctp_bind,
1954         .listen_validate = dlm_sctp_listen_validate,
1955         .listen_sockopts = dlm_sctp_sockopts,
1956         .listen_bind = dlm_sctp_bind_listen,
1957 };
1958
1959 int dlm_lowcomms_start(void)
1960 {
1961         int error = -EINVAL;
1962         int i;
1963
1964         for (i = 0; i < CONN_HASH_SIZE; i++)
1965                 INIT_HLIST_HEAD(&connection_hash[i]);
1966
1967         init_local();
1968         if (!dlm_local_count) {
1969                 error = -ENOTCONN;
1970                 log_print("no local IP address has been set");
1971                 goto fail;
1972         }
1973
1974         INIT_WORK(&listen_con.rwork, process_listen_recv_socket);
1975
1976         error = work_start();
1977         if (error)
1978                 goto fail_local;
1979
1980         dlm_allow_conn = 1;
1981
1982         /* Start listening */
1983         switch (dlm_config.ci_protocol) {
1984         case DLM_PROTO_TCP:
1985                 dlm_proto_ops = &dlm_tcp_ops;
1986                 break;
1987         case DLM_PROTO_SCTP:
1988                 dlm_proto_ops = &dlm_sctp_ops;
1989                 break;
1990         default:
1991                 log_print("Invalid protocol identifier %d set",
1992                           dlm_config.ci_protocol);
1993                 error = -EINVAL;
1994                 goto fail_proto_ops;
1995         }
1996
1997         error = dlm_listen_for_all();
1998         if (error)
1999                 goto fail_listen;
2000
2001         return 0;
2002
2003 fail_listen:
2004         dlm_proto_ops = NULL;
2005 fail_proto_ops:
2006         dlm_allow_conn = 0;
2007         work_stop();
2008 fail_local:
2009         deinit_local();
2010 fail:
2011         return error;
2012 }
2013
2014 void dlm_lowcomms_exit(void)
2015 {
2016         struct dlm_node_addr *na, *safe;
2017
2018         spin_lock(&dlm_node_addrs_spin);
2019         list_for_each_entry_safe(na, safe, &dlm_node_addrs, list) {
2020                 list_del(&na->list);
2021                 while (na->addr_count--)
2022                         kfree(na->addr[na->addr_count]);
2023                 kfree(na);
2024         }
2025         spin_unlock(&dlm_node_addrs_spin);
2026 }