Merge tag 'kvm-x86-mmu-6.4' of https://github.com/kvm-x86/linux into HEAD
[platform/kernel/linux-rpi.git] / arch / x86 / kvm / mmu / mmu.c
index 4544605..e3d02f0 100644 (file)
@@ -125,17 +125,31 @@ module_param(dbg, bool, 0644);
 #define PTE_LIST_EXT 14
 
 /*
- * Slight optimization of cacheline layout, by putting `more' and `spte_count'
- * at the start; then accessing it will only use one single cacheline for
- * either full (entries==PTE_LIST_EXT) case or entries<=6.
+ * struct pte_list_desc is the core data structure used to implement a custom
+ * list for tracking a set of related SPTEs, e.g. all the SPTEs that map a
+ * given GFN when used in the context of rmaps.  Using a custom list allows KVM
+ * to optimize for the common case where many GFNs will have at most a handful
+ * of SPTEs pointing at them, i.e. allows packing multiple SPTEs into a small
+ * memory footprint, which in turn improves runtime performance by exploiting
+ * cache locality.
+ *
+ * A list is comprised of one or more pte_list_desc objects (descriptors).
+ * Each individual descriptor stores up to PTE_LIST_EXT SPTEs.  If a descriptor
+ * is full and a new SPTEs needs to be added, a new descriptor is allocated and
+ * becomes the head of the list.  This means that by definitions, all tail
+ * descriptors are full.
+ *
+ * Note, the meta data fields are deliberately placed at the start of the
+ * structure to optimize the cacheline layout; accessing the descriptor will
+ * touch only a single cacheline so long as @spte_count<=6 (or if only the
+ * descriptors metadata is accessed).
  */
 struct pte_list_desc {
        struct pte_list_desc *more;
-       /*
-        * Stores number of entries stored in the pte_list_desc.  No need to be
-        * u64 but just for easier alignment.  When PTE_LIST_EXT, means full.
-        */
-       u64 spte_count;
+       /* The number of PTEs stored in _this_ descriptor. */
+       u32 spte_count;
+       /* The number of PTEs stored in all tails of this descriptor. */
+       u32 tail_count;
        u64 *sptes[PTE_LIST_EXT];
 };
 
@@ -242,32 +256,35 @@ static struct kvm_mmu_role_regs vcpu_to_role_regs(struct kvm_vcpu *vcpu)
        return regs;
 }
 
-static inline bool kvm_available_flush_tlb_with_range(void)
+static unsigned long get_guest_cr3(struct kvm_vcpu *vcpu)
 {
-       return kvm_x86_ops.tlb_remote_flush_with_range;
+       return kvm_read_cr3(vcpu);
 }
 
-static void kvm_flush_remote_tlbs_with_range(struct kvm *kvm,
-               struct kvm_tlb_range *range)
+static inline unsigned long kvm_mmu_get_guest_pgd(struct kvm_vcpu *vcpu,
+                                                 struct kvm_mmu *mmu)
 {
-       int ret = -ENOTSUPP;
-
-       if (range && kvm_x86_ops.tlb_remote_flush_with_range)
-               ret = static_call(kvm_x86_tlb_remote_flush_with_range)(kvm, range);
+       if (IS_ENABLED(CONFIG_RETPOLINE) && mmu->get_guest_pgd == get_guest_cr3)
+               return kvm_read_cr3(vcpu);
 
-       if (ret)
-               kvm_flush_remote_tlbs(kvm);
+       return mmu->get_guest_pgd(vcpu);
 }
 
-void kvm_flush_remote_tlbs_with_address(struct kvm *kvm,
-               u64 start_gfn, u64 pages)
+static inline bool kvm_available_flush_remote_tlbs_range(void)
 {
-       struct kvm_tlb_range range;
+       return kvm_x86_ops.flush_remote_tlbs_range;
+}
 
-       range.start_gfn = start_gfn;
-       range.pages = pages;
+void kvm_flush_remote_tlbs_range(struct kvm *kvm, gfn_t start_gfn,
+                                gfn_t nr_pages)
+{
+       int ret = -EOPNOTSUPP;
 
-       kvm_flush_remote_tlbs_with_range(kvm, &range);
+       if (kvm_x86_ops.flush_remote_tlbs_range)
+               ret = static_call(kvm_x86_flush_remote_tlbs_range)(kvm, start_gfn,
+                                                                  nr_pages);
+       if (ret)
+               kvm_flush_remote_tlbs(kvm);
 }
 
 static gfn_t kvm_mmu_page_get_gfn(struct kvm_mmu_page *sp, int index);
@@ -888,9 +905,9 @@ static void unaccount_nx_huge_page(struct kvm *kvm, struct kvm_mmu_page *sp)
        untrack_possible_nx_huge_page(kvm, sp);
 }
 
