maple_tree: fix freeing of nodes in rcu mode
authorLiam R. Howlett <Liam.Howlett@Oracle.com>
Tue, 11 Apr 2023 15:10:51 +0000 (11:10 -0400)
committerGreg Kroah-Hartman <gregkh@linuxfoundation.org>
Thu, 13 Apr 2023 14:55:39 +0000 (16:55 +0200)
commit 2e5b4921f8efc9e845f4f04741797d16f36847eb upstream.

The walk to destroy the nodes was not always setting the node type and
would result in a destroy method potentially using the values as nodes.
Avoid this by setting the correct node types.  This is necessary for the
RCU mode of the maple tree.

Link: https://lkml.kernel.org/r/20230227173632.3292573-4-surenb@google.com
Cc: <Stable@vger.kernel.org>
Fixes: 54a611b60590 ("Maple Tree: add new data structure")
Signed-off-by: Liam Howlett <Liam.Howlett@oracle.com>
Signed-off-by: Greg Kroah-Hartman <gregkh@linuxfoundation.org>
lib/maple_tree.c

index 6b451a4..7f451e0 100644 (file)
@@ -893,6 +893,44 @@ static inline void ma_set_meta(struct maple_node *mn, enum maple_type mt,
 }
 
 /*
+ * mas_clear_meta() - clear the metadata information of a node, if it exists
+ * @mas: The maple state
+ * @mn: The maple node
+ * @mt: The maple node type
+ * @offset: The offset of the highest sub-gap in this node.
+ * @end: The end of the data in this node.
+ */
+static inline void mas_clear_meta(struct ma_state *mas, struct maple_node *mn,
+                                 enum maple_type mt)
+{
+       struct maple_metadata *meta;
+       unsigned long *pivots;
+       void __rcu **slots;
+       void *next;
+
+       switch (mt) {
+       case maple_range_64:
+               pivots = mn->mr64.pivot;
+               if (unlikely(pivots[MAPLE_RANGE64_SLOTS - 2])) {
+                       slots = mn->mr64.slot;
+                       next = mas_slot_locked(mas, slots,
+                                              MAPLE_RANGE64_SLOTS - 1);
+                       if (unlikely((mte_to_node(next) && mte_node_type(next))))
+                               return; /* The last slot is a node, no metadata */
+               }
+               fallthrough;
+       case maple_arange_64:
+               meta = ma_meta(mn, mt);
+               break;
+       default:
+               return;
+       }
+
+       meta->gap = 0;
+       meta->end = 0;
+}
+
+/*
  * ma_meta_end() - Get the data end of a node from the metadata
  * @mn: The maple node
  * @mt: The maple node type
@@ -5433,20 +5471,22 @@ no_gap:
  * mas_dead_leaves() - Mark all leaves of a node as dead.
  * @mas: The maple state
  * @slots: Pointer to the slot array
+ * @type: The maple node type
  *
  * Must hold the write lock.
  *
  * Return: The number of leaves marked as dead.
  */
 static inline
-unsigned char mas_dead_leaves(struct ma_state *mas, void __rcu **slots)
+unsigned char mas_dead_leaves(struct ma_state *mas, void __rcu **slots,
+                             enum maple_type mt)
 {
        struct maple_node *node;
        enum maple_type type;
        void *entry;
        int offset;
 
-       for (offset = 0; offset < mt_slot_count(mas->node); offset++) {
+       for (offset = 0; offset < mt_slots[mt]; offset++) {
                entry = mas_slot_locked(mas, slots, offset);
                type = mte_node_type(entry);
                node = mte_to_node(entry);
@@ -5465,14 +5505,13 @@ unsigned char mas_dead_leaves(struct ma_state *mas, void __rcu **slots)
 
 static void __rcu **mas_dead_walk(struct ma_state *mas, unsigned char offset)
 {
-       struct maple_node *node, *next;
+       struct maple_node *next;
        void __rcu **slots = NULL;
 
        next = mas_mn(mas);
        do {
-               mas->node = ma_enode_ptr(next);
-               node = mas_mn(mas);
-               slots = ma_slots(node, node->type);
+               mas->node = mt_mk_node(next, next->type);
+               slots = ma_slots(next, next->type);
                next = mas_slot_locked(mas, slots, offset);
                offset = 0;
        } while (!ma_is_leaf(next->type));
@@ -5536,11 +5575,14 @@ static inline void __rcu **mas_destroy_descend(struct ma_state *mas,
                node = mas_mn(mas);
                slots = ma_slots(node, mte_node_type(mas->node));
                next = mas_slot_locked(mas, slots, 0);
-               if ((mte_dead_node(next)))
+               if ((mte_dead_node(next))) {
+                       mte_to_node(next)->type = mte_node_type(next);
                        next = mas_slot_locked(mas, slots, 1);
+               }
 
                mte_set_node_dead(mas->node);
                node->type = mte_node_type(mas->node);
+               mas_clear_meta(mas, node, node->type);
                node->piv_parent = prev;
                node->parent_slot = offset;
                offset = 0;
@@ -5560,13 +5602,18 @@ static void mt_destroy_walk(struct maple_enode *enode, unsigned char ma_flags,
 
        MA_STATE(mas, &mt, 0, 0);
 
-       if (mte_is_leaf(enode))
+       mas.node = enode;
+       if (mte_is_leaf(enode)) {
+               node->type = mte_node_type(enode);
                goto free_leaf;
+       }
 
+       ma_flags &= ~MT_FLAGS_LOCK_MASK;
        mt_init_flags(&mt, ma_flags);
        mas_lock(&mas);
 
-       mas.node = start = enode;
+       mte_to_node(enode)->ma_flags = ma_flags;
+       start = enode;
        slots = mas_destroy_descend(&mas, start, 0);
        node = mas_mn(&mas);
        do {
@@ -5574,7 +5621,8 @@ static void mt_destroy_walk(struct maple_enode *enode, unsigned char ma_flags,
                unsigned char offset;
                struct maple_enode *parent, *tmp;
 
-               node->slot_len = mas_dead_leaves(&mas, slots);
+               node->type = mte_node_type(mas.node);
+               node->slot_len = mas_dead_leaves(&mas, slots, node->type);
                if (free)
                        mt_free_bulk(node->slot_len, slots);
                offset = node->parent_slot + 1;
@@ -5598,7 +5646,8 @@ next:
        } while (start != mas.node);
 
        node = mas_mn(&mas);
-       node->slot_len = mas_dead_leaves(&mas, slots);
+       node->type = mte_node_type(mas.node);
+       node->slot_len = mas_dead_leaves(&mas, slots, node->type);
        if (free)
                mt_free_bulk(node->slot_len, slots);
 
@@ -5608,6 +5657,8 @@ start_slots_free:
 free_leaf:
        if (free)
                mt_free_rcu(&node->rcu);
+       else
+               mas_clear_meta(&mas, node, node->type);
 }
 
 /*