maple_tree: refine ma_state init from mas_start()
[platform/kernel/linux-starfive.git] / lib / maple_tree.c
index 203bd3e..54484bd 100644 (file)
@@ -534,6 +534,7 @@ static inline bool ma_dead_node(const struct maple_node *node)
 
        return (parent == node);
 }
+
 /*
  * mte_dead_node() - check if the @enode is dead.
  * @enode: The encoded maple node
@@ -615,6 +616,8 @@ static inline unsigned int mas_alloc_req(const struct ma_state *mas)
  * @node - the maple node
  * @type - the node type
  *
+ * In the event of a dead node, this array may be %NULL
+ *
  * Return: A pointer to the maple node pivots
  */
 static inline unsigned long *ma_pivots(struct maple_node *node,
@@ -1086,8 +1089,11 @@ static int mas_ascend(struct ma_state *mas)
                a_type = mas_parent_enum(mas, p_enode);
                a_node = mte_parent(p_enode);
                a_slot = mte_parent_slot(p_enode);
-               pivots = ma_pivots(a_node, a_type);
                a_enode = mt_mk_node(a_node, a_type);
+               pivots = ma_pivots(a_node, a_type);
+
+               if (unlikely(ma_dead_node(a_node)))
+                       return 1;
 
                if (!set_min && a_slot) {
                        set_min = true;
@@ -1328,7 +1334,7 @@ static void mas_node_count(struct ma_state *mas, int count)
  * mas_start() - Sets up maple state for operations.
  * @mas: The maple state.
  *
- * If mas->node == MAS_START, then set the min, max, depth, and offset to
+ * If mas->node == MAS_START, then set the min, max and depth to
  * defaults.
  *
  * Return:
@@ -1342,22 +1348,22 @@ static inline struct maple_enode *mas_start(struct ma_state *mas)
        if (likely(mas_is_start(mas))) {
                struct maple_enode *root;
 
-               mas->node = MAS_NONE;
                mas->min = 0;
                mas->max = ULONG_MAX;
                mas->depth = 0;
-               mas->offset = 0;
 
                root = mas_root(mas);
                /* Tree with nodes */
                if (likely(xa_is_node(root))) {
                        mas->depth = 1;
                        mas->node = mte_safe_root(root);
+                       mas->offset = 0;
                        return NULL;
                }
 
                /* empty tree */
                if (unlikely(!root)) {
+                       mas->node = MAS_NONE;
                        mas->offset = MAPLE_NODE_SLOTS;
                        return NULL;
                }
@@ -1393,6 +1399,9 @@ static inline unsigned char ma_data_end(struct maple_node *node,
 {
        unsigned char offset;
 
+       if (!pivots)
+               return 0;
+
        if (type == maple_arange_64)
                return ma_meta_end(node, type);
 
@@ -1428,6 +1437,9 @@ static inline unsigned char mas_data_end(struct ma_state *mas)
                return ma_meta_end(node, type);
 
        pivots = ma_pivots(node, type);
+       if (unlikely(ma_dead_node(node)))
+               return 0;
+
        offset = mt_pivots[type] - 1;
        if (likely(!pivots[offset]))
                return ma_meta_end(node, type);
@@ -4493,6 +4505,9 @@ static inline int mas_prev_node(struct ma_state *mas, unsigned long min)
        node = mas_mn(mas);
        slots = ma_slots(node, mt);
        pivots = ma_pivots(node, mt);
+       if (unlikely(ma_dead_node(node)))
+               return 1;
+
        mas->max = pivots[offset];
        if (offset)
                mas->min = pivots[offset - 1] + 1;
@@ -4514,6 +4529,9 @@ static inline int mas_prev_node(struct ma_state *mas, unsigned long min)
                slots = ma_slots(node, mt);
                pivots = ma_pivots(node, mt);
                offset = ma_data_end(node, mt, pivots, mas->max);
+               if (unlikely(ma_dead_node(node)))
+                       return 1;
+
                if (offset)
                        mas->min = pivots[offset - 1] + 1;
 
@@ -4562,6 +4580,7 @@ static inline int mas_next_node(struct ma_state *mas, struct maple_node *node,
        struct maple_enode *enode;
        int level = 0;
        unsigned char offset;
+       unsigned char node_end;
        enum maple_type mt;
        void __rcu **slots;
 
@@ -4585,7 +4604,11 @@ static inline int mas_next_node(struct ma_state *mas, struct maple_node *node,
                node = mas_mn(mas);
                mt = mte_node_type(mas->node);
                pivots = ma_pivots(node, mt);
-       } while (unlikely(offset == ma_data_end(node, mt, pivots, mas->max)));
+               node_end = ma_data_end(node, mt, pivots, mas->max);
+               if (unlikely(ma_dead_node(node)))
+                       return 1;
+
+       } while (unlikely(offset == node_end));
 
        slots = ma_slots(node, mt);
        pivot = mas_safe_pivot(mas, pivots, ++offset, mt);
@@ -4601,6 +4624,9 @@ static inline int mas_next_node(struct ma_state *mas, struct maple_node *node,
                mt = mte_node_type(mas->node);
                slots = ma_slots(node, mt);
                pivots = ma_pivots(node, mt);
+               if (unlikely(ma_dead_node(node)))
+                       return 1;
+
                offset = 0;
                pivot = pivots[0];
        }
@@ -4647,11 +4673,14 @@ static inline void *mas_next_nentry(struct ma_state *mas,
                return NULL;
        }
 
-       pivots = ma_pivots(node, type);
        slots = ma_slots(node, type);
-       mas->index = mas_safe_min(mas, pivots, mas->offset);
+       pivots = ma_pivots(node, type);
        count = ma_data_end(node, type, pivots, mas->max);
-       if (ma_dead_node(node))
+       if (unlikely(ma_dead_node(node)))
+               return NULL;
+
+       mas->index = mas_safe_min(mas, pivots, mas->offset);
+       if (unlikely(ma_dead_node(node)))
                return NULL;
 
        if (mas->index > max)
@@ -4809,6 +4838,11 @@ retry:
 
        slots = ma_slots(mn, mt);
        pivots = ma_pivots(mn, mt);
+       if (unlikely(ma_dead_node(mn))) {
+               mas_rewalk(mas, index);
+               goto retry;
+       }
+
        if (offset == mt_pivots[mt])
                pivot = mas->max;
        else
@@ -4844,7 +4878,7 @@ static inline void *mas_prev_entry(struct ma_state *mas, unsigned long min)
 
        if (mas->index < min) {
                mas->index = mas->last = min;
-               mas_pause(mas);
+               mas->node = MAS_NONE;
                return NULL;
        }
 retry:
@@ -5906,6 +5940,7 @@ void *mas_prev(struct ma_state *mas, unsigned long min)
        if (!mas->index) {
                /* Nothing comes before 0 */
                mas->last = 0;
+               mas->node = MAS_NONE;
                return NULL;
        }
 
@@ -5996,6 +6031,9 @@ void *mas_find(struct ma_state *mas, unsigned long max)
                mas->index = ++mas->last;
        }
 
+       if (unlikely(mas_is_none(mas)))
+               mas->node = MAS_START;
+
        if (unlikely(mas_is_start(mas))) {
                /* First run or continue */
                void *entry;
@@ -6607,11 +6645,11 @@ static inline void *mas_first_entry(struct ma_state *mas, struct maple_node *mn,
        while (likely(!ma_is_leaf(mt))) {
                MT_BUG_ON(mas->tree, mte_dead_node(mas->node));
                slots = ma_slots(mn, mt);
-               pivots = ma_pivots(mn, mt);
-               max = pivots[0];
                entry = mas_slot(mas, slots, 0);
+               pivots = ma_pivots(mn, mt);
                if (unlikely(ma_dead_node(mn)))
                        return NULL;
+               max = pivots[0];
                mas->node = entry;
                mn = mas_mn(mas);
                mt = mte_node_type(mas->node);
@@ -6631,13 +6669,13 @@ static inline void *mas_first_entry(struct ma_state *mas, struct maple_node *mn,
        if (likely(entry))
                return entry;
 
-       pivots = ma_pivots(mn, mt);
-       mas->index = pivots[0] + 1;
        mas->offset = 1;
        entry = mas_slot(mas, slots, 1);
+       pivots = ma_pivots(mn, mt);
        if (unlikely(ma_dead_node(mn)))
                return NULL;
 
+       mas->index = pivots[0] + 1;
        if (mas->index > limit)
                goto none;