wifi: mac80211: maintain link-sta hash table
authorJohannes Berg <johannes.berg@intel.com>
Tue, 14 Jun 2022 11:07:42 +0000 (13:07 +0200)
committerJohannes Berg <johannes.berg@intel.com>
Mon, 20 Jun 2022 10:57:08 +0000 (12:57 +0200)
Maintain a hash table of link-sta addresses so we can find
them for management frames etc. where addresses haven't
been replaced by the drivers to the MLD address yet.

Signed-off-by: Johannes Berg <johannes.berg@intel.com>
net/mac80211/ieee80211_i.h
net/mac80211/sta_info.c
net/mac80211/sta_info.h

index 46f4e89825a03c14020f4c3276f6db231c0333a7..2190d08f4e344496d600ea58e3368b57138de0f9 100644 (file)
@@ -1366,6 +1366,7 @@ struct ieee80211_local {
        unsigned long num_sta;
        struct list_head sta_list;
        struct rhltable sta_hash;
+       struct rhltable link_sta_hash;
        struct timer_list sta_cleanup;
        int sta_generation;
 
index 1f4189e08675ab05a104e9bb039c1796d1fd501a..ccd792af6b5695bf5d29f96b872b2f0da22a5e41 100644 (file)
@@ -79,6 +79,15 @@ static const struct rhashtable_params sta_rht_params = {
        .max_size = CONFIG_MAC80211_STA_HASH_MAX_SIZE,
 };
 
+static const struct rhashtable_params link_sta_rht_params = {
+       .nelem_hint = 3, /* start small */
+       .automatic_shrinking = true,
+       .head_offset = offsetof(struct link_sta_info, link_hash_node),
+       .key_offset = offsetof(struct link_sta_info, addr),
+       .key_len = ETH_ALEN,
+       .max_size = CONFIG_MAC80211_STA_HASH_MAX_SIZE,
+};
+
 /* Caller must hold local->sta_mtx */
 static int sta_info_hash_del(struct ieee80211_local *local,
                             struct sta_info *sta)
@@ -87,6 +96,14 @@ static int sta_info_hash_del(struct ieee80211_local *local,
                               sta_rht_params);
 }
 
