wifi: mac80211: properly implement MLO key handling
authorJohannes Berg <johannes.berg@intel.com>
Wed, 17 Aug 2022 09:17:01 +0000 (11:17 +0200)
committerJohannes Berg <johannes.berg@intel.com>
Thu, 25 Aug 2022 08:41:07 +0000 (10:41 +0200)
Implement key installation and lookup (on TX and RX)
for MLO, so we can use multiple GTKs/IGTKs/BIGTKs.

Co-authored-by: Ilan Peer <ilan.peer@intel.com>
Signed-off-by: Ilan Peer <ilan.peer@intel.com>
Signed-off-by: Johannes Berg <johannes.berg@intel.com>
include/net/mac80211.h
net/mac80211/cfg.c
net/mac80211/ieee80211_i.h
net/mac80211/iface.c
net/mac80211/key.c
net/mac80211/key.h
net/mac80211/rx.c
net/mac80211/tx.c

index f198af6..7951400 100644 (file)
@@ -1975,6 +1975,7 @@ enum ieee80211_key_flags {
  *     - Temporal Authenticator Rx MIC Key (64 bits)
  * @icv_len: The ICV length for this key type
  * @iv_len: The IV length for this key type
+ * @link_id: the link ID for MLO, or -1 for non-MLO or pairwise keys
  */
 struct ieee80211_key_conf {
        atomic64_t tx_pn;
@@ -1984,6 +1985,7 @@ struct ieee80211_key_conf {
        u8 hw_key_idx;
        s8 keyidx;
        u16 flags;
+       s8 link_id;
        u8 keylen;
        u8 key[];
 };
index c4c5e2d..854becd 100644 (file)
 #include "wme.h"
 
 static struct ieee80211_link_data *
-ieee80211_link_or_deflink(struct ieee80211_sub_if_data *sdata, int link_id)
+ieee80211_link_or_deflink(struct ieee80211_sub_if_data *sdata, int link_id,
+                         bool require_valid)
 {
        struct ieee80211_link_data *link;
 
        if (link_id < 0) {
-               if (sdata->vif.valid_links)
+               /*
+                * For keys, if sdata is not an MLD, we might not use
+                * the return value at all (if it's not a pairwise key),
+                * so in that case (require_valid==false) don't error.
+                */
+               if (require_valid && sdata->vif.valid_links)
                        return ERR_PTR(-EINVAL);
 
                return &sdata->deflink;
@@ -456,6 +462,8 @@ static int ieee80211_add_key(struct wiphy *wiphy, struct net_device *dev,
                             const u8 *mac_addr, struct key_params *params)
 {
        struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev);
+       struct ieee80211_link_data *link =
+               ieee80211_link_or_deflink(sdata, link_id, false);
        struct ieee80211_local *local = sdata->local;
        struct sta_info *sta = NULL;
        struct ieee80211_key *key;
@@ -464,6 +472,9 @@ static int ieee80211_add_key(struct wiphy *wiphy, struct net_device *dev,
        if (!ieee80211_sdata_running(sdata))
                return -ENETDOWN;
 
+       if (IS_ERR(link))
+               return PTR_ERR(link);
+
        if (pairwise && params->mode == NL80211_KEY_SET_TX)
                return ieee80211_set_tx(sdata, mac_addr, key_idx);
 
@@ -472,6 +483,8 @@ static int ieee80211_add_key(struct wiphy *wiphy, struct net_device *dev,
        case WLAN_CIPHER_SUITE_WEP40:
        case WLAN_CIPHER_SUITE_TKIP:
        case WLAN_CIPHER_SUITE_WEP104:
+               if (link_id >= 0)
+                       return -EINVAL;
                if (WARN_ON_ONCE(fips_enabled))
                        return -EINVAL;
                break;
@@ -484,6 +497,8 @@ static int ieee80211_add_key(struct wiphy *wiphy, struct net_device *dev,
        if (IS_ERR(key))
                return PTR_ERR(key);
 
+       key->conf.link_id = link_id;
+
        if (pairwise)
                key->conf.flags |= IEEE80211_KEY_FLAG_PAIRWISE;
 
@@ -545,7 +560,7 @@ static int ieee80211_add_key(struct wiphy *wiphy, struct net_device *dev,
                break;
        }
 
-       err = ieee80211_key_link(key, sdata, sta);
+       err = ieee80211_key_link(key, link, sta);
 
  out_unlock:
        mutex_unlock(&local->sta_mtx);
@@ -554,18 +569,37 @@ static int ieee80211_add_key(struct wiphy *wiphy, struct net_device *dev,
 }
 
 static struct ieee80211_key *
-ieee80211_lookup_key(struct ieee80211_sub_if_data *sdata,
+ieee80211_lookup_key(struct ieee80211_sub_if_data *sdata, int link_id,
                     u8 key_idx, bool pairwise, const u8 *mac_addr)
 {
        struct ieee80211_local *local = sdata->local;
+       struct ieee80211_link_data *link = &sdata->deflink;
        struct ieee80211_key *key;
-       struct sta_info *sta;
+
+       if (link_id >= 0) {
+               link = rcu_dereference_check(sdata->link[link_id],
+                                            lockdep_is_held(&sdata->wdev.mtx));
+               if (!link)
+                       return NULL;
+       }
 
        if (mac_addr) {
+               struct sta_info *sta;
+               struct link_sta_info *link_sta;
+
                sta = sta_info_get_bss(sdata, mac_addr);
                if (!sta)
                        return NULL;
 
+               if (link_id >= 0) {
+                       link_sta = rcu_dereference_check(sta->link[link_id],
+                                                        lockdep_is_held(&local->sta_mtx));
+                       if (!link_sta)
+                               return NULL;
+               } else {
+                       link_sta = &sta->deflink;
+               }
+
                if (pairwise && key_idx < NUM_DEFAULT_KEYS)
                        return rcu_dereference_check_key_mtx(local,
                                                             sta->ptk[key_idx]);
@@ -575,7 +609,7 @@ ieee80211_lookup_key(struct ieee80211_sub_if_data *sdata,
                              NUM_DEFAULT_MGMT_KEYS +
                              NUM_DEFAULT_BEACON_KEYS)
                        return rcu_dereference_check_key_mtx(local,
-                                                            sta->deflink.gtk[key_idx]);
+                                                            link_sta->gtk[key_idx]);
 
                return NULL;
        }
@@ -584,7 +618,7 @@ ieee80211_lookup_key(struct ieee80211_sub_if_data *sdata,
                return rcu_dereference_check_key_mtx(local,
                                                     sdata->keys[key_idx]);
 
-       key = rcu_dereference_check_key_mtx(local, sdata->deflink.gtk[key_idx]);
+       key = rcu_dereference_check_key_mtx(local, link->gtk[key_idx]);
        if (key)
                return key;
 
@@ -607,7 +641,7 @@ static int ieee80211_del_key(struct wiphy *wiphy, struct net_device *dev,
        mutex_lock(&local->sta_mtx);
        mutex_lock(&local->key_mtx);
 
-       key = ieee80211_lookup_key(sdata, key_idx, pairwise, mac_addr);
+       key = ieee80211_lookup_key(sdata, link_id, key_idx, pairwise, mac_addr);
        if (!key) {
                ret = -ENOENT;
                goto out_unlock;
@@ -643,7 +677,7 @@ static int ieee80211_get_key(struct wiphy *wiphy, struct net_device *dev,
 
        rcu_read_lock();
 
-       key = ieee80211_lookup_key(sdata, key_idx, pairwise, mac_addr);
+       key = ieee80211_lookup_key(sdata, link_id, key_idx, pairwise, mac_addr);
        if (!key)
                goto out;
 
@@ -734,8 +768,13 @@ static int ieee80211_config_default_key(struct wiphy *wiphy,
                                        bool multi)
 {
        struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev);
+       struct ieee80211_link_data *link =
+               ieee80211_link_or_deflink(sdata, link_id, false);
 
-       ieee80211_set_default_key(sdata, key_idx, uni, multi);
+       if (IS_ERR(link))
+               return PTR_ERR(link);
+
+       ieee80211_set_default_key(link, key_idx, uni, multi);
 
        return 0;
 }
@@ -745,8 +784,13 @@ static int ieee80211_config_default_mgmt_key(struct wiphy *wiphy,
                                             int link_id, u8 key_idx)
 {
        struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev);
+       struct ieee80211_link_data *link =
+               ieee80211_link_or_deflink(sdata, link_id, true);
 
-       ieee80211_set_default_mgmt_key(sdata, key_idx);
+       if (IS_ERR(link))
+               return PTR_ERR(link);
+
+       ieee80211_set_default_mgmt_key(link, key_idx);
 
        return 0;
 }
@@ -756,8 +800,13 @@ static int ieee80211_config_default_beacon_key(struct wiphy *wiphy,
                                               int link_id, u8 key_idx)
 {
        struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev);
+       struct ieee80211_link_data *link =
+               ieee80211_link_or_deflink(sdata, link_id, true);
+
+       if (IS_ERR(link))
+               return PTR_ERR(link);
 
-       ieee80211_set_default_beacon_key(sdata, key_idx);
+       ieee80211_set_default_beacon_key(link, key_idx);
 
        return 0;
 }
@@ -2588,7 +2637,7 @@ static int ieee80211_set_txq_params(struct wiphy *wiphy,
        struct ieee80211_local *local = wiphy_priv(wiphy);
        struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev);
        struct ieee80211_link_data *link =
-               ieee80211_link_or_deflink(sdata, params->link_id);
+               ieee80211_link_or_deflink(sdata, params->link_id, true);
        struct ieee80211_tx_queue_params p;
 
        if (!local->ops->conf_tx)
index e192e1e..6313c49 100644 (file)
@@ -213,6 +213,7 @@ struct ieee80211_rx_data {
        struct ieee80211_sub_if_data *sdata;
        struct ieee80211_link_data *link;
        struct sta_info *sta;
+       struct link_sta_info *link_sta;
        struct ieee80211_key *key;
 
        unsigned int flags;
index 3c30e12..b6e581f 100644 (file)
@@ -434,10 +434,19 @@ struct link_container {
 static void ieee80211_free_links(struct ieee80211_sub_if_data *sdata,
                                 struct link_container **links)
 {
+       LIST_HEAD(keys);
        unsigned int link_id;
 
+       for (link_id = 0; link_id < IEEE80211_MLD_MAX_NUM_LINKS; link_id++) {
+               if (!links[link_id])
+                       continue;
+               ieee80211_remove_link_keys(&links[link_id]->data, &keys);
+       }
+
        synchronize_rcu();
 
+       ieee80211_free_key_list(sdata->local, &keys);
+
        for (link_id = 0; link_id < IEEE80211_MLD_MAX_NUM_LINKS; link_id++) {
                if (!links[link_id])
                        continue;
index 6befb57..86aac87 100644 (file)
@@ -344,9 +344,10 @@ static void ieee80211_pairwise_rekey(struct ieee80211_key *old,
        }
 }
 
-static void __ieee80211_set_default_key(struct ieee80211_sub_if_data *sdata,
+static void __ieee80211_set_default_key(struct ieee80211_link_data *link,
                                        int idx, bool uni, bool multi)
 {
+       struct ieee80211_sub_if_data *sdata = link->sdata;
        struct ieee80211_key *key = NULL;
 
        assert_key_lock(sdata->local);
@@ -354,7 +355,7 @@ static void __ieee80211_set_default_key(struct ieee80211_sub_if_data *sdata,
        if (idx >= 0 && idx < NUM_DEFAULT_KEYS) {
                key = key_mtx_dereference(sdata->local, sdata->keys[idx]);
                if (!key)
-                       key = key_mtx_dereference(sdata->local, sdata->deflink.gtk[idx]);
+                       key = key_mtx_dereference(sdata->local, link->gtk[idx]);
        }
 
        if (uni) {
@@ -365,47 +366,48 @@ static void __ieee80211_set_default_key(struct ieee80211_sub_if_data *sdata,
        }
 
        if (multi)
-               rcu_assign_pointer(sdata->deflink.default_multicast_key, key);
+               rcu_assign_pointer(link->default_multicast_key, key);
 
        ieee80211_debugfs_key_update_default(sdata);
 }
 
-void ieee80211_set_default_key(struct ieee80211_sub_if_data *sdata, int idx,
+void ieee80211_set_default_key(struct ieee80211_link_data *link, int idx,
                               bool uni, bool multi)
 {
-       mutex_lock(&sdata->local->key_mtx);
-       __ieee80211_set_default_key(sdata, idx, uni, multi);
-       mutex_unlock(&sdata->local->key_mtx);
+       mutex_lock(&link->sdata->local->key_mtx);
+       __ieee80211_set_default_key(link, idx, uni, multi);
+       mutex_unlock(&link->sdata->local->key_mtx);
 }
 
 static void
-__ieee80211_set_default_mgmt_key(struct ieee80211_sub_if_data *sdata, int idx)
+__ieee80211_set_default_mgmt_key(struct ieee80211_link_data *link, int idx)
 {
+       struct ieee80211_sub_if_data *sdata = link->sdata;
        struct ieee80211_key *key = NULL;
 
        assert_key_lock(sdata->local);
 
        if (idx >= NUM_DEFAULT_KEYS &&
            idx < NUM_DEFAULT_KEYS + NUM_DEFAULT_MGMT_KEYS)
-               key = key_mtx_dereference(sdata->local,
-                                         sdata->deflink.gtk[idx]);
+               key = key_mtx_dereference(sdata->local, link->gtk[idx]);
 
-       rcu_assign_pointer(sdata->deflink.default_mgmt_key, key);
+       rcu_assign_pointer(link->default_mgmt_key, key);
 
        ieee80211_debugfs_key_update_default(sdata);
 }
 
-void ieee80211_set_default_mgmt_key(struct ieee80211_sub_if_data *sdata,
+void ieee80211_set_default_mgmt_key(struct ieee80211_link_data *link,
                                    int idx)
 {
-       mutex_lock(&sdata->local->key_mtx);
-       __ieee80211_set_default_mgmt_key(sdata, idx);
-       mutex_unlock(&sdata->local->key_mtx);
+       mutex_lock(&link->sdata->local->key_mtx);
+       __ieee80211_set_default_mgmt_key(link, idx);
+       mutex_unlock(&link->sdata->local->key_mtx);
 }
 
 static void
-__ieee80211_set_default_beacon_key(struct ieee80211_sub_if_data *sdata, int idx)
+__ieee80211_set_default_beacon_key(struct ieee80211_link_data *link, int idx)
 {
+       struct ieee80211_sub_if_data *sdata = link->sdata;
        struct ieee80211_key *key = NULL;
 
        assert_key_lock(sdata->local);
@@ -413,28 +415,30 @@ __ieee80211_set_default_beacon_key(struct ieee80211_sub_if_data *sdata, int idx)
        if (idx >= NUM_DEFAULT_KEYS + NUM_DEFAULT_MGMT_KEYS &&
            idx < NUM_DEFAULT_KEYS + NUM_DEFAULT_MGMT_KEYS +
            NUM_DEFAULT_BEACON_KEYS)
-               key = key_mtx_dereference(sdata->local,
-                                         sdata->deflink.gtk[idx]);
+               key = key_mtx_dereference(sdata->local, link->gtk[idx]);
 
-       rcu_assign_pointer(sdata->deflink.default_beacon_key, key);
+       rcu_assign_pointer(link->default_beacon_key, key);
 
        ieee80211_debugfs_key_update_default(sdata);
 }
 
-void ieee80211_set_default_beacon_key(struct ieee80211_sub_if_data *sdata,
+void ieee80211_set_default_beacon_key(struct ieee80211_link_data *link,
                                      int idx)
 {
-       mutex_lock(&sdata->local->key_mtx);
-       __ieee80211_set_default_beacon_key(sdata, idx);
-       mutex_unlock(&sdata->local->key_mtx);
+       mutex_lock(&link->sdata->local->key_mtx);
+       __ieee80211_set_default_beacon_key(link, idx);
+       mutex_unlock(&link->sdata->local->key_mtx);
 }
 
 static int ieee80211_key_replace(struct ieee80211_sub_if_data *sdata,
-                                 struct sta_info *sta,
-                                 bool pairwise,
-                                 struct ieee80211_key *old,
-                                 struct ieee80211_key *new)
+                                struct ieee80211_link_data *link,
+                                struct sta_info *sta,
+                                bool pairwise,
+                                struct ieee80211_key *old,
+                                struct ieee80211_key *new)
 {
+       struct link_sta_info *link_sta = sta ? &sta->deflink : NULL;
+       int link_id;
        int idx;
        int ret = 0;
        bool defunikey, defmultikey, defmgmtkey, defbeaconkey;
@@ -446,13 +450,36 @@ static int ieee80211_key_replace(struct ieee80211_sub_if_data *sdata,
 
        if (new) {
                idx = new->conf.keyidx;
-               list_add_tail_rcu(&new->list, &sdata->key_list);
                is_wep = new->conf.cipher == WLAN_CIPHER_SUITE_WEP40 ||
                         new->conf.cipher == WLAN_CIPHER_SUITE_WEP104;
+               link_id = new->conf.link_id;
        } else {
                idx = old->conf.keyidx;
                is_wep = old->conf.cipher == WLAN_CIPHER_SUITE_WEP40 ||
                         old->conf.cipher == WLAN_CIPHER_SUITE_WEP104;
+               link_id = old->conf.link_id;
+       }
+
+       if (WARN(old && old->conf.link_id != link_id,
+                "old link ID %d doesn't match new link ID %d\n",
+                old->conf.link_id, link_id))
+               return -EINVAL;
+
+       if (link_id >= 0) {
+               if (!link) {
+                       link = sdata_dereference(sdata->link[link_id], sdata);
+                       if (!link)
+                               return -ENOLINK;
+               }
+
+               if (sta) {
+                       link_sta = rcu_dereference_protected(sta->link[link_id],
+                                                            lockdep_is_held(&sta->local->sta_mtx));
+                       if (!link_sta)
+                               return -ENOLINK;
+               }
+       } else {
+               link = &sdata->deflink;
        }
 
        if ((is_wep || pairwise) && idx >= NUM_DEFAULT_KEYS)
@@ -482,6 +509,9 @@ static int ieee80211_key_replace(struct ieee80211_sub_if_data *sdata,
        if (ret)
                return ret;
 
+       if (new)
+               list_add_tail_rcu(&new->list, &sdata->key_list);
+
        if (sta) {
                if (pairwise) {
                        rcu_assign_pointer(sta->ptk[idx], new);
@@ -489,7 +519,7 @@ static int ieee80211_key_replace(struct ieee80211_sub_if_data *sdata,
                            !(new->conf.flags & IEEE80211_KEY_FLAG_NO_AUTO_TX))
                                _ieee80211_set_tx_key(new, true);
                } else {
-                       rcu_assign_pointer(sta->deflink.gtk[idx], new);
+                       rcu_assign_pointer(link_sta->gtk[idx], new);
                }
                /* Only needed for transition from no key -> key.
                 * Still triggers unnecessary when using Extended Key ID
@@ -503,39 +533,39 @@ static int ieee80211_key_replace(struct ieee80211_sub_if_data *sdata,
                                                sdata->default_unicast_key);
                defmultikey = old &&
                        old == key_mtx_dereference(sdata->local,
-                                               sdata->deflink.default_multicast_key);
+                                                  link->default_multicast_key);
                defmgmtkey = old &&
                        old == key_mtx_dereference(sdata->local,
-                                               sdata->deflink.default_mgmt_key);
+                                                  link->default_mgmt_key);
                defbeaconkey = old &&
                        old == key_mtx_dereference(sdata->local,
-                                                  sdata->deflink.default_beacon_key);
+                                                  link->default_beacon_key);
 
                if (defunikey && !new)
-                       __ieee80211_set_default_key(sdata, -1, true, false);
+                       __ieee80211_set_default_key(link, -1, true, false);
                if (defmultikey && !new)
-                       __ieee80211_set_default_key(sdata, -1, false, true);
+                       __ieee80211_set_default_key(link, -1, false, true);
                if (defmgmtkey && !new)
-                       __ieee80211_set_default_mgmt_key(sdata, -1);
+                       __ieee80211_set_default_mgmt_key(link, -1);
                if (defbeaconkey && !new)
-                       __ieee80211_set_default_beacon_key(sdata, -1);
+                       __ieee80211_set_default_beacon_key(link, -1);
 
                if (is_wep || pairwise)
                        rcu_assign_pointer(sdata->keys[idx], new);
                else
-                       rcu_assign_pointer(sdata->deflink.gtk[idx], new);
+                       rcu_assign_pointer(link->gtk[idx], new);
 
                if (defunikey && new)
-                       __ieee80211_set_default_key(sdata, new->conf.keyidx,
+                       __ieee80211_set_default_key(link, new->conf.keyidx,
                                                    true, false);
                if (defmultikey && new)
-                       __ieee80211_set_default_key(sdata, new->conf.keyidx,
+                       __ieee80211_set_default_key(link, new->conf.keyidx,
                                                    false, true);
                if (defmgmtkey && new)
-                       __ieee80211_set_default_mgmt_key(sdata,
+                       __ieee80211_set_default_mgmt_key(link,
                                                         new->conf.keyidx);
                if (defbeaconkey && new)
-                       __ieee80211_set_default_beacon_key(sdata,
+                       __ieee80211_set_default_beacon_key(link,
                                                           new->conf.keyidx);
        }
 
@@ -569,6 +599,7 @@ ieee80211_key_alloc(u32 cipher, int idx, size_t key_len,
        key->conf.flags = 0;
        key->flags = 0;
 
+       key->conf.link_id = -1;
        key->conf.cipher = cipher;
        key->conf.keyidx = idx;
        key->conf.keylen = key_len;
@@ -797,9 +828,10 @@ static bool ieee80211_key_identical(struct ieee80211_sub_if_data *sdata,
 }
 
 int ieee80211_key_link(struct ieee80211_key *key,
-                      struct ieee80211_sub_if_data *sdata,
+                      struct ieee80211_link_data *link,
                       struct sta_info *sta)
 {
+       struct ieee80211_sub_if_data *sdata = link->sdata;
        static atomic_t key_color = ATOMIC_INIT(0);
        struct ieee80211_key *old_key = NULL;
        int idx = key->conf.keyidx;
@@ -827,15 +859,24 @@ int ieee80211_key_link(struct ieee80211_key *key,
                    (old_key && old_key->conf.cipher != key->conf.cipher))
                        goto out;
        } else if (sta) {
-               old_key = key_mtx_dereference(sdata->local,
-                                             sta->deflink.gtk[idx]);
+               struct link_sta_info *link_sta = &sta->deflink;
+               int link_id = key->conf.link_id;
+
+               if (link_id >= 0) {
+                       link_sta = rcu_dereference_protected(sta->link[link_id],
+                                                            lockdep_is_held(&sta->local->sta_mtx));
+                       if (!link_sta)
+                               return -ENOLINK;
+               }
+
+               old_key = key_mtx_dereference(sdata->local, link_sta->gtk[idx]);
        } else {
                if (idx < NUM_DEFAULT_KEYS)
                        old_key = key_mtx_dereference(sdata->local,
                                                      sdata->keys[idx]);
                if (!old_key)
                        old_key = key_mtx_dereference(sdata->local,
-                                                     sdata->deflink.gtk[idx]);
+                                                     link->gtk[idx]);
        }
 
        /* Non-pairwise keys must also not switch the cipher on rekey */
@@ -866,7 +907,7 @@ int ieee80211_key_link(struct ieee80211_key *key,
 
        increment_tailroom_need_count(sdata);
 
-       ret = ieee80211_key_replace(sdata, sta, pairwise, old_key, key);
+       ret = ieee80211_key_replace(sdata, link, sta, pairwise, old_key, key);
 
        if (!ret) {
                ieee80211_debugfs_key_add(key);
@@ -890,9 +931,9 @@ void ieee80211_key_free(struct ieee80211_key *key, bool delay_tailroom)
         * Replace key with nothingness if it was ever used.
         */
        if (key->sdata)
-               ieee80211_key_replace(key->sdata, key->sta,
-                               key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE,
-                               key, NULL);
+               ieee80211_key_replace(key->sdata, NULL, key->sta,
+                                     key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE,
+                                     key, NULL);
        ieee80211_key_destroy(key, delay_tailroom);
 }
 
@@ -1019,15 +1060,45 @@ static void ieee80211_free_keys_iface(struct ieee80211_sub_if_data *sdata,
        ieee80211_debugfs_key_remove_beacon_default(sdata);
 
        list_for_each_entry_safe(key, tmp, &sdata->key_list, list) {
-               ieee80211_key_replace(key->sdata, key->sta,
-                               key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE,
-                               key, NULL);
+               ieee80211_key_replace(key->sdata, NULL, key->sta,
+                                     key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE,
+                                     key, NULL);
                list_add_tail(&key->list, keys);
        }
 
        ieee80211_debugfs_key_update_default(sdata);
 }
 
+void ieee80211_remove_link_keys(struct ieee80211_link_data *link,
+                               struct list_head *keys)
+{
+       struct ieee80211_sub_if_data *sdata = link->sdata;
+       struct ieee80211_local *local = sdata->local;
+       struct ieee80211_key *key, *tmp;
+
+       mutex_lock(&local->key_mtx);
+       list_for_each_entry_safe(key, tmp, &sdata->key_list, list) {
+               if (key->conf.link_id != link->link_id)
+                       continue;
+               ieee80211_key_replace(key->sdata, link, key->sta,
+                                     key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE,
+                                     key, NULL);
+               list_add_tail(&key->list, keys);
+       }
+       mutex_unlock(&local->key_mtx);
+}
+
+void ieee80211_free_key_list(struct ieee80211_local *local,
+                            struct list_head *keys)
+{
+       struct ieee80211_key *key, *tmp;
+
+       mutex_lock(&local->key_mtx);
+       list_for_each_entry_safe(key, tmp, keys, list)
+               __ieee80211_key_destroy(key, false);
+       mutex_unlock(&local->key_mtx);
+}
+
 void ieee80211_free_keys(struct ieee80211_sub_if_data *sdata,
                         bool force_synchronize)
 {
@@ -1087,9 +1158,9 @@ void ieee80211_free_sta_keys(struct ieee80211_local *local,
                key = key_mtx_dereference(local, sta->deflink.gtk[i]);
                if (!key)
                        continue;
-               ieee80211_key_replace(key->sdata, key->sta,
-                               key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE,
-                               key, NULL);
+               ieee80211_key_replace(key->sdata, NULL, key->sta,
+                                     key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE,
+                                     key, NULL);
                __ieee80211_key_destroy(key, key->sdata->vif.type ==
                                        NL80211_IFTYPE_STATION);
        }
@@ -1098,9 +1169,9 @@ void ieee80211_free_sta_keys(struct ieee80211_local *local,
                key = key_mtx_dereference(local, sta->ptk[i]);
                if (!key)
                        continue;
-               ieee80211_key_replace(key->sdata, key->sta,
-                               key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE,
-                               key, NULL);
+               ieee80211_key_replace(key->sdata, NULL, key->sta,
+                                     key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE,
+                                     key, NULL);
                __ieee80211_key_destroy(key, key->sdata->vif.type ==
                                        NL80211_IFTYPE_STATION);
        }
@@ -1307,7 +1378,8 @@ ieee80211_gtk_rekey_add(struct ieee80211_vif *vif,
        if (sdata->u.mgd.mfp != IEEE80211_MFP_DISABLED)
                key->conf.flags |= IEEE80211_KEY_FLAG_RX_MGMT;
 
-       err = ieee80211_key_link(key, sdata, NULL);
+       /* FIXME: this function needs to get a link ID */
+       err = ieee80211_key_link(key, &sdata->deflink, NULL);
        if (err)
                return ERR_PTR(err);
 
index e994dce..518af24 100644 (file)
@@ -22,6 +22,7 @@
 
 struct ieee80211_local;
 struct ieee80211_sub_if_data;
+struct ieee80211_link_data;
 struct sta_info;
 
 /**
@@ -144,17 +145,21 @@ ieee80211_key_alloc(u32 cipher, int idx, size_t key_len,
  * to make it used, free old key. On failure, also free the new key.
  */
 int ieee80211_key_link(struct ieee80211_key *key,
-                      struct ieee80211_sub_if_data *sdata,
+                      struct ieee80211_link_data *link,
                       struct sta_info *sta);
 int ieee80211_set_tx_key(struct ieee80211_key *key);
 void ieee80211_key_free(struct ieee80211_key *key, bool delay_tailroom);
 void ieee80211_key_free_unused(struct ieee80211_key *key);
-void ieee80211_set_default_key(struct ieee80211_sub_if_data *sdata, int idx,
+void ieee80211_set_default_key(struct ieee80211_link_data *link, int idx,
                               bool uni, bool multi);
-void ieee80211_set_default_mgmt_key(struct ieee80211_sub_if_data *sdata,
+void ieee80211_set_default_mgmt_key(struct ieee80211_link_data *link,
                                    int idx);
-void ieee80211_set_default_beacon_key(struct ieee80211_sub_if_data *sdata,
+void ieee80211_set_default_beacon_key(struct ieee80211_link_data *link,
                                      int idx);
+void ieee80211_remove_link_keys(struct ieee80211_link_data *link,
+                               struct list_head *keys);
+void ieee80211_free_key_list(struct ieee80211_local *local,
+                            struct list_head *keys);
 void ieee80211_free_keys(struct ieee80211_sub_if_data *sdata,
                         bool force_synchronize);
 void ieee80211_free_sta_keys(struct ieee80211_local *local,
index 57df21e..aad6179 100644 (file)
@@ -1854,7 +1854,6 @@ static struct ieee80211_key *
 ieee80211_rx_get_bigtk(struct ieee80211_rx_data *rx, int idx)
 {
        struct ieee80211_key *key = NULL;
-       struct ieee80211_sub_if_data *sdata = rx->sdata;
        int idx2;
 
        /* Make sure key gets set if either BIGTK key index is set so that
@@ -1873,14 +1872,14 @@ ieee80211_rx_get_bigtk(struct ieee80211_rx_data *rx, int idx)
                        idx2 = idx - 1;
        }
 
-       if (rx->sta)
-               key = rcu_dereference(rx->sta->deflink.gtk[idx]);
+       if (rx->link_sta)
+               key = rcu_dereference(rx->link_sta->gtk[idx]);
        if (!key)
-               key = rcu_dereference(sdata->deflink.gtk[idx]);
-       if (!key && rx->sta)
-               key = rcu_dereference(rx->sta->deflink.gtk[idx2]);
+               key = rcu_dereference(rx->link->gtk[idx]);
+       if (!key && rx->link_sta)
+               key = rcu_dereference(rx->link_sta->gtk[idx2]);
        if (!key)
-               key = rcu_dereference(sdata->deflink.gtk[idx2]);
+               key = rcu_dereference(rx->link->gtk[idx2]);
 
        return key;
 }
@@ -1986,15 +1985,15 @@ ieee80211_rx_h_decrypt(struct ieee80211_rx_data *rx)
                if (mmie_keyidx < NUM_DEFAULT_KEYS ||
                    mmie_keyidx >= NUM_DEFAULT_KEYS + NUM_DEFAULT_MGMT_KEYS)
                        return RX_DROP_MONITOR; /* unexpected BIP keyidx */
-               if (rx->sta) {
+               if (rx->link_sta) {
                        if (ieee80211_is_group_privacy_action(skb) &&
                            test_sta_flag(rx->sta, WLAN_STA_MFP))
                                return RX_DROP_MONITOR;
 
-                       rx->key = rcu_dereference(rx->sta->deflink.gtk[mmie_keyidx]);
+                       rx->key = rcu_dereference(rx->link_sta->gtk[mmie_keyidx]);
                }
                if (!rx->key)
-                       rx->key = rcu_dereference(rx->sdata->deflink.gtk[mmie_keyidx]);
+                       rx->key = rcu_dereference(rx->link->gtk[mmie_keyidx]);
        } else if (!ieee80211_has_protected(fc)) {
                /*
                 * The frame was not protected, so skip decryption. However, we
@@ -2003,25 +2002,24 @@ ieee80211_rx_h_decrypt(struct ieee80211_rx_data *rx)
                 * have been expected.
                 */
                struct ieee80211_key *key = NULL;
-               struct ieee80211_sub_if_data *sdata = rx->sdata;
                int i;
 
                if (ieee80211_is_beacon(fc)) {
                        key = ieee80211_rx_get_bigtk(rx, -1);
                } else if (ieee80211_is_mgmt(fc) &&
                           is_multicast_ether_addr(hdr->addr1)) {
-                       key = rcu_dereference(rx->sdata->deflink.default_mgmt_key);
+                       key = rcu_dereference(rx->link->default_mgmt_key);
                } else {
-                       if (rx->sta) {
+                       if (rx->link_sta) {
                                for (i = 0; i < NUM_DEFAULT_KEYS; i++) {
-                                       key = rcu_dereference(rx->sta->deflink.gtk[i]);
+                                       key = rcu_dereference(rx->link_sta->gtk[i]);
                                        if (key)
                                                break;
                                }
                        }
                        if (!key) {
                                for (i = 0; i < NUM_DEFAULT_KEYS; i++) {
-                                       key = rcu_dereference(sdata->deflink.gtk[i]);
+                                       key = rcu_dereference(rx->link->gtk[i]);
                                        if (key)
                                                break;
                                }
@@ -2050,13 +2048,13 @@ ieee80211_rx_h_decrypt(struct ieee80211_rx_data *rx)
                        return RX_DROP_UNUSABLE;
 
                /* check per-station GTK first, if multicast packet */
-               if (is_multicast_ether_addr(hdr->addr1) && rx->sta)
-                       rx->key = rcu_dereference(rx->sta->deflink.gtk[keyidx]);
+               if (is_multicast_ether_addr(hdr->addr1) && rx->link_sta)
+                       rx->key = rcu_dereference(rx->link_sta->gtk[keyidx]);
 
                /* if not found, try default key */
                if (!rx->key) {
                        if (is_multicast_ether_addr(hdr->addr1))
-                               rx->key = rcu_dereference(rx->sdata->deflink.gtk[keyidx]);
+                               rx->key = rcu_dereference(rx->link->gtk[keyidx]);
                        if (!rx->key)
                                rx->key = rcu_dereference(rx->sdata->keys[keyidx]);
 
@@ -4769,7 +4767,17 @@ static bool ieee80211_prepare_and_rx_handle(struct ieee80211_rx_data *rx,
                if (!link)
                        return true;
                rx->link = link;
+
+               if (rx->sta) {
+                       rx->link_sta =
+                               rcu_dereference(rx->sta->link[rx->link_id]);
+                       if (!rx->link_sta)
+                               return true;
+               }
        } else {
+               if (rx->sta)
+                       rx->link_sta = &rx->sta->deflink;
+
                rx->link = &sdata->deflink;
        }
 
index 45df993..8683f24 100644 (file)
@@ -576,6 +576,51 @@ ieee80211_tx_h_check_control_port_protocol(struct ieee80211_tx_data *tx)
        return TX_CONTINUE;
 }
 
+static struct ieee80211_key *
+ieee80211_select_link_key(struct ieee80211_tx_data *tx)
+{
+       struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)tx->skb->data;
+       struct ieee80211_tx_info *info = IEEE80211_SKB_CB(tx->skb);
+       enum {
+               USE_NONE,
+               USE_MGMT_KEY,
+               USE_MCAST_KEY,
+       } which_key = USE_NONE;
+       struct ieee80211_link_data *link;
+       unsigned int link_id;
+
+       if (ieee80211_is_group_privacy_action(tx->skb))
+               which_key = USE_MCAST_KEY;
+       else if (ieee80211_is_mgmt(hdr->frame_control) &&
+                is_multicast_ether_addr(hdr->addr1) &&
+                ieee80211_is_robust_mgmt_frame(tx->skb))
+               which_key = USE_MGMT_KEY;
+       else if (is_multicast_ether_addr(hdr->addr1))
+               which_key = USE_MCAST_KEY;
+       else
+               return NULL;
+
+       link_id = u32_get_bits(info->control.flags, IEEE80211_TX_CTRL_MLO_LINK);
+       if (link_id == IEEE80211_LINK_UNSPECIFIED) {
+               link = &tx->sdata->deflink;
+       } else {
+               link = rcu_dereference(tx->sdata->link[link_id]);
+               if (!link)
+                       return NULL;
+       }
+
+       switch (which_key) {
+       case USE_NONE:
+               break;
+       case USE_MGMT_KEY:
+               return rcu_dereference(link->default_mgmt_key);
+       case USE_MCAST_KEY:
+               return rcu_dereference(link->default_multicast_key);
+       }
+
+       return NULL;
+}
+
 static ieee80211_tx_result debug_noinline
 ieee80211_tx_h_select_key(struct ieee80211_tx_data *tx)
 {
@@ -591,16 +636,7 @@ ieee80211_tx_h_select_key(struct ieee80211_tx_data *tx)
        if (tx->sta &&
            (key = rcu_dereference(tx->sta->ptk[tx->sta->ptk_idx])))
                tx->key = key;
-       else if (ieee80211_is_group_privacy_action(tx->skb) &&
-               (key = rcu_dereference(tx->sdata->deflink.default_multicast_key)))
-               tx->key = key;
-       else if (ieee80211_is_mgmt(hdr->frame_control) &&
-                is_multicast_ether_addr(hdr->addr1) &&
-                ieee80211_is_robust_mgmt_frame(tx->skb) &&
-                (key = rcu_dereference(tx->sdata->deflink.default_mgmt_key)))
-               tx->key = key;
-       else if (is_multicast_ether_addr(hdr->addr1) &&
-                (key = rcu_dereference(tx->sdata->deflink.default_multicast_key)))
+       else if ((key = ieee80211_select_link_key(tx)))
                tx->key = key;
        else if (!is_multicast_ether_addr(hdr->addr1) &&
                 (key = rcu_dereference(tx->sdata->default_unicast_key)))