inet_diag: Remove indirect sizeof from inet diag handlers
[platform/kernel/linux-starfive.git] / net / ipv4 / inet_diag.c
index 68e8ac5..a247f85 100644 (file)
@@ -33,6 +33,7 @@
 #include <linux/stddef.h>
 
 #include <linux/inet_diag.h>
+#include <linux/sock_diag.h>
 
 static const struct inet_diag_handler **inet_diag_table;
 
@@ -45,24 +46,22 @@ struct inet_diag_entry {
        u16 userlocks;
 };
 
-static struct sock *idiagnl;
-
 #define INET_DIAG_PUT(skb, attrtype, attrlen) \
        RTA_DATA(__RTA_PUT(skb, attrtype, attrlen))
 
 static DEFINE_MUTEX(inet_diag_table_mutex);
 
-static const struct inet_diag_handler *inet_diag_lock_handler(int type)
+static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
 {
-       if (!inet_diag_table[type])
+       if (!inet_diag_table[proto])
                request_module("net-pf-%d-proto-%d-type-%d", PF_NETLINK,
-                              NETLINK_INET_DIAG, type);
+                              NETLINK_SOCK_DIAG, proto);
 
        mutex_lock(&inet_diag_table_mutex);
-       if (!inet_diag_table[type])
+       if (!inet_diag_table[proto])
                return ERR_PTR(-ENOENT);
 
