rxrpc: Extract the peer address from an incoming packet earlier
authorDavid Howells <dhowells@redhat.com>
Thu, 20 Oct 2022 21:36:20 +0000 (22:36 +0100)
committerDavid Howells <dhowells@redhat.com>
Thu, 1 Dec 2022 13:36:42 +0000 (13:36 +0000)
Extract the peer address from an incoming packet earlier, at the beginning
of rxrpc_input_packet() and thence pass a pointer to it to various
functions that use it as part of the lookup rather than doing it on several
separate paths.

Signed-off-by: David Howells <dhowells@redhat.com>
cc: Marc Dionne <marc.dionne@auristor.com>
cc: linux-afs@lists.infradead.org

net/rxrpc/ar-internal.h
net/rxrpc/call_accept.c
net/rxrpc/conn_object.c
net/rxrpc/io_thread.c

index cfd16f1e5c838ae861ffa9f95e16b1f8c56048af..c3c915a0562772951989009348ce71fe2a5e2dae 100644 (file)
@@ -824,6 +824,7 @@ int rxrpc_service_prealloc(struct rxrpc_sock *, gfp_t);
 void rxrpc_discard_prealloc(struct rxrpc_sock *);
 struct rxrpc_call *rxrpc_new_incoming_call(struct rxrpc_local *,
                                           struct rxrpc_sock *,
 void rxrpc_discard_prealloc(struct rxrpc_sock *);
 struct rxrpc_call *rxrpc_new_incoming_call(struct rxrpc_local *,
                                           struct rxrpc_sock *,
+                                          struct sockaddr_rxrpc *,
                                           struct sk_buff *);
 void rxrpc_accept_incoming_calls(struct rxrpc_local *);
 int rxrpc_user_charge_accept(struct rxrpc_sock *, unsigned long);
                                           struct sk_buff *);
 void rxrpc_accept_incoming_calls(struct rxrpc_local *);
 int rxrpc_user_charge_accept(struct rxrpc_sock *, unsigned long);
@@ -916,6 +917,7 @@ extern unsigned int rxrpc_closed_conn_expiry;
 
 struct rxrpc_connection *rxrpc_alloc_connection(struct rxrpc_net *, gfp_t);
 struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *,
 
 struct rxrpc_connection *rxrpc_alloc_connection(struct rxrpc_net *, gfp_t);
 struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *,
+                                                  struct sockaddr_rxrpc *,
                                                   struct sk_buff *,
                                                   struct rxrpc_peer **);
 void __rxrpc_disconnect_call(struct rxrpc_connection *, struct rxrpc_call *);
                                                   struct sk_buff *,
                                                   struct rxrpc_peer **);
 void __rxrpc_disconnect_call(struct rxrpc_connection *, struct rxrpc_call *);
