l2tp: prevent lockdep issue in l2tp_tunnel_register()
[platform/kernel/linux-rpi.git] / net / l2tp / l2tp_core.c
index 93271a2..a2b13e2 100644 (file)
@@ -104,9 +104,9 @@ static struct workqueue_struct *l2tp_wq;
 /* per-net private data for this module */
 static unsigned int l2tp_net_id;
 struct l2tp_net {
-       struct list_head l2tp_tunnel_list;
-       /* Lock for write access to l2tp_tunnel_list */
-       spinlock_t l2tp_tunnel_list_lock;
+       /* Lock for write access to l2tp_tunnel_idr */
+       spinlock_t l2tp_tunnel_idr_lock;
+       struct idr l2tp_tunnel_idr;
        struct hlist_head l2tp_session_hlist[L2TP_HASH_SIZE_2];
        /* Lock for write access to l2tp_session_hlist */
        spinlock_t l2tp_session_hlist_lock;
@@ -208,13 +208,10 @@ struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id)
        struct l2tp_tunnel *tunnel;
 
        rcu_read_lock_bh();
-       list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) {
-               if (tunnel->tunnel_id == tunnel_id &&
-                   refcount_inc_not_zero(&tunnel->ref_count)) {
-                       rcu_read_unlock_bh();
-
-                       return tunnel;
-               }
+       tunnel = idr_find(&pn->l2tp_tunnel_idr, tunnel_id);
+       if (tunnel && refcount_inc_not_zero(&tunnel->ref_count)) {
+               rcu_read_unlock_bh();
+               return tunnel;
        }
        rcu_read_unlock_bh();
 
@@ -224,13 +221,14 @@ EXPORT_SYMBOL_GPL(l2tp_tunnel_get);
 
 struct l2tp_tunnel *l2tp_tunnel_get_nth(const struct net *net, int nth)
 {
-       const struct l2tp_net *pn = l2tp_pernet(net);
+       struct l2tp_net *pn = l2tp_pernet(net);
+       unsigned long tunnel_id, tmp;
        struct l2tp_tunnel *tunnel;
        int count = 0;
 
        rcu_read_lock_bh();
-       list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) {
-               if (++count > nth &&
+       idr_for_each_entry_ul(&pn->l2tp_tunnel_idr, tunnel, tmp, tunnel_id) {
+               if (tunnel && ++count > nth &&
                    refcount_inc_not_zero(&tunnel->ref_count)) {
                        rcu_read_unlock_bh();
                        return tunnel;
@@ -1043,7 +1041,7 @@ static int l2tp_xmit_core(struct l2tp_session *session, struct sk_buff *skb, uns
        IPCB(skb)->flags &= ~(IPSKB_XFRM_TUNNEL_SIZE | IPSKB_XFRM_TRANSFORMED | IPSKB_REROUTED);
        nf_reset_ct(skb);
 
-       bh_lock_sock(sk);
+       bh_lock_sock_nested(sk);
        if (sock_owned_by_user(sk)) {
                kfree_skb(skb);
                ret = NET_XMIT_DROP;
@@ -1150,8 +1148,10 @@ static void l2tp_tunnel_destruct(struct sock *sk)
        }
 
        /* Remove hooks into tunnel socket */
+       write_lock_bh(&sk->sk_callback_lock);
        sk->sk_destruct = tunnel->old_sk_destruct;
        sk->sk_user_data = NULL;
+       write_unlock_bh(&sk->sk_callback_lock);
 
        /* Call the original destructor */
        if (sk->sk_destruct)
@@ -1227,6 +1227,15 @@ static void l2tp_udp_encap_destroy(struct sock *sk)
                l2tp_tunnel_delete(tunnel);
 }
 
+static void l2tp_tunnel_remove(struct net *net, struct l2tp_tunnel *tunnel)
+{
+       struct l2tp_net *pn = l2tp_pernet(net);
+
+       spin_lock_bh(&pn->l2tp_tunnel_idr_lock);
+       idr_remove(&pn->l2tp_tunnel_idr, tunnel->tunnel_id);
+       spin_unlock_bh(&pn->l2tp_tunnel_idr_lock);
+}
+
 /* Workqueue tunnel deletion function */
 static void l2tp_tunnel_del_work(struct work_struct *work)
 {
@@ -1234,7 +1243,6 @@ static void l2tp_tunnel_del_work(struct work_struct *work)
                                                  del_work);
        struct sock *sk = tunnel->sock;
        struct socket *sock = sk->sk_socket;
-       struct l2tp_net *pn;
 
        l2tp_tunnel_closeall(tunnel);
 
@@ -1248,12 +1256,7 @@ static void l2tp_tunnel_del_work(struct work_struct *work)
                }
        }
 
-       /* Remove the tunnel struct from the tunnel list */
-       pn = l2tp_pernet(tunnel->l2tp_net);
-       spin_lock_bh(&pn->l2tp_tunnel_list_lock);
-       list_del_rcu(&tunnel->list);
-       spin_unlock_bh(&pn->l2tp_tunnel_list_lock);
-
+       l2tp_tunnel_remove(tunnel->l2tp_net, tunnel);
        /* drop initial ref */
        l2tp_tunnel_dec_refcount(tunnel);
 
@@ -1384,8 +1387,6 @@ out:
        return err;
 }
 
-static struct lock_class_key l2tp_socket_class;
-
 int l2tp_tunnel_create(int fd, int version, u32 tunnel_id, u32 peer_tunnel_id,
                       struct l2tp_tunnel_cfg *cfg, struct l2tp_tunnel **tunnelp)
 {
@@ -1455,12 +1456,19 @@ static int l2tp_validate_socket(const struct sock *sk, const struct net *net,
 int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net,
                         struct l2tp_tunnel_cfg *cfg)
 {
-       struct l2tp_tunnel *tunnel_walk;
-       struct l2tp_net *pn;
+       struct l2tp_net *pn = l2tp_pernet(net);
+       u32 tunnel_id = tunnel->tunnel_id;
        struct socket *sock;
        struct sock *sk;
        int ret;
 
+       spin_lock_bh(&pn->l2tp_tunnel_idr_lock);
+       ret = idr_alloc_u32(&pn->l2tp_tunnel_idr, NULL, &tunnel_id, tunnel_id,
+                           GFP_ATOMIC);
+       spin_unlock_bh(&pn->l2tp_tunnel_idr_lock);
+       if (ret)
+               return ret == -ENOSPC ? -EEXIST : ret;
+
        if (tunnel->fd < 0) {
                ret = l2tp_tunnel_sock_create(net, tunnel->tunnel_id,
                                              tunnel->peer_tunnel_id, cfg,
@@ -1471,30 +1479,16 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net,
                sock = sockfd_lookup(tunnel->fd, &ret);
                if (!sock)
                        goto err;
-
-               ret = l2tp_validate_socket(sock->sk, net, tunnel->encap);
-               if (ret < 0)
-                       goto err_sock;
        }
 
-       tunnel->l2tp_net = net;
-       pn = l2tp_pernet(net);
-
        sk = sock->sk;
-       sock_hold(sk);
-       tunnel->sock = sk;
-
-       spin_lock_bh(&pn->l2tp_tunnel_list_lock);
-       list_for_each_entry(tunnel_walk, &pn->l2tp_tunnel_list, list) {
-               if (tunnel_walk->tunnel_id == tunnel->tunnel_id) {
-                       spin_unlock_bh(&pn->l2tp_tunnel_list_lock);
-                       sock_put(sk);
-                       ret = -EEXIST;
-                       goto err_sock;
-               }
-       }
-       list_add_rcu(&tunnel->list, &pn->l2tp_tunnel_list);
-       spin_unlock_bh(&pn->l2tp_tunnel_list_lock);
+       lock_sock(sk);
+       write_lock_bh(&sk->sk_callback_lock);
+       ret = l2tp_validate_socket(sk, net, tunnel->encap);
+       if (ret < 0)
+               goto err_inval_sock;
+       rcu_assign_sk_user_data(sk, tunnel);
+       write_unlock_bh(&sk->sk_callback_lock);
 
        if (tunnel->encap == L2TP_ENCAPTYPE_UDP) {
                struct udp_tunnel_sock_cfg udp_cfg = {
@@ -1505,15 +1499,20 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net,
                };
 
                setup_udp_tunnel_sock(net, sock, &udp_cfg);
-       } else {
-               sk->sk_user_data = tunnel;
        }
 
        tunnel->old_sk_destruct = sk->sk_destruct;
        sk->sk_destruct = &l2tp_tunnel_destruct;
-       lockdep_set_class_and_name(&sk->sk_lock.slock, &l2tp_socket_class,
-                                  "l2tp_sock");
        sk->sk_allocation = GFP_ATOMIC;
+       release_sock(sk);
+
+       sock_hold(sk);
+       tunnel->sock = sk;
+       tunnel->l2tp_net = net;
+
+       spin_lock_bh(&pn->l2tp_tunnel_idr_lock);
+       idr_replace(&pn->l2tp_tunnel_idr, tunnel, tunnel->tunnel_id);
+       spin_unlock_bh(&pn->l2tp_tunnel_idr_lock);
 
        trace_register_tunnel(tunnel);
 
@@ -1522,12 +1521,16 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net,
 
        return 0;
 
-err_sock:
+err_inval_sock:
+       write_unlock_bh(&sk->sk_callback_lock);
+       release_sock(sk);
+
        if (tunnel->fd < 0)
                sock_release(sock);
        else
                sockfd_put(sock);
 err:
+       l2tp_tunnel_remove(net, tunnel);
        return ret;
 }
 EXPORT_SYMBOL_GPL(l2tp_tunnel_register);
@@ -1641,8 +1644,8 @@ static __net_init int l2tp_init_net(struct net *net)
        struct l2tp_net *pn = net_generic(net, l2tp_net_id);
        int hash;
 
-       INIT_LIST_HEAD(&pn->l2tp_tunnel_list);
-       spin_lock_init(&pn->l2tp_tunnel_list_lock);
+       idr_init(&pn->l2tp_tunnel_idr);
+       spin_lock_init(&pn->l2tp_tunnel_idr_lock);
 
        for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++)
                INIT_HLIST_HEAD(&pn->l2tp_session_hlist[hash]);
@@ -1656,11 +1659,13 @@ static __net_exit void l2tp_exit_net(struct net *net)
 {
        struct l2tp_net *pn = l2tp_pernet(net);
        struct l2tp_tunnel *tunnel = NULL;
+       unsigned long tunnel_id, tmp;
        int hash;
 
        rcu_read_lock_bh();
-       list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) {
-               l2tp_tunnel_delete(tunnel);
+       idr_for_each_entry_ul(&pn->l2tp_tunnel_idr, tunnel, tmp, tunnel_id) {
+               if (tunnel)
+                       l2tp_tunnel_delete(tunnel);
        }
        rcu_read_unlock_bh();
 
@@ -1670,6 +1675,7 @@ static __net_exit void l2tp_exit_net(struct net *net)
 
        for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++)
                WARN_ON_ONCE(!hlist_empty(&pn->l2tp_session_hlist[hash]));
+       idr_destroy(&pn->l2tp_tunnel_idr);
 }
 
 static struct pernet_operations l2tp_net_ops = {