Merge tag 'kvm-4.20-1' of git://git.kernel.org/pub/scm/virt/kvm/kvm
diff --git a/arch/x86/kvm/mmu.c b/arch/x86/kvm/mmu.c
index e843ec4..cf5f572 100644
--- a/arch/x86/kvm/mmu.c
+++ b/arch/x86/kvm/mmu.c
@@ -932,7 +932,7 @@ static int mmu_topup_memory_cache(struct kvm_mmu_memory_cache *cache,
        while (cache->nobjs < ARRAY_SIZE(cache->objects)) {
                obj = kmem_cache_zalloc(base_cache, GFP_KERNEL);
                if (!obj)
-                       return -ENOMEM;
+                       return cache->nobjs >= min ? 0 : -ENOMEM;
                cache->objects[cache->nobjs++] = obj;
        }
        return 0;
@@ -960,7 +960,7 @@ static int mmu_topup_memory_cache_page(struct kvm_mmu_memory_cache *cache,
        while (cache->nobjs < ARRAY_SIZE(cache->objects)) {
                page = (void *)__get_free_page(GFP_KERNEL_ACCOUNT);
                if (!page)
-                       return -ENOMEM;
+                       return cache->nobjs >= min ? 0 : -ENOMEM;
                cache->objects[cache->nobjs++] = page;
        }
        return 0;
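
Both top-up helpers above now treat a partially refilled cache as success as
long as at least min objects are available, instead of failing the whole fault
path on the first allocation failure. A minimal, self-contained sketch of that
contract (illustrative only; the names, capacity and error value below are not
taken from the kernel):

#include <stdlib.h>

#define CACHE_CAP 40

struct obj_cache {
	int nobjs;
	void *objects[CACHE_CAP];
};

/* Refill up to capacity, but report failure only if fewer than @min
 * objects would be left for the caller. */
static int cache_topup(struct obj_cache *cache, size_t objsize, int min)
{
	while (cache->nobjs < CACHE_CAP) {
		void *obj = calloc(1, objsize);

		if (!obj)
			return cache->nobjs >= min ? 0 : -1;
		cache->objects[cache->nobjs++] = obj;
	}
	return 0;
}
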
@@ -1265,24 +1265,24 @@ pte_list_desc_remove_entry(struct kvm_rmap_head *rmap_head,
        mmu_free_pte_list_desc(desc);
 }
 
-static void pte_list_remove(u64 *spte, struct kvm_rmap_head *rmap_head)
+static void __pte_list_remove(u64 *spte, struct kvm_rmap_head *rmap_head)
 {
        struct pte_list_desc *desc;
        struct pte_list_desc *prev_desc;
        int i;
 
        if (!rmap_head->val) {
-               printk(KERN_ERR "pte_list_remove: %p 0->BUG\n", spte);
+               pr_err("%s: %p 0->BUG\n", __func__, spte);
                BUG();
        } else if (!(rmap_head->val & 1)) {
-               rmap_printk("pte_list_remove:  %p 1->0\n", spte);
+               rmap_printk("%s:  %p 1->0\n", __func__, spte);
                if ((u64 *)rmap_head->val != spte) {
-                       printk(KERN_ERR "pte_list_remove:  %p 1->BUG\n", spte);
+                       pr_err("%s:  %p 1->BUG\n", __func__, spte);
                        BUG();
                }
                rmap_head->val = 0;
        } else {
-               rmap_printk("pte_list_remove:  %p many->many\n", spte);
+               rmap_printk("%s:  %p many->many\n", __func__, spte);
                desc = (struct pte_list_desc *)(rmap_head->val & ~1ul);
                prev_desc = NULL;
                while (desc) {
@@ -1296,11 +1296,17 @@ static void pte_list_remove(u64 *spte, struct kvm_rmap_head *rmap_head)
                        prev_desc = desc;
                        desc = desc->more;
                }
-               pr_err("pte_list_remove: %p many->many\n", spte);
+               pr_err("%s: %p many->many\n", __func__, spte);
                BUG();
        }
 }
 
+static void pte_list_remove(struct kvm_rmap_head *rmap_head, u64 *sptep)
+{
+       mmu_spte_clear_track_bits(sptep);
+       __pte_list_remove(sptep, rmap_head);
+}
+
 static struct kvm_rmap_head *__gfn_to_rmap(gfn_t gfn, int level,
                                           struct kvm_memory_slot *slot)
 {
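
The split above matters for the rmap zap loops changed later in this patch
(e.g. kvm_zap_rmapp()): they already hold the rmap head, so the new
pte_list_remove() wrapper lets them clear and unlink an SPTE without the
gfn-to-rmap lookup that drop_spte() redoes via rmap_remove(). For reference, a
rough sketch of the two shapes (drop_spte() is not changed by this patch; its
body here is an assumption about the existing helper, not part of the diff):

/* existing helper: has to re-derive the rmap head from the gfn */
static void drop_spte(struct kvm *kvm, u64 *sptep)
{
	if (mmu_spte_clear_track_bits(sptep))
		rmap_remove(kvm, sptep);	/* gfn_to_rmap() lookup inside */
}

/* new wrapper: the caller already knows the rmap head */
static void pte_list_remove(struct kvm_rmap_head *rmap_head, u64 *sptep)
{
	mmu_spte_clear_track_bits(sptep);
	__pte_list_remove(sptep, rmap_head);
}
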
@@ -1349,7 +1355,7 @@ static void rmap_remove(struct kvm *kvm, u64 *spte)
        sp = page_header(__pa(spte));
        gfn = kvm_mmu_page_get_gfn(sp, spte - sp->spt);
        rmap_head = gfn_to_rmap(kvm, gfn, sp);
-       pte_list_remove(spte, rmap_head);
+       __pte_list_remove(spte, rmap_head);
 }
 
 /*
@@ -1685,7 +1691,7 @@ static bool kvm_zap_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head)
        while ((sptep = rmap_get_first(rmap_head, &iter))) {
                rmap_printk("%s: spte %p %llx.\n", __func__, sptep, *sptep);
 
-               drop_spte(kvm, sptep);
+               pte_list_remove(rmap_head, sptep);
                flush = true;
        }
 
@@ -1721,7 +1727,7 @@ restart:
                need_flush = 1;
 
                if (pte_write(*ptep)) {
-                       drop_spte(kvm, sptep);
+                       pte_list_remove(rmap_head, sptep);
                        goto restart;
                } else {
                        new_spte = *sptep & ~PT64_BASE_ADDR_MASK;
@@ -1988,7 +1994,7 @@ static void mmu_page_add_parent_pte(struct kvm_vcpu *vcpu,
 static void mmu_page_remove_parent_pte(struct kvm_mmu_page *sp,
                                       u64 *parent_pte)
 {
-       pte_list_remove(parent_pte, &sp->parent_ptes);
+       __pte_list_remove(parent_pte, &sp->parent_ptes);
 }
 
 static void drop_parent_pte(struct kvm_mmu_page *sp,
@@ -2181,7 +2187,7 @@ static bool __kvm_sync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
                            struct list_head *invalid_list)
 {
        if (sp->role.cr4_pae != !!is_pae(vcpu)
-           || vcpu->arch.mmu.sync_page(vcpu, sp) == 0) {
+           || vcpu->arch.mmu->sync_page(vcpu, sp) == 0) {
                kvm_mmu_prepare_zap_page(vcpu->kvm, sp, invalid_list);
                return false;
        }
@@ -2375,14 +2381,14 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
        int collisions = 0;
        LIST_HEAD(invalid_list);
 
-       role = vcpu->arch.mmu.base_role;
+       role = vcpu->arch.mmu->mmu_role.base;
        role.level = level;
        role.direct = direct;
        if (role.direct)
                role.cr4_pae = 0;
        role.access = access;
-       if (!vcpu->arch.mmu.direct_map
-           && vcpu->arch.mmu.root_level <= PT32_ROOT_LEVEL) {
+       if (!vcpu->arch.mmu->direct_map
+           && vcpu->arch.mmu->root_level <= PT32_ROOT_LEVEL) {
                quadrant = gaddr >> (PAGE_SHIFT + (PT64_PT_BITS * level));
                quadrant &= (1 << ((PT32_PT_BITS - PT64_PT_BITS) * level)) - 1;
                role.quadrant = quadrant;
@@ -2457,11 +2463,11 @@ static void shadow_walk_init_using_root(struct kvm_shadow_walk_iterator *iterato
 {
        iterator->addr = addr;
        iterator->shadow_addr = root;
-       iterator->level = vcpu->arch.mmu.shadow_root_level;
+       iterator->level = vcpu->arch.mmu->shadow_root_level;
 
        if (iterator->level == PT64_ROOT_4LEVEL &&
-           vcpu->arch.mmu.root_level < PT64_ROOT_4LEVEL &&
-           !vcpu->arch.mmu.direct_map)
+           vcpu->arch.mmu->root_level < PT64_ROOT_4LEVEL &&
+           !vcpu->arch.mmu->direct_map)
                --iterator->level;
 
        if (iterator->level == PT32E_ROOT_LEVEL) {
@@ -2469,10 +2475,10 @@ static void shadow_walk_init_using_root(struct kvm_shadow_walk_iterator *iterato
                 * prev_root is currently only used for 64-bit hosts. So only
                 * the active root_hpa is valid here.
                 */
-               BUG_ON(root != vcpu->arch.mmu.root_hpa);
+               BUG_ON(root != vcpu->arch.mmu->root_hpa);
 
                iterator->shadow_addr
-                       = vcpu->arch.mmu.pae_root[(addr >> 30) & 3];
+                       = vcpu->arch.mmu->pae_root[(addr >> 30) & 3];
                iterator->shadow_addr &= PT64_BASE_ADDR_MASK;
                --iterator->level;
                if (!iterator->shadow_addr)
@@ -2483,7 +2489,7 @@ static void shadow_walk_init_using_root(struct kvm_shadow_walk_iterator *iterato
 static void shadow_walk_init(struct kvm_shadow_walk_iterator *iterator,
                             struct kvm_vcpu *vcpu, u64 addr)
 {
-       shadow_walk_init_using_root(iterator, vcpu, vcpu->arch.mmu.root_hpa,
+       shadow_walk_init_using_root(iterator, vcpu, vcpu->arch.mmu->root_hpa,
                                    addr);
 }
 
@@ -3095,7 +3101,7 @@ static int __direct_map(struct kvm_vcpu *vcpu, int write, int map_writable,
        int emulate = 0;
        gfn_t pseudo_gfn;
 
-       if (!VALID_PAGE(vcpu->arch.mmu.root_hpa))
+       if (!VALID_PAGE(vcpu->arch.mmu->root_hpa))
                return 0;
 
        for_each_shadow_entry(vcpu, (u64)gfn << PAGE_SHIFT, iterator) {
@@ -3301,7 +3307,7 @@ static bool fast_page_fault(struct kvm_vcpu *vcpu, gva_t gva, int level,
        u64 spte = 0ull;
        uint retry_count = 0;
 
-       if (!VALID_PAGE(vcpu->arch.mmu.root_hpa))
+       if (!VALID_PAGE(vcpu->arch.mmu->root_hpa))
                return false;
 
        if (!page_fault_can_be_fast(error_code))
@@ -3471,11 +3477,11 @@ static void mmu_free_root_page(struct kvm *kvm, hpa_t *root_hpa,
 }
 
 /* roots_to_free must be some combination of the KVM_MMU_ROOT_* flags */
-void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, ulong roots_to_free)
+void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
+                       ulong roots_to_free)
 {
        int i;
        LIST_HEAD(invalid_list);
-       struct kvm_mmu *mmu = &vcpu->arch.mmu;
        bool free_active_root = roots_to_free & KVM_MMU_ROOT_CURRENT;
 
        BUILD_BUG_ON(KVM_MMU_NUM_PREV_ROOTS >= BITS_PER_LONG);
@@ -3535,20 +3541,20 @@ static int mmu_alloc_direct_roots(struct kvm_vcpu *vcpu)
        struct kvm_mmu_page *sp;
        unsigned i;
 
-       if (vcpu->arch.mmu.shadow_root_level >= PT64_ROOT_4LEVEL) {
+       if (vcpu->arch.mmu->shadow_root_level >= PT64_ROOT_4LEVEL) {
                spin_lock(&vcpu->kvm->mmu_lock);
                if(make_mmu_pages_available(vcpu) < 0) {
                        spin_unlock(&vcpu->kvm->mmu_lock);
                        return -ENOSPC;
                }
                sp = kvm_mmu_get_page(vcpu, 0, 0,
-                               vcpu->arch.mmu.shadow_root_level, 1, ACC_ALL);
+                               vcpu->arch.mmu->shadow_root_level, 1, ACC_ALL);
                ++sp->root_count;
                spin_unlock(&vcpu->kvm->mmu_lock);
-               vcpu->arch.mmu.root_hpa = __pa(sp->spt);
-       } else if (vcpu->arch.mmu.shadow_root_level == PT32E_ROOT_LEVEL) {
+               vcpu->arch.mmu->root_hpa = __pa(sp->spt);
+       } else if (vcpu->arch.mmu->shadow_root_level == PT32E_ROOT_LEVEL) {
                for (i = 0; i < 4; ++i) {
-                       hpa_t root = vcpu->arch.mmu.pae_root[i];
+                       hpa_t root = vcpu->arch.mmu->pae_root[i];
 
                        MMU_WARN_ON(VALID_PAGE(root));
                        spin_lock(&vcpu->kvm->mmu_lock);
@@ -3561,9 +3567,9 @@ static int mmu_alloc_direct_roots(struct kvm_vcpu *vcpu)
                        root = __pa(sp->spt);
                        ++sp->root_count;
                        spin_unlock(&vcpu->kvm->mmu_lock);
-                       vcpu->arch.mmu.pae_root[i] = root | PT_PRESENT_MASK;
+                       vcpu->arch.mmu->pae_root[i] = root | PT_PRESENT_MASK;
                }
-               vcpu->arch.mmu.root_hpa = __pa(vcpu->arch.mmu.pae_root);
+               vcpu->arch.mmu->root_hpa = __pa(vcpu->arch.mmu->pae_root);
        } else
                BUG();
 
@@ -3577,7 +3583,7 @@ static int mmu_alloc_shadow_roots(struct kvm_vcpu *vcpu)
        gfn_t root_gfn;
        int i;
 
-       root_gfn = vcpu->arch.mmu.get_cr3(vcpu) >> PAGE_SHIFT;
+       root_gfn = vcpu->arch.mmu->get_cr3(vcpu) >> PAGE_SHIFT;
 
        if (mmu_check_root(vcpu, root_gfn))
                return 1;
@@ -3586,8 +3592,8 @@ static int mmu_alloc_shadow_roots(struct kvm_vcpu *vcpu)
         * Do we shadow a long mode page table? If so we need to
         * write-protect the guests page table root.
         */
-       if (vcpu->arch.mmu.root_level >= PT64_ROOT_4LEVEL) {
-               hpa_t root = vcpu->arch.mmu.root_hpa;
+       if (vcpu->arch.mmu->root_level >= PT64_ROOT_4LEVEL) {
+               hpa_t root = vcpu->arch.mmu->root_hpa;
 
                MMU_WARN_ON(VALID_PAGE(root));
 
@@ -3597,11 +3603,11 @@ static int mmu_alloc_shadow_roots(struct kvm_vcpu *vcpu)
                        return -ENOSPC;
                }
                sp = kvm_mmu_get_page(vcpu, root_gfn, 0,
-                               vcpu->arch.mmu.shadow_root_level, 0, ACC_ALL);
+                               vcpu->arch.mmu->shadow_root_level, 0, ACC_ALL);
                root = __pa(sp->spt);
                ++sp->root_count;
                spin_unlock(&vcpu->kvm->mmu_lock);
-               vcpu->arch.mmu.root_hpa = root;
+               vcpu->arch.mmu->root_hpa = root;
                return 0;
        }
 
@@ -3611,17 +3617,17 @@ static int mmu_alloc_shadow_roots(struct kvm_vcpu *vcpu)
         * the shadow page table may be a PAE or a long mode page table.
         */
        pm_mask = PT_PRESENT_MASK;
-       if (vcpu->arch.mmu.shadow_root_level == PT64_ROOT_4LEVEL)
+       if (vcpu->arch.mmu->shadow_root_level == PT64_ROOT_4LEVEL)
                pm_mask |= PT_ACCESSED_MASK | PT_WRITABLE_MASK | PT_USER_MASK;
 
        for (i = 0; i < 4; ++i) {
-               hpa_t root = vcpu->arch.mmu.pae_root[i];
+               hpa_t root = vcpu->arch.mmu->pae_root[i];
 
                MMU_WARN_ON(VALID_PAGE(root));
-               if (vcpu->arch.mmu.root_level == PT32E_ROOT_LEVEL) {
-                       pdptr = vcpu->arch.mmu.get_pdptr(vcpu, i);
+               if (vcpu->arch.mmu->root_level == PT32E_ROOT_LEVEL) {
+                       pdptr = vcpu->arch.mmu->get_pdptr(vcpu, i);
                        if (!(pdptr & PT_PRESENT_MASK)) {
-                               vcpu->arch.mmu.pae_root[i] = 0;
+                               vcpu->arch.mmu->pae_root[i] = 0;
                                continue;
                        }
                        root_gfn = pdptr >> PAGE_SHIFT;
@@ -3639,16 +3645,16 @@ static int mmu_alloc_shadow_roots(struct kvm_vcpu *vcpu)
                ++sp->root_count;
                spin_unlock(&vcpu->kvm->mmu_lock);
 
-               vcpu->arch.mmu.pae_root[i] = root | pm_mask;
+               vcpu->arch.mmu->pae_root[i] = root | pm_mask;
        }
-       vcpu->arch.mmu.root_hpa = __pa(vcpu->arch.mmu.pae_root);
+       vcpu->arch.mmu->root_hpa = __pa(vcpu->arch.mmu->pae_root);
 
        /*
         * If we shadow a 32 bit page table with a long mode page
         * table we enter this path.
         */
-       if (vcpu->arch.mmu.shadow_root_level == PT64_ROOT_4LEVEL) {
-               if (vcpu->arch.mmu.lm_root == NULL) {
+       if (vcpu->arch.mmu->shadow_root_level == PT64_ROOT_4LEVEL) {
+               if (vcpu->arch.mmu->lm_root == NULL) {
                        /*
                         * The additional page necessary for this is only
                         * allocated on demand.
@@ -3660,12 +3666,12 @@ static int mmu_alloc_shadow_roots(struct kvm_vcpu *vcpu)
                        if (lm_root == NULL)
                                return 1;
 
-                       lm_root[0] = __pa(vcpu->arch.mmu.pae_root) | pm_mask;
+                       lm_root[0] = __pa(vcpu->arch.mmu->pae_root) | pm_mask;
 
-                       vcpu->arch.mmu.lm_root = lm_root;
+                       vcpu->arch.mmu->lm_root = lm_root;
                }
 
-               vcpu->arch.mmu.root_hpa = __pa(vcpu->arch.mmu.lm_root);
+               vcpu->arch.mmu->root_hpa = __pa(vcpu->arch.mmu->lm_root);
        }
 
        return 0;
@@ -3673,7 +3679,7 @@ static int mmu_alloc_shadow_roots(struct kvm_vcpu *vcpu)
 
 static int mmu_alloc_roots(struct kvm_vcpu *vcpu)
 {
-       if (vcpu->arch.mmu.direct_map)
+       if (vcpu->arch.mmu->direct_map)
                return mmu_alloc_direct_roots(vcpu);
        else
                return mmu_alloc_shadow_roots(vcpu);
@@ -3684,17 +3690,16 @@ void kvm_mmu_sync_roots(struct kvm_vcpu *vcpu)
        int i;
        struct kvm_mmu_page *sp;
 
-       if (vcpu->arch.mmu.direct_map)
+       if (vcpu->arch.mmu->direct_map)
                return;
 
-       if (!VALID_PAGE(vcpu->arch.mmu.root_hpa))
+       if (!VALID_PAGE(vcpu->arch.mmu->root_hpa))
                return;
 
        vcpu_clear_mmio_info(vcpu, MMIO_GVA_ANY);
 
-       if (vcpu->arch.mmu.root_level >= PT64_ROOT_4LEVEL) {
-               hpa_t root = vcpu->arch.mmu.root_hpa;
-
+       if (vcpu->arch.mmu->root_level >= PT64_ROOT_4LEVEL) {
+               hpa_t root = vcpu->arch.mmu->root_hpa;
                sp = page_header(root);
 
                /*
@@ -3725,7 +3730,7 @@ void kvm_mmu_sync_roots(struct kvm_vcpu *vcpu)
        kvm_mmu_audit(vcpu, AUDIT_PRE_SYNC);
 
        for (i = 0; i < 4; ++i) {
-               hpa_t root = vcpu->arch.mmu.pae_root[i];
+               hpa_t root = vcpu->arch.mmu->pae_root[i];
 
                if (root && VALID_PAGE(root)) {
                        root &= PT64_BASE_ADDR_MASK;
@@ -3799,7 +3804,7 @@ walk_shadow_page_get_mmio_spte(struct kvm_vcpu *vcpu, u64 addr, u64 *sptep)
        int root, leaf;
        bool reserved = false;
 
-       if (!VALID_PAGE(vcpu->arch.mmu.root_hpa))
+       if (!VALID_PAGE(vcpu->arch.mmu->root_hpa))
                goto exit;
 
        walk_shadow_page_lockless_begin(vcpu);
@@ -3816,7 +3821,7 @@ walk_shadow_page_get_mmio_spte(struct kvm_vcpu *vcpu, u64 addr, u64 *sptep)
                if (!is_shadow_present_pte(spte))
                        break;
 
-               reserved |= is_shadow_zero_bits_set(&vcpu->arch.mmu, spte,
+               reserved |= is_shadow_zero_bits_set(vcpu->arch.mmu, spte,
                                                    iterator.level);
        }
 
@@ -3895,7 +3900,7 @@ static void shadow_page_table_clear_flood(struct kvm_vcpu *vcpu, gva_t addr)
        struct kvm_shadow_walk_iterator iterator;
        u64 spte;
 
-       if (!VALID_PAGE(vcpu->arch.mmu.root_hpa))
+       if (!VALID_PAGE(vcpu->arch.mmu->root_hpa))
                return;
 
        walk_shadow_page_lockless_begin(vcpu);
@@ -3922,7 +3927,7 @@ static int nonpaging_page_fault(struct kvm_vcpu *vcpu, gva_t gva,
        if (r)
                return r;
 
-       MMU_WARN_ON(!VALID_PAGE(vcpu->arch.mmu.root_hpa));
+       MMU_WARN_ON(!VALID_PAGE(vcpu->arch.mmu->root_hpa));
 
 
        return nonpaging_map(vcpu, gva & PAGE_MASK,
@@ -3935,8 +3940,8 @@ static int kvm_arch_setup_async_pf(struct kvm_vcpu *vcpu, gva_t gva, gfn_t gfn)
 
        arch.token = (vcpu->arch.apf.id++ << 12) | vcpu->vcpu_id;
        arch.gfn = gfn;
-       arch.direct_map = vcpu->arch.mmu.direct_map;
-       arch.cr3 = vcpu->arch.mmu.get_cr3(vcpu);
+       arch.direct_map = vcpu->arch.mmu->direct_map;
+       arch.cr3 = vcpu->arch.mmu->get_cr3(vcpu);
 
        return kvm_setup_async_pf(vcpu, gva, kvm_vcpu_gfn_to_hva(vcpu, gfn), &arch);
 }
@@ -4042,7 +4047,7 @@ static int tdp_page_fault(struct kvm_vcpu *vcpu, gva_t gpa, u32 error_code,
        int write = error_code & PFERR_WRITE_MASK;
        bool map_writable;
 
-       MMU_WARN_ON(!VALID_PAGE(vcpu->arch.mmu.root_hpa));
+       MMU_WARN_ON(!VALID_PAGE(vcpu->arch.mmu->root_hpa));
 
        if (page_fault_handle_page_track(vcpu, error_code, gfn))
                return RET_PF_EMULATE;
@@ -4118,7 +4123,7 @@ static bool cached_root_available(struct kvm_vcpu *vcpu, gpa_t new_cr3,
 {
        uint i;
        struct kvm_mmu_root_info root;
-       struct kvm_mmu *mmu = &vcpu->arch.mmu;
+       struct kvm_mmu *mmu = vcpu->arch.mmu;
 
        root.cr3 = mmu->get_cr3(vcpu);
        root.hpa = mmu->root_hpa;
@@ -4141,7 +4146,7 @@ static bool fast_cr3_switch(struct kvm_vcpu *vcpu, gpa_t new_cr3,
                            union kvm_mmu_page_role new_role,
                            bool skip_tlb_flush)
 {
-       struct kvm_mmu *mmu = &vcpu->arch.mmu;
+       struct kvm_mmu *mmu = vcpu->arch.mmu;
 
        /*
         * For now, limit the fast switch to 64-bit hosts+VMs in order to avoid
@@ -4192,7 +4197,8 @@ static void __kvm_mmu_new_cr3(struct kvm_vcpu *vcpu, gpa_t new_cr3,
                              bool skip_tlb_flush)
 {
        if (!fast_cr3_switch(vcpu, new_cr3, new_role, skip_tlb_flush))
-               kvm_mmu_free_roots(vcpu, KVM_MMU_ROOT_CURRENT);
+               kvm_mmu_free_roots(vcpu, vcpu->arch.mmu,
+                                  KVM_MMU_ROOT_CURRENT);
 }
 
 void kvm_mmu_new_cr3(struct kvm_vcpu *vcpu, gpa_t new_cr3, bool skip_tlb_flush)
@@ -4210,7 +4216,7 @@ static unsigned long get_cr3(struct kvm_vcpu *vcpu)
 static void inject_page_fault(struct kvm_vcpu *vcpu,
                              struct x86_exception *fault)
 {
-       vcpu->arch.mmu.inject_page_fault(vcpu, fault);
+       vcpu->arch.mmu->inject_page_fault(vcpu, fault);
 }
 
 static bool sync_mmio_spte(struct kvm_vcpu *vcpu, u64 *sptep, gfn_t gfn,
@@ -4414,7 +4420,8 @@ static void reset_rsvds_bits_mask_ept(struct kvm_vcpu *vcpu,
 void
 reset_shadow_zero_bits_mask(struct kvm_vcpu *vcpu, struct kvm_mmu *context)
 {
-       bool uses_nx = context->nx || context->base_role.smep_andnot_wp;
+       bool uses_nx = context->nx ||
+               context->mmu_role.base.smep_andnot_wp;
        struct rsvd_bits_validate *shadow_zero_check;
        int i;
 
@@ -4553,7 +4560,7 @@ static void update_permission_bitmask(struct kvm_vcpu *vcpu,
                         * SMAP:kernel-mode data accesses from user-mode
                         * mappings should fault. A fault is considered
                         * as a SMAP violation if all of the following
-                        * conditions are ture:
+                        * conditions are true:
                         *   - X86_CR4_SMAP is set in CR4
                         *   - A user page is accessed
                         *   - The access is not a fetch
@@ -4714,27 +4721,65 @@ static void paging32E_init_context(struct kvm_vcpu *vcpu,
        paging64_init_context_common(vcpu, context, PT32E_ROOT_LEVEL);
 }
 
-static union kvm_mmu_page_role
-kvm_calc_tdp_mmu_root_page_role(struct kvm_vcpu *vcpu)
+static union kvm_mmu_extended_role kvm_calc_mmu_role_ext(struct kvm_vcpu *vcpu)
+{
+       union kvm_mmu_extended_role ext = {0};
+
+       ext.cr0_pg = !!is_paging(vcpu);
+       ext.cr4_smep = !!kvm_read_cr4_bits(vcpu, X86_CR4_SMEP);
+       ext.cr4_smap = !!kvm_read_cr4_bits(vcpu, X86_CR4_SMAP);
+       ext.cr4_pse = !!is_pse(vcpu);
+       ext.cr4_pke = !!kvm_read_cr4_bits(vcpu, X86_CR4_PKE);
+       ext.cr4_la57 = !!kvm_read_cr4_bits(vcpu, X86_CR4_LA57);
+
+       ext.valid = 1;
+
+       return ext;
+}
+
+static union kvm_mmu_role kvm_calc_mmu_role_common(struct kvm_vcpu *vcpu,
+                                                  bool base_only)
+{
+       union kvm_mmu_role role = {0};
+
+       role.base.access = ACC_ALL;
+       role.base.nxe = !!is_nx(vcpu);
+       role.base.cr4_pae = !!is_pae(vcpu);
+       role.base.cr0_wp = is_write_protection(vcpu);
+       role.base.smm = is_smm(vcpu);
+       role.base.guest_mode = is_guest_mode(vcpu);
+
+       if (base_only)
+               return role;
+
+       role.ext = kvm_calc_mmu_role_ext(vcpu);
+
+       return role;
+}
+
+static union kvm_mmu_role
+kvm_calc_tdp_mmu_root_page_role(struct kvm_vcpu *vcpu, bool base_only)
 {
-       union kvm_mmu_page_role role = {0};
+       union kvm_mmu_role role = kvm_calc_mmu_role_common(vcpu, base_only);
 
-       role.guest_mode = is_guest_mode(vcpu);
-       role.smm = is_smm(vcpu);
-       role.ad_disabled = (shadow_accessed_mask == 0);
-       role.level = kvm_x86_ops->get_tdp_level(vcpu);
-       role.direct = true;
-       role.access = ACC_ALL;
+       role.base.ad_disabled = (shadow_accessed_mask == 0);
+       role.base.level = kvm_x86_ops->get_tdp_level(vcpu);
+       role.base.direct = true;
 
        return role;
 }
 
 static void init_kvm_tdp_mmu(struct kvm_vcpu *vcpu)
 {
-       struct kvm_mmu *context = &vcpu->arch.mmu;
+       struct kvm_mmu *context = vcpu->arch.mmu;
+       union kvm_mmu_role new_role =
+               kvm_calc_tdp_mmu_root_page_role(vcpu, false);
 
-       context->base_role.word = mmu_base_role_mask.word &
-                                 kvm_calc_tdp_mmu_root_page_role(vcpu).word;
+       new_role.base.word &= mmu_base_role_mask.word;
+       if (new_role.as_u64 == context->mmu_role.as_u64)
+               return;
+
+       context->mmu_role.as_u64 = new_role.as_u64;
        context->page_fault = tdp_page_fault;
        context->sync_page = nonpaging_sync_page;
        context->invlpg = nonpaging_invlpg;
@@ -4774,36 +4819,36 @@ static void init_kvm_tdp_mmu(struct kvm_vcpu *vcpu)
        reset_tdp_shadow_zero_bits_mask(vcpu, context);
 }
 
-static union kvm_mmu_page_role
-kvm_calc_shadow_mmu_root_page_role(struct kvm_vcpu *vcpu)
-{
-       union kvm_mmu_page_role role = {0};
-       bool smep = kvm_read_cr4_bits(vcpu, X86_CR4_SMEP);
-       bool smap = kvm_read_cr4_bits(vcpu, X86_CR4_SMAP);
-
-       role.nxe = is_nx(vcpu);
-       role.cr4_pae = !!is_pae(vcpu);
-       role.cr0_wp  = is_write_protection(vcpu);
-       role.smep_andnot_wp = smep && !is_write_protection(vcpu);
-       role.smap_andnot_wp = smap && !is_write_protection(vcpu);
-       role.guest_mode = is_guest_mode(vcpu);
-       role.smm = is_smm(vcpu);
-       role.direct = !is_paging(vcpu);
-       role.access = ACC_ALL;
+static union kvm_mmu_role
+kvm_calc_shadow_mmu_root_page_role(struct kvm_vcpu *vcpu, bool base_only)
+{
+       union kvm_mmu_role role = kvm_calc_mmu_role_common(vcpu, base_only);
+
+       role.base.smep_andnot_wp = role.ext.cr4_smep &&
+               !is_write_protection(vcpu);
+       role.base.smap_andnot_wp = role.ext.cr4_smap &&
+               !is_write_protection(vcpu);
+       role.base.direct = !is_paging(vcpu);
 
        if (!is_long_mode(vcpu))
-               role.level = PT32E_ROOT_LEVEL;
+               role.base.level = PT32E_ROOT_LEVEL;
        else if (is_la57_mode(vcpu))
-               role.level = PT64_ROOT_5LEVEL;
+               role.base.level = PT64_ROOT_5LEVEL;
        else
-               role.level = PT64_ROOT_4LEVEL;
+               role.base.level = PT64_ROOT_4LEVEL;
 
        return role;
 }
 
 void kvm_init_shadow_mmu(struct kvm_vcpu *vcpu)
 {
-       struct kvm_mmu *context = &vcpu->arch.mmu;
+       struct kvm_mmu *context = vcpu->arch.mmu;
+       union kvm_mmu_role new_role =
+               kvm_calc_shadow_mmu_root_page_role(vcpu, false);
+
+       new_role.base.word &= mmu_base_role_mask.word;
+       if (new_role.as_u64 == context->mmu_role.as_u64)
+               return;
 
        if (!is_paging(vcpu))
                nonpaging_init_context(vcpu, context);
@@ -4814,22 +4859,28 @@ void kvm_init_shadow_mmu(struct kvm_vcpu *vcpu)
        else
                paging32_init_context(vcpu, context);
 
-       context->base_role.word = mmu_base_role_mask.word &
-                                 kvm_calc_shadow_mmu_root_page_role(vcpu).word;
+       context->mmu_role.as_u64 = new_role.as_u64;
        reset_shadow_zero_bits_mask(vcpu, context);
 }
 EXPORT_SYMBOL_GPL(kvm_init_shadow_mmu);
 
-static union kvm_mmu_page_role
-kvm_calc_shadow_ept_root_page_role(struct kvm_vcpu *vcpu, bool accessed_dirty)
+static union kvm_mmu_role
+kvm_calc_shadow_ept_root_page_role(struct kvm_vcpu *vcpu, bool accessed_dirty,
+                                  bool execonly)
 {
-       union kvm_mmu_page_role role = vcpu->arch.mmu.base_role;
+       union kvm_mmu_role role;
+
+       /* Base role is inherited from root_mmu */
+       role.base.word = vcpu->arch.root_mmu.mmu_role.base.word;
+       role.ext = kvm_calc_mmu_role_ext(vcpu);
+
+       role.base.level = PT64_ROOT_4LEVEL;
+       role.base.direct = false;
+       role.base.ad_disabled = !accessed_dirty;
+       role.base.guest_mode = true;
+       role.base.access = ACC_ALL;
 
-       role.level = PT64_ROOT_4LEVEL;
-       role.direct = false;
-       role.ad_disabled = !accessed_dirty;
-       role.guest_mode = true;
-       role.access = ACC_ALL;
+       role.ext.execonly = execonly;
 
        return role;
 }
@@ -4837,11 +4888,17 @@ kvm_calc_shadow_ept_root_page_role(struct kvm_vcpu *vcpu, bool accessed_dirty)
 void kvm_init_shadow_ept_mmu(struct kvm_vcpu *vcpu, bool execonly,
                             bool accessed_dirty, gpa_t new_eptp)
 {
-       struct kvm_mmu *context = &vcpu->arch.mmu;
-       union kvm_mmu_page_role root_page_role =
-               kvm_calc_shadow_ept_root_page_role(vcpu, accessed_dirty);
+       struct kvm_mmu *context = vcpu->arch.mmu;
+       union kvm_mmu_role new_role =
+               kvm_calc_shadow_ept_root_page_role(vcpu, accessed_dirty,
+                                                  execonly);
+
+       __kvm_mmu_new_cr3(vcpu, new_eptp, new_role.base, false);
+
+       new_role.base.word &= mmu_base_role_mask.word;
+       if (new_role.as_u64 == context->mmu_role.as_u64)
+               return;
 
-       __kvm_mmu_new_cr3(vcpu, new_eptp, root_page_role, false);
        context->shadow_root_level = PT64_ROOT_4LEVEL;
 
        context->nx = true;
@@ -4853,7 +4910,8 @@ void kvm_init_shadow_ept_mmu(struct kvm_vcpu *vcpu, bool execonly,
        context->update_pte = ept_update_pte;
        context->root_level = PT64_ROOT_4LEVEL;
        context->direct_map = false;
-       context->base_role.word = root_page_role.word & mmu_base_role_mask.word;
+       context->mmu_role.as_u64 = new_role.as_u64;
+
        update_permission_bitmask(vcpu, context, true);
        update_pkru_bitmask(vcpu, context, true);
        update_last_nonleaf_level(vcpu, context);
@@ -4864,7 +4922,7 @@ EXPORT_SYMBOL_GPL(kvm_init_shadow_ept_mmu);
 
 static void init_kvm_softmmu(struct kvm_vcpu *vcpu)
 {
-       struct kvm_mmu *context = &vcpu->arch.mmu;
+       struct kvm_mmu *context = vcpu->arch.mmu;
 
        kvm_init_shadow_mmu(vcpu);
        context->set_cr3           = kvm_x86_ops->set_cr3;
@@ -4875,14 +4933,20 @@ static void init_kvm_softmmu(struct kvm_vcpu *vcpu)
 
 static void init_kvm_nested_mmu(struct kvm_vcpu *vcpu)
 {
+       union kvm_mmu_role new_role = kvm_calc_mmu_role_common(vcpu, false);
        struct kvm_mmu *g_context = &vcpu->arch.nested_mmu;
 
+       new_role.base.word &= mmu_base_role_mask.word;
+       if (new_role.as_u64 == g_context->mmu_role.as_u64)
+               return;
+
+       g_context->mmu_role.as_u64 = new_role.as_u64;
        g_context->get_cr3           = get_cr3;
        g_context->get_pdptr         = kvm_pdptr_read;
        g_context->inject_page_fault = kvm_inject_page_fault;
 
        /*
-        * Note that arch.mmu.gva_to_gpa translates l2_gpa to l1_gpa using
+        * Note that arch.mmu->gva_to_gpa translates l2_gpa to l1_gpa using
         * L1's nested page tables (e.g. EPT12). The nested translation
         * of l2_gva to l1_gpa is done by arch.nested_mmu.gva_to_gpa using
         * L2's page tables as the first level of translation and L1's
@@ -4921,10 +4985,10 @@ void kvm_init_mmu(struct kvm_vcpu *vcpu, bool reset_roots)
        if (reset_roots) {
                uint i;
 
-               vcpu->arch.mmu.root_hpa = INVALID_PAGE;
+               vcpu->arch.mmu->root_hpa = INVALID_PAGE;
 
                for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++)
-                       vcpu->arch.mmu.prev_roots[i] = KVM_MMU_ROOT_INFO_INVALID;
+                       vcpu->arch.mmu->prev_roots[i] = KVM_MMU_ROOT_INFO_INVALID;
        }
 
        if (mmu_is_nested(vcpu))
@@ -4939,10 +5003,14 @@ EXPORT_SYMBOL_GPL(kvm_init_mmu);
 static union kvm_mmu_page_role
 kvm_mmu_calc_root_page_role(struct kvm_vcpu *vcpu)
 {
+       union kvm_mmu_role role;
+
        if (tdp_enabled)
-               return kvm_calc_tdp_mmu_root_page_role(vcpu);
+               role = kvm_calc_tdp_mmu_root_page_role(vcpu, true);
        else
-               return kvm_calc_shadow_mmu_root_page_role(vcpu);
+               role = kvm_calc_shadow_mmu_root_page_role(vcpu, true);
+
+       return role.base;
 }
 
 void kvm_mmu_reset_context(struct kvm_vcpu *vcpu)
@@ -4972,8 +5040,10 @@ EXPORT_SYMBOL_GPL(kvm_mmu_load);
 
 void kvm_mmu_unload(struct kvm_vcpu *vcpu)
 {
-       kvm_mmu_free_roots(vcpu, KVM_MMU_ROOTS_ALL);
-       WARN_ON(VALID_PAGE(vcpu->arch.mmu.root_hpa));
+       kvm_mmu_free_roots(vcpu, &vcpu->arch.root_mmu, KVM_MMU_ROOTS_ALL);
+       WARN_ON(VALID_PAGE(vcpu->arch.root_mmu.root_hpa));
+       kvm_mmu_free_roots(vcpu, &vcpu->arch.guest_mmu, KVM_MMU_ROOTS_ALL);
+       WARN_ON(VALID_PAGE(vcpu->arch.guest_mmu.root_hpa));
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_unload);
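
kvm_mmu_unload() now tears down both contexts because, with this series,
vcpu->arch.mmu becomes a pointer that selects between two embedded
struct kvm_mmu instances (the switch-over is visible in kvm_mmu_create()
further down). An abridged sketch of the assumed kvm_vcpu_arch fields (the real
definition and comments live in arch/x86/include/asm/kvm_host.h):

struct kvm_vcpu_arch {
	/* ... */
	struct kvm_mmu *mmu;		/* active context: &root_mmu or &guest_mmu */
	struct kvm_mmu *walk_mmu;	/* context used for gva->gpa walks */
	struct kvm_mmu root_mmu;	/* ordinary, non-nested MMU for L1 */
	struct kvm_mmu guest_mmu;	/* shadow EPT/NPT context while L2 runs */
	struct kvm_mmu nested_mmu;	/* pre-existing gva->gpa walker for L2 */
	/* ... */
};
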
 
@@ -4987,7 +5057,7 @@ static void mmu_pte_write_new_pte(struct kvm_vcpu *vcpu,
         }
 
        ++vcpu->kvm->stat.mmu_pte_updated;
-       vcpu->arch.mmu.update_pte(vcpu, sp, spte, new);
+       vcpu->arch.mmu->update_pte(vcpu, sp, spte, new);
 }
 
 static bool need_remote_flush(u64 old, u64 new)
@@ -5164,10 +5234,12 @@ static void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
 
                local_flush = true;
                while (npte--) {
+                       u32 base_role = vcpu->arch.mmu->mmu_role.base.word;
+
                        entry = *spte;
                        mmu_page_zap_pte(vcpu->kvm, sp, spte);
                        if (gentry &&
-                             !((sp->role.word ^ vcpu->arch.mmu.base_role.word)
+                             !((sp->role.word ^ base_role)
                              & mmu_base_role_mask.word) && rmap_can_add(vcpu))
                                mmu_pte_write_new_pte(vcpu, sp, spte, &gentry);
                        if (need_remote_flush(entry, *spte))
@@ -5185,7 +5257,7 @@ int kvm_mmu_unprotect_page_virt(struct kvm_vcpu *vcpu, gva_t gva)
        gpa_t gpa;
        int r;
 
-       if (vcpu->arch.mmu.direct_map)
+       if (vcpu->arch.mmu->direct_map)
                return 0;
 
        gpa = kvm_mmu_gva_to_gpa_read(vcpu, gva, NULL);
@@ -5221,10 +5293,10 @@ int kvm_mmu_page_fault(struct kvm_vcpu *vcpu, gva_t cr2, u64 error_code,
 {
        int r, emulation_type = 0;
        enum emulation_result er;
-       bool direct = vcpu->arch.mmu.direct_map;
+       bool direct = vcpu->arch.mmu->direct_map;
 
        /* With shadow page tables, fault_address contains a GVA or nGPA.  */
-       if (vcpu->arch.mmu.direct_map) {
+       if (vcpu->arch.mmu->direct_map) {
                vcpu->arch.gpa_available = true;
                vcpu->arch.gpa_val = cr2;
        }
@@ -5237,8 +5309,9 @@ int kvm_mmu_page_fault(struct kvm_vcpu *vcpu, gva_t cr2, u64 error_code,
        }
 
        if (r == RET_PF_INVALID) {
-               r = vcpu->arch.mmu.page_fault(vcpu, cr2, lower_32_bits(error_code),
-                                             false);
+               r = vcpu->arch.mmu->page_fault(vcpu, cr2,
+                                              lower_32_bits(error_code),
+                                              false);
                WARN_ON(r == RET_PF_INVALID);
        }
 
@@ -5254,7 +5327,7 @@ int kvm_mmu_page_fault(struct kvm_vcpu *vcpu, gva_t cr2, u64 error_code,
         * paging in both guests. If true, we simply unprotect the page
         * and resume the guest.
         */
-       if (vcpu->arch.mmu.direct_map &&
+       if (vcpu->arch.mmu->direct_map &&
            (error_code & PFERR_NESTED_GUEST_PAGE) == PFERR_NESTED_GUEST_PAGE) {
                kvm_mmu_unprotect_page(vcpu->kvm, gpa_to_gfn(cr2));
                return 1;
@@ -5302,7 +5375,7 @@ EXPORT_SYMBOL_GPL(kvm_mmu_page_fault);
 
 void kvm_mmu_invlpg(struct kvm_vcpu *vcpu, gva_t gva)
 {
-       struct kvm_mmu *mmu = &vcpu->arch.mmu;
+       struct kvm_mmu *mmu = vcpu->arch.mmu;
        int i;
 
        /* INVLPG on a non-canonical address is a NOP according to the SDM.  */
@@ -5333,7 +5406,7 @@ EXPORT_SYMBOL_GPL(kvm_mmu_invlpg);
 
 void kvm_mmu_invpcid_gva(struct kvm_vcpu *vcpu, gva_t gva, unsigned long pcid)
 {
-       struct kvm_mmu *mmu = &vcpu->arch.mmu;
+       struct kvm_mmu *mmu = vcpu->arch.mmu;
        bool tlb_flush = false;
        uint i;
 
@@ -5377,8 +5450,8 @@ EXPORT_SYMBOL_GPL(kvm_disable_tdp);
 
 static void free_mmu_pages(struct kvm_vcpu *vcpu)
 {
-       free_page((unsigned long)vcpu->arch.mmu.pae_root);
-       free_page((unsigned long)vcpu->arch.mmu.lm_root);
+       free_page((unsigned long)vcpu->arch.mmu->pae_root);
+       free_page((unsigned long)vcpu->arch.mmu->lm_root);
 }
 
 static int alloc_mmu_pages(struct kvm_vcpu *vcpu)
@@ -5398,9 +5471,9 @@ static int alloc_mmu_pages(struct kvm_vcpu *vcpu)
        if (!page)
                return -ENOMEM;
 
-       vcpu->arch.mmu.pae_root = page_address(page);
+       vcpu->arch.mmu->pae_root = page_address(page);
        for (i = 0; i < 4; ++i)
-               vcpu->arch.mmu.pae_root[i] = INVALID_PAGE;
+               vcpu->arch.mmu->pae_root[i] = INVALID_PAGE;
 
        return 0;
 }
@@ -5409,27 +5482,21 @@ int kvm_mmu_create(struct kvm_vcpu *vcpu)
 {
        uint i;
 
-       vcpu->arch.walk_mmu = &vcpu->arch.mmu;
-       vcpu->arch.mmu.root_hpa = INVALID_PAGE;
-       vcpu->arch.mmu.translate_gpa = translate_gpa;
-       vcpu->arch.nested_mmu.translate_gpa = translate_nested_gpa;
+       vcpu->arch.mmu = &vcpu->arch.root_mmu;
+       vcpu->arch.walk_mmu = &vcpu->arch.root_mmu;
 
+       vcpu->arch.root_mmu.root_hpa = INVALID_PAGE;
+       vcpu->arch.root_mmu.translate_gpa = translate_gpa;
        for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++)
-               vcpu->arch.mmu.prev_roots[i] = KVM_MMU_ROOT_INFO_INVALID;
-
-       return alloc_mmu_pages(vcpu);
-}
+               vcpu->arch.root_mmu.prev_roots[i] = KVM_MMU_ROOT_INFO_INVALID;
 
-void kvm_mmu_setup(struct kvm_vcpu *vcpu)
-{
-       MMU_WARN_ON(VALID_PAGE(vcpu->arch.mmu.root_hpa));
+       vcpu->arch.guest_mmu.root_hpa = INVALID_PAGE;
+       vcpu->arch.guest_mmu.translate_gpa = translate_gpa;
+       for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++)
+               vcpu->arch.guest_mmu.prev_roots[i] = KVM_MMU_ROOT_INFO_INVALID;
 
-       /*
-        * kvm_mmu_setup() is called only on vCPU initialization.  
-        * Therefore, no need to reset mmu roots as they are not yet
-        * initialized.
-        */
-       kvm_init_mmu(vcpu, false);
+       vcpu->arch.nested_mmu.translate_gpa = translate_nested_gpa;
+       return alloc_mmu_pages(vcpu);
 }
 
 static void kvm_mmu_invalidate_zap_pages_in_memslot(struct kvm *kvm,
@@ -5612,7 +5679,7 @@ restart:
                if (sp->role.direct &&
                        !kvm_is_reserved_pfn(pfn) &&
                        PageTransCompoundMap(pfn_to_page(pfn))) {
-                       drop_spte(kvm, sptep);
+                       pte_list_remove(rmap_head, sptep);
                        need_tlb_flush = 1;
                        goto restart;
                }
@@ -5869,6 +5936,16 @@ int kvm_mmu_module_init(void)
 {
        int ret = -ENOMEM;
 
+       /*
+        * MMU roles use union aliasing which is, generally speaking, an
+        * undefined behavior. However, we supposedly know how compilers behave
+        * and the current status quo is unlikely to change. Guardians below are
+        * supposed to let us know if the assumption becomes false.
+        */
+       BUILD_BUG_ON(sizeof(union kvm_mmu_page_role) != sizeof(u32));
+       BUILD_BUG_ON(sizeof(union kvm_mmu_extended_role) != sizeof(u32));
+       BUILD_BUG_ON(sizeof(union kvm_mmu_role) != sizeof(u64));
+
        kvm_mmu_reset_all_pte_masks();
 
        pte_list_desc_cache = kmem_cache_create("pte_list_desc",
@@ -5898,7 +5975,7 @@ out:
 }
 
 /*
- * Caculate mmu pages needed for kvm.
+ * Calculate mmu pages needed for kvm.
  */
 unsigned int kvm_mmu_calculate_mmu_pages(struct kvm *kvm)
 {