bpf: Use scalar ids in mark_chain_precision()
author Eduard Zingerman <eddyz87@gmail.com>
Tue, 13 Jun 2023 15:38:21 +0000 (18:38 +0300)
committer Andrii Nakryiko <andrii@kernel.org>
Tue, 13 Jun 2023 22:14:27 +0000 (15:14 -0700)
Change mark_chain_precision() to track precision in situations
like the one below:

    r2 = unknown value
    ...
  --- state #0 ---
    ...
    r1 = r2                 // r1 and r2 now share the same ID
    ...
  --- state #1 {r1.id = A, r2.id = A} ---
    ...
    if (r2 > 10) goto exit; // find_equal_scalars() assigns range to r1
    ...
  --- state #2 {r1.id = A, r2.id = A} ---
    r3 = r10
    r3 += r1                // need to mark both r1 and r2

At the beginning of the processing of each state, ensure that if a
register with a scalar ID is marked as precise, all registers sharing
this ID are also marked as precise.

This property will be used by a follow-up change in regsafe().

Signed-off-by: Eduard Zingerman <eddyz87@gmail.com>
Signed-off-by: Andrii Nakryiko <andrii@kernel.org>
Acked-by: Andrii Nakryiko <andrii@kernel.org>
Link: https://lore.kernel.org/bpf/20230613153824.3324830-2-eddyz87@gmail.com
include/linux/bpf_verifier.h
kernel/bpf/verifier.c
tools/testing/selftests/bpf/verifier/precise.c

index 5b11a3b..22fb13c 100644 (file)
@@ -557,6 +557,11 @@ struct backtrack_state {
        u64 stack_masks[MAX_CALL_FRAMES];
 };
 
+struct bpf_idset {                     /* fixed-capacity, append-only list of IDs */
+       u32 count;                      /* number of entries of ids[] currently in use */
+       u32 ids[BPF_ID_MAP_SIZE];       /* stored IDs; only ids[0..count-1] are meaningful */
+};
+
 /* single container for all structs
  * one verifier_env per bpf_check() call
  */
