maple_tree: try harder to keep active node after mas_next()
authorLiam R. Howlett <Liam.Howlett@oracle.com>
Thu, 18 May 2023 14:55:32 +0000 (10:55 -0400)
committerAndrew Morton <akpm@linux-foundation.org>
Fri, 9 Jun 2023 23:25:32 +0000 (16:25 -0700)
Clean up the mas_next() call to try and keep a node reference when
possible.  This will avoid re-walking the tree in most cases.

Also clean up the single entry tree handling to ensure index/last are
consistent with what one would expect.  (returning NULL with limit of
1-oo).

Link: https://lkml.kernel.org/r/20230518145544.1722059-24-Liam.Howlett@oracle.com
Signed-off-by: Liam R. Howlett <Liam.Howlett@oracle.com>
Cc: David Binderman <dcb314@hotmail.com>
Cc: Peng Zhang <zhangpeng.00@bytedance.com>
Cc: Sergey Senozhatsky <senozhatsky@chromium.org>
Cc: Vernon Yang <vernon2gm@gmail.com>
Cc: Wei Yang <richard.weiyang@gmail.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
lib/maple_tree.c

index 9eec435..3fa1276 100644 (file)
@@ -4727,33 +4727,25 @@ static inline void *mas_next_nentry(struct ma_state *mas,
                if (ma_dead_node(node))
                        return NULL;
 
+               mas->last = pivot;
                if (entry)
-                       goto found;
+                       return entry;
 
                if (pivot >= max)
                        return NULL;
 
+               if (pivot >= mas->max)
+                       return NULL;
+
                mas->index = pivot + 1;
                mas->offset++;
        }
 
-       if (mas->index > mas->max) {
-               mas->index = mas->last;
-               return NULL;
-       }
-
-       pivot = mas_safe_pivot(mas, pivots, mas->offset, type);
+       pivot = mas_logical_pivot(mas, pivots, mas->offset, type);
        entry = mas_slot(mas, slots, mas->offset);
        if (ma_dead_node(node))
                return NULL;
 
-       if (!pivot)
-               return NULL;
-
-       if (!entry)
-               return NULL;
-
-found:
        mas->last = pivot;
        return entry;
 }
@@ -4782,21 +4774,15 @@ retry:
 static inline void *mas_next_entry(struct ma_state *mas, unsigned long limit)
 {
        void *entry = NULL;
-       struct maple_enode *prev_node;
        struct maple_node *node;
-       unsigned char offset;
        unsigned long last;
        enum maple_type mt;
 
-       if (mas->index > limit) {
-               mas->index = mas->last = limit;
-               mas_pause(mas);
+       if (mas->last >= limit)
                return NULL;
-       }
+
        last = mas->last;
 retry:
-       offset = mas->offset;
-       prev_node = mas->node;
        node = mas_mn(mas);
        mt = mte_node_type(mas->node);
        mas->offset++;
@@ -4815,12 +4801,10 @@ retry:
                if (likely(entry))
                        return entry;
 
-               if (unlikely((mas->index > limit)))
-                       break;
+               if (unlikely((mas->last >= limit)))
+                       return NULL;
 
 next_node:
-               prev_node = mas->node;
-               offset = mas->offset;
                if (unlikely(mas_next_node(mas, node, limit))) {
                        mas_rewalk(mas, last);
                        goto retry;
@@ -4830,9 +4814,6 @@ next_node:
                mt = mte_node_type(mas->node);
        }
 
-       mas->index = mas->last = limit;
-       mas->offset = offset;
-       mas->node = prev_node;
        return NULL;
 }
 
@@ -5914,6 +5895,8 @@ EXPORT_SYMBOL_GPL(mas_expected_entries);
  */
 void *mas_next(struct ma_state *mas, unsigned long max)
 {
+       bool was_none = mas_is_none(mas);
+
        if (mas_is_none(mas) || mas_is_paused(mas))
                mas->node = MAS_START;
 
@@ -5921,16 +5904,16 @@ void *mas_next(struct ma_state *mas, unsigned long max)
                mas_walk(mas); /* Retries on dead nodes handled by mas_walk */
 
        if (mas_is_ptr(mas)) {
-               if (!mas->index) {
-                       mas->index = 1;
-                       mas->last = ULONG_MAX;
+               if (was_none && mas->index == 0) {
+                       mas->index = mas->last = 0;
+                       return mas_root(mas);
                }
+               mas->index = 1;
+               mas->last = ULONG_MAX;
+               mas->node = MAS_NONE;
                return NULL;
        }
 
-       if (mas->last == ULONG_MAX)
-               return NULL;
-
        /* Retries on dead nodes handled by mas_next_entry */
        return mas_next_entry(mas, max);
 }
@@ -6054,17 +6037,25 @@ EXPORT_SYMBOL_GPL(mas_pause);
  */
 void *mas_find(struct ma_state *mas, unsigned long max)
 {
+       if (unlikely(mas_is_none(mas))) {
+               if (unlikely(mas->last >= max))
+                       return NULL;
+
+               mas->index = mas->last;
+               mas->node = MAS_START;
+       }
+
        if (unlikely(mas_is_paused(mas))) {
-               if (unlikely(mas->last == ULONG_MAX)) {
-                       mas->node = MAS_NONE;
+               if (unlikely(mas->last >= max))
                        return NULL;
-               }
+
                mas->node = MAS_START;
                mas->index = ++mas->last;
        }
 
-       if (unlikely(mas_is_none(mas)))
-               mas->node = MAS_START;
+
+       if (unlikely(mas_is_ptr(mas)))
+               goto ptr_out_of_range;
 
        if (unlikely(mas_is_start(mas))) {
                /* First run or continue */
@@ -6076,13 +6067,27 @@ void *mas_find(struct ma_state *mas, unsigned long max)
                entry = mas_walk(mas);
                if (entry)
                        return entry;
+
        }
 
-       if (unlikely(!mas_searchable(mas)))
+       if (unlikely(!mas_searchable(mas))) {
+               if (unlikely(mas_is_ptr(mas)))
+                       goto ptr_out_of_range;
+
+               return NULL;
+       }
+
+       if (mas->index == max)
                return NULL;
 
        /* Retries on dead nodes handled by mas_next_entry */
        return mas_next_entry(mas, max);
+
+ptr_out_of_range:
+       mas->node = MAS_NONE;
+       mas->index = 1;
+       mas->last = ULONG_MAX;
+       return NULL;
 }
 EXPORT_SYMBOL_GPL(mas_find);
 
@@ -6513,7 +6518,7 @@ retry:
        if (entry)
                goto unlock;
 
-       while (mas_searchable(&mas) && (mas.index < max)) {
+       while (mas_searchable(&mas) && (mas.last < max)) {
                entry = mas_next_entry(&mas, max);
                if (likely(entry && !xa_is_zero(entry)))
                        break;