index beb8efa2e7a98a22865d7fee1bc12fcc1e4be22d..11134b7cec17a1ebc44ee3cba15e9963b20e2963 100644 (file)
@@ -258,6 +258,7 @@ static struct rxrpc_call *rxrpc_alloc_incoming_call(struct rxrpc_sock *rx,
                                                    struct rxrpc_peer *peer,
                                                    struct rxrpc_connection *conn,
                                                    const struct rxrpc_security *sec,
                                                    struct rxrpc_peer *peer,
                                                    struct rxrpc_connection *conn,
                                                    const struct rxrpc_security *sec,
+                                                   struct sockaddr_rxrpc *peer_srx,
                                                    struct sk_buff *skb)
 {
        struct rxrpc_backlog *b = rx->backlog;
                                                    struct sk_buff *skb)
 {
        struct rxrpc_backlog *b = rx->backlog;
@@ -287,8 +288,7 @@ static struct rxrpc_call *rxrpc_alloc_incoming_call(struct rxrpc_sock *rx,
                        peer = NULL;
                if (!peer) {
                        peer = b->peer_backlog[peer_tail];
                        peer = NULL;
                if (!peer) {
                        peer = b->peer_backlog[peer_tail];
-                       if (rxrpc_extract_addr_from_skb(&peer->srx, skb) < 0)
-                               return NULL;
+                       peer->srx = *peer_srx;
                        b->peer_backlog[peer_tail] = NULL;
                        smp_store_release(&b->peer_backlog_tail,
                                          (peer_tail + 1) &
                        b->peer_backlog[peer_tail] = NULL;
                        smp_store_release(&b->peer_backlog_tail,
                                          (peer_tail + 1) &
@@ -346,6 +346,7 @@ static struct rxrpc_call *rxrpc_alloc_incoming_call(struct rxrpc_sock *rx,
  */
 struct rxrpc_call *rxrpc_new_incoming_call(struct rxrpc_local *local,
                                           struct rxrpc_sock *rx,
  */
 struct rxrpc_call *rxrpc_new_incoming_call(struct rxrpc_local *local,
                                           struct rxrpc_sock *rx,
+                                          struct sockaddr_rxrpc *peer_srx,
                                           struct sk_buff *skb)
 {
        struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
                                           struct sk_buff *skb)
 {
        struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
@@ -371,7 +372,7 @@ struct rxrpc_call *rxrpc_new_incoming_call(struct rxrpc_local *local,
         * we have to recheck the routing.  However, we're now holding
         * rx->incoming_lock, so the values should remain stable.
         */
         * we have to recheck the routing.  However, we're now holding
         * rx->incoming_lock, so the values should remain stable.
         */
-       conn = rxrpc_find_connection_rcu(local, skb, &peer);
+       conn = rxrpc_find_connection_rcu(local, peer_srx, skb, &peer);
 
        if (!conn) {
                sec = rxrpc_get_incoming_security(rx, skb);
 
        if (!conn) {
                sec = rxrpc_get_incoming_security(rx, skb);
@@ -379,7 +380,8 @@ struct rxrpc_call *rxrpc_new_incoming_call(struct rxrpc_local *local,
                        goto no_call;
        }
 
                        goto no_call;
        }
 
-       call = rxrpc_alloc_incoming_call(rx, local, peer, conn, sec, skb);
+       call = rxrpc_alloc_incoming_call(rx, local, peer, conn, sec, peer_srx,
+                                        skb);
        if (!call) {
                skb->mark = RXRPC_SKB_MARK_REJECT_BUSY;
                goto no_call;
        if (!call) {
                skb->mark = RXRPC_SKB_MARK_REJECT_BUSY;
                goto no_call;
index 5a39255ea014d55104d25eb7877561b210b63947..98e49646ca1d76e7831f495f68860047b3aacd48 100644 (file)
@@ -73,29 +73,17 @@ struct rxrpc_connection *rxrpc_alloc_connection(struct rxrpc_net *rxnet,
  * The caller must be holding the RCU read lock.
  */
 struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *local,
  * The caller must be holding the RCU read lock.
  */
 struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *local,
+                                                  struct sockaddr_rxrpc *srx,
                                                   struct sk_buff *skb,
                                                   struct rxrpc_peer **_peer)
 {
        struct rxrpc_connection *conn;
        struct rxrpc_conn_proto k;
        struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
                                                   struct sk_buff *skb,
                                                   struct rxrpc_peer **_peer)
 {
        struct rxrpc_connection *conn;
        struct rxrpc_conn_proto k;
        struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
-       struct sockaddr_rxrpc srx;
        struct rxrpc_peer *peer;
 
        _enter(",%x", sp->hdr.cid & RXRPC_CIDMASK);
 
        struct rxrpc_peer *peer;
 
        _enter(",%x", sp->hdr.cid & RXRPC_CIDMASK);
 
-       if (rxrpc_extract_addr_from_skb(&srx, skb) < 0)
-               goto not_found;
-
-       if (srx.transport.family != local->srx.transport.family &&
-           (srx.transport.family == AF_INET &&
-            local->srx.transport.family != AF_INET6)) {
-               pr_warn_ratelimited("AF_RXRPC: Protocol mismatch %u not %u\n",
-                                   srx.transport.family,
-                                   local->srx.transport.family);
-               goto not_found;
-       }
-
        k.epoch = sp->hdr.epoch;
        k.cid   = sp->hdr.cid & RXRPC_CIDMASK;
 
        k.epoch = sp->hdr.epoch;
        k.cid   = sp->hdr.cid & RXRPC_CIDMASK;
 
@@ -104,7 +92,7 @@ struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *local,
                 * parameter set.  We look up the peer first as an intermediate
                 * step and then the connection from the peer's tree.
                 */
                 * parameter set.  We look up the peer first as an intermediate
                 * step and then the connection from the peer's tree.
                 */
-               peer = rxrpc_lookup_peer_rcu(local, &srx);
+               peer = rxrpc_lookup_peer_rcu(local, srx);
                if (!peer)
                        goto not_found;
                *_peer = peer;
                if (!peer)
                        goto not_found;
                *_peer = peer;
@@ -117,8 +105,7 @@ struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *local,
                /* Look up client connections by connection ID alone as their
                 * IDs are unique for this machine.
                 */
                /* Look up client connections by connection ID alone as their
                 * IDs are unique for this machine.
                 */
-               conn = idr_find(&rxrpc_client_conn_ids,
-                               sp->hdr.cid >> RXRPC_CIDSHIFT);
+               conn = idr_find(&rxrpc_client_conn_ids, sp->hdr.cid >> RXRPC_CIDSHIFT);
                if (!conn || refcount_read(&conn->ref) == 0) {
                        _debug("no conn");
                        goto not_found;
                if (!conn || refcount_read(&conn->ref) == 0) {
                        _debug("no conn");
                        goto not_found;
@@ -129,20 +116,20 @@ struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *local,
                        goto not_found;
 
                peer = conn->peer;
                        goto not_found;
 
                peer = conn->peer;
-               switch (srx.transport.family) {
+               switch (srx->transport.family) {
                case AF_INET:
                        if (peer->srx.transport.sin.sin_port !=
                case AF_INET:
                        if (peer->srx.transport.sin.sin_port !=
-                           srx.transport.sin.sin_port ||
+                           srx->transport.sin.sin_port ||
                            peer->srx.transport.sin.sin_addr.s_addr !=
                            peer->srx.transport.sin.sin_addr.s_addr !=
-                           srx.transport.sin.sin_addr.s_addr)
+                           srx->transport.sin.sin_addr.s_addr)
                                goto not_found;
                        break;
 #ifdef CONFIG_AF_RXRPC_IPV6
                case AF_INET6:
                        if (peer->srx.transport.sin6.sin6_port !=
                                goto not_found;
                        break;
 #ifdef CONFIG_AF_RXRPC_IPV6
                case AF_INET6:
                        if (peer->srx.transport.sin6.sin6_port !=
-                           srx.transport.sin6.sin6_port ||
+                           srx->transport.sin6.sin6_port ||
                            memcmp(&peer->srx.transport.sin6.sin6_addr,
                            memcmp(&peer->srx.transport.sin6.sin6_addr,
-                                  &srx.transport.sin6.sin6_addr,
+                                  &srx->transport.sin6.sin6_addr,
                                   sizeof(struct in6_addr)) != 0)
                                goto not_found;
                        break;
                                   sizeof(struct in6_addr)) != 0)
                                goto not_found;
                        break;
index 3b6927610677c706ccdbb993f79c6923c516f921..bc65d83fab88a1587a1bbccaaab2aa3d27268335 100644 (file)
@@ -155,6 +155,7 @@ static bool rxrpc_extract_abort(struct sk_buff *skb)
 static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff **_skb)
 {
        struct rxrpc_connection *conn;
 static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff **_skb)
 {
        struct rxrpc_connection *conn;
+       struct sockaddr_rxrpc peer_srx;
        struct rxrpc_channel *chan;
        struct rxrpc_call *call = NULL;
        struct rxrpc_skb_priv *sp;
        struct rxrpc_channel *chan;
        struct rxrpc_call *call = NULL;
        struct rxrpc_skb_priv *sp;
@@ -257,6 +258,18 @@ static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff **_skb)
        if (sp->hdr.serviceId == 0)
                goto bad_message;
 
        if (sp->hdr.serviceId == 0)
                goto bad_message;
 
+       if (WARN_ON_ONCE(rxrpc_extract_addr_from_skb(&peer_srx, skb) < 0))
+               return 0; /* Unsupported address type - discard. */
+
+       if (peer_srx.transport.family != local->srx.transport.family &&
+           (peer_srx.transport.family == AF_INET &&
+            local->srx.transport.family != AF_INET6)) {
+               pr_warn_ratelimited("AF_RXRPC: Protocol mismatch %u not %u\n",
+                                   peer_srx.transport.family,
+                                   local->srx.transport.family);
+               return 0; /* Wrong address type - discard. */
+       }
+
        rcu_read_lock();
 
        if (rxrpc_to_server(sp)) {
        rcu_read_lock();
 
        if (rxrpc_to_server(sp)) {
@@ -276,7 +289,7 @@ static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff **_skb)
                }
        }
 
                }
        }
 
-       conn = rxrpc_find_connection_rcu(local, skb, &peer);
+       conn = rxrpc_find_connection_rcu(local, &peer_srx, skb, &peer);
        if (conn) {
                if (sp->hdr.securityIndex != conn->security_ix)
                        goto wrong_security;
        if (conn) {
                if (sp->hdr.securityIndex != conn->security_ix)
                        goto wrong_security;
@@ -389,7 +402,7 @@ static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff **_skb)
                        rcu_read_unlock();
                        return 0;
                }
                        rcu_read_unlock();
                        return 0;
                }
-               call = rxrpc_new_incoming_call(local, rx, skb);
+               call = rxrpc_new_incoming_call(local, rx, &peer_srx, skb);
                if (!call) {
                        rcu_read_unlock();
                        goto reject_packet;
                if (!call) {
                        rcu_read_unlock();
                        goto reject_packet;