maple_tree: refine ma_state init from mas_start()
[platform/kernel/linux-starfive.git] / lib / maple_tree.c
index fe21bf2..54484bd 100644 (file)
@@ -149,13 +149,12 @@ struct maple_subtree_state {
 /* Functions */
 static inline struct maple_node *mt_alloc_one(gfp_t gfp)
 {
-       return kmem_cache_alloc(maple_node_cache, gfp | __GFP_ZERO);
+       return kmem_cache_alloc(maple_node_cache, gfp);
 }
 
 static inline int mt_alloc_bulk(gfp_t gfp, size_t size, void **nodes)
 {
-       return kmem_cache_alloc_bulk(maple_node_cache, gfp | __GFP_ZERO, size,
-                                    nodes);
+       return kmem_cache_alloc_bulk(maple_node_cache, gfp, size, nodes);
 }
 
 static inline void mt_free_bulk(size_t size, void __rcu **nodes)
@@ -535,6 +534,7 @@ static inline bool ma_dead_node(const struct maple_node *node)
 
        return (parent == node);
 }
+
 /*
  * mte_dead_node() - check if the @enode is dead.
  * @enode: The encoded maple node
@@ -616,6 +616,8 @@ static inline unsigned int mas_alloc_req(const struct ma_state *mas)
  * @node - the maple node
  * @type - the node type
  *
+ * In the event of a dead node, this array may be %NULL
+ *
  * Return: A pointer to the maple node pivots
  */
 static inline unsigned long *ma_pivots(struct maple_node *node,
@@ -665,12 +667,13 @@ static inline unsigned long mte_pivot(const struct maple_enode *mn,
                                 unsigned char piv)
 {
        struct maple_node *node = mte_to_node(mn);
+       enum maple_type type = mte_node_type(mn);
 
-       if (piv >= mt_pivots[piv]) {
+       if (piv >= mt_pivots[type]) {
                WARN_ON(1);
                return 0;
        }
-       switch (mte_node_type(mn)) {
+       switch (type) {
        case maple_arange_64:
                return node->ma64.pivot[piv];
        case maple_range_64:
@@ -1086,8 +1089,11 @@ static int mas_ascend(struct ma_state *mas)
                a_type = mas_parent_enum(mas, p_enode);
                a_node = mte_parent(p_enode);
                a_slot = mte_parent_slot(p_enode);
-               pivots = ma_pivots(a_node, a_type);
                a_enode = mt_mk_node(a_node, a_type);
+               pivots = ma_pivots(a_node, a_type);
+
+               if (unlikely(ma_dead_node(a_node)))
+                       return 1;
 
                if (!set_min && a_slot) {
                        set_min = true;
@@ -1122,9 +1128,10 @@ static inline struct maple_node *mas_pop_node(struct ma_state *mas)
 {
        struct maple_alloc *ret, *node = mas->alloc;
        unsigned long total = mas_allocated(mas);
+       unsigned int req = mas_alloc_req(mas);
 
        /* nothing or a request pending. */
-       if (unlikely(!total))
+       if (WARN_ON(!total))
                return NULL;
 
        if (total == 1) {
@@ -1134,27 +1141,25 @@ static inline struct maple_node *mas_pop_node(struct ma_state *mas)
                goto single_node;
        }
 
-       if (!node->node_count) {
+       if (node->node_count == 1) {
                /* Single allocation in this node. */
                mas->alloc = node->slot[0];
-               node->slot[0] = NULL;
                mas->alloc->total = node->total - 1;
                ret = node;
                goto new_head;
        }
-
        node->total--;
-       ret = node->slot[node->node_count];
-       node->slot[node->node_count--] = NULL;
+       ret = node->slot[--node->node_count];
+       node->slot[node->node_count] = NULL;
 
 single_node:
 new_head:
-       ret->total = 0;
-       ret->node_count = 0;
-       if (ret->request_count) {
-               mas_set_alloc_req(mas, ret->request_count + 1);
-               ret->request_count = 0;
+       if (req) {
+               req++;
+               mas_set_alloc_req(mas, req);
        }
+
+       memset(ret, 0, sizeof(*ret));
        return (struct maple_node *)ret;
 }
 
@@ -1173,21 +1178,20 @@ static inline void mas_push_node(struct ma_state *mas, struct maple_node *used)
        unsigned long count;
        unsigned int requested = mas_alloc_req(mas);
 
-       memset(reuse, 0, sizeof(*reuse));
        count = mas_allocated(mas);
 
-       if (count && (head->node_count < MAPLE_ALLOC_SLOTS - 1)) {
-               if (head->slot[0])
-                       head->node_count++;
-               head->slot[head->node_count] = reuse;
+       reuse->request_count = 0;
+       reuse->node_count = 0;
+       if (count && (head->node_count < MAPLE_ALLOC_SLOTS)) {
+               head->slot[head->node_count++] = reuse;
                head->total++;
                goto done;
        }
 
        reuse->total = 1;
        if ((head) && !((unsigned long)head & 0x1)) {
-               head->request_count = 0;
                reuse->slot[0] = head;
+               reuse->node_count = 1;
                reuse->total += head->total;
        }
 
@@ -1206,7 +1210,6 @@ static inline void mas_alloc_nodes(struct ma_state *mas, gfp_t gfp)
 {
        struct maple_alloc *node;
        unsigned long allocated = mas_allocated(mas);
-       unsigned long success = allocated;
        unsigned int requested = mas_alloc_req(mas);
        unsigned int count;
        void **slots = NULL;
@@ -1222,24 +1225,29 @@ static inline void mas_alloc_nodes(struct ma_state *mas, gfp_t gfp)
                WARN_ON(!allocated);
        }
 
-       if (!allocated || mas->alloc->node_count == MAPLE_ALLOC_SLOTS - 1) {
+       if (!allocated || mas->alloc->node_count == MAPLE_ALLOC_SLOTS) {
                node = (struct maple_alloc *)mt_alloc_one(gfp);
                if (!node)
                        goto nomem_one;
 
-               if (allocated)
+               if (allocated) {
                        node->slot[0] = mas->alloc;
+                       node->node_count = 1;
+               } else {
+                       node->node_count = 0;
+               }
 
-               success++;
                mas->alloc = node;
+               node->total = ++allocated;
                requested--;
        }
 
        node = mas->alloc;
+       node->request_count = 0;
        while (requested) {
                max_req = MAPLE_ALLOC_SLOTS;
-               if (node->slot[0]) {
-                       unsigned int offset = node->node_count + 1;
+               if (node->node_count) {
+                       unsigned int offset = node->node_count;
 
                        slots = (void **)&node->slot[offset];
                        max_req -= offset;
@@ -1253,15 +1261,13 @@ static inline void mas_alloc_nodes(struct ma_state *mas, gfp_t gfp)
                        goto nomem_bulk;
 
                node->node_count += count;
-               /* zero indexed. */
-               if (slots == (void **)&node->slot)
-                       node->node_count--;
-
-               success += count;
+               allocated += count;
                node = node->slot[0];
+               node->node_count = 0;
+               node->request_count = 0;
                requested -= count;
        }
-       mas->alloc->total = success;
+       mas->alloc->total = allocated;
        return;
 
 nomem_bulk:
@@ -1270,7 +1276,7 @@ nomem_bulk:
 nomem_one:
        mas_set_alloc_req(mas, requested);
        if (mas->alloc && !(((unsigned long)mas->alloc & 0x1)))
-               mas->alloc->total = success;
+               mas->alloc->total = allocated;
        mas_set_err(mas, -ENOMEM);
        return;
 
@@ -1328,7 +1334,7 @@ static void mas_node_count(struct ma_state *mas, int count)
  * mas_start() - Sets up maple state for operations.
  * @mas: The maple state.
  *
- * If mas->node == MAS_START, then set the min, max, depth, and offset to
+ * If mas->node == MAS_START, then set the min, max and depth to
  * defaults.
  *
  * Return:
@@ -1342,22 +1348,22 @@ static inline struct maple_enode *mas_start(struct ma_state *mas)
        if (likely(mas_is_start(mas))) {
                struct maple_enode *root;
 
-               mas->node = MAS_NONE;
                mas->min = 0;
                mas->max = ULONG_MAX;
                mas->depth = 0;
-               mas->offset = 0;
 
                root = mas_root(mas);
                /* Tree with nodes */
                if (likely(xa_is_node(root))) {
                        mas->depth = 1;
                        mas->node = mte_safe_root(root);
+                       mas->offset = 0;
                        return NULL;
                }
 
                /* empty tree */
                if (unlikely(!root)) {
+                       mas->node = MAS_NONE;
                        mas->offset = MAPLE_NODE_SLOTS;
                        return NULL;
                }
@@ -1393,6 +1399,9 @@ static inline unsigned char ma_data_end(struct maple_node *node,
 {
        unsigned char offset;
 
+       if (!pivots)
+               return 0;
+
        if (type == maple_arange_64)
                return ma_meta_end(node, type);
 
@@ -1428,6 +1437,9 @@ static inline unsigned char mas_data_end(struct ma_state *mas)
                return ma_meta_end(node, type);
 
        pivots = ma_pivots(node, type);
+       if (unlikely(ma_dead_node(node)))
+               return 0;
+
        offset = mt_pivots[type] - 1;
        if (likely(!pivots[offset]))
                return ma_meta_end(node, type);
@@ -3653,10 +3665,9 @@ static inline int mas_root_expand(struct ma_state *mas, void *entry)
                slot++;
        mas->depth = 1;
        mas_set_height(mas);
-
+       ma_set_meta(node, maple_leaf_64, 0, slot);
        /* swap the new root into the tree */
        rcu_assign_pointer(mas->tree->ma_root, mte_mk_root(mas->node));
-       ma_set_meta(node, maple_leaf_64, 0, slot);
        return slot;
 }
 
@@ -3869,18 +3880,13 @@ static inline void *mtree_lookup_walk(struct ma_state *mas)
                end = ma_data_end(node, type, pivots, max);
                if (unlikely(ma_dead_node(node)))
                        goto dead_node;
-
-               if (pivots[offset] >= mas->index)
-                       goto next;
-
                do {
-                       offset++;
-               } while ((offset < end) && (pivots[offset] < mas->index));
-
-               if (likely(offset > end))
-                       max = pivots[offset];
+                       if (pivots[offset] >= mas->index) {
+                               max = pivots[offset];
+                               break;
+                       }
+               } while (++offset < end);
 
-next:
                slots = ma_slots(node, type);
                next = mt_slot(mas->tree, slots, offset);
                if (unlikely(ma_dead_node(node)))
@@ -4499,6 +4505,9 @@ static inline int mas_prev_node(struct ma_state *mas, unsigned long min)
        node = mas_mn(mas);
        slots = ma_slots(node, mt);
        pivots = ma_pivots(node, mt);
+       if (unlikely(ma_dead_node(node)))
+               return 1;
+
        mas->max = pivots[offset];
        if (offset)
                mas->min = pivots[offset - 1] + 1;
@@ -4520,6 +4529,9 @@ static inline int mas_prev_node(struct ma_state *mas, unsigned long min)
                slots = ma_slots(node, mt);
                pivots = ma_pivots(node, mt);
                offset = ma_data_end(node, mt, pivots, mas->max);
+               if (unlikely(ma_dead_node(node)))
+                       return 1;
+
                if (offset)
                        mas->min = pivots[offset - 1] + 1;
 
@@ -4568,6 +4580,7 @@ static inline int mas_next_node(struct ma_state *mas, struct maple_node *node,
        struct maple_enode *enode;
        int level = 0;
        unsigned char offset;
+       unsigned char node_end;
        enum maple_type mt;
        void __rcu **slots;
 
@@ -4591,7 +4604,11 @@ static inline int mas_next_node(struct ma_state *mas, struct maple_node *node,
                node = mas_mn(mas);
                mt = mte_node_type(mas->node);
                pivots = ma_pivots(node, mt);
-       } while (unlikely(offset == ma_data_end(node, mt, pivots, mas->max)));
+               node_end = ma_data_end(node, mt, pivots, mas->max);
+               if (unlikely(ma_dead_node(node)))
+                       return 1;
+
+       } while (unlikely(offset == node_end));
 
        slots = ma_slots(node, mt);
        pivot = mas_safe_pivot(mas, pivots, ++offset, mt);
@@ -4607,6 +4624,9 @@ static inline int mas_next_node(struct ma_state *mas, struct maple_node *node,
                mt = mte_node_type(mas->node);
                slots = ma_slots(node, mt);
                pivots = ma_pivots(node, mt);
+               if (unlikely(ma_dead_node(node)))
+                       return 1;
+
                offset = 0;
                pivot = pivots[0];
        }
@@ -4653,16 +4673,19 @@ static inline void *mas_next_nentry(struct ma_state *mas,
                return NULL;
        }
 
-       pivots = ma_pivots(node, type);
        slots = ma_slots(node, type);
+       pivots = ma_pivots(node, type);
+       count = ma_data_end(node, type, pivots, mas->max);
+       if (unlikely(ma_dead_node(node)))
+               return NULL;
+
        mas->index = mas_safe_min(mas, pivots, mas->offset);
-       if (ma_dead_node(node))
+       if (unlikely(ma_dead_node(node)))
                return NULL;
 
        if (mas->index > max)
                return NULL;
 
-       count = ma_data_end(node, type, pivots, mas->max);
        if (mas->offset > count)
                return NULL;
 
@@ -4737,6 +4760,11 @@ static inline void *mas_next_entry(struct ma_state *mas, unsigned long limit)
        unsigned long last;
        enum maple_type mt;
 
+       if (mas->index > limit) {
+               mas->index = mas->last = limit;
+               mas_pause(mas);
+               return NULL;
+       }
        last = mas->last;
 retry:
        offset = mas->offset;
@@ -4810,6 +4838,11 @@ retry:
 
        slots = ma_slots(mn, mt);
        pivots = ma_pivots(mn, mt);
+       if (unlikely(ma_dead_node(mn))) {
+               mas_rewalk(mas, index);
+               goto retry;
+       }
+
        if (offset == mt_pivots[mt])
                pivot = mas->max;
        else
@@ -4843,6 +4876,11 @@ static inline void *mas_prev_entry(struct ma_state *mas, unsigned long min)
 {
        void *entry;
 
+       if (mas->index < min) {
+               mas->index = mas->last = min;
+               mas->node = MAS_NONE;
+               return NULL;
+       }
 retry:
        while (likely(!mas_is_none(mas))) {
                entry = mas_prev_nentry(mas, min, mas->index);
@@ -4882,7 +4920,7 @@ static bool mas_rev_awalk(struct ma_state *mas, unsigned long size)
        unsigned long *pivots, *gaps;
        void __rcu **slots;
        unsigned long gap = 0;
-       unsigned long max, min, index;
+       unsigned long max, min;
        unsigned char offset;
 
        if (unlikely(mas_is_err(mas)))
@@ -4904,8 +4942,7 @@ static bool mas_rev_awalk(struct ma_state *mas, unsigned long size)
                min = mas_safe_min(mas, pivots, --offset);
 
        max = mas_safe_pivot(mas, pivots, offset, type);
-       index = mas->index;
-       while (index <= max) {
+       while (mas->index <= max) {
                gap = 0;
                if (gaps)
                        gap = gaps[offset];
@@ -4936,10 +4973,8 @@ static bool mas_rev_awalk(struct ma_state *mas, unsigned long size)
                min = mas_safe_min(mas, pivots, offset);
        }
 
-       if (unlikely(index > max)) {
-               mas_set_err(mas, -EBUSY);
-               return false;
-       }
+       if (unlikely((mas->index > max) || (size - 1 > max - mas->index)))
+               goto no_space;
 
        if (unlikely(ma_is_leaf(type))) {
                mas->offset = offset;
@@ -4956,9 +4991,11 @@ static bool mas_rev_awalk(struct ma_state *mas, unsigned long size)
        return false;
 
 ascend:
-       if (mte_is_root(mas->node))
-               mas_set_err(mas, -EBUSY);
+       if (!mte_is_root(mas->node))
+               return false;
 
+no_space:
+       mas_set_err(mas, -EBUSY);
        return false;
 }
 
@@ -5088,35 +5125,21 @@ static inline bool mas_rewind_node(struct ma_state *mas)
  */
 static inline bool mas_skip_node(struct ma_state *mas)
 {
-       unsigned char slot, slot_count;
-       unsigned long *pivots;
-       enum maple_type mt;
+       if (mas_is_err(mas))
+               return false;
 
-       mt = mte_node_type(mas->node);
-       slot_count = mt_slots[mt] - 1;
        do {
                if (mte_is_root(mas->node)) {
-                       slot = mas->offset;
-                       if (slot > slot_count) {
+                       if (mas->offset >= mas_data_end(mas)) {
                                mas_set_err(mas, -EBUSY);
                                return false;
                        }
                } else {
                        mas_ascend(mas);
-                       slot = mas->offset;
-                       mt = mte_node_type(mas->node);
-                       slot_count = mt_slots[mt] - 1;
                }
-       } while (slot > slot_count);
-
-       mas->offset = ++slot;
-       pivots = ma_pivots(mas_mn(mas), mt);
-       if (slot > 0)
-               mas->min = pivots[slot - 1] + 1;
-
-       if (slot <= slot_count)
-               mas->max = pivots[slot];
+       } while (mas->offset >= mas_data_end(mas));
 
+       mas->offset++;
        return true;
 }
 
@@ -5605,6 +5628,9 @@ static inline void mte_destroy_walk(struct maple_enode *enode,
 
 static void mas_wr_store_setup(struct ma_wr_state *wr_mas)
 {
+       if (unlikely(mas_is_paused(wr_mas->mas)))
+               mas_reset(wr_mas->mas);
+
        if (!mas_is_start(wr_mas->mas)) {
                if (mas_is_none(wr_mas->mas)) {
                        mas_reset(wr_mas->mas);
@@ -5740,6 +5766,7 @@ int mas_preallocate(struct ma_state *mas, void *entry, gfp_t gfp)
 void mas_destroy(struct ma_state *mas)
 {
        struct maple_alloc *node;
+       unsigned long total;
 
        /*
         * When using mas_for_each() to insert an expected number of elements,
@@ -5762,14 +5789,20 @@ void mas_destroy(struct ma_state *mas)
        }
        mas->mas_flags &= ~(MA_STATE_BULK|MA_STATE_PREALLOC);
 
-       while (mas->alloc && !((unsigned long)mas->alloc & 0x1)) {
+       total = mas_allocated(mas);
+       while (total) {
                node = mas->alloc;
                mas->alloc = node->slot[0];
-               if (node->node_count > 0)
-                       mt_free_bulk(node->node_count,
-                                    (void __rcu **)&node->slot[1]);
+               if (node->node_count > 1) {
+                       size_t count = node->node_count - 1;
+
+                       mt_free_bulk(count, (void __rcu **)&node->slot[1]);
+                       total -= count;
+               }
                kmem_cache_free(maple_node_cache, node);
+               total--;
        }
+
        mas->alloc = NULL;
 }
 EXPORT_SYMBOL_GPL(mas_destroy);
@@ -5907,6 +5940,7 @@ void *mas_prev(struct ma_state *mas, unsigned long min)
        if (!mas->index) {
                /* Nothing comes before 0 */
                mas->last = 0;
+               mas->node = MAS_NONE;
                return NULL;
        }
 
@@ -5997,6 +6031,9 @@ void *mas_find(struct ma_state *mas, unsigned long max)
                mas->index = ++mas->last;
        }
 
+       if (unlikely(mas_is_none(mas)))
+               mas->node = MAS_START;
+
        if (unlikely(mas_is_start(mas))) {
                /* First run or continue */
                void *entry;
@@ -6608,11 +6645,11 @@ static inline void *mas_first_entry(struct ma_state *mas, struct maple_node *mn,
        while (likely(!ma_is_leaf(mt))) {
                MT_BUG_ON(mas->tree, mte_dead_node(mas->node));
                slots = ma_slots(mn, mt);
-               pivots = ma_pivots(mn, mt);
-               max = pivots[0];
                entry = mas_slot(mas, slots, 0);
+               pivots = ma_pivots(mn, mt);
                if (unlikely(ma_dead_node(mn)))
                        return NULL;
+               max = pivots[0];
                mas->node = entry;
                mn = mas_mn(mas);
                mt = mte_node_type(mas->node);
@@ -6632,13 +6669,13 @@ static inline void *mas_first_entry(struct ma_state *mas, struct maple_node *mn,
        if (likely(entry))
                return entry;
 
-       pivots = ma_pivots(mn, mt);
-       mas->index = pivots[0] + 1;
        mas->offset = 1;
        entry = mas_slot(mas, slots, 1);
+       pivots = ma_pivots(mn, mt);
        if (unlikely(ma_dead_node(mn)))
                return NULL;
 
+       mas->index = pivots[0] + 1;
        if (mas->index > limit)
                goto none;