l2tp: define helper for parsing struct sockaddr_pppol2tp*
authorGuillaume Nault <g.nault@alphalink.fr>
Tue, 26 Jun 2018 16:41:36 +0000 (18:41 +0200)
committerDavid S. Miller <davem@davemloft.net>
Thu, 28 Jun 2018 07:06:50 +0000 (16:06 +0900)
'sockaddr_len' is checked against various values when entering
pppol2tp_connect(), to verify its validity. It is used again later, to
find out which sockaddr structure was passed from user space. This
patch combines these two operations into one new function in order to
simplify pppol2tp_connect().

A new structure, l2tp_connect_info, is used to pass sockaddr data back
to pppol2tp_connect(), to avoid passing too many parameters to
l2tp_sockaddr_get_info(). Also, the first parameter is void* in order
to avoid casting between all sockaddr_* structures manually.

Signed-off-by: Guillaume Nault <g.nault@alphalink.fr>
Signed-off-by: David S. Miller <davem@davemloft.net>
net/l2tp/l2tp_ppp.c

index eea5d78..d3a9355 100644 (file)
@@ -588,40 +588,113 @@ static void pppol2tp_session_init(struct l2tp_session *session)
        }
 }
 
+struct l2tp_connect_info {
+       u8 version;
+       int fd;
+       u32 tunnel_id;
+       u32 peer_tunnel_id;
+       u32 session_id;
+       u32 peer_session_id;
+};
+
+static int pppol2tp_sockaddr_get_info(const void *sa, int sa_len,
+                                     struct l2tp_connect_info *info)
+{
+       switch (sa_len) {
+       case sizeof(struct sockaddr_pppol2tp):
+       {
+               const struct sockaddr_pppol2tp *sa_v2in4 = sa;
+
+               if (sa_v2in4->sa_protocol != PX_PROTO_OL2TP)
+                       return -EINVAL;
+
+               info->version = 2;
+               info->fd = sa_v2in4->pppol2tp.fd;
+               info->tunnel_id = sa_v2in4->pppol2tp.s_tunnel;
+               info->peer_tunnel_id = sa_v2in4->pppol2tp.d_tunnel;
+               info->session_id = sa_v2in4->pppol2tp.s_session;
+               info->peer_session_id = sa_v2in4->pppol2tp.d_session;
+
+               break;
+       }
+       case sizeof(struct sockaddr_pppol2tpv3):
+       {
+               const struct sockaddr_pppol2tpv3 *sa_v3in4 = sa;
+
+               if (sa_v3in4->sa_protocol != PX_PROTO_OL2TP)
+                       return -EINVAL;
+
+               info->version = 3;
+               info->fd = sa_v3in4->pppol2tp.fd;
+               info->tunnel_id = sa_v3in4->pppol2tp.s_tunnel;
+               info->peer_tunnel_id = sa_v3in4->pppol2tp.d_tunnel;
+               info->session_id = sa_v3in4->pppol2tp.s_session;
+               info->peer_session_id = sa_v3in4->pppol2tp.d_session;
+
+               break;
+       }
+       case sizeof(struct sockaddr_pppol2tpin6):
+       {
+               const struct sockaddr_pppol2tpin6 *sa_v2in6 = sa;
+
+               if (sa_v2in6->sa_protocol != PX_PROTO_OL2TP)
+                       return -EINVAL;
+
+               info->version = 2;
+               info->fd = sa_v2in6->pppol2tp.fd;
+               info->tunnel_id = sa_v2in6->pppol2tp.s_tunnel;
+               info->peer_tunnel_id = sa_v2in6->pppol2tp.d_tunnel;
+               info->session_id = sa_v2in6->pppol2tp.s_session;
+               info->peer_session_id = sa_v2in6->pppol2tp.d_session;
+
+               break;
+       }
+       case sizeof(struct sockaddr_pppol2tpv3in6):
+       {
+               const struct sockaddr_pppol2tpv3in6 *sa_v3in6 = sa;
+
+               if (sa_v3in6->sa_protocol != PX_PROTO_OL2TP)
+                       return -EINVAL;
+
+               info->version = 3;
+               info->fd = sa_v3in6->pppol2tp.fd;
+               info->tunnel_id = sa_v3in6->pppol2tp.s_tunnel;
+               info->peer_tunnel_id = sa_v3in6->pppol2tp.d_tunnel;
+               info->session_id = sa_v3in6->pppol2tp.s_session;
+               info->peer_session_id = sa_v3in6->pppol2tp.d_session;
+
+               break;
+       }
+       default:
+               return -EINVAL;
+       }
+
+       return 0;
+}
+
 /* connect() handler. Attach a PPPoX socket to a tunnel UDP socket
  */
 static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
                            int sockaddr_len, int flags)
 {
        struct sock *sk = sock->sk;
-       struct sockaddr_pppol2tp *sp = (struct sockaddr_pppol2tp *) uservaddr;
        struct pppox_sock *po = pppox_sk(sk);
        struct l2tp_session *session = NULL;
+       struct l2tp_connect_info info;
        struct l2tp_tunnel *tunnel;
        struct pppol2tp_session *ps;
        struct l2tp_session_cfg cfg = { 0, };
-       int error = 0;
-       u32 tunnel_id, peer_tunnel_id;
-       u32 session_id, peer_session_id;
        bool drop_refcnt = false;
        bool drop_tunnel = false;
        bool new_session = false;
        bool new_tunnel = false;
-       int ver = 2;
-       int fd;
-
-       lock_sock(sk);
-
-       error = -EINVAL;
+       int error;
 
-       if (sockaddr_len != sizeof(struct sockaddr_pppol2tp) &&
-           sockaddr_len != sizeof(struct sockaddr_pppol2tpv3) &&
-           sockaddr_len != sizeof(struct sockaddr_pppol2tpin6) &&
-           sockaddr_len != sizeof(struct sockaddr_pppol2tpv3in6))
-               goto end;
+       error = pppol2tp_sockaddr_get_info(uservaddr, sockaddr_len, &info);
+       if (error < 0)
+               return error;
 
-       if (sp->sa_protocol != PX_PROTO_OL2TP)
-               goto end;
+       lock_sock(sk);
 
        /* Check for already bound sockets */
        error = -EBUSY;
@@ -633,56 +706,12 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
        if (sk->sk_user_data)
                goto end; /* socket is already attached */
 
-       /* Get params from socket address. Handle L2TPv2 and L2TPv3.
-        * This is nasty because there are different sockaddr_pppol2tp
-        * structs for L2TPv2, L2TPv3, over IPv4 and IPv6. We use
-        * the sockaddr size to determine which structure the caller
-        * is using.
-        */
-       peer_tunnel_id = 0;
-       if (sockaddr_len == sizeof(struct sockaddr_pppol2tp)) {
-               fd = sp->pppol2tp.fd;
-               tunnel_id = sp->pppol2tp.s_tunnel;
-               peer_tunnel_id = sp->pppol2tp.d_tunnel;
-               session_id = sp->pppol2tp.s_session;
-               peer_session_id = sp->pppol2tp.d_session;
-       } else if (sockaddr_len == sizeof(struct sockaddr_pppol2tpv3)) {
-               struct sockaddr_pppol2tpv3 *sp3 =
-                       (struct sockaddr_pppol2tpv3 *) sp;
-               ver = 3;
-               fd = sp3->pppol2tp.fd;
-               tunnel_id = sp3->pppol2tp.s_tunnel;
-               peer_tunnel_id = sp3->pppol2tp.d_tunnel;
-               session_id = sp3->pppol2tp.s_session;
-               peer_session_id = sp3->pppol2tp.d_session;
-       } else if (sockaddr_len == sizeof(struct sockaddr_pppol2tpin6)) {
-               struct sockaddr_pppol2tpin6 *sp6 =
-                       (struct sockaddr_pppol2tpin6 *) sp;
-               fd = sp6->pppol2tp.fd;
-               tunnel_id = sp6->pppol2tp.s_tunnel;
-               peer_tunnel_id = sp6->pppol2tp.d_tunnel;
-               session_id = sp6->pppol2tp.s_session;
-               peer_session_id = sp6->pppol2tp.d_session;
-       } else if (sockaddr_len == sizeof(struct sockaddr_pppol2tpv3in6)) {
-               struct sockaddr_pppol2tpv3in6 *sp6 =
-                       (struct sockaddr_pppol2tpv3in6 *) sp;
-               ver = 3;
-               fd = sp6->pppol2tp.fd;
-               tunnel_id = sp6->pppol2tp.s_tunnel;
-               peer_tunnel_id = sp6->pppol2tp.d_tunnel;
-               session_id = sp6->pppol2tp.s_session;
-               peer_session_id = sp6->pppol2tp.d_session;
-       } else {
-               error = -EINVAL;
-               goto end; /* bad socket address */
-       }
-
        /* Don't bind if tunnel_id is 0 */
        error = -EINVAL;
-       if (tunnel_id == 0)
+       if (!info.tunnel_id)
                goto end;
 
-       tunnel = l2tp_tunnel_get(sock_net(sk), tunnel_id);
+       tunnel = l2tp_tunnel_get(sock_net(sk), info.tunnel_id);
        if (tunnel)
                drop_tunnel = true;
 
@@ -690,7 +719,7 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
         * peer_session_id is 0. Otherwise look up tunnel using supplied
         * tunnel id.
         */
-       if ((session_id == 0) && (peer_session_id == 0)) {
+       if (!info.session_id && !info.peer_session_id) {
                if (tunnel == NULL) {
                        struct l2tp_tunnel_cfg tcfg = {
                                .encap = L2TP_ENCAPTYPE_UDP,
@@ -700,12 +729,16 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
                        /* Prevent l2tp_tunnel_register() from trying to set up
                         * a kernel socket.
                         */
-                       if (fd < 0) {
+                       if (info.fd < 0) {
                                error = -EBADF;
                                goto end;
                        }
 
-                       error = l2tp_tunnel_create(sock_net(sk), fd, ver, tunnel_id, peer_tunnel_id, &tcfg, &tunnel);
+                       error = l2tp_tunnel_create(sock_net(sk), info.fd,
+                                                  info.version,
+                                                  info.tunnel_id,
+                                                  info.peer_tunnel_id, &tcfg,
+                                                  &tunnel);
                        if (error < 0)
                                goto end;
 
@@ -734,9 +767,9 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
                tunnel->recv_payload_hook = pppol2tp_recv_payload_hook;
 
        if (tunnel->peer_tunnel_id == 0)
-               tunnel->peer_tunnel_id = peer_tunnel_id;
+               tunnel->peer_tunnel_id = info.peer_tunnel_id;
 
-       session = l2tp_session_get(sock_net(sk), tunnel, session_id);
+       session = l2tp_session_get(sock_net(sk), tunnel, info.session_id);
        if (session) {
                drop_refcnt = true;
 
@@ -765,8 +798,8 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
                cfg.pw_type = L2TP_PWTYPE_PPP;
 
                session = l2tp_session_create(sizeof(struct pppol2tp_session),
-                                             tunnel, session_id,
-                                             peer_session_id, &cfg);
+                                             tunnel, info.session_id,
+                                             info.peer_session_id, &cfg);
                if (IS_ERR(session)) {
                        error = PTR_ERR(session);
                        goto end;