+static int link_sta_info_hash_del(struct ieee80211_local *local,
+                                 struct link_sta_info *link_sta)
+{
+       return rhltable_remove(&local->link_sta_hash,
+                              &link_sta->link_hash_node,
+                              link_sta_rht_params);
+}
+
 static void __cleanup_single_sta(struct sta_info *sta)
 {
        int ac, i;
@@ -216,6 +233,37 @@ struct sta_info *sta_info_get_bss(struct ieee80211_sub_if_data *sdata,
        return NULL;
 }
 
+struct rhlist_head *link_sta_info_hash_lookup(struct ieee80211_local *local,
+                                             const u8 *addr)
+{
+       return rhltable_lookup(&local->link_sta_hash, addr,
+                              link_sta_rht_params);
+}
+
+struct link_sta_info *
+link_sta_info_get_bss(struct ieee80211_sub_if_data *sdata, const u8 *addr)
+{
+       struct ieee80211_local *local = sdata->local;
+       struct rhlist_head *tmp;
+       struct link_sta_info *link_sta;
+
+       rcu_read_lock();
+       for_each_link_sta_info(local, addr, link_sta, tmp) {
+               struct sta_info *sta = link_sta->sta;
+
+               if (sta->sdata == sdata ||
+                   (sta->sdata->bss && sta->sdata->bss == sdata->bss)) {
+                       rcu_read_unlock();
+                       /* this is safe as the caller must already hold
+                        * another rcu read section or the mutex
+                        */
+                       return link_sta;
+               }
+       }
+       rcu_read_unlock();
+       return NULL;
+}
+
 struct sta_info *sta_info_get_by_addrs(struct ieee80211_local *local,
                                       const u8 *sta_addr, const u8 *vif_addr)
 {
@@ -256,7 +304,8 @@ static void sta_info_free_link(struct link_sta_info *link_sta)
        free_percpu(link_sta->pcpu_rx_stats);
 }
 
-static void sta_remove_link(struct sta_info *sta, unsigned int link_id)
+static void sta_remove_link(struct sta_info *sta, unsigned int link_id,
+                           bool unhash)
 {
        struct sta_link_alloc *alloc = NULL;
        struct link_sta_info *link_sta;
@@ -267,6 +316,9 @@ static void sta_remove_link(struct sta_info *sta, unsigned int link_id)
        if (WARN_ON(!link_sta))
                return;
 
+       if (unhash)
+               link_sta_info_hash_del(sta->local, link_sta);
+
        if (link_sta != &sta->deflink)
                alloc = container_of(link_sta, typeof(*alloc), info);
 
@@ -298,7 +350,7 @@ void sta_info_free(struct ieee80211_local *local, struct sta_info *sta)
                if (!(sta->sta.valid_links & BIT(i)))
                        continue;
 
-               sta_remove_link(sta, i);
+               sta_remove_link(sta, i, true);
        }
 
        /*
@@ -1262,6 +1314,12 @@ int sta_info_init(struct ieee80211_local *local)
        if (err)
                return err;
 
+       err = rhltable_init(&local->link_sta_hash, &link_sta_rht_params);
+       if (err) {
+               rhltable_destroy(&local->sta_hash);
+               return err;
+       }
+
        spin_lock_init(&local->tim_lock);
        mutex_init(&local->sta_mtx);
        INIT_LIST_HEAD(&local->sta_list);
@@ -1274,6 +1332,7 @@ void sta_info_stop(struct ieee80211_local *local)
 {
        del_timer_sync(&local->sta_cleanup);
        rhltable_destroy(&local->sta_hash);
+       rhltable_destroy(&local->link_sta_hash);
 }
 
 
@@ -2685,6 +2744,14 @@ int ieee80211_sta_allocate_link(struct sta_info *sta, unsigned int link_id)
        return 0;
 }
 
+static int link_sta_info_hash_add(struct ieee80211_local *local,
+                                 struct link_sta_info *link_sta)
+{
+       return rhltable_insert(&local->link_sta_hash,
+                              &link_sta->link_hash_node,
+                              link_sta_rht_params);
+}
+
 int ieee80211_sta_activate_link(struct sta_info *sta, unsigned int link_id)
 {
        struct ieee80211_sub_if_data *sdata = sta->sdata;
@@ -2701,16 +2768,21 @@ int ieee80211_sta_activate_link(struct sta_info *sta, unsigned int link_id)
 
        sta->sta.valid_links = new_links;
 
-       if (!test_sta_flag(sta, WLAN_STA_INSERTED))
-               return 0;
+       if (!test_sta_flag(sta, WLAN_STA_INSERTED)) {
+               ret = 0;
+               goto hash;
+       }
 
        ret = drv_change_sta_links(sdata->local, sdata, &sta->sta,
                                   old_links, new_links);
        if (ret) {
                sta->sta.valid_links = old_links;
-               sta_remove_link(sta, link_id);
+               sta_remove_link(sta, link_id, false);
        }
 
+hash:
+       link_sta_info_hash_add(sdata->local, link_sta);
+
        return ret;
 }
 
@@ -2727,5 +2799,5 @@ void ieee80211_sta_remove_link(struct sta_info *sta, unsigned int link_id)
                                     sta->sta.valid_links,
                                     sta->sta.valid_links & ~BIT(link_id));
 
-       sta_remove_link(sta, link_id);
+       sta_remove_link(sta, link_id, true);
 }
index 27c96c04b13f7690a0010fd6985d0b2e64223baa..2184307906602991950490299bb2011c8bc5051d 100644 (file)
@@ -491,6 +491,7 @@ struct ieee80211_fragment_cache {
  *     same for non-MLD STA. This is used as key for searching link STA
  * @link_id: Link ID uniquely identifying the link STA. This is 0 for non-MLD
  *     and set to the corresponding vif LinkId for MLD STA
+ * @link_hash_node: hash node for rhashtable
  * @sta: Points to the STA info
  * @gtk: group keys negotiated with this station, if any
  * @tx_stats: TX statistics
@@ -523,7 +524,7 @@ struct link_sta_info {
        u8 addr[ETH_ALEN];
        u8 link_id;
 
-       /* TODO rhash head/node for finding link_sta based on addr */
+       struct rhlist_head link_hash_node;
 
        struct sta_info *sta;
        struct ieee80211_key __rcu *gtk[NUM_DEFAULT_KEYS +
@@ -824,6 +825,17 @@ struct sta_info *sta_info_get_by_addrs(struct ieee80211_local *local,
        rhl_for_each_entry_rcu(_sta, _tmp,                              \
                               sta_info_hash_lookup(local, _addr), hash_node)
 
+struct rhlist_head *link_sta_info_hash_lookup(struct ieee80211_local *local,
+                                             const u8 *addr);
+
+#define for_each_link_sta_info(local, _addr, _sta, _tmp)               \
+       rhl_for_each_entry_rcu(_sta, _tmp,                              \
+                              link_sta_info_hash_lookup(local, _addr), \
+                              link_hash_node)
+
+struct link_sta_info *
+link_sta_info_get_bss(struct ieee80211_sub_if_data *sdata, const u8 *addr);
+
 /*
  * Get STA info by index, BROKEN!
  */