rxrpc: Fix local endpoint refcounting
[platform/kernel/linux-rpi.git] / net / rxrpc / local_object.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /* Local endpoint object management
3  *
4  * Copyright (C) 2016 Red Hat, Inc. All Rights Reserved.
5  * Written by David Howells (dhowells@redhat.com)
6  */
7
8 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
9
10 #include <linux/module.h>
11 #include <linux/net.h>
12 #include <linux/skbuff.h>
13 #include <linux/slab.h>
14 #include <linux/udp.h>
15 #include <linux/ip.h>
16 #include <linux/hashtable.h>
17 #include <net/sock.h>
18 #include <net/udp.h>
19 #include <net/af_rxrpc.h>
20 #include "ar-internal.h"
21
22 static void rxrpc_local_processor(struct work_struct *);
23 static void rxrpc_local_rcu(struct rcu_head *);
24
25 /*
26  * Compare a local to an address.  Return -ve, 0 or +ve to indicate less than,
27  * same or greater than.
28  *
29  * We explicitly don't compare the RxRPC service ID as we want to reject
30  * conflicting uses by differing services.  Further, we don't want to share
31  * addresses with different options (IPv6), so we don't compare those bits
32  * either.
33  */
34 static long rxrpc_local_cmp_key(const struct rxrpc_local *local,
35                                 const struct sockaddr_rxrpc *srx)
36 {
37         long diff;
38
39         diff = ((local->srx.transport_type - srx->transport_type) ?:
40                 (local->srx.transport_len - srx->transport_len) ?:
41                 (local->srx.transport.family - srx->transport.family));
42         if (diff != 0)
43                 return diff;
44
45         switch (srx->transport.family) {
46         case AF_INET:
47                 /* If the choice of UDP port is left up to the transport, then
48                  * the endpoint record doesn't match.
49                  */
50                 return ((u16 __force)local->srx.transport.sin.sin_port -
51                         (u16 __force)srx->transport.sin.sin_port) ?:
52                         memcmp(&local->srx.transport.sin.sin_addr,
53                                &srx->transport.sin.sin_addr,
54                                sizeof(struct in_addr));
55 #ifdef CONFIG_AF_RXRPC_IPV6
56         case AF_INET6:
57                 /* If the choice of UDP6 port is left up to the transport, then
58                  * the endpoint record doesn't match.
59                  */
60                 return ((u16 __force)local->srx.transport.sin6.sin6_port -
61                         (u16 __force)srx->transport.sin6.sin6_port) ?:
62                         memcmp(&local->srx.transport.sin6.sin6_addr,
63                                &srx->transport.sin6.sin6_addr,
64                                sizeof(struct in6_addr));
65 #endif
66         default:
67                 BUG();
68         }
69 }
70
71 /*
72  * Allocate a new local endpoint.
73  */
74 static struct rxrpc_local *rxrpc_alloc_local(struct rxrpc_net *rxnet,
75                                              const struct sockaddr_rxrpc *srx)
76 {
77         struct rxrpc_local *local;
78
79         local = kzalloc(sizeof(struct rxrpc_local), GFP_KERNEL);
80         if (local) {
81                 atomic_set(&local->usage, 1);
82                 atomic_set(&local->active_users, 1);
83                 local->rxnet = rxnet;
84                 INIT_LIST_HEAD(&local->link);
85                 INIT_WORK(&local->processor, rxrpc_local_processor);
86                 init_rwsem(&local->defrag_sem);
87                 skb_queue_head_init(&local->reject_queue);
88                 skb_queue_head_init(&local->event_queue);
89                 local->client_conns = RB_ROOT;
90                 spin_lock_init(&local->client_conns_lock);
91                 spin_lock_init(&local->lock);
92                 rwlock_init(&local->services_lock);
93                 local->debug_id = atomic_inc_return(&rxrpc_debug_id);
94                 memcpy(&local->srx, srx, sizeof(*srx));
95                 local->srx.srx_service = 0;
96                 trace_rxrpc_local(local, rxrpc_local_new, 1, NULL);
97         }
98
99         _leave(" = %p", local);
100         return local;
101 }
102
103 /*
104  * create the local socket
105  * - must be called with rxrpc_local_mutex locked
106  */
107 static int rxrpc_open_socket(struct rxrpc_local *local, struct net *net)
108 {
109         struct sock *usk;
110         int ret, opt;
111
112         _enter("%p{%d,%d}",
113                local, local->srx.transport_type, local->srx.transport.family);
114
115         /* create a socket to represent the local endpoint */
116         ret = sock_create_kern(net, local->srx.transport.family,
117                                local->srx.transport_type, 0, &local->socket);
118         if (ret < 0) {
119                 _leave(" = %d [socket]", ret);
120                 return ret;
121         }
122
123         /* set the socket up */
124         usk = local->socket->sk;
125         inet_sk(usk)->mc_loop = 0;
126
127         /* Enable CHECKSUM_UNNECESSARY to CHECKSUM_COMPLETE conversion */
128         inet_inc_convert_csum(usk);
129
130         rcu_assign_sk_user_data(usk, local);
131
132         udp_sk(usk)->encap_type = UDP_ENCAP_RXRPC;
133         udp_sk(usk)->encap_rcv = rxrpc_input_packet;
134         udp_sk(usk)->encap_destroy = NULL;
135         udp_sk(usk)->gro_receive = NULL;
136         udp_sk(usk)->gro_complete = NULL;
137
138         udp_encap_enable();
139 #if IS_ENABLED(CONFIG_AF_RXRPC_IPV6)
140         if (local->srx.transport.family == AF_INET6)
141                 udpv6_encap_enable();
142 #endif
143         usk->sk_error_report = rxrpc_error_report;
144
145         /* if a local address was supplied then bind it */
146         if (local->srx.transport_len > sizeof(sa_family_t)) {
147                 _debug("bind");
148                 ret = kernel_bind(local->socket,
149                                   (struct sockaddr *)&local->srx.transport,
150                                   local->srx.transport_len);
151                 if (ret < 0) {
152                         _debug("bind failed %d", ret);
153                         goto error;
154                 }
155         }
156
157         switch (local->srx.transport.family) {
158         case AF_INET6:
159                 /* we want to receive ICMPv6 errors */
160                 opt = 1;
161                 ret = kernel_setsockopt(local->socket, SOL_IPV6, IPV6_RECVERR,
162                                         (char *) &opt, sizeof(opt));
163                 if (ret < 0) {
164                         _debug("setsockopt failed");
165                         goto error;
166                 }
167
168                 /* we want to set the don't fragment bit */
169                 opt = IPV6_PMTUDISC_DO;
170                 ret = kernel_setsockopt(local->socket, SOL_IPV6, IPV6_MTU_DISCOVER,
171                                         (char *) &opt, sizeof(opt));
172                 if (ret < 0) {
173                         _debug("setsockopt failed");
174                         goto error;
175                 }
176
177                 /* Fall through and set IPv4 options too otherwise we don't get
178                  * errors from IPv4 packets sent through the IPv6 socket.
179                  */
180                 /* Fall through */
181         case AF_INET:
182                 /* we want to receive ICMP errors */
183                 opt = 1;
184                 ret = kernel_setsockopt(local->socket, SOL_IP, IP_RECVERR,
185                                         (char *) &opt, sizeof(opt));
186                 if (ret < 0) {
187                         _debug("setsockopt failed");
188                         goto error;
189                 }
190
191                 /* we want to set the don't fragment bit */
192                 opt = IP_PMTUDISC_DO;
193                 ret = kernel_setsockopt(local->socket, SOL_IP, IP_MTU_DISCOVER,
194                                         (char *) &opt, sizeof(opt));
195                 if (ret < 0) {
196                         _debug("setsockopt failed");
197                         goto error;
198                 }
199
200                 /* We want receive timestamps. */
201                 opt = 1;
202                 ret = kernel_setsockopt(local->socket, SOL_SOCKET, SO_TIMESTAMPNS_OLD,
203                                         (char *)&opt, sizeof(opt));
204                 if (ret < 0) {
205                         _debug("setsockopt failed");
206                         goto error;
207                 }
208                 break;
209
210         default:
211                 BUG();
212         }
213
214         _leave(" = 0");
215         return 0;
216
217 error:
218         kernel_sock_shutdown(local->socket, SHUT_RDWR);
219         local->socket->sk->sk_user_data = NULL;
220         sock_release(local->socket);
221         local->socket = NULL;
222
223         _leave(" = %d", ret);
224         return ret;
225 }
226
227 /*
228  * Look up or create a new local endpoint using the specified local address.
229  */
230 struct rxrpc_local *rxrpc_lookup_local(struct net *net,
231                                        const struct sockaddr_rxrpc *srx)
232 {
233         struct rxrpc_local *local;
234         struct rxrpc_net *rxnet = rxrpc_net(net);
235         struct list_head *cursor;
236         const char *age;
237         long diff;
238         int ret;
239
240         _enter("{%d,%d,%pISp}",
241                srx->transport_type, srx->transport.family, &srx->transport);
242
243         mutex_lock(&rxnet->local_mutex);
244
245         for (cursor = rxnet->local_endpoints.next;
246              cursor != &rxnet->local_endpoints;
247              cursor = cursor->next) {
248                 local = list_entry(cursor, struct rxrpc_local, link);
249
250                 diff = rxrpc_local_cmp_key(local, srx);
251                 if (diff < 0)
252                         continue;
253                 if (diff > 0)
254                         break;
255
256                 /* Services aren't allowed to share transport sockets, so
257                  * reject that here.  It is possible that the object is dying -
258                  * but it may also still have the local transport address that
259                  * we want bound.
260                  */
261                 if (srx->srx_service) {
262                         local = NULL;
263                         goto addr_in_use;
264                 }
265
266                 /* Found a match.  We replace a dying object.  Attempting to
267                  * bind the transport socket may still fail if we're attempting
268                  * to use a local address that the dying object is still using.
269                  */
270                 if (!rxrpc_use_local(local))
271                         break;
272
273                 age = "old";
274                 goto found;
275         }
276
277         local = rxrpc_alloc_local(rxnet, srx);
278         if (!local)
279                 goto nomem;
280
281         ret = rxrpc_open_socket(local, net);
282         if (ret < 0)
283                 goto sock_error;
284
285         if (cursor != &rxnet->local_endpoints)
286                 list_replace(cursor, &local->link);
287         else
288                 list_add_tail(&local->link, cursor);
289         age = "new";
290
291 found:
292         mutex_unlock(&rxnet->local_mutex);
293
294         _net("LOCAL %s %d {%pISp}",
295              age, local->debug_id, &local->srx.transport);
296
297         _leave(" = %p", local);
298         return local;
299
300 nomem:
301         ret = -ENOMEM;
302 sock_error:
303         mutex_unlock(&rxnet->local_mutex);
304         if (local)
305                 call_rcu(&local->rcu, rxrpc_local_rcu);
306         _leave(" = %d", ret);
307         return ERR_PTR(ret);
308
309 addr_in_use:
310         mutex_unlock(&rxnet->local_mutex);
311         _leave(" = -EADDRINUSE");
312         return ERR_PTR(-EADDRINUSE);
313 }
314
315 /*
316  * Get a ref on a local endpoint.
317  */
318 struct rxrpc_local *rxrpc_get_local(struct rxrpc_local *local)
319 {
320         const void *here = __builtin_return_address(0);
321         int n;
322
323         n = atomic_inc_return(&local->usage);
324         trace_rxrpc_local(local, rxrpc_local_got, n, here);
325         return local;
326 }
327
328 /*
329  * Get a ref on a local endpoint unless its usage has already reached 0.
330  */
331 struct rxrpc_local *rxrpc_get_local_maybe(struct rxrpc_local *local)
332 {
333         const void *here = __builtin_return_address(0);
334
335         if (local) {
336                 int n = atomic_fetch_add_unless(&local->usage, 1, 0);
337                 if (n > 0)
338                         trace_rxrpc_local(local, rxrpc_local_got, n + 1, here);
339                 else
340                         local = NULL;
341         }
342         return local;
343 }
344
345 /*
346  * Queue a local endpoint unless it has become unreferenced and pass the
347  * caller's reference to the work item.
348  */
349 void rxrpc_queue_local(struct rxrpc_local *local)
350 {
351         const void *here = __builtin_return_address(0);
352
353         if (rxrpc_queue_work(&local->processor))
354                 trace_rxrpc_local(local, rxrpc_local_queued,
355                                   atomic_read(&local->usage), here);
356         else
357                 rxrpc_put_local(local);
358 }
359
360 /*
361  * Drop a ref on a local endpoint.
362  */
363 void rxrpc_put_local(struct rxrpc_local *local)
364 {
365         const void *here = __builtin_return_address(0);
366         int n;
367
368         if (local) {
369                 n = atomic_dec_return(&local->usage);
370                 trace_rxrpc_local(local, rxrpc_local_put, n, here);
371
372                 if (n == 0)
373                         call_rcu(&local->rcu, rxrpc_local_rcu);
374         }
375 }
376
377 /*
378  * Start using a local endpoint.
379  */
380 struct rxrpc_local *rxrpc_use_local(struct rxrpc_local *local)
381 {
382         unsigned int au;
383
384         local = rxrpc_get_local_maybe(local);
385         if (!local)
386                 return NULL;
387
388         au = atomic_fetch_add_unless(&local->active_users, 1, 0);
389         if (au == 0) {
390                 rxrpc_put_local(local);
391                 return NULL;
392         }
393
394         return local;
395 }
396
397 /*
398  * Cease using a local endpoint.  Once the number of active users reaches 0, we
399  * start the closure of the transport in the work processor.
400  */
401 void rxrpc_unuse_local(struct rxrpc_local *local)
402 {
403         unsigned int au;
404
405         au = atomic_dec_return(&local->active_users);
406         if (au == 0)
407                 rxrpc_queue_local(local);
408         else
409                 rxrpc_put_local(local);
410 }
411
412 /*
413  * Destroy a local endpoint's socket and then hand the record to RCU to dispose
414  * of.
415  *
416  * Closing the socket cannot be done from bottom half context or RCU callback
417  * context because it might sleep.
418  */
419 static void rxrpc_local_destroyer(struct rxrpc_local *local)
420 {
421         struct socket *socket = local->socket;
422         struct rxrpc_net *rxnet = local->rxnet;
423
424         _enter("%d", local->debug_id);
425
426         mutex_lock(&rxnet->local_mutex);
427         list_del_init(&local->link);
428         mutex_unlock(&rxnet->local_mutex);
429
430         ASSERT(RB_EMPTY_ROOT(&local->client_conns));
431         ASSERT(!local->service);
432
433         if (socket) {
434                 local->socket = NULL;
435                 kernel_sock_shutdown(socket, SHUT_RDWR);
436                 socket->sk->sk_user_data = NULL;
437                 sock_release(socket);
438         }
439
440         /* At this point, there should be no more packets coming in to the
441          * local endpoint.
442          */
443         rxrpc_purge_queue(&local->reject_queue);
444         rxrpc_purge_queue(&local->event_queue);
445 }
446
447 /*
448  * Process events on an endpoint.  The work item carries a ref which
449  * we must release.
450  */
451 static void rxrpc_local_processor(struct work_struct *work)
452 {
453         struct rxrpc_local *local =
454                 container_of(work, struct rxrpc_local, processor);
455         bool again;
456
457         trace_rxrpc_local(local, rxrpc_local_processing,
458                           atomic_read(&local->usage), NULL);
459
460         do {
461                 again = false;
462                 if (atomic_read(&local->active_users) == 0) {
463                         rxrpc_local_destroyer(local);
464                         break;
465                 }
466
467                 if (!skb_queue_empty(&local->reject_queue)) {
468                         rxrpc_reject_packets(local);
469                         again = true;
470                 }
471
472                 if (!skb_queue_empty(&local->event_queue)) {
473                         rxrpc_process_local_events(local);
474                         again = true;
475                 }
476         } while (again);
477
478         rxrpc_put_local(local);
479 }
480
481 /*
482  * Destroy a local endpoint after the RCU grace period expires.
483  */
484 static void rxrpc_local_rcu(struct rcu_head *rcu)
485 {
486         struct rxrpc_local *local = container_of(rcu, struct rxrpc_local, rcu);
487
488         _enter("%d", local->debug_id);
489
490         ASSERT(!work_pending(&local->processor));
491
492         _net("DESTROY LOCAL %d", local->debug_id);
493         kfree(local);
494         _leave("");
495 }
496
497 /*
498  * Verify the local endpoint list is empty by this point.
499  */
500 void rxrpc_destroy_all_locals(struct rxrpc_net *rxnet)
501 {
502         struct rxrpc_local *local;
503
504         _enter("");
505
506         flush_workqueue(rxrpc_workqueue);
507
508         if (!list_empty(&rxnet->local_endpoints)) {
509                 mutex_lock(&rxnet->local_mutex);
510                 list_for_each_entry(local, &rxnet->local_endpoints, link) {
511                         pr_err("AF_RXRPC: Leaked local %p {%d}\n",
512                                local, atomic_read(&local->usage));
513                 }
514                 mutex_unlock(&rxnet->local_mutex);
515                 BUG();
516         }
517 }