@@ -588,7 +593,10 @@ struct bpf_verifier_env {
        const struct bpf_line_info *prev_linfo;
        struct bpf_verifier_log log;
        struct bpf_subprog_info subprog_info[BPF_MAX_SUBPROGS + 1];
-       struct bpf_id_pair idmap_scratch[BPF_ID_MAP_SIZE];
+       union {
+               struct bpf_id_pair idmap_scratch[BPF_ID_MAP_SIZE];
+               struct bpf_idset idset_scratch;
+       };
        struct {
                int *insn_state;
                int *insn_stack;
index 1e38584..064aef5 100644 (file)
@@ -3779,6 +3779,96 @@ static void mark_all_scalars_imprecise(struct bpf_verifier_env *env, struct bpf_
        }
 }
 
+/* Linear search: report whether @id is among the first s->count
+ * entries of s->ids.
+ */
+static bool idset_contains(struct bpf_idset *s, u32 id)
+{
+       u32 idx = 0;
+
+       while (idx < s->count) {
+               if (s->ids[idx] == id)
+                       return true;
+               idx++;
+       }
+
+       return false;
+}
+
+/* Append @id after the last stored entry. No duplicate check is done.
+ * Returns 0 on success; warns once and returns -EFAULT when s->ids
+ * has no free entry left.
+ */
+static int idset_push(struct bpf_idset *s, u32 id)
+{
+       if (WARN_ON_ONCE(s->count >= ARRAY_SIZE(s->ids)))
+               return -EFAULT;
+
+       s->ids[s->count] = id;
+       s->count++;
+       return 0;
+}
+
+/* Empty the set by zeroing the entry count; the ids[] array itself
+ * is not cleared.
+ */
+static void idset_reset(struct bpf_idset *s)
+{
+       s->count = 0;
+}
+
+/* Record @id in @s unless it is already stored.
+ * Several precise registers or stack slots can carry the same scalar ID;
+ * storing each ID only once keeps such duplicates from using up the
+ * fixed-size ids[] array and from lengthening idset_contains() scans.
+ * Returns 0 on success or the error reported by idset_push().
+ */
+static int idset_push_unique(struct bpf_idset *s, u32 id)
+{
+       if (idset_contains(s, id))
+               return 0;
+       return idset_push(s, id);
+}
+
+/* Collect a set of IDs for all registers currently marked as precise in env->bt.
+ * Mark all registers with these IDs as precise.
+ *
+ * Pass 1 walks frames bt->frame..0 and gathers the non-zero IDs of the
+ * scalar registers and spilled scalar stack slots whose bits are set in
+ * the backtrack masks of that frame.
+ * Pass 2 walks frames 0..st->curframe and sets the backtrack bit of every
+ * register (r0-r9) and spilled scalar stack slot whose ID was gathered.
+ *
+ * env->idset_scratch is reset and used as temporary storage.
+ *
+ * Returns 0 on success, -EFAULT if an ID could not be recorded.
+ */
+static int mark_precise_scalar_ids(struct bpf_verifier_env *env, struct bpf_verifier_state *st)
+{
+       struct bpf_idset *precise_ids = &env->idset_scratch;
+       struct backtrack_state *bt = &env->bt;
+       struct bpf_func_state *func;
+       struct bpf_reg_state *reg;
+       DECLARE_BITMAP(mask, 64);
+       int i, fr;
+
+       idset_reset(precise_ids);
+
+       /* Pass 1: gather IDs of everything already marked precise. */
+       for (fr = bt->frame; fr >= 0; fr--) {
+               func = st->frame[fr];
+
+               /* Only the low 32 bits of the register mask are scanned. */
+               bitmap_from_u64(mask, bt_frame_reg_mask(bt, fr));
+               for_each_set_bit(i, mask, 32) {
+                       reg = &func->regs[i];
+                       if (!reg->id || reg->type != SCALAR_VALUE)
+                               continue;
+                       if (idset_push_unique(precise_ids, reg->id))
+                               return -EFAULT;
+               }
+
+               bitmap_from_u64(mask, bt_frame_stack_mask(bt, fr));
+               for_each_set_bit(i, mask, 64) {
+                       /* Bits are visited in ascending order, so every
+                        * later slot lies beyond the allocated stack too.
+                        */
+                       if (i >= func->allocated_stack / BPF_REG_SIZE)
+                               break;
+                       if (!is_spilled_scalar_reg(&func->stack[i]))
+                               continue;
+                       reg = &func->stack[i].spilled_ptr;
+                       if (!reg->id)
+                               continue;
+                       if (idset_push_unique(precise_ids, reg->id))
+                               return -EFAULT;
+               }
+       }
+
+       /* Pass 2: propagate precision to everything sharing a gathered ID. */
+       for (fr = 0; fr <= st->curframe; ++fr) {
+               func = st->frame[fr];
+
+               /* BPF_REG_10 is excluded by the loop bound. */
+               for (i = BPF_REG_0; i < BPF_REG_10; ++i) {
+                       reg = &func->regs[i];
+                       if (!reg->id)
+                               continue;
+                       if (!idset_contains(precise_ids, reg->id))
+                               continue;
+                       bt_set_frame_reg(bt, fr, i);
+               }
+               for (i = 0; i < func->allocated_stack / BPF_REG_SIZE; ++i) {
+                       if (!is_spilled_scalar_reg(&func->stack[i]))
+                               continue;
+                       reg = &func->stack[i].spilled_ptr;
+                       if (!reg->id)
+                               continue;
+                       if (!idset_contains(precise_ids, reg->id))
+                               continue;
+                       bt_set_frame_slot(bt, fr, i);
+               }
+       }
+
+       return 0;
+}
+
 /*
  * __mark_chain_precision() backtracks BPF program instruction sequence and
  * chain of verifier states making sure that register *regno* (if regno >= 0)
@@ -3910,6 +4000,31 @@ static int __mark_chain_precision(struct bpf_verifier_env *env, int regno)
                                bt->frame, last_idx, first_idx, subseq_idx);
                }
 
+               /* If some register with scalar ID is marked as precise,
+                * make sure that all registers sharing this ID are also precise.
+                * This is needed to estimate effect of find_equal_scalars().
+                * Do this at the last instruction of each state, because
+                * bpf_reg_state::id fields are valid for these instructions.
+                *
+                * This allows tracking precision in situations like the one below:
+                *
+                *     r2 = unknown value
+                *     ...
+                *   --- state #0 ---
+                *     ...
+                *     r1 = r2                 // r1 and r2 now share the same ID
+                *     ...
+                *   --- state #1 {r1.id = A, r2.id = A} ---
+                *     ...
+                *     if (r2 > 10) goto exit; // find_equal_scalars() assigns range to r1
+                *     ...
+                *   --- state #2 {r1.id = A, r2.id = A} ---
+                *     r3 = r10
+                *     r3 += r1                // need to mark both r1 and r2
+                */
+               if (mark_precise_scalar_ids(env, st))
+                       return -EFAULT;
+
                if (last_idx < 0) {
                        /* we are at the entry into subprog, which
                         * is expected for global funcs, but only if
index b8c0aae..99272bb 100644 (file)
@@ -46,7 +46,7 @@
        mark_precise: frame0: regs=r2 stack= before 20\
        mark_precise: frame0: parent state regs=r2 stack=:\
        mark_precise: frame0: last_idx 19 first_idx 10\
-       mark_precise: frame0: regs=r2 stack= before 19\
+       mark_precise: frame0: regs=r2,r9 stack= before 19\
        mark_precise: frame0: regs=r9 stack= before 18\
        mark_precise: frame0: regs=r8,r9 stack= before 17\
        mark_precise: frame0: regs=r0,r9 stack= before 15\
        mark_precise: frame0: regs=r2 stack= before 22\
        mark_precise: frame0: parent state regs=r2 stack=:\
        mark_precise: frame0: last_idx 20 first_idx 20\
-       mark_precise: frame0: regs=r2 stack= before 20\
-       mark_precise: frame0: parent state regs=r2 stack=:\
+       mark_precise: frame0: regs=r2,r9 stack= before 20\
+       mark_precise: frame0: parent state regs=r2,r9 stack=:\
        mark_precise: frame0: last_idx 19 first_idx 17\
-       mark_precise: frame0: regs=r2 stack= before 19\
+       mark_precise: frame0: regs=r2,r9 stack= before 19\
        mark_precise: frame0: regs=r9 stack= before 18\
        mark_precise: frame0: regs=r8,r9 stack= before 17\
        mark_precise: frame0: parent state regs= stack=:",