#include <linux/tcp_diag.h>
+static const struct inet_diag_handler **inet_diag_table;
+
struct tcpdiag_entry
{
u32 *saddr;
const struct inet_connection_sock *icsk = inet_csk(sk);
struct tcpdiagmsg *r;
struct nlmsghdr *nlh;
- struct tcp_info *info = NULL;
+ void *info = NULL;
struct tcpdiag_meminfo *minfo = NULL;
unsigned char *b = skb->tail;
+ const struct inet_diag_handler *handler;
+
+ handler = inet_diag_table[unlh->nlmsg_type];
+ BUG_ON(handler == NULL);
nlh = NLMSG_PUT(skb, pid, seq, unlh->nlmsg_type, sizeof(*r));
nlh->nlmsg_flags = nlmsg_flags;
+
r = NLMSG_DATA(nlh);
if (sk->sk_state != TCP_TIME_WAIT) {
if (ext & (1<<(TCPDIAG_MEMINFO-1)))
minfo = TCPDIAG_PUT(skb, TCPDIAG_MEMINFO, sizeof(*minfo));
if (ext & (1<<(TCPDIAG_INFO-1)))
- info = TCPDIAG_PUT(skb, TCPDIAG_INFO, sizeof(*info));
+ info = TCPDIAG_PUT(skb, TCPDIAG_INFO,
+ handler->idiag_info_size);
if ((ext & (1 << (TCPDIAG_CONG - 1))) && icsk->icsk_ca_ops) {
size_t len = strlen(icsk->icsk_ca_ops->name);
r->tcpdiag_expires = 0;
}
#undef EXPIRES_IN_MS
- /*
- * Ahem... for now we'll have some knowledge about TCP -acme
- * But this is just one of two small exceptions, both in this
- * function, so lets close our eyes for some 15 lines or so... 8)
- * -acme
- */
- if (sk->sk_protocol == IPPROTO_TCP) {
- const struct tcp_sock *tp = tcp_sk(sk);
-
- r->tcpdiag_rqueue = tp->rcv_nxt - tp->copied_seq;
- r->tcpdiag_wqueue = tp->write_seq - tp->snd_una;
- } else
- r->tcpdiag_rqueue = r->tcpdiag_wqueue = 0;
r->tcpdiag_uid = sock_i_uid(sk);
r->tcpdiag_inode = sock_i_ino(sk);
minfo->tcpdiag_tmem = atomic_read(&sk->sk_wmem_alloc);
}
- /* Ahem... for now we'll have some knowledge about TCP -acme */
- if (info) {
- if (sk->sk_protocol == IPPROTO_TCP)
- tcp_get_info(sk, info);
- else
- memset(info, 0, sizeof(*info));
- }
+ handler->idiag_get_info(sk, r, info);
if (sk->sk_state < TCP_TIME_WAIT &&
icsk->icsk_ca_ops && icsk->icsk_ca_ops->get_info)
struct sock *sk;
struct tcpdiagreq *req = NLMSG_DATA(nlh);
struct sk_buff *rep;
- struct inet_hashinfo *hashinfo = &tcp_hashinfo;
-#ifdef CONFIG_IP_TCPDIAG_DCCP
- if (nlh->nlmsg_type == DCCPDIAG_GETSOCK)
- hashinfo = &dccp_hashinfo;
-#endif
+ struct inet_hashinfo *hashinfo;
+ const struct inet_diag_handler *handler;
+
+ handler = inet_diag_table[nlh->nlmsg_type];
+ BUG_ON(handler == NULL);
+ hashinfo = handler->idiag_hashinfo;
+
if (req->tcpdiag_family == AF_INET) {
sk = inet_lookup(hashinfo, req->id.tcpdiag_dst[0],
req->id.tcpdiag_dport, req->id.tcpdiag_src[0],
goto out;
err = -ENOMEM;
- rep = alloc_skb(NLMSG_SPACE(sizeof(struct tcpdiagmsg)+
- sizeof(struct tcpdiag_meminfo)+
- sizeof(struct tcp_info)+64), GFP_KERNEL);
+ rep = alloc_skb(NLMSG_SPACE((sizeof(struct tcpdiagmsg) +
+ sizeof(struct tcpdiag_meminfo) +
+ handler->idiag_info_size + 64)),
+ GFP_KERNEL);
if (!rep)
goto out;
int i, num;
int s_i, s_num;
struct tcpdiagreq *r = NLMSG_DATA(cb->nlh);
+ const struct inet_diag_handler *handler;
struct inet_hashinfo *hashinfo;
+ handler = inet_diag_table[cb->nlh->nlmsg_type];
+ BUG_ON(handler == NULL);
+ hashinfo = handler->idiag_hashinfo;
+
s_i = cb->args[1];
s_num = num = cb->args[2];
- hashinfo = &tcp_hashinfo;
-#ifdef CONFIG_IP_TCPDIAG_DCCP
- if (cb->nlh->nlmsg_type == DCCPDIAG_GETSOCK)
- hashinfo = &dccp_hashinfo;
-#endif
+
if (cb->args[0] == 0) {
if (!(r->tcpdiag_states&(TCPF_LISTEN|TCPF_SYN_RECV)))
goto skip_listen_ht;
if (!(nlh->nlmsg_flags&NLM_F_REQUEST))
return 0;
- if (nlh->nlmsg_type != TCPDIAG_GETSOCK
-#ifdef CONFIG_IP_TCPDIAG_DCCP
- && nlh->nlmsg_type != DCCPDIAG_GETSOCK
-#endif
- )
+ if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX)
goto err_inval;
+ if (inet_diag_table[nlh->nlmsg_type] == NULL)
+ return -ENOENT;
+
if (NLMSG_LENGTH(sizeof(struct tcpdiagreq)) > skb->len)
goto err_inval;
}
}
+static void tcp_diag_get_info(struct sock *sk, struct tcpdiagmsg *r,
+ void *_info)
+{
+ const struct tcp_sock *tp = tcp_sk(sk);
+ struct tcp_info *info = _info;
+
+ r->tcpdiag_rqueue = tp->rcv_nxt - tp->copied_seq;
+ r->tcpdiag_wqueue = tp->write_seq - tp->snd_una;
+ if (info != NULL)
+ tcp_get_info(sk, info);
+}
+
+static struct inet_diag_handler tcp_diag_handler = {
+ .idiag_hashinfo = &tcp_hashinfo,
+ .idiag_get_info = tcp_diag_get_info,
+ .idiag_type = TCPDIAG_GETSOCK,
+ .idiag_info_size = sizeof(struct tcp_info),
+};
+
+static DEFINE_SPINLOCK(inet_diag_register_lock);
+
+int inet_diag_register(const struct inet_diag_handler *h)
+{
+ const __u16 type = h->idiag_type;
+ int err = -EINVAL;
+
+ if (type >= INET_DIAG_GETSOCK_MAX)
+ goto out;
+
+ spin_lock(&inet_diag_register_lock);
+ err = -EEXIST;
+ if (inet_diag_table[type] == NULL) {
+ inet_diag_table[type] = h;
+ err = 0;
+ }
+ spin_unlock(&inet_diag_register_lock);
+out:
+ return err;
+}
+EXPORT_SYMBOL_GPL(inet_diag_register);
+
+void inet_diag_unregister(const struct inet_diag_handler *h)
+{
+ const __u16 type = h->idiag_type;
+
+ if (type >= INET_DIAG_GETSOCK_MAX)
+ return;
+
+ spin_lock(&inet_diag_register_lock);
+ inet_diag_table[type] = NULL;
+ spin_unlock(&inet_diag_register_lock);
+
+ synchronize_rcu();
+}
+EXPORT_SYMBOL_GPL(inet_diag_unregister);
+
static int __init tcpdiag_init(void)
{
+ const int inet_diag_table_size = (INET_DIAG_GETSOCK_MAX *
+ sizeof(struct inet_diag_handler *));
+ int err = -ENOMEM;
+
+ inet_diag_table = kmalloc(inet_diag_table_size, GFP_KERNEL);
+ if (!inet_diag_table)
+ goto out;
+
+ memset(inet_diag_table, 0, inet_diag_table_size);
+
tcpnl = netlink_kernel_create(NETLINK_TCPDIAG, tcpdiag_rcv,
THIS_MODULE);
if (tcpnl == NULL)
- return -ENOMEM;
- return 0;
+ goto out_free_table;
+
+ err = inet_diag_register(&tcp_diag_handler);
+ if (err)
+ goto out_sock_release;
+out:
+ return err;
+out_sock_release:
+ sock_release(tcpnl->sk_socket);
+out_free_table:
+ kfree(inet_diag_table);
+ goto out;
}
static void __exit tcpdiag_exit(void)
{
sock_release(tcpnl->sk_socket);
+ kfree(inet_diag_table);
}
module_init(tcpdiag_init);