mm/mmu_notifier: add an interval tree notifier
authorJason Gunthorpe <jgg@mellanox.com>
Tue, 12 Nov 2019 20:22:19 +0000 (16:22 -0400)
committerJason Gunthorpe <jgg@mellanox.com>
Sat, 23 Nov 2019 23:56:44 +0000 (19:56 -0400)
Of the 13 users of mmu_notifiers, 8 of them use only
invalidate_range_start/end() and immediately intersect the
mmu_notifier_range with some kind of internal list of VAs.  4 use an
interval tree (i915_gem, radeon_mn, umem_odp, hfi1). 4 use a linked list
of some kind (scif_dma, vhost, gntdev, hmm)

And the remaining 5 either don't use invalidate_range_start() or do some
special thing with it.

It turns out that building a correct scheme with an interval tree is
pretty complicated, particularly if the use case is synchronizing against
another thread doing get_user_pages().  Many of these implementations have
various subtle and difficult to fix races.

This approach puts the interval tree as common code at the top of the mmu
notifier call tree and implements a shareable locking scheme.

It includes:
 - An interval tree tracking VA ranges, with per-range callbacks
 - A read/write locking scheme for the interval tree that avoids
   sleeping in the notifier path (for OOM killer)
 - A sequence counter based collision-retry locking scheme to tell
   device page fault that a VA range is being concurrently invalidated.

This is based on various ideas:
- hmm accumulates invalidated VA ranges and releases them when all
  invalidates are done, via active_invalidate_ranges count.
  This approach avoids having to intersect the interval tree twice (as
  umem_odp does) at the potential cost of a longer device page fault.

- kvm/umem_odp use a sequence counter to drive the collision retry,
  via invalidate_seq

- a deferred work todo list on unlock scheme like RTNL, via deferred_list.
  This makes adding/removing interval tree members more deterministic

- seqlock, except this version makes the seqlock idea multi-holder on the
  write side by protecting it with active_invalidate_ranges and a spinlock

To minimize MM overhead when only the interval tree is being used, the
entire SRCU and hlist overheads are dropped using some simple
branches. Similarly the interval tree overhead is dropped when in hlist
mode.

The overhead from the mandatory spinlock is broadly the same as most of
existing users which already had a lock (or two) of some sort on the
invalidation path.

Link: https://lore.kernel.org/r/20191112202231.3856-3-jgg@ziepe.ca
Acked-by: Christian König <christian.koenig@amd.com>
Tested-by: Philip Yang <Philip.Yang@amd.com>
Tested-by: Ralph Campbell <rcampbell@nvidia.com>
Reviewed-by: John Hubbard <jhubbard@nvidia.com>
Reviewed-by: Christoph Hellwig <hch@lst.de>
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
include/linux/mmu_notifier.h
mm/Kconfig
mm/mmu_notifier.c

index 12bd603..9e6caa8 100644 (file)
@@ -6,10 +6,12 @@
 #include <linux/spinlock.h>
 #include <linux/mm_types.h>
 #include <linux/srcu.h>
+#include <linux/interval_tree.h>
 
 struct mmu_notifier_mm;
 struct mmu_notifier;
 struct mmu_notifier_range;
