net: qrtr: ns: Change servers radix tree to xarray
authorVignesh Viswanathan <quic_viswanat@quicinc.com>
Fri, 14 Jul 2023 05:58:44 +0000 (11:28 +0530)
committerDavid S. Miller <davem@davemloft.net>
Mon, 17 Jul 2023 08:02:30 +0000 (09:02 +0100)
There is a use after free scenario while iterating through the servers
radix tree despite the ns being a single threaded process. This can
happen when the radix tree APIs are not synchronized with the
rcu_read_lock() APIs.

Convert the radix tree for servers to xarray to take advantage of the
built in rcu lock usage provided by xarray.

Signed-off-by: Chris Lew <quic_clew@quicinc.com>
Signed-off-by: Vignesh Viswanathan <quic_viswanat@quicinc.com>
Reviewed-by: Simon Horman <simon.horman@corigine.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
net/qrtr/ns.c

index 0f7a729..af28d9e 100644 (file)
@@ -66,7 +66,7 @@ struct qrtr_server {
 
 struct qrtr_node {
        unsigned int id;
-       struct radix_tree_root servers;
+       struct xarray servers;
 };
 
 static struct qrtr_node *node_get(unsigned int node_id)
@@ -83,6 +83,7 @@ static struct qrtr_node *node_get(unsigned int node_id)
                return NULL;
 
        node->id = node_id;
+       xa_init(&node->servers);
 
        if (radix_tree_insert(&nodes, node_id, node)) {
                kfree(node);
@@ -193,40 +194,23 @@ static void lookup_notify(struct sockaddr_qrtr *to, struct qrtr_server *srv,
 
 static int announce_servers(struct sockaddr_qrtr *sq)
 {
-       struct radix_tree_iter iter;
        struct qrtr_server *srv;
        struct qrtr_node *node;
-       void __rcu **slot;
+       unsigned long index;
        int ret;
 
        node = node_get(qrtr_ns.local_node);
        if (!node)
                return 0;
 
-       rcu_read_lock();
        /* Announce the list of servers registered in this node */
-       radix_tree_for_each_slot(slot, &node->servers, &iter, 0) {
-               srv = radix_tree_deref_slot(slot);
-               if (!srv)
-                       continue;
-               if (radix_tree_deref_retry(srv)) {
-                       slot = radix_tree_iter_retry(&iter);
-                       continue;
-               }
-               slot = radix_tree_iter_resume(slot, &iter);
-               rcu_read_unlock();
-
+       xa_for_each(&node->servers, index, srv) {
                ret = service_announce_new(sq, srv);
                if (ret < 0) {
                        pr_err("failed to announce new service\n");
                        return ret;
                }
-
-               rcu_read_lock();
        }
-
-       rcu_read_unlock();
-
        return 0;
 }
 
@@ -256,14 +240,17 @@ static struct qrtr_server *server_add(unsigned int service,
                goto err;
 
        /* Delete the old server on the same port */
-       old = radix_tree_lookup(&node->servers, port);
+       old = xa_store(&node->servers, port, srv, GFP_KERNEL);
        if (old) {
-               radix_tree_delete(&node->servers, port);
-               kfree(old);
+               if (xa_is_err(old)) {
+                       pr_err("failed to add server [0x%x:0x%x] ret:%d\n",
+                              srv->service, srv->instance, xa_err(old));
+                       goto err;
+               } else {
+                       kfree(old);
+               }
        }
 
-       radix_tree_insert(&node->servers, port, srv);
-
        trace_qrtr_ns_server_add(srv->service, srv->instance,
                                 srv->node, srv->port);
 
@@ -280,11 +267,11 @@ static int server_del(struct qrtr_node *node, unsigned int port, bool bcast)
        struct qrtr_server *srv;
        struct list_head *li;
 
-       srv = radix_tree_lookup(&node->servers, port);
+       srv = xa_load(&node->servers, port);
        if (!srv)
                return -ENOENT;
 
-       radix_tree_delete(&node->servers, port);
+       xa_erase(&node->servers, port);
 
        /* Broadcast the removal of local servers */
        if (srv->node == qrtr_ns.local_node && bcast)
@@ -344,13 +331,12 @@ static int ctrl_cmd_hello(struct sockaddr_qrtr *sq)
 static int ctrl_cmd_bye(struct sockaddr_qrtr *from)
 {
        struct qrtr_node *local_node;
-       struct radix_tree_iter iter;
        struct qrtr_ctrl_pkt pkt;
        struct qrtr_server *srv;
        struct sockaddr_qrtr sq;
        struct msghdr msg = { };
        struct qrtr_node *node;
-       void __rcu **slot;
+       unsigned long index;
        struct kvec iv;
        int ret;
 
@@ -361,22 +347,9 @@ static int ctrl_cmd_bye(struct sockaddr_qrtr *from)
        if (!node)
                return 0;
 
-       rcu_read_lock();
        /* Advertise removal of this client to all servers of remote node */
-       radix_tree_for_each_slot(slot, &node->servers, &iter, 0) {
-               srv = radix_tree_deref_slot(slot);
-               if (!srv)
-                       continue;
-               if (radix_tree_deref_retry(srv)) {
-                       slot = radix_tree_iter_retry(&iter);
-                       continue;
-               }
-               slot = radix_tree_iter_resume(slot, &iter);
-               rcu_read_unlock();
+       xa_for_each(&node->servers, index, srv)
                server_del(node, srv->port, true);
-               rcu_read_lock();
-       }
-       rcu_read_unlock();
 
        /* Advertise the removal of this client to all local servers */
        local_node = node_get(qrtr_ns.local_node);
@@ -387,18 +360,7 @@ static int ctrl_cmd_bye(struct sockaddr_qrtr *from)
        pkt.cmd = cpu_to_le32(QRTR_TYPE_BYE);
        pkt.client.node = cpu_to_le32(from->sq_node);
 
-       rcu_read_lock();
-       radix_tree_for_each_slot(slot, &local_node->servers, &iter, 0) {
-               srv = radix_tree_deref_slot(slot);
-               if (!srv)
-                       continue;
-               if (radix_tree_deref_retry(srv)) {
-                       slot = radix_tree_iter_retry(&iter);
-                       continue;
-               }
-               slot = radix_tree_iter_resume(slot, &iter);
-               rcu_read_unlock();
-
+       xa_for_each(&local_node->servers, index, srv) {
                sq.sq_family = AF_QIPCRTR;
                sq.sq_node = srv->node;
                sq.sq_port = srv->port;
@@ -411,11 +373,7 @@ static int ctrl_cmd_bye(struct sockaddr_qrtr *from)
                        pr_err("failed to send bye cmd\n");
                        return ret;
                }
-               rcu_read_lock();
        }
-
-       rcu_read_unlock();
-
        return 0;
 }
 
@@ -423,7 +381,6 @@ static int ctrl_cmd_del_client(struct sockaddr_qrtr *from,
                               unsigned int node_id, unsigned int port)
 {
        struct qrtr_node *local_node;
-       struct radix_tree_iter iter;
        struct qrtr_lookup *lookup;
        struct qrtr_ctrl_pkt pkt;
        struct msghdr msg = { };
@@ -432,7 +389,7 @@ static int ctrl_cmd_del_client(struct sockaddr_qrtr *from,
        struct qrtr_node *node;
        struct list_head *tmp;
        struct list_head *li;
-       void __rcu **slot;
+       unsigned long index;
        struct kvec iv;
        int ret;
 
@@ -477,18 +434,7 @@ static int ctrl_cmd_del_client(struct sockaddr_qrtr *from,
        pkt.client.node = cpu_to_le32(node_id);
        pkt.client.port = cpu_to_le32(port);
 
-       rcu_read_lock();
-       radix_tree_for_each_slot(slot, &local_node->servers, &iter, 0) {
-               srv = radix_tree_deref_slot(slot);
-               if (!srv)
-                       continue;
-               if (radix_tree_deref_retry(srv)) {
-                       slot = radix_tree_iter_retry(&iter);
-                       continue;
-               }
-               slot = radix_tree_iter_resume(slot, &iter);
-               rcu_read_unlock();
-
+       xa_for_each(&local_node->servers, index, srv) {
                sq.sq_family = AF_QIPCRTR;
                sq.sq_node = srv->node;
                sq.sq_port = srv->port;
@@ -501,11 +447,7 @@ static int ctrl_cmd_del_client(struct sockaddr_qrtr *from,
                        pr_err("failed to send del client cmd\n");
                        return ret;
                }
-               rcu_read_lock();
        }
-
-       rcu_read_unlock();
-
        return 0;
 }
 
@@ -576,13 +518,12 @@ static int ctrl_cmd_del_server(struct sockaddr_qrtr *from,
 static int ctrl_cmd_new_lookup(struct sockaddr_qrtr *from,
                               unsigned int service, unsigned int instance)
 {
-       struct radix_tree_iter node_iter;
        struct qrtr_server_filter filter;
-       struct radix_tree_iter srv_iter;
        struct qrtr_lookup *lookup;
+       struct qrtr_server *srv;
        struct qrtr_node *node;
-       void __rcu **node_slot;
-       void __rcu **srv_slot;
+       unsigned long node_idx;
+       unsigned long srv_idx;
 
        /* Accept only local observers */
        if (from->sq_node != qrtr_ns.local_node)
@@ -601,40 +542,14 @@ static int ctrl_cmd_new_lookup(struct sockaddr_qrtr *from,
        filter.service = service;
        filter.instance = instance;
 
-       rcu_read_lock();
-       radix_tree_for_each_slot(node_slot, &nodes, &node_iter, 0) {
-               node = radix_tree_deref_slot(node_slot);
-               if (!node)
-                       continue;
-               if (radix_tree_deref_retry(node)) {
-                       node_slot = radix_tree_iter_retry(&node_iter);
-                       continue;
-               }
-               node_slot = radix_tree_iter_resume(node_slot, &node_iter);
-
-               radix_tree_for_each_slot(srv_slot, &node->servers,
-                                        &srv_iter, 0) {
-                       struct qrtr_server *srv;
-
-                       srv = radix_tree_deref_slot(srv_slot);
-                       if (!srv)
-                               continue;
-                       if (radix_tree_deref_retry(srv)) {
-                               srv_slot = radix_tree_iter_retry(&srv_iter);
-                               continue;
-                       }
-
+       xa_for_each(&nodes, node_idx, node) {
+               xa_for_each(&node->servers, srv_idx, srv) {
                        if (!server_match(srv, &filter))
                                continue;
 
-                       srv_slot = radix_tree_iter_resume(srv_slot, &srv_iter);
-
-                       rcu_read_unlock();
                        lookup_notify(from, srv, true);
-                       rcu_read_lock();
                }
        }
-       rcu_read_unlock();
 
        /* Empty notification, to indicate end of listing */
        lookup_notify(from, NULL, true);