netlink: fix splat in skb_clone with large messages
[platform/adaptation/renesas_rcar/renesas_kernel.git] / net / netlink / af_netlink.c
index d0b3dd6..0c61b59 100644 (file)
@@ -57,6 +57,7 @@
 #include <linux/audit.h>
 #include <linux/mutex.h>
 #include <linux/vmalloc.h>
+#include <linux/if_arp.h>
 #include <asm/cacheflush.h>
 
 #include <net/net_namespace.h>
@@ -101,6 +102,9 @@ static atomic_t nl_table_users = ATOMIC_INIT(0);
 
 static ATOMIC_NOTIFIER_HEAD(netlink_chain);
 
+static DEFINE_SPINLOCK(netlink_tap_lock);
+static struct list_head netlink_tap_all __read_mostly;
+
 static inline u32 netlink_group_mask(u32 group)
 {
        return group ? 1 << (group - 1) : 0;
@@ -111,6 +115,100 @@ static inline struct hlist_head *nl_portid_hashfn(struct nl_portid_hash *hash, u
        return &hash->table[jhash_1word(portid, hash->rnd) & hash->mask];
 }
 
+int netlink_add_tap(struct netlink_tap *nt)
+{
+       if (unlikely(nt->dev->type != ARPHRD_NETLINK))
+               return -EINVAL;
+
+       spin_lock(&netlink_tap_lock);
+       list_add_rcu(&nt->list, &netlink_tap_all);
+       spin_unlock(&netlink_tap_lock);
+
+       if (nt->module)
+               __module_get(nt->module);
+
+       return 0;
+}
+EXPORT_SYMBOL_GPL(netlink_add_tap);
+
+int __netlink_remove_tap(struct netlink_tap *nt)
+{
+       bool found = false;
+       struct netlink_tap *tmp;
+
+       spin_lock(&netlink_tap_lock);
+
+       list_for_each_entry(tmp, &netlink_tap_all, list) {
+               if (nt == tmp) {
+                       list_del_rcu(&nt->list);
+                       found = true;
+                       goto out;
+               }
+       }
+
+       pr_warn("__netlink_remove_tap: %p not found\n", nt);
+out:
+       spin_unlock(&netlink_tap_lock);
+
+       if (found && nt->module)
+               module_put(nt->module);
+
+       return found ? 0 : -ENODEV;
+}
+EXPORT_SYMBOL_GPL(__netlink_remove_tap);
+
+int netlink_remove_tap(struct netlink_tap *nt)
+{
+       int ret;
+
+       ret = __netlink_remove_tap(nt);
+       synchronize_net();
+
+       return ret;
+}
+EXPORT_SYMBOL_GPL(netlink_remove_tap);
+
+static int __netlink_deliver_tap_skb(struct sk_buff *skb,
+                                    struct net_device *dev)
+{
+       struct sk_buff *nskb;
+       int ret = -ENOMEM;
+
+       dev_hold(dev);
+       nskb = skb_clone(skb, GFP_ATOMIC);
+       if (nskb) {
+               nskb->dev = dev;
+               ret = dev_queue_xmit(nskb);
+               if (unlikely(ret > 0))
+                       ret = net_xmit_errno(ret);
+       }
+
+       dev_put(dev);
+       return ret;
+}
+
+static void __netlink_deliver_tap(struct sk_buff *skb)
+{
+       int ret;
+       struct netlink_tap *tmp;
+
+       list_for_each_entry_rcu(tmp, &netlink_tap_all, list) {
+               ret = __netlink_deliver_tap_skb(skb, tmp->dev);
+               if (unlikely(ret))
+                       break;
+       }
+}
+
+static void netlink_deliver_tap(struct sk_buff *skb)
+{
+       rcu_read_lock();
+
+       if (unlikely(!list_empty(&netlink_tap_all)))
+               __netlink_deliver_tap(skb);
+
+       rcu_read_unlock();
+}
+
 static void netlink_overrun(struct sock *sk)
 {
        struct netlink_sock *nlk = nlk_sk(sk);
@@ -371,7 +469,7 @@ static int netlink_mmap(struct file *file, struct socket *sock,
        err = 0;
 out:
        mutex_unlock(&nlk->pg_vec_lock);
-       return 0;
+       return err;
 }
 
 static void netlink_frame_flush_dcache(const struct nl_mmap_hdr *hdr)
@@ -750,6 +848,13 @@ static void netlink_skb_destructor(struct sk_buff *skb)
                skb->head = NULL;
        }
 #endif
+       if (is_vmalloc_addr(skb->head)) {
+               if (!skb->cloned ||
+                   !atomic_dec_return(&(skb_shinfo(skb)->dataref)))
+                       vfree(skb->head);
+
+               skb->head = NULL;
+       }
        if (skb->sk != NULL)
                sock_rfree(skb);
 }