+struct mmu_interval_notifier;
 
 /**
  * enum mmu_notifier_event - reason for the mmu notifier callback
@@ -32,6 +34,9 @@ struct mmu_notifier_range;
  * access flags). User should soft dirty the page in the end callback to make
  * sure that anyone relying on soft dirtyness catch pages that might be written
  * through non CPU mappings.
+ *
+ * @MMU_NOTIFY_RELEASE: used during mmu_interval_notifier invalidate to signal
+ * that the mm refcount is zero and the range is no longer accessible.
  */
 enum mmu_notifier_event {
        MMU_NOTIFY_UNMAP = 0,
@@ -39,6 +44,7 @@ enum mmu_notifier_event {
        MMU_NOTIFY_PROTECTION_VMA,
        MMU_NOTIFY_PROTECTION_PAGE,
        MMU_NOTIFY_SOFT_DIRTY,
+       MMU_NOTIFY_RELEASE,
 };
 
 #define MMU_NOTIFIER_RANGE_BLOCKABLE (1 << 0)
@@ -222,6 +228,26 @@ struct mmu_notifier {
        unsigned int users;
 };
 
+/**
+ * struct mmu_interval_notifier_ops
+ * @invalidate: Upon return the caller must stop using any SPTEs within this
+ *              range. This function can sleep. Return false only if sleeping
+ *              was required but mmu_notifier_range_blockable(range) is false.
+ */
+struct mmu_interval_notifier_ops {
+       bool (*invalidate)(struct mmu_interval_notifier *mni,
+                          const struct mmu_notifier_range *range,
+                          unsigned long cur_seq);
+};
+
+struct mmu_interval_notifier {
+       struct interval_tree_node interval_tree;
+       const struct mmu_interval_notifier_ops *ops;
+       struct mm_struct *mm;
+       struct hlist_node deferred_item;
+       unsigned long invalidate_seq;
+};
+
 #ifdef CONFIG_MMU_NOTIFIER
 
 #ifdef CONFIG_LOCKDEP
@@ -263,6 +289,81 @@ extern int __mmu_notifier_register(struct mmu_notifier *mn,
                                   struct mm_struct *mm);
 extern void mmu_notifier_unregister(struct mmu_notifier *mn,
                                    struct mm_struct *mm);
+
+unsigned long mmu_interval_read_begin(struct mmu_interval_notifier *mni);
+int mmu_interval_notifier_insert(struct mmu_interval_notifier *mni,
+                                struct mm_struct *mm, unsigned long start,
+                                unsigned long length,
+                                const struct mmu_interval_notifier_ops *ops);
+int mmu_interval_notifier_insert_locked(
+       struct mmu_interval_notifier *mni, struct mm_struct *mm,
+       unsigned long start, unsigned long length,
+       const struct mmu_interval_notifier_ops *ops);
+void mmu_interval_notifier_remove(struct mmu_interval_notifier *mni);
+
+/**
+ * mmu_interval_set_seq - Save the invalidation sequence
+ * @mni - The mni passed to invalidate
+ * @cur_seq - The cur_seq passed to the invalidate() callback
+ *
+ * This must be called unconditionally from the invalidate callback of a
+ * struct mmu_interval_notifier_ops under the same lock that is used to call
+ * mmu_interval_read_retry(). It updates the sequence number for later use by
+ * mmu_interval_read_retry(). The provided cur_seq will always be odd.
+ *
+ * If the caller does not call mmu_interval_read_begin() or
+ * mmu_interval_read_retry() then this call is not required.
+ */
+static inline void mmu_interval_set_seq(struct mmu_interval_notifier *mni,
+                                       unsigned long cur_seq)
+{
+       WRITE_ONCE(mni->invalidate_seq, cur_seq);
+}
+
+/**
+ * mmu_interval_read_retry - End a read side critical section against a VA range
+ * mni: The range
+ * seq: The return of the paired mmu_interval_read_begin()
+ *
+ * This MUST be called under a user provided lock that is also held
+ * unconditionally by op->invalidate() when it calls mmu_interval_set_seq().
+ *
+ * Each call should be paired with a single mmu_interval_read_begin() and
+ * should be used to conclude the read side.
+ *
+ * Returns true if an invalidation collided with this critical section, and
+ * the caller should retry.
+ */
+static inline bool mmu_interval_read_retry(struct mmu_interval_notifier *mni,
+                                          unsigned long seq)
+{
+       return mni->invalidate_seq != seq;
+}
+
+/**
+ * mmu_interval_check_retry - Test if a collision has occurred
+ * mni: The range
+ * seq: The return of the matching mmu_interval_read_begin()
+ *
+ * This can be used in the critical section between mmu_interval_read_begin()
+ * and mmu_interval_read_retry().  A return of true indicates an invalidation
+ * has collided with this critical region and a future
+ * mmu_interval_read_retry() will return true.
+ *
+ * False is not reliable and only suggests a collision may not have
+ * occured. It can be called many times and does not have to hold the user
+ * provided lock.
+ *
+ * This call can be used as part of loops and other expensive operations to
+ * expedite a retry.
+ */
+static inline bool mmu_interval_check_retry(struct mmu_interval_notifier *mni,
+                                           unsigned long seq)
+{
+       /* Pairs with the WRITE_ONCE in mmu_interval_set_seq() */
+       return READ_ONCE(mni->invalidate_seq) != seq;
+}
+
 extern void __mmu_notifier_mm_destroy(struct mm_struct *mm);
 extern void __mmu_notifier_release(struct mm_struct *mm);
 extern int __mmu_notifier_clear_flush_young(struct mm_struct *mm,
index a5dae9a..d0b5046 100644 (file)
@@ -284,6 +284,7 @@ config VIRT_TO_BUS
 config MMU_NOTIFIER
        bool
        select SRCU
+       select INTERVAL_TREE
 
 config KSM
        bool "Enable KSM for page merging"
index 367670c..30abbfd 100644 (file)
@@ -12,6 +12,7 @@
 #include <linux/export.h>
 #include <linux/mm.h>
 #include <linux/err.h>
+#include <linux/interval_tree.h>
 #include <linux/srcu.h>
 #include <linux/rcupdate.h>
 #include <linux/sched.h>
@@ -36,11 +37,246 @@ struct lockdep_map __mmu_notifier_invalidate_range_start_map = {
 struct mmu_notifier_mm {
        /* all mmu notifiers registered in this mm are queued in this list */
        struct hlist_head list;
+       bool has_itree;
        /* to serialize the list modifications and hlist_unhashed */
        spinlock_t lock;
+       unsigned long invalidate_seq;
+       unsigned long active_invalidate_ranges;
+       struct rb_root_cached itree;
+       wait_queue_head_t wq;
+       struct hlist_head deferred_list;
 };
 
 /*
+ * This is a collision-retry read-side/write-side 'lock', a lot like a
+ * seqcount, however this allows multiple write-sides to hold it at
+ * once. Conceptually the write side is protecting the values of the PTEs in
+ * this mm, such that PTES cannot be read into SPTEs (shadow PTEs) while any
+ * writer exists.
+ *
+ * Note that the core mm creates nested invalidate_range_start()/end() regions
+ * within the same thread, and runs invalidate_range_start()/end() in parallel
+ * on multiple CPUs. This is designed to not reduce concurrency or block
+ * progress on the mm side.
+ *
+ * As a secondary function, holding the full write side also serves to prevent
+ * writers for the itree, this is an optimization to avoid extra locking
+ * during invalidate_range_start/end notifiers.
+ *
+ * The write side has two states, fully excluded:
+ *  - mm->active_invalidate_ranges != 0
+ *  - mnn->invalidate_seq & 1 == True (odd)
+ *  - some range on the mm_struct is being invalidated
+ *  - the itree is not allowed to change
+ *
+ * And partially excluded:
+ *  - mm->active_invalidate_ranges != 0
+ *  - mnn->invalidate_seq & 1 == False (even)
+ *  - some range on the mm_struct is being invalidated
+ *  - the itree is allowed to change
+ *
+ * Operations on mmu_notifier_mm->invalidate_seq (under spinlock):
+ *    seq |= 1  # Begin writing
+ *    seq++     # Release the writing state
+ *    seq & 1   # True if a writer exists
+ *
+ * The later state avoids some expensive work on inv_end in the common case of
+ * no mni monitoring the VA.
+ */
+static bool mn_itree_is_invalidating(struct mmu_notifier_mm *mmn_mm)
+{
+       lockdep_assert_held(&mmn_mm->lock);
+       return mmn_mm->invalidate_seq & 1;
+}
+
+static struct mmu_interval_notifier *
+mn_itree_inv_start_range(struct mmu_notifier_mm *mmn_mm,
+                        const struct mmu_notifier_range *range,
+                        unsigned long *seq)
+{
+       struct interval_tree_node *node;
+       struct mmu_interval_notifier *res = NULL;
+
+       spin_lock(&mmn_mm->lock);
+       mmn_mm->active_invalidate_ranges++;
+       node = interval_tree_iter_first(&mmn_mm->itree, range->start,
+                                       range->end - 1);
+       if (node) {
+               mmn_mm->invalidate_seq |= 1;
+               res = container_of(node, struct mmu_interval_notifier,
+                                  interval_tree);
+       }
+
+       *seq = mmn_mm->invalidate_seq;
+       spin_unlock(&mmn_mm->lock);
+       return res;
+}
+
+static struct mmu_interval_notifier *
+mn_itree_inv_next(struct mmu_interval_notifier *mni,
+                 const struct mmu_notifier_range *range)
+{
+       struct interval_tree_node *node;
+
+       node = interval_tree_iter_next(&mni->interval_tree, range->start,
+                                      range->end - 1);
+       if (!node)
+               return NULL;
+       return container_of(node, struct mmu_interval_notifier, interval_tree);
+}
+
+static void mn_itree_inv_end(struct mmu_notifier_mm *mmn_mm)
+{
+       struct mmu_interval_notifier *mni;
+       struct hlist_node *next;
+
+       spin_lock(&mmn_mm->lock);
+       if (--mmn_mm->active_invalidate_ranges ||
+           !mn_itree_is_invalidating(mmn_mm)) {
+               spin_unlock(&mmn_mm->lock);
+               return;
+       }
+
+       /* Make invalidate_seq even */
+       mmn_mm->invalidate_seq++;
+
+       /*
+        * The inv_end incorporates a deferred mechanism like rtnl_unlock().
+        * Adds and removes are queued until the final inv_end happens then
+        * they are progressed. This arrangement for tree updates is used to
+        * avoid using a blocking lock during invalidate_range_start.
+        */
+       hlist_for_each_entry_safe(mni, next, &mmn_mm->deferred_list,
+                                 deferred_item) {
+               if (RB_EMPTY_NODE(&mni->interval_tree.rb))
+                       interval_tree_insert(&mni->interval_tree,
+                                            &mmn_mm->itree);
+               else
+                       interval_tree_remove(&mni->interval_tree,
+                                            &mmn_mm->itree);
+               hlist_del(&mni->deferred_item);
+       }
+       spin_unlock(&mmn_mm->lock);
+
+       wake_up_all(&mmn_mm->wq);
+}
+
+/**
+ * mmu_interval_read_begin - Begin a read side critical section against a VA
+ *                           range
+ * mni: The range to use
+ *
+ * mmu_iterval_read_begin()/mmu_iterval_read_retry() implement a
+ * collision-retry scheme similar to seqcount for the VA range under mni. If
+ * the mm invokes invalidation during the critical section then
+ * mmu_interval_read_retry() will return true.
+ *
+ * This is useful to obtain shadow PTEs where teardown or setup of the SPTEs
+ * require a blocking context.  The critical region formed by this can sleep,
+ * and the required 'user_lock' can also be a sleeping lock.
+ *
+ * The caller is required to provide a 'user_lock' to serialize both teardown
+ * and setup.
+ *
+ * The return value should be passed to mmu_interval_read_retry().
+ */
+unsigned long mmu_interval_read_begin(struct mmu_interval_notifier *mni)
+{
+       struct mmu_notifier_mm *mmn_mm = mni->mm->mmu_notifier_mm;
+       unsigned long seq;
+       bool is_invalidating;
+
+       /*
+        * If the mni has a different seq value under the user_lock than we
+        * started with then it has collided.
+        *
+        * If the mni currently has the same seq value as the mmn_mm seq, then
+        * it is currently between invalidate_start/end and is colliding.
+        *
+        * The locking looks broadly like this:
+        *   mn_tree_invalidate_start():          mmu_interval_read_begin():
+        *                                         spin_lock
+        *                                          seq = READ_ONCE(mni->invalidate_seq);
+        *                                          seq == mmn_mm->invalidate_seq
+        *                                         spin_unlock
+        *    spin_lock
+        *     seq = ++mmn_mm->invalidate_seq
+        *    spin_unlock
+        *     op->invalidate_range():
+        *       user_lock
+        *        mmu_interval_set_seq()
+        *         mni->invalidate_seq = seq
+        *       user_unlock
+        *
+        *                          [Required: mmu_interval_read_retry() == true]
+        *
+        *   mn_itree_inv_end():
+        *    spin_lock
+        *     seq = ++mmn_mm->invalidate_seq
+        *    spin_unlock
+        *
+        *                                        user_lock
+        *                                         mmu_interval_read_retry():
+        *                                          mni->invalidate_seq != seq
+        *                                        user_unlock
+        *
+        * Barriers are not needed here as any races here are closed by an
+        * eventual mmu_interval_read_retry(), which provides a barrier via the
+        * user_lock.
+        */
+       spin_lock(&mmn_mm->lock);
+       /* Pairs with the WRITE_ONCE in mmu_interval_set_seq() */
+       seq = READ_ONCE(mni->invalidate_seq);
+       is_invalidating = seq == mmn_mm->invalidate_seq;
+       spin_unlock(&mmn_mm->lock);
+
+       /*
+        * mni->invalidate_seq must always be set to an odd value via
+        * mmu_interval_set_seq() using the provided cur_seq from
+        * mn_itree_inv_start_range(). This ensures that if seq does wrap we
+        * will always clear the below sleep in some reasonable time as
+        * mmn_mm->invalidate_seq is even in the idle state.
+        */
+       lock_map_acquire(&__mmu_notifier_invalidate_range_start_map);
+       lock_map_release(&__mmu_notifier_invalidate_range_start_map);
+       if (is_invalidating)
+               wait_event(mmn_mm->wq,
+                          READ_ONCE(mmn_mm->invalidate_seq) != seq);
+
+       /*
+        * Notice that mmu_interval_read_retry() can already be true at this
+        * point, avoiding loops here allows the caller to provide a global
+        * time bound.
+        */
+
+       return seq;
+}
+EXPORT_SYMBOL_GPL(mmu_interval_read_begin);
+
+static void mn_itree_release(struct mmu_notifier_mm *mmn_mm,
+                            struct mm_struct *mm)
+{
+       struct mmu_notifier_range range = {
+               .flags = MMU_NOTIFIER_RANGE_BLOCKABLE,
+               .event = MMU_NOTIFY_RELEASE,
+               .mm = mm,
+               .start = 0,
+               .end = ULONG_MAX,
+       };
+       struct mmu_interval_notifier *mni;
+       unsigned long cur_seq;
+       bool ret;
+
+       for (mni = mn_itree_inv_start_range(mmn_mm, &range, &cur_seq); mni;
+            mni = mn_itree_inv_next(mni, &range)) {
+               ret = mni->ops->invalidate(mni, &range, cur_seq);
+               WARN_ON(!ret);
+       }
+
+       mn_itree_inv_end(mmn_mm);
+}
+
+/*
  * This function can't run concurrently against mmu_notifier_register
  * because mm->mm_users > 0 during mmu_notifier_register and exit_mmap
  * runs with mm_users == 0. Other tasks may still invoke mmu notifiers
@@ -52,7 +288,8 @@ struct mmu_notifier_mm {
  * can't go away from under us as exit_mmap holds an mm_count pin
  * itself.
  */
-void __mmu_notifier_release(struct mm_struct *mm)
+static void mn_hlist_release(struct mmu_notifier_mm *mmn_mm,
+                            struct mm_struct *mm)
 {
        struct mmu_notifier *mn;
        int id;
@@ -62,7 +299,7 @@ void __mmu_notifier_release(struct mm_struct *mm)
         * ->release returns.
         */
        id = srcu_read_lock(&srcu);
-       hlist_for_each_entry_rcu(mn, &mm->mmu_notifier_mm->list, hlist)
+       hlist_for_each_entry_rcu(mn, &mmn_mm->list, hlist)
                /*
                 * If ->release runs before mmu_notifier_unregister it must be
                 * handled, as it's the only way for the driver to flush all
@@ -72,10 +309,9 @@ void __mmu_notifier_release(struct mm_struct *mm)
                if (mn->ops->release)
                        mn->ops->release(mn, mm);
 
-       spin_lock(&mm->mmu_notifier_mm->lock);
-       while (unlikely(!hlist_empty(&mm->mmu_notifier_mm->list))) {
-               mn = hlist_entry(mm->mmu_notifier_mm->list.first,
-                                struct mmu_notifier,
+       spin_lock(&mmn_mm->lock);
+       while (unlikely(!hlist_empty(&mmn_mm->list))) {
+               mn = hlist_entry(mmn_mm->list.first, struct mmu_notifier,
                                 hlist);
                /*
                 * We arrived before mmu_notifier_unregister so
@@ -85,7 +321,7 @@ void __mmu_notifier_release(struct mm_struct *mm)
                 */
                hlist_del_init_rcu(&mn->hlist);
        }
-       spin_unlock(&mm->mmu_notifier_mm->lock);
+       spin_unlock(&mmn_mm->lock);
        srcu_read_unlock(&srcu, id);
 
        /*
@@ -100,6 +336,17 @@ void __mmu_notifier_release(struct mm_struct *mm)
        synchronize_srcu(&srcu);
 }
 
+void __mmu_notifier_release(struct mm_struct *mm)
+{
+       struct mmu_notifier_mm *mmn_mm = mm->mmu_notifier_mm;
+
+       if (mmn_mm->has_itree)
+               mn_itree_release(mmn_mm, mm);
+
+       if (!hlist_empty(&mmn_mm->list))
+               mn_hlist_release(mmn_mm, mm);
+}
+
 /*
  * If no young bitflag is supported by the hardware, ->clear_flush_young can
  * unmap the address and return 1 or 0 depending if the mapping previously
@@ -172,14 +419,43 @@ void __mmu_notifier_change_pte(struct mm_struct *mm, unsigned long address,
        srcu_read_unlock(&srcu, id);
 }
 
-int __mmu_notifier_invalidate_range_start(struct mmu_notifier_range *range)
+static int mn_itree_invalidate(struct mmu_notifier_mm *mmn_mm,
+                              const struct mmu_notifier_range *range)
+{
+       struct mmu_interval_notifier *mni;
+       unsigned long cur_seq;
+
+       for (mni = mn_itree_inv_start_range(mmn_mm, range, &cur_seq); mni;
+            mni = mn_itree_inv_next(mni, range)) {
+               bool ret;
+
+               ret = mni->ops->invalidate(mni, range, cur_seq);
+               if (!ret) {
+                       if (WARN_ON(mmu_notifier_range_blockable(range)))
+                               continue;
+                       goto out_would_block;
+               }
+       }
+       return 0;
+
+out_would_block:
+       /*
+        * On -EAGAIN the non-blocking caller is not allowed to call
+        * invalidate_range_end()
+        */
+       mn_itree_inv_end(mmn_mm);
+       return -EAGAIN;
+}
+
+static int mn_hlist_invalidate_range_start(struct mmu_notifier_mm *mmn_mm,
+                                          struct mmu_notifier_range *range)
 {
        struct mmu_notifier *mn;
        int ret = 0;
        int id;
 
        id = srcu_read_lock(&srcu);
-       hlist_for_each_entry_rcu(mn, &range->mm->mmu_notifier_mm->list, hlist) {
+       hlist_for_each_entry_rcu(mn, &mmn_mm->list, hlist) {
                if (mn->ops->invalidate_range_start) {
                        int _ret;
 
@@ -203,15 +479,30 @@ int __mmu_notifier_invalidate_range_start(struct mmu_notifier_range *range)
        return ret;
 }
 
-void __mmu_notifier_invalidate_range_end(struct mmu_notifier_range *range,
-                                        bool only_end)
+int __mmu_notifier_invalidate_range_start(struct mmu_notifier_range *range)
+{
+       struct mmu_notifier_mm *mmn_mm = range->mm->mmu_notifier_mm;
+       int ret;
+
+       if (mmn_mm->has_itree) {
+               ret = mn_itree_invalidate(mmn_mm, range);
+               if (ret)
+                       return ret;
+       }
+       if (!hlist_empty(&mmn_mm->list))
+               return mn_hlist_invalidate_range_start(mmn_mm, range);
+       return 0;
+}
+
+static void mn_hlist_invalidate_end(struct mmu_notifier_mm *mmn_mm,
+                                   struct mmu_notifier_range *range,
+                                   bool only_end)
 {
        struct mmu_notifier *mn;
        int id;
 
-       lock_map_acquire(&__mmu_notifier_invalidate_range_start_map);
        id = srcu_read_lock(&srcu);
-       hlist_for_each_entry_rcu(mn, &range->mm->mmu_notifier_mm->list, hlist) {
+       hlist_for_each_entry_rcu(mn, &mmn_mm->list, hlist) {
                /*
                 * Call invalidate_range here too to avoid the need for the
                 * subsystem of having to register an invalidate_range_end
@@ -238,6 +529,19 @@ void __mmu_notifier_invalidate_range_end(struct mmu_notifier_range *range,
                }
        }
        srcu_read_unlock(&srcu, id);
+}
+
+void __mmu_notifier_invalidate_range_end(struct mmu_notifier_range *range,
+                                        bool only_end)
+{
+       struct mmu_notifier_mm *mmn_mm = range->mm->mmu_notifier_mm;
+
+       lock_map_acquire(&__mmu_notifier_invalidate_range_start_map);
+       if (mmn_mm->has_itree)
+               mn_itree_inv_end(mmn_mm);
+
+       if (!hlist_empty(&mmn_mm->list))
+               mn_hlist_invalidate_end(mmn_mm, range, only_end);
        lock_map_release(&__mmu_notifier_invalidate_range_start_map);
 }
 
@@ -256,8 +560,9 @@ void __mmu_notifier_invalidate_range(struct mm_struct *mm,
 }
 
 /*
- * Same as mmu_notifier_register but here the caller must hold the
- * mmap_sem in write mode.
+ * Same as mmu_notifier_register but here the caller must hold the mmap_sem in
+ * write mode. A NULL mn signals the notifier is being registered for itree
+ * mode.
  */
 int __mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm)
 {
@@ -274,9 +579,6 @@ int __mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm)
                fs_reclaim_release(GFP_KERNEL);
        }
 
-       mn->mm = mm;
-       mn->users = 1;
-
        if (!mm->mmu_notifier_mm) {
                /*
                 * kmalloc cannot be called under mm_take_all_locks(), but we
@@ -284,21 +586,22 @@ int __mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm)
                 * the write side of the mmap_sem.
                 */
                mmu_notifier_mm =
-                       kmalloc(sizeof(struct mmu_notifier_mm), GFP_KERNEL);
+                       kzalloc(sizeof(struct mmu_notifier_mm), GFP_KERNEL);
                if (!mmu_notifier_mm)
                        return -ENOMEM;
 
                INIT_HLIST_HEAD(&mmu_notifier_mm->list);
                spin_lock_init(&mmu_notifier_mm->lock);
+               mmu_notifier_mm->invalidate_seq = 2;
+               mmu_notifier_mm->itree = RB_ROOT_CACHED;
+               init_waitqueue_head(&mmu_notifier_mm->wq);
+               INIT_HLIST_HEAD(&mmu_notifier_mm->deferred_list);
        }
 
        ret = mm_take_all_locks(mm);
        if (unlikely(ret))
                goto out_clean;
 
-       /* Pairs with the mmdrop in mmu_notifier_unregister_* */
-       mmgrab(mm);
-
        /*
         * Serialize the update against mmu_notifier_unregister. A
         * side note: mmu_notifier_release can't run concurrently with
@@ -306,13 +609,28 @@ int __mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm)
         * current->mm or explicitly with get_task_mm() or similar).
         * We can't race against any other mmu notifier method either
         * thanks to mm_take_all_locks().
+        *
+        * release semantics on the initialization of the mmu_notifier_mm's
+        * contents are provided for unlocked readers.  acquire can only be
+        * used while holding the mmgrab or mmget, and is safe because once
+        * created the mmu_notififer_mm is not freed until the mm is
+        * destroyed.  As above, users holding the mmap_sem or one of the
+        * mm_take_all_locks() do not need to use acquire semantics.
         */
        if (mmu_notifier_mm)
-               mm->mmu_notifier_mm = mmu_notifier_mm;
+               smp_store_release(&mm->mmu_notifier_mm, mmu_notifier_mm);
 
-       spin_lock(&mm->mmu_notifier_mm->lock);
-       hlist_add_head_rcu(&mn->hlist, &mm->mmu_notifier_mm->list);
-       spin_unlock(&mm->mmu_notifier_mm->lock);
+       if (mn) {
+               /* Pairs with the mmdrop in mmu_notifier_unregister_* */
+               mmgrab(mm);
+               mn->mm = mm;
+               mn->users = 1;
+
+               spin_lock(&mm->mmu_notifier_mm->lock);
+               hlist_add_head_rcu(&mn->hlist, &mm->mmu_notifier_mm->list);
+               spin_unlock(&mm->mmu_notifier_mm->lock);
+       } else
+               mm->mmu_notifier_mm->has_itree = true;
 
        mm_drop_all_locks(mm);
        BUG_ON(atomic_read(&mm->mm_users) <= 0);
@@ -529,6 +847,180 @@ out_unlock:
 }
 EXPORT_SYMBOL_GPL(mmu_notifier_put);
 
+static int __mmu_interval_notifier_insert(
+       struct mmu_interval_notifier *mni, struct mm_struct *mm,
+       struct mmu_notifier_mm *mmn_mm, unsigned long start,
+       unsigned long length, const struct mmu_interval_notifier_ops *ops)
+{
+       mni->mm = mm;
+       mni->ops = ops;
+       RB_CLEAR_NODE(&mni->interval_tree.rb);
+       mni->interval_tree.start = start;
+       /*
+        * Note that the representation of the intervals in the interval tree
+        * considers the ending point as contained in the interval.
+        */
+       if (length == 0 ||
+           check_add_overflow(start, length - 1, &mni->interval_tree.last))
+               return -EOVERFLOW;
+
+       /* Must call with a mmget() held */
+       if (WARN_ON(atomic_read(&mm->mm_count) <= 0))
+               return -EINVAL;
+
+       /* pairs with mmdrop in mmu_interval_notifier_remove() */
+       mmgrab(mm);
+
+       /*
+        * If some invalidate_range_start/end region is going on in parallel
+        * we don't know what VA ranges are affected, so we must assume this
+        * new range is included.
+        *
+        * If the itree is invalidating then we are not allowed to change
+        * it. Retrying until invalidation is done is tricky due to the
+        * possibility for live lock, instead defer the add to
+        * mn_itree_inv_end() so this algorithm is deterministic.
+        *
+        * In all cases the value for the mni->invalidate_seq should be
+        * odd, see mmu_interval_read_begin()
+        */
+       spin_lock(&mmn_mm->lock);
+       if (mmn_mm->active_invalidate_ranges) {
+               if (mn_itree_is_invalidating(mmn_mm))
+                       hlist_add_head(&mni->deferred_item,
+                                      &mmn_mm->deferred_list);
+               else {
+                       mmn_mm->invalidate_seq |= 1;
+                       interval_tree_insert(&mni->interval_tree,
+                                            &mmn_mm->itree);
+               }
+               mni->invalidate_seq = mmn_mm->invalidate_seq;
+       } else {
+               WARN_ON(mn_itree_is_invalidating(mmn_mm));
+               /*
+                * The starting seq for a mni not under invalidation should be
+                * odd, not equal to the current invalidate_seq and
+                * invalidate_seq should not 'wrap' to the new seq any time
+                * soon.
+                */
+               mni->invalidate_seq = mmn_mm->invalidate_seq - 1;
+               interval_tree_insert(&mni->interval_tree, &mmn_mm->itree);
+       }
+       spin_unlock(&mmn_mm->lock);
+       return 0;
+}
+
+/**
+ * mmu_interval_notifier_insert - Insert an interval notifier
+ * @mni: Interval notifier to register
+ * @start: Starting virtual address to monitor
+ * @length: Length of the range to monitor
+ * @mm : mm_struct to attach to
+ *
+ * This function subscribes the interval notifier for notifications from the
+ * mm.  Upon return the ops related to mmu_interval_notifier will be called
+ * whenever an event that intersects with the given range occurs.
+ *
+ * Upon return the range_notifier may not be present in the interval tree yet.
+ * The caller must use the normal interval notifier read flow via
+ * mmu_interval_read_begin() to establish SPTEs for this range.
+ */
+int mmu_interval_notifier_insert(struct mmu_interval_notifier *mni,
+                                struct mm_struct *mm, unsigned long start,
+                                unsigned long length,
+                                const struct mmu_interval_notifier_ops *ops)
+{
+       struct mmu_notifier_mm *mmn_mm;
+       int ret;
+
+       might_lock(&mm->mmap_sem);
+
+       mmn_mm = smp_load_acquire(&mm->mmu_notifier_mm);
+       if (!mmn_mm || !mmn_mm->has_itree) {
+               ret = mmu_notifier_register(NULL, mm);
+               if (ret)
+                       return ret;
+               mmn_mm = mm->mmu_notifier_mm;
+       }
+       return __mmu_interval_notifier_insert(mni, mm, mmn_mm, start, length,
+                                             ops);
+}
+EXPORT_SYMBOL_GPL(mmu_interval_notifier_insert);
+
+int mmu_interval_notifier_insert_locked(
+       struct mmu_interval_notifier *mni, struct mm_struct *mm,
+       unsigned long start, unsigned long length,
+       const struct mmu_interval_notifier_ops *ops)
+{
+       struct mmu_notifier_mm *mmn_mm;
+       int ret;
+
+       lockdep_assert_held_write(&mm->mmap_sem);
+
+       mmn_mm = mm->mmu_notifier_mm;
+       if (!mmn_mm || !mmn_mm->has_itree) {
+               ret = __mmu_notifier_register(NULL, mm);
+               if (ret)
+                       return ret;
+               mmn_mm = mm->mmu_notifier_mm;
+       }
+       return __mmu_interval_notifier_insert(mni, mm, mmn_mm, start, length,
+                                             ops);
+}
+EXPORT_SYMBOL_GPL(mmu_interval_notifier_insert_locked);
+
+/**
+ * mmu_interval_notifier_remove - Remove a interval notifier
+ * @mni: Interval notifier to unregister
+ *
+ * This function must be paired with mmu_interval_notifier_insert(). It cannot
+ * be called from any ops callback.
+ *
+ * Once this returns ops callbacks are no longer running on other CPUs and
+ * will not be called in future.
+ */
+void mmu_interval_notifier_remove(struct mmu_interval_notifier *mni)
+{
+       struct mm_struct *mm = mni->mm;
+       struct mmu_notifier_mm *mmn_mm = mm->mmu_notifier_mm;
+       unsigned long seq = 0;
+
+       might_sleep();
+
+       spin_lock(&mmn_mm->lock);
+       if (mn_itree_is_invalidating(mmn_mm)) {
+               /*
+                * remove is being called after insert put this on the
+                * deferred list, but before the deferred list was processed.
+                */
+               if (RB_EMPTY_NODE(&mni->interval_tree.rb)) {
+                       hlist_del(&mni->deferred_item);
+               } else {
+                       hlist_add_head(&mni->deferred_item,
+                                      &mmn_mm->deferred_list);
+                       seq = mmn_mm->invalidate_seq;
+               }
+       } else {
+               WARN_ON(RB_EMPTY_NODE(&mni->interval_tree.rb));
+               interval_tree_remove(&mni->interval_tree, &mmn_mm->itree);
+       }
+       spin_unlock(&mmn_mm->lock);
+
+       /*
+        * The possible sleep on progress in the invalidation requires the
+        * caller not hold any locks held by invalidation callbacks.
+        */
+       lock_map_acquire(&__mmu_notifier_invalidate_range_start_map);
+       lock_map_release(&__mmu_notifier_invalidate_range_start_map);
+       if (seq)
+               wait_event(mmn_mm->wq,
+                          READ_ONCE(mmn_mm->invalidate_seq) != seq);
+
+       /* pairs with mmgrab in mmu_interval_notifier_insert() */
+       mmdrop(mm);
+}
+EXPORT_SYMBOL_GPL(mmu_interval_notifier_remove);
+
 /**
  * mmu_notifier_synchronize - Ensure all mmu_notifiers are freed
  *