Merge tag 'for-linus' of git://git.armlinux.org.uk/~rmk/linux-arm
[platform/kernel/linux-rpi.git] / kernel / ucount.c
index 87799e2..eb03f3c 100644 (file)
@@ -58,14 +58,17 @@ static struct ctl_table_root set_root = {
        .permissions = set_permissions,
 };
 
-#define UCOUNT_ENTRY(name)                             \
-       {                                               \
-               .procname       = name,                 \
-               .maxlen         = sizeof(int),          \
-               .mode           = 0644,                 \
-               .proc_handler   = proc_dointvec_minmax, \
-               .extra1         = SYSCTL_ZERO,          \
-               .extra2         = SYSCTL_INT_MAX,       \
+static long ue_zero = 0;
+static long ue_int_max = INT_MAX;
+
+#define UCOUNT_ENTRY(name)                                     \
+       {                                                       \
+               .procname       = name,                         \
+               .maxlen         = sizeof(long),                 \
+               .mode           = 0644,                         \
+               .proc_handler   = proc_doulongvec_minmax,       \
+               .extra1         = &ue_zero,                     \
+               .extra2         = &ue_int_max,                  \
        }
 static struct ctl_table user_table[] = {
        UCOUNT_ENTRY("max_user_namespaces"),
@@ -160,6 +163,7 @@ struct ucounts *alloc_ucounts(struct user_namespace *ns, kuid_t uid)
 {
        struct hlist_head *hashent = ucounts_hashentry(ns, uid);
        struct ucounts *ucounts, *new;
+       long overflow;
 
        spin_lock_irq(&ucounts_lock);
        ucounts = find_ucounts(ns, uid, hashent);
@@ -184,8 +188,12 @@ struct ucounts *alloc_ucounts(struct user_namespace *ns, kuid_t uid)
                        return new;
                }
        }
+       overflow = atomic_add_negative(1, &ucounts->count);
        spin_unlock_irq(&ucounts_lock);
-       ucounts = get_ucounts(ucounts);
+       if (overflow) {
+               put_ucounts(ucounts);
+               return NULL;
+       }
        return ucounts;
 }
 
@@ -193,8 +201,7 @@ void put_ucounts(struct ucounts *ucounts)
 {
        unsigned long flags;
 
-       if (atomic_dec_and_test(&ucounts->count)) {
-               spin_lock_irqsave(&ucounts_lock, flags);
+       if (atomic_dec_and_lock_irqsave(&ucounts->count, &ucounts_lock, flags)) {
                hlist_del_init(&ucounts->node);
                spin_unlock_irqrestore(&ucounts_lock, flags);
                kfree(ucounts);
@@ -277,6 +284,55 @@ bool dec_rlimit_ucounts(struct ucounts *ucounts, enum ucount_type type, long v)
        return (new == 0);
 }
 
+static void do_dec_rlimit_put_ucounts(struct ucounts *ucounts,
+                               struct ucounts *last, enum ucount_type type)
+{
+       struct ucounts *iter, *next;
+       for (iter = ucounts; iter != last; iter = next) {
+               long dec = atomic_long_add_return(-1, &iter->ucount[type]);
+               WARN_ON_ONCE(dec < 0);
+               next = iter->ns->ucounts;
+               if (dec == 0)
+                       put_ucounts(iter);
+       }
+}
+
+void dec_rlimit_put_ucounts(struct ucounts *ucounts, enum ucount_type type)
+{
+       do_dec_rlimit_put_ucounts(ucounts, NULL, type);
+}
+
+long inc_rlimit_get_ucounts(struct ucounts *ucounts, enum ucount_type type)
+{
+       /* Caller must hold a reference to ucounts */
+       struct ucounts *iter;
+       long dec, ret = 0;
+
+       for (iter = ucounts; iter; iter = iter->ns->ucounts) {
+               long max = READ_ONCE(iter->ns->ucount_max[type]);
+               long new = atomic_long_add_return(1, &iter->ucount[type]);
+               if (new < 0 || new > max)
+                       goto unwind;
+               if (iter == ucounts)
+                       ret = new;
+               /*
+                * Grab an extra ucount reference for the caller when
+                * the rlimit count was previously 0.
+                */
+               if (new != 1)
+                       continue;
+               if (!get_ucounts(iter))
+                       goto dec_unwind;
+       }
+       return ret;
+dec_unwind:
+       dec = atomic_long_add_return(-1, &iter->ucount[type]);
+       WARN_ON_ONCE(dec < 0);
+unwind:
+       do_dec_rlimit_put_ucounts(ucounts, iter, type);
+       return 0;
+}
+
 bool is_ucounts_overlimit(struct ucounts *ucounts, enum ucount_type type, unsigned long max)
 {
        struct ucounts *iter;