-static struct kvm_memory_slot *
-gfn_to_memslot_dirty_bitmap(struct kvm_vcpu *vcpu, gfn_t gfn,
-                           bool no_dirty_log)
+static struct kvm_memory_slot *gfn_to_memslot_dirty_bitmap(struct kvm_vcpu *vcpu,
+                                                          gfn_t gfn,
+                                                          bool no_dirty_log)
 {
        struct kvm_memory_slot *slot;
 
@@ -929,53 +946,69 @@ static int pte_list_add(struct kvm_mmu_memory_cache *cache, u64 *spte,
                desc->sptes[0] = (u64 *)rmap_head->val;
                desc->sptes[1] = spte;
                desc->spte_count = 2;
+               desc->tail_count = 0;
                rmap_head->val = (unsigned long)desc | 1;
                ++count;
        } else {
                rmap_printk("%p %llx many->many\n", spte, *spte);
                desc = (struct pte_list_desc *)(rmap_head->val & ~1ul);
-               while (desc->spte_count == PTE_LIST_EXT) {
-                       count += PTE_LIST_EXT;
-                       if (!desc->more) {
-                               desc->more = kvm_mmu_memory_cache_alloc(cache);
-                               desc = desc->more;
-                               desc->spte_count = 0;
-                               break;
-                       }
-                       desc = desc->more;
+               count = desc->tail_count + desc->spte_count;
+
+               /*
+                * If the previous head is full, allocate a new head descriptor
+                * as tail descriptors are always kept full.
+                */
+               if (desc->spte_count == PTE_LIST_EXT) {
+                       desc = kvm_mmu_memory_cache_alloc(cache);
+                       desc->more = (struct pte_list_desc *)(rmap_head->val & ~1ul);
+                       desc->spte_count = 0;
+                       desc->tail_count = count;
+                       rmap_head->val = (unsigned long)desc | 1;
                }
-               count += desc->spte_count;
                desc->sptes[desc->spte_count++] = spte;
        }
        return count;
 }
 
-static void
-pte_list_desc_remove_entry(struct kvm_rmap_head *rmap_head,
-                          struct pte_list_desc *desc, int i,
-                          struct pte_list_desc *prev_desc)
+static void pte_list_desc_remove_entry(struct kvm_rmap_head *rmap_head,
+                                      struct pte_list_desc *desc, int i)
 {
-       int j = desc->spte_count - 1;
+       struct pte_list_desc *head_desc = (struct pte_list_desc *)(rmap_head->val & ~1ul);
+       int j = head_desc->spte_count - 1;
 
-       desc->sptes[i] = desc->sptes[j];
-       desc->sptes[j] = NULL;
-       desc->spte_count--;
-       if (desc->spte_count)
+       /*
+        * The head descriptor should never be empty.  A new head is added only
+        * when adding an entry and the previous head is full, and heads are
+        * removed (this flow) when they become empty.
+        */
+       BUG_ON(j < 0);
+
+       /*
+        * Replace the to-be-freed SPTE with the last valid entry from the head
+        * descriptor to ensure that tail descriptors are full at all times.
+        * Note, this also means that tail_count is stable for each descriptor.
+        */
+       desc->sptes[i] = head_desc->sptes[j];
+       head_desc->sptes[j] = NULL;
+       head_desc->spte_count--;
+       if (head_desc->spte_count)
                return;
-       if (!prev_desc && !desc->more)
+
+       /*
+        * The head descriptor is empty.  If there are no tail descriptors,
+        * nullify the rmap head to mark the list as emtpy, else point the rmap
+        * head at the next descriptor, i.e. the new head.
+        */
+       if (!head_desc->more)
                rmap_head->val = 0;
        else
-               if (prev_desc)
-                       prev_desc->more = desc->more;
-               else
-                       rmap_head->val = (unsigned long)desc->more | 1;
-       mmu_free_pte_list_desc(desc);
+               rmap_head->val = (unsigned long)head_desc->more | 1;
+       mmu_free_pte_list_desc(head_desc);
 }
 
 static void pte_list_remove(u64 *spte, struct kvm_rmap_head *rmap_head)
 {
        struct pte_list_desc *desc;
-       struct pte_list_desc *prev_desc;
        int i;
 
        if (!rmap_head->val) {
@@ -991,16 +1024,13 @@ static void pte_list_remove(u64 *spte, struct kvm_rmap_head *rmap_head)
        } else {
                rmap_printk("%p many->many\n", spte);
                desc = (struct pte_list_desc *)(rmap_head->val & ~1ul);
-               prev_desc = NULL;
                while (desc) {
                        for (i = 0; i < desc->spte_count; ++i) {
                                if (desc->sptes[i] == spte) {
-                                       pte_list_desc_remove_entry(rmap_head,
-                                                       desc, i, prev_desc);
+                                       pte_list_desc_remove_entry(rmap_head, desc, i);
                                        return;
                                }
                        }
-                       prev_desc = desc;
                        desc = desc->more;
                }
                pr_err("%s: %p many->many\n", __func__, spte);
@@ -1047,7 +1077,6 @@ out:
 unsigned int pte_list_count(struct kvm_rmap_head *rmap_head)
 {
        struct pte_list_desc *desc;
-       unsigned int count = 0;
 
        if (!rmap_head->val)
                return 0;
@@ -1055,13 +1084,7 @@ unsigned int pte_list_count(struct kvm_rmap_head *rmap_head)
                return 1;
 
        desc = (struct pte_list_desc *)(rmap_head->val & ~1ul);
-
-       while (desc) {
-               count += desc->spte_count;
-               desc = desc->more;
-       }
-
-       return count;
+       return desc->tail_count + desc->spte_count;
 }
 
 static struct kvm_rmap_head *gfn_to_rmap(gfn_t gfn, int level,
@@ -1073,14 +1096,6 @@ static struct kvm_rmap_head *gfn_to_rmap(gfn_t gfn, int level,
        return &slot->arch.rmap[level - PG_LEVEL_4K][idx];
 }
 
-static bool rmap_can_add(struct kvm_vcpu *vcpu)
-{
-       struct kvm_mmu_memory_cache *mc;
-
-       mc = &vcpu->arch.mmu_pte_list_desc_cache;
-       return kvm_mmu_memory_cache_nr_free_objects(mc);
-}
-
 static void rmap_remove(struct kvm *kvm, u64 *spte)
 {
        struct kvm_memslots *slots;
@@ -1479,7 +1494,7 @@ restart:
                }
        }
 
-       if (need_flush && kvm_available_flush_tlb_with_range()) {
+       if (need_flush && kvm_available_flush_remote_tlbs_range()) {
                kvm_flush_remote_tlbs_gfn(kvm, gfn, level);
                return false;
        }
@@ -1504,8 +1519,8 @@ struct slot_rmap_walk_iterator {
        struct kvm_rmap_head *end_rmap;
 };
 
-static void
-rmap_walk_init_level(struct slot_rmap_walk_iterator *iterator, int level)
+static void rmap_walk_init_level(struct slot_rmap_walk_iterator *iterator,
+                                int level)
 {
        iterator->level = level;
        iterator->gfn = iterator->start_gfn;
@@ -1513,10 +1528,10 @@ rmap_walk_init_level(struct slot_rmap_walk_iterator *iterator, int level)
        iterator->end_rmap = gfn_to_rmap(iterator->end_gfn, level, iterator->slot);
 }
 
-static void
-slot_rmap_walk_init(struct slot_rmap_walk_iterator *iterator,
-                   const struct kvm_memory_slot *slot, int start_level,
-                   int end_level, gfn_t start_gfn, gfn_t end_gfn)
+static void slot_rmap_walk_init(struct slot_rmap_walk_iterator *iterator,
+                               const struct kvm_memory_slot *slot,
+                               int start_level, int end_level,
+                               gfn_t start_gfn, gfn_t end_gfn)
 {
        iterator->slot = slot;
        iterator->start_level = start_level;
@@ -1789,12 +1804,6 @@ static void mark_unsync(u64 *spte)
        kvm_mmu_mark_parents_unsync(sp);
 }
 
-static int nonpaging_sync_page(struct kvm_vcpu *vcpu,
-                              struct kvm_mmu_page *sp)
-{
-       return -1;
-}
-
 #define KVM_PAGE_ARRAY_NR 16
 
 struct kvm_mmu_pages {
@@ -1914,10 +1923,79 @@ static bool sp_has_gptes(struct kvm_mmu_page *sp)
          &(_kvm)->arch.mmu_page_hash[kvm_page_table_hashfn(_gfn)])     \
                if ((_sp)->gfn != (_gfn) || !sp_has_gptes(_sp)) {} else
 
+static bool kvm_sync_page_check(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp)
+{
+       union kvm_mmu_page_role root_role = vcpu->arch.mmu->root_role;
+
+       /*
+        * Ignore various flags when verifying that it's safe to sync a shadow
+        * page using the current MMU context.
+        *
+        *  - level: not part of the overall MMU role and will never match as the MMU's
+        *           level tracks the root level
+        *  - access: updated based on the new guest PTE
+        *  - quadrant: not part of the overall MMU role (similar to level)
+        */
+       const union kvm_mmu_page_role sync_role_ign = {
+               .level = 0xf,
+               .access = 0x7,
+               .quadrant = 0x3,
+               .passthrough = 0x1,
+       };
+
+       /*
+        * Direct pages can never be unsync, and KVM should never attempt to
+        * sync a shadow page for a different MMU context, e.g. if the role
+        * differs then the memslot lookup (SMM vs. non-SMM) will be bogus, the
+        * reserved bits checks will be wrong, etc...
+        */
+       if (WARN_ON_ONCE(sp->role.direct || !vcpu->arch.mmu->sync_spte ||
+                        (sp->role.word ^ root_role.word) & ~sync_role_ign.word))
+               return false;
+
+       return true;
+}
+
+static int kvm_sync_spte(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp, int i)
+{
+       if (!sp->spt[i])
+               return 0;
+
+       return vcpu->arch.mmu->sync_spte(vcpu, sp, i);
+}
+
+static int __kvm_sync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp)
+{
+       int flush = 0;
+       int i;
+
+       if (!kvm_sync_page_check(vcpu, sp))
+               return -1;
+
+       for (i = 0; i < SPTE_ENT_PER_PAGE; i++) {
+               int ret = kvm_sync_spte(vcpu, sp, i);
+
+               if (ret < -1)
+                       return -1;
+               flush |= ret;
+       }
+
+       /*
+        * Note, any flush is purely for KVM's correctness, e.g. when dropping
+        * an existing SPTE or clearing W/A/D bits to ensure an mmu_notifier
+        * unmap or dirty logging event doesn't fail to flush.  The guest is
+        * responsible for flushing the TLB to ensure any changes in protection
+        * bits are recognized, i.e. until the guest flushes or page faults on
+        * a relevant address, KVM is architecturally allowed to let vCPUs use
+        * cached translations with the old protection bits.
+        */
+       return flush;
+}
+
 static int kvm_sync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
                         struct list_head *invalid_list)
 {
-       int ret = vcpu->arch.mmu->sync_page(vcpu, sp);
+       int ret = __kvm_sync_page(vcpu, sp);
 
        if (ret < 0)
                kvm_mmu_prepare_zap_page(vcpu->kvm, sp, invalid_list);
@@ -3304,9 +3382,9 @@ static bool page_fault_can_be_fast(struct kvm_page_fault *fault)
  * Returns true if the SPTE was fixed successfully. Otherwise,
  * someone else modified the SPTE from its original value.
  */
-static bool
-fast_pf_fix_direct_spte(struct kvm_vcpu *vcpu, struct kvm_page_fault *fault,
-                       u64 *sptep, u64 old_spte, u64 new_spte)
+static bool fast_pf_fix_direct_spte(struct kvm_vcpu *vcpu,
+                                   struct kvm_page_fault *fault,
+                                   u64 *sptep, u64 old_spte, u64 new_spte)
 {
        /*
         * Theoretically we could also set dirty bit (and flush TLB) here in
@@ -3513,6 +3591,8 @@ void kvm_mmu_free_roots(struct kvm *kvm, struct kvm_mmu *mmu,
        LIST_HEAD(invalid_list);
        bool free_active_root;
 
+       WARN_ON_ONCE(roots_to_free & ~KVM_MMU_ROOTS_ALL);
+
        BUILD_BUG_ON(KVM_MMU_NUM_PREV_ROOTS >= BITS_PER_LONG);
 
        /* Before acquiring the MMU lock, see if we need to do any real work. */
@@ -3731,7 +3811,7 @@ static int mmu_alloc_shadow_roots(struct kvm_vcpu *vcpu)
        int quadrant, i, r;
        hpa_t root;
 
-       root_pgd = mmu->get_guest_pgd(vcpu);
+       root_pgd = kvm_mmu_get_guest_pgd(vcpu, mmu);
        root_gfn = root_pgd >> PAGE_SHIFT;
 
        if (mmu_check_root(vcpu, root_gfn))
@@ -4181,7 +4261,7 @@ static bool kvm_arch_setup_async_pf(struct kvm_vcpu *vcpu, gpa_t cr2_or_gpa,
        arch.token = alloc_apf_token(vcpu);
        arch.gfn = gfn;
        arch.direct_map = vcpu->arch.mmu->root_role.direct;
-       arch.cr3 = vcpu->arch.mmu->get_guest_pgd(vcpu);
+       arch.cr3 = kvm_mmu_get_guest_pgd(vcpu, vcpu->arch.mmu);
 
        return kvm_setup_async_pf(vcpu, cr2_or_gpa,
                                  kvm_vcpu_gfn_to_hva(vcpu, gfn), &arch);
@@ -4200,7 +4280,7 @@ void kvm_arch_async_page_ready(struct kvm_vcpu *vcpu, struct kvm_async_pf *work)
                return;
 
        if (!vcpu->arch.mmu->root_role.direct &&
-             work->arch.cr3 != vcpu->arch.mmu->get_guest_pgd(vcpu))
+             work->arch.cr3 != kvm_mmu_get_guest_pgd(vcpu, vcpu->arch.mmu))
                return;
 
        kvm_mmu_do_page_fault(vcpu, work->cr2_or_gpa, 0, true, NULL);
@@ -4469,8 +4549,7 @@ static void nonpaging_init_context(struct kvm_mmu *context)
 {
        context->page_fault = nonpaging_page_fault;
        context->gva_to_gpa = nonpaging_gva_to_gpa;
-       context->sync_page = nonpaging_sync_page;
-       context->invlpg = NULL;
+       context->sync_spte = NULL;
 }
 
 static inline bool is_root_usable(struct kvm_mmu_root_info *root, gpa_t pgd,
@@ -4604,11 +4683,6 @@ void kvm_mmu_new_pgd(struct kvm_vcpu *vcpu, gpa_t new_pgd)
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_new_pgd);
 
-static unsigned long get_cr3(struct kvm_vcpu *vcpu)
-{
-       return kvm_read_cr3(vcpu);
-}
-
 static bool sync_mmio_spte(struct kvm_vcpu *vcpu, u64 *sptep, gfn_t gfn,
                           unsigned int access)
 {
@@ -4638,10 +4712,9 @@ static bool sync_mmio_spte(struct kvm_vcpu *vcpu, u64 *sptep, gfn_t gfn,
 #include "paging_tmpl.h"
 #undef PTTYPE
 
-static void
-__reset_rsvds_bits_mask(struct rsvd_bits_validate *rsvd_check,
-                       u64 pa_bits_rsvd, int level, bool nx, bool gbpages,
-                       bool pse, bool amd)
+static void __reset_rsvds_bits_mask(struct rsvd_bits_validate *rsvd_check,
+                                   u64 pa_bits_rsvd, int level, bool nx,
+                                   bool gbpages, bool pse, bool amd)
 {
        u64 gbpages_bit_rsvd = 0;
        u64 nonleaf_bit8_rsvd = 0;
@@ -4754,9 +4827,9 @@ static void reset_guest_rsvds_bits_mask(struct kvm_vcpu *vcpu,
                                guest_cpuid_is_amd_or_hygon(vcpu));
 }
 
-static void
-__reset_rsvds_bits_mask_ept(struct rsvd_bits_validate *rsvd_check,
-                           u64 pa_bits_rsvd, bool execonly, int huge_page_level)
+static void __reset_rsvds_bits_mask_ept(struct rsvd_bits_validate *rsvd_check,
+                                       u64 pa_bits_rsvd, bool execonly,
+                                       int huge_page_level)
 {
        u64 high_bits_rsvd = pa_bits_rsvd & rsvd_bits(0, 51);
        u64 large_1g_rsvd = 0, large_2m_rsvd = 0;
@@ -4856,8 +4929,7 @@ static inline bool boot_cpu_is_amd(void)
  * the direct page table on host, use as much mmu features as
  * possible, however, kvm currently does not do execution-protection.
  */
-static void
-reset_tdp_shadow_zero_bits_mask(struct kvm_mmu *context)
+static void reset_tdp_shadow_zero_bits_mask(struct kvm_mmu *context)
 {
        struct rsvd_bits_validate *shadow_zero_check;
        int i;
@@ -5060,20 +5132,18 @@ static void paging64_init_context(struct kvm_mmu *context)
 {
        context->page_fault = paging64_page_fault;
        context->gva_to_gpa = paging64_gva_to_gpa;
-       context->sync_page = paging64_sync_page;
-       context->invlpg = paging64_invlpg;
+       context->sync_spte = paging64_sync_spte;
 }
 
 static void paging32_init_context(struct kvm_mmu *context)
 {
        context->page_fault = paging32_page_fault;
        context->gva_to_gpa = paging32_gva_to_gpa;
-       context->sync_page = paging32_sync_page;
-       context->invlpg = paging32_invlpg;
+       context->sync_spte = paging32_sync_spte;
 }
 
-static union kvm_cpu_role
-kvm_calc_cpu_role(struct kvm_vcpu *vcpu, const struct kvm_mmu_role_regs *regs)
+static union kvm_cpu_role kvm_calc_cpu_role(struct kvm_vcpu *vcpu,
+                                           const struct kvm_mmu_role_regs *regs)
 {
        union kvm_cpu_role role = {0};
 
@@ -5172,9 +5242,8 @@ static void init_kvm_tdp_mmu(struct kvm_vcpu *vcpu,
        context->cpu_role.as_u64 = cpu_role.as_u64;
        context->root_role.word = root_role.word;
        context->page_fault = kvm_tdp_page_fault;
-       context->sync_page = nonpaging_sync_page;
-       context->invlpg = NULL;
-       context->get_guest_pgd = get_cr3;
+       context->sync_spte = NULL;
+       context->get_guest_pgd = get_guest_cr3;
        context->get_pdptr = kvm_pdptr_read;
        context->inject_page_fault = kvm_inject_page_fault;
 
@@ -5304,8 +5373,7 @@ void kvm_init_shadow_ept_mmu(struct kvm_vcpu *vcpu, bool execonly,
 
                context->page_fault = ept_page_fault;
                context->gva_to_gpa = ept_gva_to_gpa;
-               context->sync_page = ept_sync_page;
-               context->invlpg = ept_invlpg;
+               context->sync_spte = ept_sync_spte;
 
                update_permission_bitmask(context, true);
                context->pkru_mask = 0;
@@ -5324,7 +5392,7 @@ static void init_kvm_softmmu(struct kvm_vcpu *vcpu,
 
        kvm_init_shadow_mmu(vcpu, cpu_role);
 
-       context->get_guest_pgd     = get_cr3;
+       context->get_guest_pgd     = get_guest_cr3;
        context->get_pdptr         = kvm_pdptr_read;
        context->inject_page_fault = kvm_inject_page_fault;
 }
@@ -5338,7 +5406,7 @@ static void init_kvm_nested_mmu(struct kvm_vcpu *vcpu,
                return;
 
        g_context->cpu_role.as_u64   = new_mode.as_u64;
-       g_context->get_guest_pgd     = get_cr3;
+       g_context->get_guest_pgd     = get_guest_cr3;
        g_context->get_pdptr         = kvm_pdptr_read;
        g_context->inject_page_fault = kvm_inject_page_fault;
 
@@ -5346,7 +5414,7 @@ static void init_kvm_nested_mmu(struct kvm_vcpu *vcpu,
         * L2 page tables are never shadowed, so there is no need to sync
         * SPTEs.
         */
-       g_context->invlpg            = NULL;
+       g_context->sync_spte         = NULL;
 
        /*
         * Note that arch.mmu->gva_to_gpa translates l2_gpa to l1_gpa using
@@ -5722,48 +5790,77 @@ emulate:
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_page_fault);
 
-void kvm_mmu_invalidate_gva(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
-                           gva_t gva, hpa_t root_hpa)
+static void __kvm_mmu_invalidate_addr(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
+                                     u64 addr, hpa_t root_hpa)
+{
+       struct kvm_shadow_walk_iterator iterator;
+
+       vcpu_clear_mmio_info(vcpu, addr);
+
+       if (!VALID_PAGE(root_hpa))
+               return;
+
+       write_lock(&vcpu->kvm->mmu_lock);
+       for_each_shadow_entry_using_root(vcpu, root_hpa, addr, iterator) {
+               struct kvm_mmu_page *sp = sptep_to_sp(iterator.sptep);
+
+               if (sp->unsync) {
+                       int ret = kvm_sync_spte(vcpu, sp, iterator.index);
+
+                       if (ret < 0)
+                               mmu_page_zap_pte(vcpu->kvm, sp, iterator.sptep, NULL);
+                       if (ret)
+                               kvm_flush_remote_tlbs_sptep(vcpu->kvm, iterator.sptep);
+               }
+
+               if (!sp->unsync_children)
+                       break;
+       }
+       write_unlock(&vcpu->kvm->mmu_lock);
+}
+
+void kvm_mmu_invalidate_addr(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
+                            u64 addr, unsigned long roots)
 {
        int i;
 
+       WARN_ON_ONCE(roots & ~KVM_MMU_ROOTS_ALL);
+
        /* It's actually a GPA for vcpu->arch.guest_mmu.  */
        if (mmu != &vcpu->arch.guest_mmu) {
                /* INVLPG on a non-canonical address is a NOP according to the SDM.  */
-               if (is_noncanonical_address(gva, vcpu))
+               if (is_noncanonical_address(addr, vcpu))
                        return;
 
-               static_call(kvm_x86_flush_tlb_gva)(vcpu, gva);
+               static_call(kvm_x86_flush_tlb_gva)(vcpu, addr);
        }
 
-       if (!mmu->invlpg)
+       if (!mmu->sync_spte)
                return;
 
-       if (root_hpa == INVALID_PAGE) {
-               mmu->invlpg(vcpu, gva, mmu->root.hpa);
+       if (roots & KVM_MMU_ROOT_CURRENT)
+               __kvm_mmu_invalidate_addr(vcpu, mmu, addr, mmu->root.hpa);
 
-               /*
-                * INVLPG is required to invalidate any global mappings for the VA,
-                * irrespective of PCID. Since it would take us roughly similar amount
-                * of work to determine whether any of the prev_root mappings of the VA
-                * is marked global, or to just sync it blindly, so we might as well
-                * just always sync it.
-                *
-                * Mappings not reachable via the current cr3 or the prev_roots will be
-                * synced when switching to that cr3, so nothing needs to be done here
-                * for them.
-                */
-               for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++)
-                       if (VALID_PAGE(mmu->prev_roots[i].hpa))
-                               mmu->invlpg(vcpu, gva, mmu->prev_roots[i].hpa);
-       } else {
-               mmu->invlpg(vcpu, gva, root_hpa);
+       for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++) {
+               if (roots & KVM_MMU_ROOT_PREVIOUS(i))
+                       __kvm_mmu_invalidate_addr(vcpu, mmu, addr, mmu->prev_roots[i].hpa);
        }
 }
+EXPORT_SYMBOL_GPL(kvm_mmu_invalidate_addr);
 
 void kvm_mmu_invlpg(struct kvm_vcpu *vcpu, gva_t gva)
 {
-       kvm_mmu_invalidate_gva(vcpu, vcpu->arch.walk_mmu, gva, INVALID_PAGE);
+       /*
+        * INVLPG is required to invalidate any global mappings for the VA,
+        * irrespective of PCID.  Blindly sync all roots as it would take
+        * roughly the same amount of work/time to determine whether any of the
+        * previous roots have a global mapping.
+        *
+        * Mappings not reachable via the current or previous cached roots will
+        * be synced when switching to that new cr3, so nothing needs to be
+        * done here for them.
+        */
+       kvm_mmu_invalidate_addr(vcpu, vcpu->arch.walk_mmu, gva, KVM_MMU_ROOTS_ALL);
        ++vcpu->stat.invlpg;
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_invlpg);
@@ -5772,27 +5869,20 @@ EXPORT_SYMBOL_GPL(kvm_mmu_invlpg);
 void kvm_mmu_invpcid_gva(struct kvm_vcpu *vcpu, gva_t gva, unsigned long pcid)
 {
        struct kvm_mmu *mmu = vcpu->arch.mmu;
-       bool tlb_flush = false;
+       unsigned long roots = 0;
        uint i;
 
-       if (pcid == kvm_get_active_pcid(vcpu)) {
-               if (mmu->invlpg)
-                       mmu->invlpg(vcpu, gva, mmu->root.hpa);
-               tlb_flush = true;
-       }
+       if (pcid == kvm_get_active_pcid(vcpu))
+               roots |= KVM_MMU_ROOT_CURRENT;
 
        for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++) {
                if (VALID_PAGE(mmu->prev_roots[i].hpa) &&
-                   pcid == kvm_get_pcid(vcpu, mmu->prev_roots[i].pgd)) {
-                       if (mmu->invlpg)
-                               mmu->invlpg(vcpu, gva, mmu->prev_roots[i].hpa);
-                       tlb_flush = true;
-               }
+                   pcid == kvm_get_pcid(vcpu, mmu->prev_roots[i].pgd))
+                       roots |= KVM_MMU_ROOT_PREVIOUS(i);
        }
 
-       if (tlb_flush)
-               static_call(kvm_x86_flush_tlb_gva)(vcpu, gva);
-
+       if (roots)
+               kvm_mmu_invalidate_addr(vcpu, mmu, gva, roots);
        ++vcpu->stat.invlpg;
 
        /*
@@ -5829,29 +5919,30 @@ void kvm_configure_mmu(bool enable_tdp, int tdp_forced_root_level,
 EXPORT_SYMBOL_GPL(kvm_configure_mmu);
 
 /* The return value indicates if tlb flush on all vcpus is needed. */
-typedef bool (*slot_level_handler) (struct kvm *kvm,
+typedef bool (*slot_rmaps_handler) (struct kvm *kvm,
                                    struct kvm_rmap_head *rmap_head,
                                    const struct kvm_memory_slot *slot);
 
-/* The caller should hold mmu-lock before calling this function. */
-static __always_inline bool
-slot_handle_level_range(struct kvm *kvm, const struct kvm_memory_slot *memslot,
-                       slot_level_handler fn, int start_level, int end_level,
-                       gfn_t start_gfn, gfn_t end_gfn, bool flush_on_yield,
-                       bool flush)
+static __always_inline bool __walk_slot_rmaps(struct kvm *kvm,
+                                             const struct kvm_memory_slot *slot,
+                                             slot_rmaps_handler fn,
+                                             int start_level, int end_level,
+                                             gfn_t start_gfn, gfn_t end_gfn,
+                                             bool flush_on_yield, bool flush)
 {
        struct slot_rmap_walk_iterator iterator;
 
-       for_each_slot_rmap_range(memslot, start_level, end_level, start_gfn,
+       lockdep_assert_held_write(&kvm->mmu_lock);
+
+       for_each_slot_rmap_range(slot, start_level, end_level, start_gfn,
                        end_gfn, &iterator) {
                if (iterator.rmap)
-                       flush |= fn(kvm, iterator.rmap, memslot);
+                       flush |= fn(kvm, iterator.rmap, slot);
 
                if (need_resched() || rwlock_needbreak(&kvm->mmu_lock)) {
                        if (flush && flush_on_yield) {
-                               kvm_flush_remote_tlbs_with_address(kvm,
-                                               start_gfn,
-                                               iterator.gfn - start_gfn + 1);
+                               kvm_flush_remote_tlbs_range(kvm, start_gfn,
+                                                           iterator.gfn - start_gfn + 1);
                                flush = false;
                        }
                        cond_resched_rwlock_write(&kvm->mmu_lock);
@@ -5861,23 +5952,23 @@ slot_handle_level_range(struct kvm *kvm, const struct kvm_memory_slot *memslot,
        return flush;
 }
 
-static __always_inline bool
-slot_handle_level(struct kvm *kvm, const struct kvm_memory_slot *memslot,
-                 slot_level_handler fn, int start_level, int end_level,
-                 bool flush_on_yield)
+static __always_inline bool walk_slot_rmaps(struct kvm *kvm,
+                                           const struct kvm_memory_slot *slot,
+                                           slot_rmaps_handler fn,
+                                           int start_level, int end_level,
+                                           bool flush_on_yield)
 {
-       return slot_handle_level_range(kvm, memslot, fn, start_level,
-                       end_level, memslot->base_gfn,
-                       memslot->base_gfn + memslot->npages - 1,
-                       flush_on_yield, false);
+       return __walk_slot_rmaps(kvm, slot, fn, start_level, end_level,
+                                slot->base_gfn, slot->base_gfn + slot->npages - 1,
+                                flush_on_yield, false);
 }
 
-static __always_inline bool
-slot_handle_level_4k(struct kvm *kvm, const struct kvm_memory_slot *memslot,
-                    slot_level_handler fn, bool flush_on_yield)
+static __always_inline bool walk_slot_rmaps_4k(struct kvm *kvm,
+                                              const struct kvm_memory_slot *slot,
+                                              slot_rmaps_handler fn,
+                                              bool flush_on_yield)
 {
-       return slot_handle_level(kvm, memslot, fn, PG_LEVEL_4K,
-                                PG_LEVEL_4K, flush_on_yield);
+       return walk_slot_rmaps(kvm, slot, fn, PG_LEVEL_4K, PG_LEVEL_4K, flush_on_yield);
 }
 
 static void free_mmu_pages(struct kvm_mmu *mmu)
@@ -6172,9 +6263,9 @@ static bool kvm_rmap_zap_gfn_range(struct kvm *kvm, gfn_t gfn_start, gfn_t gfn_e
                        if (WARN_ON_ONCE(start >= end))
                                continue;
 
-                       flush = slot_handle_level_range(kvm, memslot, __kvm_zap_rmap,
-                                                       PG_LEVEL_4K, KVM_MAX_HUGEPAGE_LEVEL,
-                                                       start, end - 1, true, flush);
+                       flush = __walk_slot_rmaps(kvm, memslot, __kvm_zap_rmap,
+                                                 PG_LEVEL_4K, KVM_MAX_HUGEPAGE_LEVEL,
+                                                 start, end - 1, true, flush);
                }
        }
 
@@ -6206,8 +6297,7 @@ void kvm_zap_gfn_range(struct kvm *kvm, gfn_t gfn_start, gfn_t gfn_end)
        }
 
        if (flush)
-               kvm_flush_remote_tlbs_with_address(kvm, gfn_start,
-                                                  gfn_end - gfn_start);
+               kvm_flush_remote_tlbs_range(kvm, gfn_start, gfn_end - gfn_start);
 
        kvm_mmu_invalidate_end(kvm, 0, -1ul);
 
@@ -6227,8 +6317,8 @@ void kvm_mmu_slot_remove_write_access(struct kvm *kvm,
 {
        if (kvm_memslots_have_rmaps(kvm)) {
                write_lock(&kvm->mmu_lock);
-               slot_handle_level(kvm, memslot, slot_rmap_write_protect,
-                                 start_level, KVM_MAX_HUGEPAGE_LEVEL, false);
+               walk_slot_rmaps(kvm, memslot, slot_rmap_write_protect,
+                               start_level, KVM_MAX_HUGEPAGE_LEVEL, false);
                write_unlock(&kvm->mmu_lock);
        }
 
@@ -6463,10 +6553,9 @@ static void kvm_shadow_mmu_try_split_huge_pages(struct kvm *kvm,
         * all the way to the target level. There's no need to split pages
         * already at the target level.
         */
-       for (level = KVM_MAX_HUGEPAGE_LEVEL; level > target_level; level--) {
-               slot_handle_level_range(kvm, slot, shadow_mmu_try_split_huge_pages,
-                                       level, level, start, end - 1, true, false);
-       }
+       for (level = KVM_MAX_HUGEPAGE_LEVEL; level > target_level; level--)
+               __walk_slot_rmaps(kvm, slot, shadow_mmu_try_split_huge_pages,
+                                 level, level, start, end - 1, true, false);
 }
 
 /* Must be called with the mmu_lock held in write-mode. */
@@ -6545,7 +6634,7 @@ restart:
                                                               PG_LEVEL_NUM)) {
                        kvm_zap_one_rmap_spte(kvm, rmap_head, sptep);
 
-                       if (kvm_available_flush_tlb_with_range())
+                       if (kvm_available_flush_remote_tlbs_range())
                                kvm_flush_remote_tlbs_sptep(kvm, sptep);
                        else
                                need_tlb_flush = 1;
@@ -6564,8 +6653,8 @@ static void kvm_rmap_zap_collapsible_sptes(struct kvm *kvm,
         * Note, use KVM_MAX_HUGEPAGE_LEVEL - 1 since there's no need to zap
         * pages that are already mapped at the maximum hugepage level.
         */
-       if (slot_handle_level(kvm, slot, kvm_mmu_zap_collapsible_spte,
-                             PG_LEVEL_4K, KVM_MAX_HUGEPAGE_LEVEL - 1, true))
+       if (walk_slot_rmaps(kvm, slot, kvm_mmu_zap_collapsible_spte,
+                           PG_LEVEL_4K, KVM_MAX_HUGEPAGE_LEVEL - 1, true))
                kvm_arch_flush_remote_tlbs_memslot(kvm, slot);
 }
 
@@ -6596,8 +6685,7 @@ void kvm_arch_flush_remote_tlbs_memslot(struct kvm *kvm,
         * is observed by any other operation on the same memslot.
         */
        lockdep_assert_held(&kvm->slots_lock);
-       kvm_flush_remote_tlbs_with_address(kvm, memslot->base_gfn,
-                                          memslot->npages);
+       kvm_flush_remote_tlbs_range(kvm, memslot->base_gfn, memslot->npages);
 }
 
 void kvm_mmu_slot_leaf_clear_dirty(struct kvm *kvm,
@@ -6609,7 +6697,7 @@ void kvm_mmu_slot_leaf_clear_dirty(struct kvm *kvm,
                 * Clear dirty bits only on 4k SPTEs since the legacy MMU only
                 * support dirty logging at a 4k granularity.
                 */
-               slot_handle_level_4k(kvm, memslot, __rmap_clear_dirty, false);
+               walk_slot_rmaps_4k(kvm, memslot, __rmap_clear_dirty, false);
                write_unlock(&kvm->mmu_lock);
        }
 
@@ -6679,8 +6767,8 @@ void kvm_mmu_invalidate_mmio_sptes(struct kvm *kvm, u64 gen)
        }
 }
 
-static unsigned long
-mmu_shrink_scan(struct shrinker *shrink, struct shrink_control *sc)
+static unsigned long mmu_shrink_scan(struct shrinker *shrink,
+                                    struct shrink_control *sc)
 {
        struct kvm *kvm;
        int nr_to_scan = sc->nr_to_scan;
@@ -6738,8 +6826,8 @@ unlock:
        return freed;
 }
 
-static unsigned long
-mmu_shrink_count(struct shrinker *shrink, struct shrink_control *sc)
+static unsigned long mmu_shrink_count(struct shrinker *shrink,
+                                     struct shrink_control *sc)
 {
        return percpu_counter_read_positive(&kvm_total_used_mmu_pages);
 }