-       return inet_diag_table[type];
+       return inet_diag_table[proto];
 }
 
 static inline void inet_diag_unlock_handler(
@@ -72,8 +71,8 @@ static inline void inet_diag_unlock_handler(
 }
 
 static int inet_csk_diag_fill(struct sock *sk,
-                             struct sk_buff *skb,
-                             int ext, u32 pid, u32 seq, u16 nlmsg_flags,
+                             struct sk_buff *skb, struct inet_diag_req *req,
+                             u32 pid, u32 seq, u16 nlmsg_flags,
                              const struct nlmsghdr *unlh)
 {
        const struct inet_sock *inet = inet_sk(sk);
@@ -84,8 +83,9 @@ static int inet_csk_diag_fill(struct sock *sk,
        struct inet_diag_meminfo  *minfo = NULL;
        unsigned char    *b = skb_tail_pointer(skb);
        const struct inet_diag_handler *handler;
+       int ext = req->idiag_ext;
 
-       handler = inet_diag_table[unlh->nlmsg_type];
+       handler = inet_diag_table[req->sdiag_protocol];
        BUG_ON(handler == NULL);
 
        nlh = NLMSG_PUT(skb, pid, seq, unlh->nlmsg_type, sizeof(*r));
@@ -98,8 +98,7 @@ static int inet_csk_diag_fill(struct sock *sk,
                minfo = INET_DIAG_PUT(skb, INET_DIAG_MEMINFO, sizeof(*minfo));
 
        if (ext & (1 << (INET_DIAG_INFO - 1)))
-               info = INET_DIAG_PUT(skb, INET_DIAG_INFO,
-                                    handler->idiag_info_size);
+               info = INET_DIAG_PUT(skb, INET_DIAG_INFO, sizeof(struct tcp_info));
 
        if ((ext & (1 << (INET_DIAG_CONG - 1))) && icsk->icsk_ca_ops) {
                const size_t len = strlen(icsk->icsk_ca_ops->name);
@@ -108,9 +107,6 @@ static int inet_csk_diag_fill(struct sock *sk,
                       icsk->icsk_ca_ops->name);
        }
 
-       if ((ext & (1 << (INET_DIAG_TOS - 1))) && (sk->sk_family != AF_INET6))
-               RTA_PUT_U8(skb, INET_DIAG_TOS, inet->tos);
-
        r->idiag_family = sk->sk_family;
        r->idiag_state = sk->sk_state;
        r->idiag_timer = 0;
@@ -125,14 +121,18 @@ static int inet_csk_diag_fill(struct sock *sk,
        r->id.idiag_src[0] = inet->inet_rcv_saddr;
        r->id.idiag_dst[0] = inet->inet_daddr;
 
+       /* IPv6 dual-stack sockets use inet->tos for IPv4 connections,
+        * hence this needs to be included regardless of socket family.
+        */
+       if (ext & (1 << (INET_DIAG_TOS - 1)))
+               RTA_PUT_U8(skb, INET_DIAG_TOS, inet->tos);
+
 #if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
        if (r->idiag_family == AF_INET6) {
                const struct ipv6_pinfo *np = inet6_sk(sk);
 
-               ipv6_addr_copy((struct in6_addr *)r->id.idiag_src,
-                              &np->rcv_saddr);
-               ipv6_addr_copy((struct in6_addr *)r->id.idiag_dst,
-                              &np->daddr);
+               *(struct in6_addr *)r->id.idiag_src = np->rcv_saddr;
+               *(struct in6_addr *)r->id.idiag_dst = np->daddr;
                if (ext & (1 << (INET_DIAG_TCLASS - 1)))
                        RTA_PUT_U8(skb, INET_DIAG_TCLASS, np->tclass);
        }
@@ -184,8 +184,8 @@ nlmsg_failure:
 }
 
 static int inet_twsk_diag_fill(struct inet_timewait_sock *tw,
-                              struct sk_buff *skb, int ext, u32 pid,
-                              u32 seq, u16 nlmsg_flags,
+                              struct sk_buff *skb, struct inet_diag_req *req,
+                              u32 pid, u32 seq, u16 nlmsg_flags,
                               const struct nlmsghdr *unlh)
 {
        long tmo;
@@ -224,10 +224,8 @@ static int inet_twsk_diag_fill(struct inet_timewait_sock *tw,
                const struct inet6_timewait_sock *tw6 =
                                                inet6_twsk((struct sock *)tw);
 
-               ipv6_addr_copy((struct in6_addr *)r->id.idiag_src,
-                              &tw6->tw_v6_rcv_saddr);
-               ipv6_addr_copy((struct in6_addr *)r->id.idiag_dst,
-                              &tw6->tw_v6_daddr);
+               *(struct in6_addr *)r->id.idiag_src = tw6->tw_v6_rcv_saddr;
+               *(struct in6_addr *)r->id.idiag_dst = tw6->tw_v6_daddr;
        }
 #endif
        nlh->nlmsg_len = skb_tail_pointer(skb) - previous_tail;
@@ -238,27 +236,27 @@ nlmsg_failure:
 }
 
 static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
-                       int ext, u32 pid, u32 seq, u16 nlmsg_flags,
+                       struct inet_diag_req *r, u32 pid, u32 seq, u16 nlmsg_flags,
                        const struct nlmsghdr *unlh)
 {
        if (sk->sk_state == TCP_TIME_WAIT)
                return inet_twsk_diag_fill((struct inet_timewait_sock *)sk,
-                                          skb, ext, pid, seq, nlmsg_flags,
+                                          skb, r, pid, seq, nlmsg_flags,
                                           unlh);
-       return inet_csk_diag_fill(sk, skb, ext, pid, seq, nlmsg_flags, unlh);
+       return inet_csk_diag_fill(sk, skb, r, pid, seq, nlmsg_flags, unlh);
 }
 
 static int inet_diag_get_exact(struct sk_buff *in_skb,
-                              const struct nlmsghdr *nlh)
+                              const struct nlmsghdr *nlh,
+                              struct inet_diag_req *req)
 {
        int err;
        struct sock *sk;
-       struct inet_diag_req *req = NLMSG_DATA(nlh);
        struct sk_buff *rep;
        struct inet_hashinfo *hashinfo;
        const struct inet_diag_handler *handler;
 
-       handler = inet_diag_lock_handler(nlh->nlmsg_type);
+       handler = inet_diag_lock_handler(req->sdiag_protocol);
        if (IS_ERR(handler)) {
                err = PTR_ERR(handler);
                goto unlock;
@@ -267,13 +265,13 @@ static int inet_diag_get_exact(struct sk_buff *in_skb,
        hashinfo = handler->idiag_hashinfo;
        err = -EINVAL;
 
-       if (req->idiag_family == AF_INET) {
+       if (req->sdiag_family == AF_INET) {
                sk = inet_lookup(&init_net, hashinfo, req->id.idiag_dst[0],
                                 req->id.idiag_dport, req->id.idiag_src[0],
                                 req->id.idiag_sport, req->id.idiag_if);
        }
 #if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
-       else if (req->idiag_family == AF_INET6) {
+       else if (req->sdiag_family == AF_INET6) {
                sk = inet6_lookup(&init_net, hashinfo,
                                  (struct in6_addr *)req->id.idiag_dst,
                                  req->id.idiag_dport,
@@ -300,12 +298,12 @@ static int inet_diag_get_exact(struct sk_buff *in_skb,
        err = -ENOMEM;
        rep = alloc_skb(NLMSG_SPACE((sizeof(struct inet_diag_msg) +
                                     sizeof(struct inet_diag_meminfo) +
-                                    handler->idiag_info_size + 64)),
+                                    sizeof(struct tcp_info) + 64)),
                        GFP_KERNEL);
        if (!rep)
                goto out;
 
-       err = sk_diag_fill(sk, rep, req->idiag_ext,
+       err = sk_diag_fill(sk, rep, req,
                           NETLINK_CB(in_skb).pid,
                           nlh->nlmsg_seq, 0, nlh);
        if (err < 0) {
@@ -313,7 +311,7 @@ static int inet_diag_get_exact(struct sk_buff *in_skb,
                kfree_skb(rep);
                goto out;
        }
-       err = netlink_unicast(idiagnl, rep, NETLINK_CB(in_skb).pid,
+       err = netlink_unicast(sock_diag_nlsk, rep, NETLINK_CB(in_skb).pid,
                              MSG_DONTWAIT);
        if (err > 0)
                err = 0;
@@ -489,15 +487,12 @@ static int inet_diag_bc_audit(const void *bytecode, int bytecode_len)
 
 static int inet_csk_diag_dump(struct sock *sk,
                              struct sk_buff *skb,
-                             struct netlink_callback *cb)
+                             struct netlink_callback *cb,
+                             struct inet_diag_req *r,
+                             const struct nlattr *bc)
 {
-       struct inet_diag_req *r = NLMSG_DATA(cb->nlh);
-
-       if (nlmsg_attrlen(cb->nlh, sizeof(*r))) {
+       if (bc != NULL) {
                struct inet_diag_entry entry;
-               const struct nlattr *bc = nlmsg_find_attr(cb->nlh,
-                                                         sizeof(*r),
-                                                         INET_DIAG_REQ_BYTECODE);
                struct inet_sock *inet = inet_sk(sk);
 
                entry.family = sk->sk_family;
@@ -521,22 +516,19 @@ static int inet_csk_diag_dump(struct sock *sk,
                        return 0;
        }
 
-       return inet_csk_diag_fill(sk, skb, r->idiag_ext,
+       return inet_csk_diag_fill(sk, skb, r,
                                  NETLINK_CB(cb->skb).pid,
                                  cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh);
 }
 
 static int inet_twsk_diag_dump(struct inet_timewait_sock *tw,
                               struct sk_buff *skb,
-                              struct netlink_callback *cb)
+                              struct netlink_callback *cb,
+                              struct inet_diag_req *r,
+                              const struct nlattr *bc)
 {
-       struct inet_diag_req *r = NLMSG_DATA(cb->nlh);
-
-       if (nlmsg_attrlen(cb->nlh, sizeof(*r))) {
+       if (bc != NULL) {
                struct inet_diag_entry entry;
-               const struct nlattr *bc = nlmsg_find_attr(cb->nlh,
-                                                         sizeof(*r),
-                                                         INET_DIAG_REQ_BYTECODE);
 
                entry.family = tw->tw_family;
 #if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
@@ -559,7 +551,7 @@ static int inet_twsk_diag_dump(struct inet_timewait_sock *tw,
                        return 0;
        }
 
-       return inet_twsk_diag_fill(tw, skb, r->idiag_ext,
+       return inet_twsk_diag_fill(tw, skb, r,
                                   NETLINK_CB(cb->skb).pid,
                                   cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh);
 }
@@ -603,10 +595,8 @@ static int inet_diag_fill_req(struct sk_buff *skb, struct sock *sk,
        r->idiag_inode = 0;
 #if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
        if (r->idiag_family == AF_INET6) {
-               ipv6_addr_copy((struct in6_addr *)r->id.idiag_src,
-                              &inet6_rsk(req)->loc_addr);
-               ipv6_addr_copy((struct in6_addr *)r->id.idiag_dst,
-                              &inet6_rsk(req)->rmt_addr);
+               *(struct in6_addr *)r->id.idiag_src = inet6_rsk(req)->loc_addr;
+               *(struct in6_addr *)r->id.idiag_dst = inet6_rsk(req)->rmt_addr;
        }
 #endif
        nlh->nlmsg_len = skb_tail_pointer(skb) - b;
@@ -619,13 +609,13 @@ nlmsg_failure:
 }
 
 static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk,
-                              struct netlink_callback *cb)
+                              struct netlink_callback *cb,
+                              struct inet_diag_req *r,
+                              const struct nlattr *bc)
 {
        struct inet_diag_entry entry;
-       struct inet_diag_req *r = NLMSG_DATA(cb->nlh);
        struct inet_connection_sock *icsk = inet_csk(sk);
        struct listen_sock *lopt;
-       const struct nlattr *bc = NULL;
        struct inet_sock *inet = inet_sk(sk);
        int j, s_j;
        int reqnum, s_reqnum;
@@ -645,9 +635,7 @@ static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk,
        if (!lopt || !lopt->qlen)
                goto out;
 
-       if (nlmsg_attrlen(cb->nlh, sizeof(*r))) {
-               bc = nlmsg_find_attr(cb->nlh, sizeof(*r),
-                                    INET_DIAG_REQ_BYTECODE);
+       if (bc != NULL) {
                entry.sport = inet->inet_num;
                entry.userlocks = sk->sk_userlocks;
        }
@@ -704,15 +692,15 @@ out:
        return err;
 }
 
-static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
+static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
+               struct inet_diag_req *r, struct nlattr *bc)
 {
        int i, num;
        int s_i, s_num;
-       struct inet_diag_req *r = NLMSG_DATA(cb->nlh);
        const struct inet_diag_handler *handler;
        struct inet_hashinfo *hashinfo;
 
-       handler = inet_diag_lock_handler(cb->nlh->nlmsg_type);
+       handler = inet_diag_lock_handler(r->sdiag_protocol);
        if (IS_ERR(handler))
                goto unlock;
 
@@ -741,6 +729,10 @@ static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
                                        continue;
                                }
 
+                               if (r->sdiag_family != AF_UNSPEC &&
+                                               sk->sk_family != r->sdiag_family)
+                                       goto next_listen;
+
                                if (r->id.idiag_sport != inet->inet_sport &&
                                    r->id.idiag_sport)
                                        goto next_listen;
@@ -750,7 +742,7 @@ static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
                                    cb->args[3] > 0)
                                        goto syn_recv;
 
-                               if (inet_csk_diag_dump(sk, skb, cb) < 0) {
+                               if (inet_csk_diag_dump(sk, skb, cb, r, bc) < 0) {
                                        spin_unlock_bh(&ilb->lock);
                                        goto done;
                                }
@@ -759,7 +751,7 @@ syn_recv:
                                if (!(r->idiag_states & TCPF_SYN_RECV))
                                        goto next_listen;
 
-                               if (inet_diag_dump_reqs(skb, sk, cb) < 0) {
+                               if (inet_diag_dump_reqs(skb, sk, cb, r, bc) < 0) {
                                        spin_unlock_bh(&ilb->lock);
                                        goto done;
                                }
@@ -806,13 +798,16 @@ skip_listen_ht:
                                goto next_normal;
                        if (!(r->idiag_states & (1 << sk->sk_state)))
                                goto next_normal;
+                       if (r->sdiag_family != AF_UNSPEC &&
+                                       sk->sk_family != r->sdiag_family)
+                               goto next_normal;
                        if (r->id.idiag_sport != inet->inet_sport &&
                            r->id.idiag_sport)
                                goto next_normal;
                        if (r->id.idiag_dport != inet->inet_dport &&
                            r->id.idiag_dport)
                                goto next_normal;
-                       if (inet_csk_diag_dump(sk, skb, cb) < 0) {
+                       if (inet_csk_diag_dump(sk, skb, cb, r, bc) < 0) {
                                spin_unlock_bh(lock);
                                goto done;
                        }
@@ -828,13 +823,16 @@ next_normal:
 
                                if (num < s_num)
                                        goto next_dying;
+                               if (r->sdiag_family != AF_UNSPEC &&
+                                               tw->tw_family != r->sdiag_family)
+                                       goto next_dying;
                                if (r->id.idiag_sport != tw->tw_sport &&
                                    r->id.idiag_sport)
                                        goto next_dying;
                                if (r->id.idiag_dport != tw->tw_dport &&
                                    r->id.idiag_dport)
                                        goto next_dying;
-                               if (inet_twsk_diag_dump(tw, skb, cb) < 0) {
+                               if (inet_twsk_diag_dump(tw, skb, cb, r, bc) < 0) {
                                        spin_unlock_bh(lock);
                                        goto done;
                                }
@@ -853,10 +851,67 @@ unlock:
        return skb->len;
 }
 
-static int inet_diag_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
+static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
 {
+       struct nlattr *bc = NULL;
        int hdrlen = sizeof(struct inet_diag_req);
 
+       if (nlmsg_attrlen(cb->nlh, hdrlen))
+               bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE);
+
+       return __inet_diag_dump(skb, cb, (struct inet_diag_req *)NLMSG_DATA(cb->nlh), bc);
+}
+
+static inline int inet_diag_type2proto(int type)
+{
+       switch (type) {
+       case TCPDIAG_GETSOCK:
+               return IPPROTO_TCP;
+       case DCCPDIAG_GETSOCK:
+               return IPPROTO_DCCP;
+       default:
+               return 0;
+       }
+}
+
+static int inet_diag_dump_compat(struct sk_buff *skb, struct netlink_callback *cb)
+{
+       struct inet_diag_req_compat *rc = NLMSG_DATA(cb->nlh);
+       struct inet_diag_req req;
+       struct nlattr *bc = NULL;
+       int hdrlen = sizeof(struct inet_diag_req_compat);
+
+       req.sdiag_family = AF_UNSPEC; /* compatibility */
+       req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type);
+       req.idiag_ext = rc->idiag_ext;
+       req.idiag_states = rc->idiag_states;
+       req.id = rc->id;
+
+       if (nlmsg_attrlen(cb->nlh, hdrlen))
+               bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE);
+
+       return __inet_diag_dump(skb, cb, &req, bc);
+}
+
+static int inet_diag_get_exact_compat(struct sk_buff *in_skb,
+                              const struct nlmsghdr *nlh)
+{
+       struct inet_diag_req_compat *rc = NLMSG_DATA(nlh);
+       struct inet_diag_req req;
+
+       req.sdiag_family = rc->idiag_family;
+       req.sdiag_protocol = inet_diag_type2proto(nlh->nlmsg_type);
+       req.idiag_ext = rc->idiag_ext;
+       req.idiag_states = rc->idiag_states;
+       req.id = rc->id;
+
+       return inet_diag_get_exact(in_skb, nlh, &req);
+}
+
+static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
+{
+       int hdrlen = sizeof(struct inet_diag_req_compat);
+
        if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX ||
            nlmsg_len(nlh) < hdrlen)
                return -EINVAL;
@@ -873,28 +928,54 @@ static int inet_diag_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
                                return -EINVAL;
                }
 
-               return netlink_dump_start(idiagnl, skb, nlh,
-                                         inet_diag_dump, NULL, 0);
+               return netlink_dump_start(sock_diag_nlsk, skb, nlh,
+                                         inet_diag_dump_compat, NULL, 0);
        }
 
-       return inet_diag_get_exact(skb, nlh);
+       return inet_diag_get_exact_compat(skb, nlh);
 }
 
-static DEFINE_MUTEX(inet_diag_mutex);
-
-static void inet_diag_rcv(struct sk_buff *skb)
+static int inet_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h)
 {
-       mutex_lock(&inet_diag_mutex);
-       netlink_rcv_skb(skb, &inet_diag_rcv_msg);
-       mutex_unlock(&inet_diag_mutex);
+       int hdrlen = sizeof(struct inet_diag_req);
+
+       if (nlmsg_len(h) < hdrlen)
+               return -EINVAL;
+
+       if (h->nlmsg_flags & NLM_F_DUMP) {
+               if (nlmsg_attrlen(h, hdrlen)) {
+                       struct nlattr *attr;
+                       attr = nlmsg_find_attr(h, hdrlen,
+                                              INET_DIAG_REQ_BYTECODE);
+                       if (attr == NULL ||
+                           nla_len(attr) < sizeof(struct inet_diag_bc_op) ||
+                           inet_diag_bc_audit(nla_data(attr), nla_len(attr)))
+                               return -EINVAL;
+               }
+
+               return netlink_dump_start(sock_diag_nlsk, skb, h,
+                                         inet_diag_dump, NULL, 0);
+       }
+
+       return inet_diag_get_exact(skb, h, (struct inet_diag_req *)NLMSG_DATA(h));
 }
 
+static struct sock_diag_handler inet_diag_handler = {
+       .family = AF_INET,
+       .dump = inet_diag_handler_dump,
+};
+
+static struct sock_diag_handler inet6_diag_handler = {
+       .family = AF_INET6,
+       .dump = inet_diag_handler_dump,
+};
+
 int inet_diag_register(const struct inet_diag_handler *h)
 {
        const __u16 type = h->idiag_type;
        int err = -EINVAL;
 
-       if (type >= INET_DIAG_GETSOCK_MAX)
+       if (type >= IPPROTO_MAX)
                goto out;
 
        mutex_lock(&inet_diag_table_mutex);
@@ -913,7 +994,7 @@ void inet_diag_unregister(const struct inet_diag_handler *h)
 {
        const __u16 type = h->idiag_type;
 
-       if (type >= INET_DIAG_GETSOCK_MAX)
+       if (type >= IPPROTO_MAX)
                return;
 
        mutex_lock(&inet_diag_table_mutex);
@@ -924,7 +1005,7 @@ EXPORT_SYMBOL_GPL(inet_diag_unregister);
 
 static int __init inet_diag_init(void)
 {
-       const int inet_diag_table_size = (INET_DIAG_GETSOCK_MAX *
+       const int inet_diag_table_size = (IPPROTO_MAX *
                                          sizeof(struct inet_diag_handler *));
        int err = -ENOMEM;
 
@@ -932,25 +1013,34 @@ static int __init inet_diag_init(void)
        if (!inet_diag_table)
                goto out;
 
-       idiagnl = netlink_kernel_create(&init_net, NETLINK_INET_DIAG, 0,
-                                       inet_diag_rcv, NULL, THIS_MODULE);
-       if (idiagnl == NULL)
-               goto out_free_table;
-       err = 0;
+       err = sock_diag_register(&inet_diag_handler);
+       if (err)
+               goto out_free_nl;
+
+       err = sock_diag_register(&inet6_diag_handler);
+       if (err)
+               goto out_free_inet;
+
+       sock_diag_register_inet_compat(inet_diag_rcv_msg_compat);
 out:
        return err;
-out_free_table:
+
+out_free_inet:
+       sock_diag_unregister(&inet_diag_handler);
+out_free_nl:
        kfree(inet_diag_table);
        goto out;
 }
 
 static void __exit inet_diag_exit(void)
 {
-       netlink_kernel_release(idiagnl);
+       sock_diag_unregister(&inet6_diag_handler);
+       sock_diag_unregister(&inet_diag_handler);
+       sock_diag_unregister_inet_compat(inet_diag_rcv_msg_compat);
        kfree(inet_diag_table);
 }
 
 module_init(inet_diag_init);
 module_exit(inet_diag_exit);
 MODULE_LICENSE("GPL");
-MODULE_ALIAS_NET_PF_PROTO(PF_NETLINK, NETLINK_INET_DIAG);
+MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 0);