SUNRPC: Fix buffer overflow checking in gss_encode_v0_msg/gss_encode_v1_msg
authorTrond Myklebust <Trond.Myklebust@netapp.com>
Mon, 28 Oct 2013 22:18:00 +0000 (18:18 -0400)
committerTrond Myklebust <Trond.Myklebust@netapp.com>
Mon, 28 Oct 2013 22:53:21 +0000 (18:53 -0400)
In gss_encode_v1_msg, it is pointless to BUG() after the overflow has
happened. Replace the existing sprintf()-based code with scnprintf(),
and warn if an overflow is ever triggered.

In gss_encode_v0_msg, replace the runtime BUG_ON() with an appropriate
compile-time BUILD_BUG_ON.

Reported-by: Bruce Fields <bfields@fieldses.org>
Signed-off-by: Trond Myklebust <Trond.Myklebust@netapp.com>
net/sunrpc/auth_gss/auth_gss.c

index cc24323..97912b4 100644 (file)
@@ -420,41 +420,53 @@ static void gss_encode_v0_msg(struct gss_upcall_msg *gss_msg)
        memcpy(gss_msg->databuf, &uid, sizeof(uid));
        gss_msg->msg.data = gss_msg->databuf;
        gss_msg->msg.len = sizeof(uid);
-       BUG_ON(sizeof(uid) > UPCALL_BUF_LEN);
+
+       BUILD_BUG_ON(sizeof(uid) > sizeof(gss_msg->databuf));
 }
 
-static void gss_encode_v1_msg(struct gss_upcall_msg *gss_msg,
+static int gss_encode_v1_msg(struct gss_upcall_msg *gss_msg,
                                const char *service_name,
                                const char *target_name)
 {
        struct gss_api_mech *mech = gss_msg->auth->mech;
        char *p = gss_msg->databuf;
-       int len = 0;
-
-       gss_msg->msg.len = sprintf(gss_msg->databuf, "mech=%s uid=%d ",
-                                  mech->gm_name,
-                                  from_kuid(&init_user_ns, gss_msg->uid));
-       p += gss_msg->msg.len;
+       size_t buflen = sizeof(gss_msg->databuf);
+       int len;
+
+       len = scnprintf(p, buflen, "mech=%s uid=%d ", mech->gm_name,
+                       from_kuid(&init_user_ns, gss_msg->uid));
+       buflen -= len;
+       p += len;
+       gss_msg->msg.len = len;
        if (target_name) {
-               len = sprintf(p, "target=%s ", target_name);
+               len = scnprintf(p, buflen, "target=%s ", target_name);
+               buflen -= len;
                p += len;
                gss_msg->msg.len += len;
        }
        if (service_name != NULL) {
-               len = sprintf(p, "service=%s ", service_name);
+               len = scnprintf(p, buflen, "service=%s ", service_name);
+               buflen -= len;
                p += len;
                gss_msg->msg.len += len;
        }
        if (mech->gm_upcall_enctypes) {
-               len = sprintf(p, "enctypes=%s ", mech->gm_upcall_enctypes);
+               len = scnprintf(p, buflen, "enctypes=%s ",
+                               mech->gm_upcall_enctypes);
+               buflen -= len;
                p += len;
                gss_msg->msg.len += len;
        }
-       len = sprintf(p, "\n");
+       len = scnprintf(p, buflen, "\n");
+       if (len == 0)
+               goto out_overflow;
        gss_msg->msg.len += len;
 
        gss_msg->msg.data = gss_msg->databuf;
-       BUG_ON(gss_msg->msg.len > UPCALL_BUF_LEN);
+       return 0;
+out_overflow:
+       WARN_ON_ONCE(1);
+       return -ENOMEM;
 }
 
 static struct gss_upcall_msg *
@@ -463,15 +475,15 @@ gss_alloc_msg(struct gss_auth *gss_auth,
 {
        struct gss_upcall_msg *gss_msg;
        int vers;
+       int err = -ENOMEM;
 
        gss_msg = kzalloc(sizeof(*gss_msg), GFP_NOFS);
        if (gss_msg == NULL)
-               return ERR_PTR(-ENOMEM);
+               goto err;
        vers = get_pipe_version(gss_auth->net);
-       if (vers < 0) {
-               kfree(gss_msg);
-               return ERR_PTR(vers);
-       }
+       err = vers;
+       if (err < 0)
+               goto err_free_msg;
        gss_msg->pipe = gss_auth->gss_pipe[vers]->pipe;
        INIT_LIST_HEAD(&gss_msg->list);
        rpc_init_wait_queue(&gss_msg->rpc_waitqueue, "RPCSEC_GSS upcall waitq");
@@ -484,9 +496,15 @@ gss_alloc_msg(struct gss_auth *gss_auth,
                gss_encode_v0_msg(gss_msg);
                break;
        default:
-               gss_encode_v1_msg(gss_msg, service_name, gss_auth->target_name);
+               err = gss_encode_v1_msg(gss_msg, service_name, gss_auth->target_name);
+               if (err)
+                       goto err_free_msg;
        };
        return gss_msg;
+err_free_msg:
+       kfree(gss_msg);
+err:
+       return ERR_PTR(err);
 }
 
 static struct gss_upcall_msg *