@@ -854,16 +959,23 @@ netlink_unlock_table(void)
                wake_up(&nl_table_wait);
 }
 
+static bool netlink_compare(struct net *net, struct sock *sk)
+{
+       return net_eq(sock_net(sk), net);
+}
+
 static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
 {
-       struct nl_portid_hash *hash = &nl_table[protocol].hash;
+       struct netlink_table *table = &nl_table[protocol];
+       struct nl_portid_hash *hash = &table->hash;
        struct hlist_head *head;
        struct sock *sk;
 
        read_lock(&nl_table_lock);
        head = nl_portid_hashfn(hash, portid);
        sk_for_each(sk, head) {
-               if (net_eq(sock_net(sk), net) && (nlk_sk(sk)->portid == portid)) {
+               if (table->compare(net, sk) &&
+                   (nlk_sk(sk)->portid == portid)) {
                        sock_hold(sk);
                        goto found;
                }
@@ -976,7 +1088,8 @@ netlink_update_listeners(struct sock *sk)
 
 static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
 {
-       struct nl_portid_hash *hash = &nl_table[sk->sk_protocol].hash;
+       struct netlink_table *table = &nl_table[sk->sk_protocol];
+       struct nl_portid_hash *hash = &table->hash;
        struct hlist_head *head;
        int err = -EADDRINUSE;
        struct sock *osk;
@@ -986,7 +1099,8 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
        head = nl_portid_hashfn(hash, portid);
        len = 0;
        sk_for_each(osk, head) {
-               if (net_eq(sock_net(osk), net) && (nlk_sk(osk)->portid == portid))
+               if (table->compare(net, osk) &&
+                   (nlk_sk(osk)->portid == portid))
                        break;
                len++;
        }
@@ -1183,7 +1297,8 @@ static int netlink_autobind(struct socket *sock)
 {
        struct sock *sk = sock->sk;
        struct net *net = sock_net(sk);
-       struct nl_portid_hash *hash = &nl_table[sk->sk_protocol].hash;
+       struct netlink_table *table = &nl_table[sk->sk_protocol];
+       struct nl_portid_hash *hash = &table->hash;
        struct hlist_head *head;
        struct sock *osk;
        s32 portid = task_tgid_vnr(current);
@@ -1195,7 +1310,7 @@ retry:
        netlink_table_grab();
        head = nl_portid_hashfn(hash, portid);
        sk_for_each(osk, head) {
-               if (!net_eq(sock_net(osk), net))
+               if (!table->compare(net, osk))
                        continue;
                if (nlk_sk(osk)->portid == portid) {
                        /* Bind collision, search negative portid values. */
@@ -1420,6 +1535,33 @@ struct sock *netlink_getsockbyfilp(struct file *filp)
        return sock;
 }
 
+static struct sk_buff *netlink_alloc_large_skb(unsigned int size,
+                                              int broadcast)
+{
+       struct sk_buff *skb;
+       void *data;
+
+       if (size <= NLMSG_GOODSIZE || broadcast)
+               return alloc_skb(size, GFP_KERNEL);
+
+       size = SKB_DATA_ALIGN(size) +
+              SKB_DATA_ALIGN(sizeof(struct skb_shared_info));
+
+       data = vmalloc(size);
+       if (data == NULL)
+               return NULL;
+
+       skb = build_skb(data, size);
+       if (skb == NULL)
+               vfree(data);
+       else {
+               skb->head_frag = 0;
+               skb->destructor = netlink_skb_destructor;
+       }
+
+       return skb;
+}
+
 /*
  * Attach a skb to a netlink socket.
  * The caller must hold a reference to the destination socket. On error, the
@@ -1475,6 +1617,8 @@ static int __netlink_sendskb(struct sock *sk, struct sk_buff *skb)
 {
        int len = skb->len;
 
+       netlink_deliver_tap(skb);
+
 #ifdef CONFIG_NETLINK_MMAP
        if (netlink_skb_is_mmaped(skb))
                netlink_queue_mmaped_skb(sk, skb);
@@ -1510,7 +1654,7 @@ static struct sk_buff *netlink_trim(struct sk_buff *skb, gfp_t allocation)
                return skb;
 
        delta = skb->end - skb->tail;
-       if (delta * 2 < skb->truesize)
+       if (is_vmalloc_addr(skb->head) || delta * 2 < skb->truesize)
                return skb;
 
        if (skb_shared(skb)) {
@@ -1535,6 +1679,11 @@ static int netlink_unicast_kernel(struct sock *sk, struct sk_buff *skb,
 
        ret = -ECONNREFUSED;
        if (nlk->netlink_rcv != NULL) {
+               /* We could do a netlink_deliver_tap(skb) here as well
+                * but since this is intended for the kernel only, we
+                * should rather let it stay under the hood.
+                */
+
                ret = skb->len;
                netlink_skb_set_owner_r(skb, sk);
                NETLINK_CB(skb).sk = ssk;
@@ -2096,7 +2245,7 @@ static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock,
        if (len > sk->sk_sndbuf - 32)
                goto out;
        err = -ENOBUFS;
-       skb = alloc_skb(len, GFP_KERNEL);
+       skb = netlink_alloc_large_skb(len, dst_group);
        if (skb == NULL)
                goto out;
 
@@ -2285,6 +2434,8 @@ __netlink_kernel_create(struct net *net, int unit, struct module *module,
                if (cfg) {
                        nl_table[unit].bind = cfg->bind;
                        nl_table[unit].flags = cfg->flags;
+                       if (cfg->compare)
+                               nl_table[unit].compare = cfg->compare;
                }
                nl_table[unit].registered = 1;
        } else {
@@ -2707,6 +2858,7 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
 {
        struct sock *s;
        struct nl_seq_iter *iter;
+       struct net *net;
        int i, j;
 
        ++*pos;
@@ -2714,11 +2866,12 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
        if (v == SEQ_START_TOKEN)
                return netlink_seq_socket_idx(seq, 0);
 
+       net = seq_file_net(seq);
        iter = seq->private;
        s = v;
        do {
                s = sk_next(s);
-       } while (s && sock_net(s) != seq_file_net(seq));
+       } while (s && !nl_table[s->sk_protocol].compare(net, s));
        if (s)
                return s;
 
@@ -2730,7 +2883,8 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
 
                for (; j <= hash->mask; j++) {
                        s = sk_head(&hash->table[j]);
-                       while (s && sock_net(s) != seq_file_net(seq))
+
+                       while (s && !nl_table[s->sk_protocol].compare(net, s))
                                s = sk_next(s);
                        if (s) {
                                iter->link = i;
@@ -2923,8 +3077,12 @@ static int __init netlink_proto_init(void)
                hash->shift = 0;
                hash->mask = 0;
                hash->rehash_time = jiffies;
+
+               nl_table[i].compare = netlink_compare;
        }
 
+       INIT_LIST_HEAD(&netlink_tap_all);
+
        netlink_add_usersock_entry();
 
        sock_register(&netlink_family_ops);