mac80211: Fix MLO address translation for multiple bss case
[platform/kernel/linux-starfive.git] / net / mac80211 / rx.c
index 45d7e71..8f0d7c6 100644 (file)
@@ -49,7 +49,7 @@ static struct sk_buff *ieee80211_clean_skb(struct sk_buff *skb,
 
        if (present_fcs_len)
                __pskb_trim(skb, skb->len - present_fcs_len);
-       __pskb_pull(skb, rtap_space);
+       pskb_pull(skb, rtap_space);
 
        hdr = (void *)skb->data;
        fc = hdr->frame_control;
@@ -74,7 +74,7 @@ static struct sk_buff *ieee80211_clean_skb(struct sk_buff *skb,
 
        memmove(skb->data + IEEE80211_HT_CTL_LEN, skb->data,
                hdrlen - IEEE80211_HT_CTL_LEN);
-       __pskb_pull(skb, IEEE80211_HT_CTL_LEN);
+       pskb_pull(skb, IEEE80211_HT_CTL_LEN);
 
        return skb;
 }
@@ -215,9 +215,19 @@ ieee80211_rx_radiotap_hdrlen(struct ieee80211_local *local,
 }
 
 static void __ieee80211_queue_skb_to_iface(struct ieee80211_sub_if_data *sdata,
+                                          int link_id,
                                           struct sta_info *sta,
                                           struct sk_buff *skb)
 {
+       struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb);
+
+       if (link_id >= 0) {
+               status->link_valid = 1;
+               status->link_id = link_id;
+       } else {
+               status->link_valid = 0;
+       }
+
        skb_queue_tail(&sdata->skb_queue, skb);
        ieee80211_queue_work(&sdata->local->hw, &sdata->work);
        if (sta)
@@ -225,11 +235,12 @@ static void __ieee80211_queue_skb_to_iface(struct ieee80211_sub_if_data *sdata,
 }
 
 static void ieee80211_queue_skb_to_iface(struct ieee80211_sub_if_data *sdata,
+                                        int link_id,
                                         struct sta_info *sta,
                                         struct sk_buff *skb)
 {
        skb->protocol = 0;
-       __ieee80211_queue_skb_to_iface(sdata, sta, skb);
+       __ieee80211_queue_skb_to_iface(sdata, link_id, sta, skb);
 }
 
 static void ieee80211_handle_mu_mimo_mon(struct ieee80211_sub_if_data *sdata,
@@ -272,7 +283,7 @@ static void ieee80211_handle_mu_mimo_mon(struct ieee80211_sub_if_data *sdata,
        if (!skb)
                return;
 
-       ieee80211_queue_skb_to_iface(sdata, NULL, skb);
+       ieee80211_queue_skb_to_iface(sdata, -1, NULL, skb);
 }
 
 /*
@@ -1394,7 +1405,7 @@ static void ieee80211_rx_reorder_ampdu(struct ieee80211_rx_data *rx,
        /* if this mpdu is fragmented - terminate rx aggregation session */
        sc = le16_to_cpu(hdr->seq_ctrl);
        if (sc & IEEE80211_SCTL_FRAG) {
-               ieee80211_queue_skb_to_iface(rx->sdata, NULL, skb);
+               ieee80211_queue_skb_to_iface(rx->sdata, rx->link_id, NULL, skb);
                return;
        }
 
@@ -1441,7 +1452,7 @@ ieee80211_rx_h_check_dup(struct ieee80211_rx_data *rx)
        if (unlikely(ieee80211_has_retry(hdr->frame_control) &&
                     rx->sta->last_seq_ctrl[rx->seqno_idx] == hdr->seq_ctrl)) {
                I802_DEBUG_INC(rx->local->dot11FrameDuplicateCount);
-               rx->sta->deflink.rx_stats.num_duplicates++;
+               rx->link_sta->rx_stats.num_duplicates++;
                return RX_DROP_UNUSABLE;
        } else if (!(status->flag & RX_FLAG_AMSDU_MORE)) {
                rx->sta->last_seq_ctrl[rx->seqno_idx] = hdr->seq_ctrl;
@@ -1720,12 +1731,13 @@ static ieee80211_rx_result debug_noinline
 ieee80211_rx_h_sta_process(struct ieee80211_rx_data *rx)
 {
        struct sta_info *sta = rx->sta;
+       struct link_sta_info *link_sta = rx->link_sta;
        struct sk_buff *skb = rx->skb;
        struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb);
        struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data;
        int i;
 
-       if (!sta)
+       if (!sta || !link_sta)
                return RX_CONTINUE;
 
        /*
@@ -1741,47 +1753,47 @@ ieee80211_rx_h_sta_process(struct ieee80211_rx_data *rx)
                                                NL80211_IFTYPE_ADHOC);
                if (ether_addr_equal(bssid, rx->sdata->u.ibss.bssid) &&
                    test_sta_flag(sta, WLAN_STA_AUTHORIZED)) {
-                       sta->deflink.rx_stats.last_rx = jiffies;
+                       link_sta->rx_stats.last_rx = jiffies;
                        if (ieee80211_is_data(hdr->frame_control) &&
                            !is_multicast_ether_addr(hdr->addr1))
-                               sta->deflink.rx_stats.last_rate =
+                               link_sta->rx_stats.last_rate =
                                        sta_stats_encode_rate(status);
                }
        } else if (rx->sdata->vif.type == NL80211_IFTYPE_OCB) {
-               sta->deflink.rx_stats.last_rx = jiffies;
+               link_sta->rx_stats.last_rx = jiffies;
        } else if (!ieee80211_is_s1g_beacon(hdr->frame_control) &&
                   !is_multicast_ether_addr(hdr->addr1)) {
                /*
                 * Mesh beacons will update last_rx when if they are found to
                 * match the current local configuration when processed.
                 */
-               sta->deflink.rx_stats.last_rx = jiffies;
+               link_sta->rx_stats.last_rx = jiffies;
                if (ieee80211_is_data(hdr->frame_control))
-                       sta->deflink.rx_stats.last_rate = sta_stats_encode_rate(status);
+                       link_sta->rx_stats.last_rate = sta_stats_encode_rate(status);
        }
 
-       sta->deflink.rx_stats.fragments++;
+       link_sta->rx_stats.fragments++;
 
-       u64_stats_update_begin(&rx->sta->deflink.rx_stats.syncp);
-       sta->deflink.rx_stats.bytes += rx->skb->len;
-       u64_stats_update_end(&rx->sta->deflink.rx_stats.syncp);
+       u64_stats_update_begin(&link_sta->rx_stats.syncp);
+       link_sta->rx_stats.bytes += rx->skb->len;
+       u64_stats_update_end(&link_sta->rx_stats.syncp);
 
        if (!(status->flag & RX_FLAG_NO_SIGNAL_VAL)) {
-               sta->deflink.rx_stats.last_signal = status->signal;
-               ewma_signal_add(&sta->deflink.rx_stats_avg.signal,
+               link_sta->rx_stats.last_signal = status->signal;
+               ewma_signal_add(&link_sta->rx_stats_avg.signal,
                                -status->signal);
        }
 
        if (status->chains) {
-               sta->deflink.rx_stats.chains = status->chains;
+               link_sta->rx_stats.chains = status->chains;
                for (i = 0; i < ARRAY_SIZE(status->chain_signal); i++) {
                        int signal = status->chain_signal[i];
 
                        if (!(status->chains & BIT(i)))
                                continue;
 
-                       sta->deflink.rx_stats.chain_signal_last[i] = signal;
-                       ewma_signal_add(&sta->deflink.rx_stats_avg.chain_signal[i],
+                       link_sta->rx_stats.chain_signal_last[i] = signal;
+                       ewma_signal_add(&link_sta->rx_stats_avg.chain_signal[i],
                                        -signal);
                }
        }
@@ -1842,7 +1854,7 @@ ieee80211_rx_h_sta_process(struct ieee80211_rx_data *rx)
                 * Update counter and free packet here to avoid
                 * counting this as a dropped packed.
                 */
-               sta->deflink.rx_stats.packets++;
+               link_sta->rx_stats.packets++;
                dev_kfree_skb(rx->skb);
                return RX_QUEUED;
        }
@@ -1854,7 +1866,6 @@ static struct ieee80211_key *
 ieee80211_rx_get_bigtk(struct ieee80211_rx_data *rx, int idx)
 {
        struct ieee80211_key *key = NULL;
-       struct ieee80211_sub_if_data *sdata = rx->sdata;
        int idx2;
 
        /* Make sure key gets set if either BIGTK key index is set so that
@@ -1873,14 +1884,14 @@ ieee80211_rx_get_bigtk(struct ieee80211_rx_data *rx, int idx)
                        idx2 = idx - 1;
        }
 
-       if (rx->sta)
-               key = rcu_dereference(rx->sta->deflink.gtk[idx]);
+       if (rx->link_sta)
+               key = rcu_dereference(rx->link_sta->gtk[idx]);
        if (!key)
-               key = rcu_dereference(sdata->deflink.gtk[idx]);
-       if (!key && rx->sta)
-               key = rcu_dereference(rx->sta->deflink.gtk[idx2]);
+               key = rcu_dereference(rx->link->gtk[idx]);
+       if (!key && rx->link_sta)
+               key = rcu_dereference(rx->link_sta->gtk[idx2]);
        if (!key)
-               key = rcu_dereference(sdata->deflink.gtk[idx2]);
+               key = rcu_dereference(rx->link->gtk[idx2]);
 
        return key;
 }
@@ -1967,10 +1978,11 @@ ieee80211_rx_h_decrypt(struct ieee80211_rx_data *rx)
 
                if (mmie_keyidx < NUM_DEFAULT_KEYS + NUM_DEFAULT_MGMT_KEYS ||
                    mmie_keyidx >= NUM_DEFAULT_KEYS + NUM_DEFAULT_MGMT_KEYS +
-                   NUM_DEFAULT_BEACON_KEYS) {
-                       cfg80211_rx_unprot_mlme_mgmt(rx->sdata->dev,
-                                                    skb->data,
-                                                    skb->len);
+                                  NUM_DEFAULT_BEACON_KEYS) {
+                       if (rx->sdata->dev)
+                               cfg80211_rx_unprot_mlme_mgmt(rx->sdata->dev,
+                                                            skb->data,
+                                                            skb->len);
                        return RX_DROP_MONITOR; /* unexpected BIP keyidx */
                }
 
@@ -1986,15 +1998,15 @@ ieee80211_rx_h_decrypt(struct ieee80211_rx_data *rx)
                if (mmie_keyidx < NUM_DEFAULT_KEYS ||
                    mmie_keyidx >= NUM_DEFAULT_KEYS + NUM_DEFAULT_MGMT_KEYS)
                        return RX_DROP_MONITOR; /* unexpected BIP keyidx */
-               if (rx->sta) {
+               if (rx->link_sta) {
                        if (ieee80211_is_group_privacy_action(skb) &&
                            test_sta_flag(rx->sta, WLAN_STA_MFP))
                                return RX_DROP_MONITOR;
 
-                       rx->key = rcu_dereference(rx->sta->deflink.gtk[mmie_keyidx]);
+                       rx->key = rcu_dereference(rx->link_sta->gtk[mmie_keyidx]);
                }
                if (!rx->key)
-                       rx->key = rcu_dereference(rx->sdata->deflink.gtk[mmie_keyidx]);
+                       rx->key = rcu_dereference(rx->link->gtk[mmie_keyidx]);
        } else if (!ieee80211_has_protected(fc)) {
                /*
                 * The frame was not protected, so skip decryption. However, we
@@ -2003,25 +2015,24 @@ ieee80211_rx_h_decrypt(struct ieee80211_rx_data *rx)
                 * have been expected.
                 */
                struct ieee80211_key *key = NULL;
-               struct ieee80211_sub_if_data *sdata = rx->sdata;
                int i;
 
                if (ieee80211_is_beacon(fc)) {
                        key = ieee80211_rx_get_bigtk(rx, -1);
                } else if (ieee80211_is_mgmt(fc) &&
                           is_multicast_ether_addr(hdr->addr1)) {
-                       key = rcu_dereference(rx->sdata->deflink.default_mgmt_key);
+                       key = rcu_dereference(rx->link->default_mgmt_key);
                } else {
-                       if (rx->sta) {
+                       if (rx->link_sta) {
                                for (i = 0; i < NUM_DEFAULT_KEYS; i++) {
-                                       key = rcu_dereference(rx->sta->deflink.gtk[i]);
+                                       key = rcu_dereference(rx->link_sta->gtk[i]);
                                        if (key)
                                                break;
                                }
                        }
                        if (!key) {
                                for (i = 0; i < NUM_DEFAULT_KEYS; i++) {
-                                       key = rcu_dereference(sdata->deflink.gtk[i]);
+                                       key = rcu_dereference(rx->link->gtk[i]);
                                        if (key)
                                                break;
                                }
@@ -2050,13 +2061,13 @@ ieee80211_rx_h_decrypt(struct ieee80211_rx_data *rx)
                        return RX_DROP_UNUSABLE;
 
                /* check per-station GTK first, if multicast packet */
-               if (is_multicast_ether_addr(hdr->addr1) && rx->sta)
-                       rx->key = rcu_dereference(rx->sta->deflink.gtk[keyidx]);
+               if (is_multicast_ether_addr(hdr->addr1) && rx->link_sta)
+                       rx->key = rcu_dereference(rx->link_sta->gtk[keyidx]);
 
                /* if not found, try default key */
                if (!rx->key) {
                        if (is_multicast_ether_addr(hdr->addr1))
-                               rx->key = rcu_dereference(rx->sdata->deflink.gtk[keyidx]);
+                               rx->key = rcu_dereference(rx->link->gtk[keyidx]);
                        if (!rx->key)
                                rx->key = rcu_dereference(rx->sdata->keys[keyidx]);
 
@@ -2121,7 +2132,8 @@ ieee80211_rx_h_decrypt(struct ieee80211_rx_data *rx)
        /* either the frame has been decrypted or will be dropped */
        status->flag |= RX_FLAG_DECRYPTED;
 
-       if (unlikely(ieee80211_is_beacon(fc) && result == RX_DROP_UNUSABLE))
+       if (unlikely(ieee80211_is_beacon(fc) && result == RX_DROP_UNUSABLE &&
+                    rx->sdata->dev))
                cfg80211_rx_unprot_mlme_mgmt(rx->sdata->dev,
                                             skb->data, skb->len);
 
@@ -2380,7 +2392,7 @@ ieee80211_rx_h_defragment(struct ieee80211_rx_data *rx)
  out:
        ieee80211_led_rx(rx->local);
        if (rx->sta)
-               rx->sta->deflink.rx_stats.packets++;
+               rx->link_sta->rx_stats.packets++;
        return RX_CONTINUE;
 }
 
@@ -2656,9 +2668,9 @@ ieee80211_deliver_skb(struct ieee80211_rx_data *rx)
                 * for non-QoS-data frames. Here we know it's a data
                 * frame, so count MSDUs.
                 */
-               u64_stats_update_begin(&rx->sta->deflink.rx_stats.syncp);
-               rx->sta->deflink.rx_stats.msdu[rx->seqno_idx]++;
-               u64_stats_update_end(&rx->sta->deflink.rx_stats.syncp);
+               u64_stats_update_begin(&rx->link_sta->rx_stats.syncp);
+               rx->link_sta->rx_stats.msdu[rx->seqno_idx]++;
+               u64_stats_update_end(&rx->link_sta->rx_stats.syncp);
        }
 
        if ((sdata->vif.type == NL80211_IFTYPE_AP ||
@@ -3046,7 +3058,8 @@ ieee80211_rx_h_data(struct ieee80211_rx_data *rx)
                    (tf->action_code == WLAN_TDLS_CHANNEL_SWITCH_REQUEST ||
                     tf->action_code == WLAN_TDLS_CHANNEL_SWITCH_RESPONSE)) {
                        rx->skb->protocol = cpu_to_be16(ETH_P_TDLS);
-                       __ieee80211_queue_skb_to_iface(sdata, rx->sta, rx->skb);
+                       __ieee80211_queue_skb_to_iface(sdata, rx->link_id,
+                                                      rx->sta, rx->skb);
                        return RX_QUEUED;
                }
        }
@@ -3354,7 +3367,7 @@ ieee80211_rx_h_action(struct ieee80211_rx_data *rx)
        switch (mgmt->u.action.category) {
        case WLAN_CATEGORY_HT:
                /* reject HT action frames from stations not supporting HT */
-               if (!rx->sta->sta.deflink.ht_cap.ht_supported)
+               if (!rx->link_sta->pub->ht_cap.ht_supported)
                        goto invalid;
 
                if (sdata->vif.type != NL80211_IFTYPE_STATION &&
@@ -3394,9 +3407,9 @@ ieee80211_rx_h_action(struct ieee80211_rx_data *rx)
                        }
 
                        /* if no change do nothing */
-                       if (rx->sta->sta.smps_mode == smps_mode)
+                       if (rx->link_sta->pub->smps_mode == smps_mode)
                                goto handled;
-                       rx->sta->sta.smps_mode = smps_mode;
+                       rx->link_sta->pub->smps_mode = smps_mode;
                        sta_opmode.smps_mode =
                                ieee80211_smps_mode_to_smps_mode(smps_mode);
                        sta_opmode.changed = STA_OPMODE_SMPS_MODE_CHANGED;
@@ -3418,26 +3431,26 @@ ieee80211_rx_h_action(struct ieee80211_rx_data *rx)
                        struct sta_opmode_info sta_opmode = {};
 
                        /* If it doesn't support 40 MHz it can't change ... */
-                       if (!(rx->sta->sta.deflink.ht_cap.cap &
+                       if (!(rx->link_sta->pub->ht_cap.cap &
                                        IEEE80211_HT_CAP_SUP_WIDTH_20_40))
                                goto handled;
 
                        if (chanwidth == IEEE80211_HT_CHANWIDTH_20MHZ)
                                max_bw = IEEE80211_STA_RX_BW_20;
                        else
-                               max_bw = ieee80211_sta_cap_rx_bw(&rx->sta->deflink);
+                               max_bw = ieee80211_sta_cap_rx_bw(rx->link_sta);
 
                        /* set cur_max_bandwidth and recalc sta bw */
-                       rx->sta->deflink.cur_max_bandwidth = max_bw;
-                       new_bw = ieee80211_sta_cur_vht_bw(&rx->sta->deflink);
+                       rx->link_sta->cur_max_bandwidth = max_bw;
+                       new_bw = ieee80211_sta_cur_vht_bw(rx->link_sta);
 
-                       if (rx->sta->sta.deflink.bandwidth == new_bw)
+                       if (rx->link_sta->pub->bandwidth == new_bw)
                                goto handled;
 
-                       rx->sta->sta.deflink.bandwidth = new_bw;
+                       rx->link_sta->pub->bandwidth = new_bw;
                        sband = rx->local->hw.wiphy->bands[status->band];
                        sta_opmode.bw =
-                               ieee80211_sta_rx_bw_to_chan_width(&rx->sta->deflink);
+                               ieee80211_sta_rx_bw_to_chan_width(rx->link_sta);
                        sta_opmode.changed = STA_OPMODE_MAX_BW_CHANGED;
 
                        rate_control_rate_update(local, sband, rx->sta, 0,
@@ -3631,12 +3644,12 @@ ieee80211_rx_h_action(struct ieee80211_rx_data *rx)
 
  handled:
        if (rx->sta)
-               rx->sta->deflink.rx_stats.packets++;
+               rx->link_sta->rx_stats.packets++;
        dev_kfree_skb(rx->skb);
        return RX_QUEUED;
 
  queue:
-       ieee80211_queue_skb_to_iface(sdata, rx->sta, rx->skb);
+       ieee80211_queue_skb_to_iface(sdata, rx->link_id, rx->sta, rx->skb);
        return RX_QUEUED;
 }
 
@@ -3675,7 +3688,7 @@ ieee80211_rx_h_userspace_mgmt(struct ieee80211_rx_data *rx)
 
        if (cfg80211_rx_mgmt_ext(&rx->sdata->wdev, &info)) {
                if (rx->sta)
-                       rx->sta->deflink.rx_stats.packets++;
+                       rx->link_sta->rx_stats.packets++;
                dev_kfree_skb(rx->skb);
                return RX_QUEUED;
        }
@@ -3713,7 +3726,7 @@ ieee80211_rx_h_action_post_userspace(struct ieee80211_rx_data *rx)
 
  handled:
        if (rx->sta)
-               rx->sta->deflink.rx_stats.packets++;
+               rx->link_sta->rx_stats.packets++;
        dev_kfree_skb(rx->skb);
        return RX_QUEUED;
 }
@@ -3794,7 +3807,7 @@ ieee80211_rx_h_ext(struct ieee80211_rx_data *rx)
                return RX_DROP_MONITOR;
 
        /* for now only beacons are ext, so queue them */
-       ieee80211_queue_skb_to_iface(sdata, rx->sta, rx->skb);
+       ieee80211_queue_skb_to_iface(sdata, rx->link_id, rx->sta, rx->skb);
 
        return RX_QUEUED;
 }
@@ -3851,7 +3864,7 @@ ieee80211_rx_h_mgmt(struct ieee80211_rx_data *rx)
                return RX_DROP_MONITOR;
        }
 
-       ieee80211_queue_skb_to_iface(sdata, rx->sta, rx->skb);
+       ieee80211_queue_skb_to_iface(sdata, rx->link_id, rx->sta, rx->skb);
 
        return RX_QUEUED;
 }
@@ -3933,7 +3946,7 @@ static void ieee80211_rx_handlers_result(struct ieee80211_rx_data *rx,
        case RX_DROP_MONITOR:
                I802_DEBUG_INC(rx->sdata->local->rx_handlers_drop);
                if (rx->sta)
-                       rx->sta->deflink.rx_stats.dropped++;
+                       rx->link_sta->rx_stats.dropped++;
                fallthrough;
        case RX_CONTINUE: {
                struct ieee80211_rate *rate = NULL;
@@ -3952,7 +3965,7 @@ static void ieee80211_rx_handlers_result(struct ieee80211_rx_data *rx,
        case RX_DROP_UNUSABLE:
                I802_DEBUG_INC(rx->sdata->local->rx_handlers_drop);
                if (rx->sta)
-                       rx->sta->deflink.rx_stats.dropped++;
+                       rx->link_sta->rx_stats.dropped++;
                dev_kfree_skb(rx->skb);
                break;
        case RX_QUEUED:
@@ -4057,6 +4070,58 @@ static void ieee80211_invoke_rx_handlers(struct ieee80211_rx_data *rx)
 #undef CALL_RXH
 }
 
+static bool
+ieee80211_rx_is_valid_sta_link_id(struct ieee80211_sta *sta, u8 link_id)
+{
+       if (!sta->mlo)
+               return false;
+
+       return !!(sta->valid_links & BIT(link_id));
+}
+
+static bool ieee80211_rx_data_set_link(struct ieee80211_rx_data *rx,
+                                      u8 link_id)
+{
+       rx->link_id = link_id;
+       rx->link = rcu_dereference(rx->sdata->link[link_id]);
+
+       if (!rx->sta)
+               return rx->link;
+
+       if (!ieee80211_rx_is_valid_sta_link_id(&rx->sta->sta, link_id))
+               return false;
+
+       rx->link_sta = rcu_dereference(rx->sta->link[link_id]);
+
+       return rx->link && rx->link_sta;
+}
+
+static bool ieee80211_rx_data_set_sta(struct ieee80211_rx_data *rx,
+                                     struct ieee80211_sta *pubsta,
+                                     int link_id)
+{
+       struct sta_info *sta;
+
+       sta = container_of(pubsta, struct sta_info, sta);
+
+       rx->link_id = link_id;
+       rx->sta = sta;
+
+       if (sta) {
+               rx->local = sta->sdata->local;
+               if (!rx->sdata)
+                       rx->sdata = sta->sdata;
+               rx->link_sta = &sta->deflink;
+       }
+
+       if (link_id < 0)
+               rx->link = &rx->sdata->deflink;
+       else if (!ieee80211_rx_data_set_link(rx, link_id))
+               return false;
+
+       return true;
+}
+
 /*
  * This function makes calls into the RX path, therefore
  * it has to be invoked under RCU read lock.
@@ -4065,16 +4130,19 @@ void ieee80211_release_reorder_timeout(struct sta_info *sta, int tid)
 {
        struct sk_buff_head frames;
        struct ieee80211_rx_data rx = {
-               .sta = sta,
-               .sdata = sta->sdata,
-               .local = sta->local,
                /* This is OK -- must be QoS data frame */
                .security_idx = tid,
                .seqno_idx = tid,
-               .link_id = -1,
        };
        struct tid_ampdu_rx *tid_agg_rx;
-       u8 link_id;
+       int link_id = -1;
+
+       /* FIXME: statistics won't be right with this */
+       if (sta->sta.valid_links)
+               link_id = ffs(sta->sta.valid_links) - 1;
+
+       if (!ieee80211_rx_data_set_sta(&rx, &sta->sta, link_id))
+               return;
 
        tid_agg_rx = rcu_dereference(sta->ampdu_mlme.tid_rx[tid]);
        if (!tid_agg_rx)
@@ -4094,9 +4162,6 @@ void ieee80211_release_reorder_timeout(struct sta_info *sta, int tid)
                };
                drv_event_callback(rx.local, rx.sdata, &event);
        }
-       /* FIXME: statistics won't be right with this */
-       link_id = sta->sta.valid_links ? ffs(sta->sta.valid_links) - 1 : 0;
-       rx.link = rcu_dereference(sta->sdata->link[link_id]);
 
        ieee80211_rx_handlers(&rx, &frames);
 }
@@ -4112,7 +4177,6 @@ void ieee80211_mark_rx_ba_filtered_frames(struct ieee80211_sta *pubsta, u8 tid,
                /* This is OK -- must be QoS data frame */
                .security_idx = tid,
                .seqno_idx = tid,
-               .link_id = -1,
        };
        int i, diff;
 
@@ -4123,10 +4187,8 @@ void ieee80211_mark_rx_ba_filtered_frames(struct ieee80211_sta *pubsta, u8 tid,
 
        sta = container_of(pubsta, struct sta_info, sta);
 
-       rx.sta = sta;
-       rx.sdata = sta->sdata;
-       rx.link = &rx.sdata->deflink;
-       rx.local = sta->local;
+       if (!ieee80211_rx_data_set_sta(&rx, pubsta, -1))
+               return;
 
        rcu_read_lock();
        tid_agg_rx = rcu_dereference(sta->ampdu_mlme.tid_rx[tid]);
@@ -4340,6 +4402,7 @@ void ieee80211_check_fast_rx(struct sta_info *sta)
                .vif_type = sdata->vif.type,
                .control_port_protocol = sdata->control_port_protocol,
        }, *old, *new = NULL;
+       u32 offload_flags;
        bool set_offload = false;
        bool assign = false;
        bool offload;
@@ -4455,10 +4518,10 @@ void ieee80211_check_fast_rx(struct sta_info *sta)
        if (assign)
                new = kmemdup(&fastrx, sizeof(fastrx), GFP_KERNEL);
 
-       offload = assign &&
-                 (sdata->vif.offload_flags & IEEE80211_OFFLOAD_DECAP_ENABLED);
+       offload_flags = get_bss_sdata(sdata)->vif.offload_flags;
+       offload = offload_flags & IEEE80211_OFFLOAD_DECAP_ENABLED;
 
-       if (offload)
+       if (assign && offload)
                set_offload = !test_and_set_sta_flag(sta, WLAN_STA_DECAP_OFFLOAD);
        else
                set_offload = test_and_clear_sta_flag(sta, WLAN_STA_DECAP_OFFLOAD);
@@ -4519,19 +4582,30 @@ static void ieee80211_rx_8023(struct ieee80211_rx_data *rx,
        struct ieee80211_sta_rx_stats *stats;
        struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(rx->skb);
        struct sta_info *sta = rx->sta;
+       struct link_sta_info *link_sta;
        struct sk_buff *skb = rx->skb;
        void *sa = skb->data + ETH_ALEN;
        void *da = skb->data;
 
-       stats = &sta->deflink.rx_stats;
+       if (rx->link_id >= 0) {
+               link_sta = rcu_dereference(sta->link[rx->link_id]);
+               if (WARN_ON_ONCE(!link_sta)) {
+                       dev_kfree_skb(rx->skb);
+                       return;
+               }
+       } else {
+               link_sta = &sta->deflink;
+       }
+
+       stats = &link_sta->rx_stats;
        if (fast_rx->uses_rss)
-               stats = this_cpu_ptr(sta->deflink.pcpu_rx_stats);
+               stats = this_cpu_ptr(link_sta->pcpu_rx_stats);
 
        /* statistics part of ieee80211_rx_h_sta_process() */
        if (!(status->flag & RX_FLAG_NO_SIGNAL_VAL)) {
                stats->last_signal = status->signal;
                if (!fast_rx->uses_rss)
-                       ewma_signal_add(&sta->deflink.rx_stats_avg.signal,
+                       ewma_signal_add(&link_sta->rx_stats_avg.signal,
                                        -status->signal);
        }
 
@@ -4547,7 +4621,7 @@ static void ieee80211_rx_8023(struct ieee80211_rx_data *rx,
 
                        stats->chain_signal_last[i] = signal;
                        if (!fast_rx->uses_rss)
-                               ewma_signal_add(&sta->deflink.rx_stats_avg.chain_signal[i],
+                               ewma_signal_add(&link_sta->rx_stats_avg.chain_signal[i],
                                                -signal);
                }
        }
@@ -4611,7 +4685,6 @@ static bool ieee80211_invoke_fast_rx(struct ieee80211_rx_data *rx,
        struct sk_buff *skb = rx->skb;
        struct ieee80211_hdr *hdr = (void *)skb->data;
        struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb);
-       struct sta_info *sta = rx->sta;
        int orig_len = skb->len;
        int hdrlen = ieee80211_hdrlen(hdr->frame_control);
        int snap_offs = hdrlen;
@@ -4623,7 +4696,7 @@ static bool ieee80211_invoke_fast_rx(struct ieee80211_rx_data *rx,
                u8 da[ETH_ALEN];
                u8 sa[ETH_ALEN];
        } addrs __aligned(2);
-       struct ieee80211_sta_rx_stats *stats = &sta->deflink.rx_stats;
+       struct ieee80211_sta_rx_stats *stats;
 
        /* for parallel-rx, we need to have DUP_VALIDATED, otherwise we write
         * to a common data structure; drivers can implement that per queue
@@ -4675,7 +4748,7 @@ static bool ieee80211_invoke_fast_rx(struct ieee80211_rx_data *rx,
 
        if (!(status->rx_flags & IEEE80211_RX_AMSDU)) {
                if (!pskb_may_pull(skb, snap_offs + sizeof(*payload)))
-                       goto drop;
+                       return false;
 
                payload = (void *)(skb->data + snap_offs);
 
@@ -4724,8 +4797,11 @@ static bool ieee80211_invoke_fast_rx(struct ieee80211_rx_data *rx,
        return true;
  drop:
        dev_kfree_skb(skb);
+
        if (fast_rx->uses_rss)
-               stats = this_cpu_ptr(sta->deflink.pcpu_rx_stats);
+               stats = this_cpu_ptr(rx->link_sta->pcpu_rx_stats);
+       else
+               stats = &rx->link_sta->rx_stats;
 
        stats->dropped++;
        return true;
@@ -4743,8 +4819,8 @@ static bool ieee80211_prepare_and_rx_handle(struct ieee80211_rx_data *rx,
        struct ieee80211_local *local = rx->local;
        struct ieee80211_sub_if_data *sdata = rx->sdata;
        struct ieee80211_hdr *hdr = (void *)skb->data;
-       struct link_sta_info *link_sta = NULL;
-       struct ieee80211_link_data *link;
+       struct link_sta_info *link_sta = rx->link_sta;
+       struct ieee80211_link_data *link = rx->link;
 
        rx->skb = skb;
 
@@ -4766,25 +4842,6 @@ static bool ieee80211_prepare_and_rx_handle(struct ieee80211_rx_data *rx,
        if (!ieee80211_accept_frame(rx))
                return false;
 
-       if (rx->link_id >= 0) {
-               link = rcu_dereference(rx->sdata->link[rx->link_id]);
-
-               /* we might race link removal */
-               if (!link)
-                       return true;
-               rx->link = link;
-       } else {
-               rx->link = &sdata->deflink;
-       }
-
-       if (unlikely(!is_multicast_ether_addr(hdr->addr1) &&
-                    rx->link_id >= 0 && rx->sta && rx->sta->sta.mlo)) {
-               link_sta = rcu_dereference(rx->sta->link[rx->link_id]);
-
-               if (WARN_ON_ONCE(!link_sta))
-                       return true;
-       }
-
        if (!consume) {
                struct skb_shared_hwtstamps *shwt;
 
@@ -4802,9 +4859,12 @@ static bool ieee80211_prepare_and_rx_handle(struct ieee80211_rx_data *rx,
                 */
                shwt = skb_hwtstamps(rx->skb);
                shwt->hwtstamp = skb_hwtstamps(skb)->hwtstamp;
+
+               /* Update the hdr pointer to the new skb for translation below */
+               hdr = (struct ieee80211_hdr *)rx->skb->data;
        }
 
-       if (unlikely(link_sta)) {
+       if (unlikely(rx->sta && rx->sta->sta.mlo)) {
                /* translate to MLD addresses */
                if (ether_addr_equal(link->conf->addr, hdr->addr1))
                        ether_addr_copy(hdr->addr1, rx->sdata->vif.addr);
@@ -4831,8 +4891,10 @@ static void __ieee80211_rx_handle_8023(struct ieee80211_hw *hw,
                                       struct list_head *list)
 {
        struct ieee80211_local *local = hw_to_local(hw);
+       struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb);
        struct ieee80211_fast_rx *fast_rx;
        struct ieee80211_rx_data rx;
+       int link_id = -1;
 
        memset(&rx, 0, sizeof(rx));
        rx.skb = skb;
@@ -4849,9 +4911,18 @@ static void __ieee80211_rx_handle_8023(struct ieee80211_hw *hw,
        if (!pubsta)
                goto drop;
 
-       rx.sta = container_of(pubsta, struct sta_info, sta);
-       rx.sdata = rx.sta->sdata;
-       rx.link = &rx.sdata->deflink;
+       if (status->link_valid)
+               link_id = status->link_id;
+
+       /*
+        * TODO: Should the frame be dropped if the right link_id is not
+        * available? Or may be it is fine in the current form to proceed with
+        * the frame processing because with frame being in 802.3 format,
+        * link_id is used only for stats purpose and updating the stats on
+        * the deflink is fine?
+        */
+       if (!ieee80211_rx_data_set_sta(&rx, pubsta, link_id))
+               goto drop;
 
        fast_rx = rcu_dereference(rx.sta->fast_rx);
        if (!fast_rx)
@@ -4869,6 +4940,8 @@ static bool ieee80211_rx_for_interface(struct ieee80211_rx_data *rx,
 {
        struct link_sta_info *link_sta;
        struct ieee80211_hdr *hdr = (void *)skb->data;
+       struct sta_info *sta;
+       int link_id = -1;
 
        /*
         * Look up link station first, in case there's a
@@ -4878,12 +4951,19 @@ static bool ieee80211_rx_for_interface(struct ieee80211_rx_data *rx,
         */
        link_sta = link_sta_info_get_bss(rx->sdata, hdr->addr2);
        if (link_sta) {
-               rx->sta = link_sta->sta;
-               rx->link_id = link_sta->link_id;
+               sta = link_sta->sta;
+               link_id = link_sta->link_id;
        } else {
-               rx->sta = sta_info_get_bss(rx->sdata, hdr->addr2);
+               struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb);
+
+               sta = sta_info_get_bss(rx->sdata, hdr->addr2);
+               if (status->link_valid)
+                       link_id = status->link_id;
        }
 
+       if (!ieee80211_rx_data_set_sta(rx, &sta->sta, link_id))
+               return false;
+
        return ieee80211_prepare_and_rx_handle(rx, skb, consume);
 }
 
@@ -4897,6 +4977,7 @@ static void __ieee80211_rx_handle_packet(struct ieee80211_hw *hw,
                                         struct list_head *list)
 {
        struct ieee80211_local *local = hw_to_local(hw);
+       struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb);
        struct ieee80211_sub_if_data *sdata;
        struct ieee80211_hdr *hdr;
        __le16 fc;
@@ -4941,10 +5022,35 @@ static void __ieee80211_rx_handle_packet(struct ieee80211_hw *hw,
 
        if (ieee80211_is_data(fc)) {
                struct sta_info *sta, *prev_sta;
+               int link_id = -1;
+
+               if (status->link_valid)
+                       link_id = status->link_id;
 
                if (pubsta) {
-                       rx.sta = container_of(pubsta, struct sta_info, sta);
-                       rx.sdata = rx.sta->sdata;
+                       if (!ieee80211_rx_data_set_sta(&rx, pubsta, link_id))
+                               goto out;
+
+                       /*
+                        * In MLO connection, fetch the link_id using addr2
+                        * when the driver does not pass link_id in status.
+                        * When the address translation is already performed by
+                        * driver/hw, the valid link_id must be passed in
+                        * status.
+                        */
+
+                       if (!status->link_valid && pubsta->mlo) {
+                               struct ieee80211_hdr *hdr = (void *)skb->data;
+                               struct link_sta_info *link_sta;
+
+                               link_sta = link_sta_info_get_bss(rx.sdata,
+                                                                hdr->addr2);
+                               if (!link_sta)
+                                       goto out;
+
+                               ieee80211_rx_data_set_link(&rx, link_sta->link_id);
+                       }
+
                        if (ieee80211_prepare_and_rx_handle(&rx, skb, true))
                                return;
                        goto out;
@@ -4958,16 +5064,27 @@ static void __ieee80211_rx_handle_packet(struct ieee80211_hw *hw,
                                continue;
                        }
 
-                       rx.sta = prev_sta;
                        rx.sdata = prev_sta->sdata;
+                       if (!ieee80211_rx_data_set_sta(&rx, &prev_sta->sta,
+                                                      link_id))
+                               goto out;
+
+                       if (!status->link_valid && prev_sta->sta.mlo)
+                               continue;
+
                        ieee80211_prepare_and_rx_handle(&rx, skb, false);
 
                        prev_sta = sta;
                }
 
                if (prev_sta) {
-                       rx.sta = prev_sta;
                        rx.sdata = prev_sta->sdata;
+                       if (!ieee80211_rx_data_set_sta(&rx, &prev_sta->sta,
+                                                      link_id))
+                               goto out;
+
+                       if (!status->link_valid && prev_sta->sta.mlo)
+                               goto out;
 
                        if (ieee80211_prepare_and_rx_handle(&rx, skb, true))
                                return;
@@ -5108,6 +5225,9 @@ void ieee80211_rx_list(struct ieee80211_hw *hw, struct ieee80211_sta *pubsta,
                }
        }
 
+       if (WARN_ON_ONCE(status->link_id >= IEEE80211_LINK_UNSPECIFIED))
+               goto drop;
+
        status->rx_flags = 0;
 
        kcov_remote_start_common(skb_get_kcov_handle(skb));