bpf: Add helper macro bpf_for_each_reg_in_vstate
authorKumar Kartikeya Dwivedi <memxor@gmail.com>
Sun, 4 Sep 2022 20:41:28 +0000 (22:41 +0200)
committerAlexei Starovoitov <ast@kernel.org>
Wed, 7 Sep 2022 23:42:23 +0000 (16:42 -0700)
For a lot of use cases in future patches, we will want to modify the
state of registers part of some same 'group' (e.g. same ref_obj_id). It
won't just be limited to releasing reference state, but setting a type
flag dynamically based on certain actions, etc.

Hence, we need a way to easily pass a callback to the function that
iterates over all registers in current bpf_verifier_state in all frames
upto (and including) the curframe.

While in C++ we would be able to easily use a lambda to pass state and
the callback together, sadly we aren't using C++ in the kernel. The next
best thing to avoid defining a function for each case seems like
statement expressions in GNU C. The kernel already uses them heavily,
hence they can passed to the macro in the style of a lambda. The
statement expression will then be substituted in the for loop bodies.

Variables __state and __reg are set to current bpf_func_state and reg
for each invocation of the expression inside the passed in verifier
state.

Then, convert mark_ptr_or_null_regs, clear_all_pkt_pointers,
release_reference, find_good_pkt_pointers, find_equal_scalars to
use bpf_for_each_reg_in_vstate.

Signed-off-by: Kumar Kartikeya Dwivedi <memxor@gmail.com>
Link: https://lore.kernel.org/r/20220904204145.3089-16-memxor@gmail.com
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
include/linux/bpf_verifier.h
kernel/bpf/verifier.c

index 8fbc1d05281e8fcab548b88fa6a95d707ceb07eb..b49a349cc6aed6c8b4d49954c8f826e9cab49045 100644 (file)
@@ -348,6 +348,27 @@ struct bpf_verifier_state {
             iter < frame->allocated_stack / BPF_REG_SIZE;              \
             iter++, reg = bpf_get_spilled_reg(iter, frame))
 
+/* Invoke __expr over regsiters in __vst, setting __state and __reg */
+#define bpf_for_each_reg_in_vstate(__vst, __state, __reg, __expr)   \
+       ({                                                               \
+               struct bpf_verifier_state *___vstate = __vst;            \
+               int ___i, ___j;                                          \
+               for (___i = 0; ___i <= ___vstate->curframe; ___i++) {    \
+                       struct bpf_reg_state *___regs;                   \
+                       __state = ___vstate->frame[___i];                \
+                       ___regs = __state->regs;                         \
+                       for (___j = 0; ___j < MAX_BPF_REG; ___j++) {     \
+                               __reg = &___regs[___j];                  \
+                               (void)(__expr);                          \
+                       }                                                \
+                       bpf_for_each_spilled_reg(___j, __state, __reg) { \
+                               if (!__reg)                              \
+                                       continue;                        \
+                               (void)(__expr);                          \
+                       }                                                \
+               }                                                        \
+       })
+
 /* linked list of verifier states used to prune search */
 struct bpf_verifier_state_list {
        struct bpf_verifier_state state;
index f3344a86d88d0f577e6875fc88fde6d1fca4d451..c0f175ac187af05b0d3f5fa7f19e205c62e2f54f 100644 (file)
@@ -6513,31 +6513,15 @@ static int check_func_proto(const struct bpf_func_proto *fn, int func_id)
 /* Packet data might have moved, any old PTR_TO_PACKET[_META,_END]
  * are now invalid, so turn them into unknown SCALAR_VALUE.
  */
-static void __clear_all_pkt_pointers(struct bpf_verifier_env *env,
-                                    struct bpf_func_state *state)
+static void clear_all_pkt_pointers(struct bpf_verifier_env *env)
 {
-       struct bpf_reg_state *regs = state->regs, *reg;
-       int i;
-
-       for (i = 0; i < MAX_BPF_REG; i++)
-               if (reg_is_pkt_pointer_any(&regs[i]))
-                       mark_reg_unknown(env, regs, i);
+       struct bpf_func_state *state;
+       struct bpf_reg_state *reg;
 
-       bpf_for_each_spilled_reg(i, state, reg) {
-               if (!reg)
-                       continue;
+       bpf_for_each_reg_in_vstate(env->cur_state, state, reg, ({
                if (reg_is_pkt_pointer_any(reg))
                        __mark_reg_unknown(env, reg);
-       }
-}
-
-static void clear_all_pkt_pointers(struct bpf_verifier_env *env)
-{
-       struct bpf_verifier_state *vstate = env->cur_state;
-       int i;
-
-       for (i = 0; i <= vstate->curframe; i++)
-               __clear_all_pkt_pointers(env, vstate->frame[i]);
+       }));
 }
 
 enum {
@@ -6566,41 +6550,24 @@ static void mark_pkt_end(struct bpf_verifier_state *vstate, int regn, bool range
                reg->range = AT_PKT_END;
 }
 
-static void release_reg_references(struct bpf_verifier_env *env,
-                                  struct bpf_func_state *state,
-                                  int ref_obj_id)
-{
-       struct bpf_reg_state *regs = state->regs, *reg;
-       int i;
-
-       for (i = 0; i < MAX_BPF_REG; i++)
-               if (regs[i].ref_obj_id == ref_obj_id)
-                       mark_reg_unknown(env, regs, i);
-
-       bpf_for_each_spilled_reg(i, state, reg) {
-               if (!reg)
-                       continue;
-               if (reg->ref_obj_id == ref_obj_id)
-                       __mark_reg_unknown(env, reg);
-       }
-}
-
 /* The pointer with the specified id has released its reference to kernel
  * resources. Identify all copies of the same pointer and clear the reference.
  */
 static int release_reference(struct bpf_verifier_env *env,
                             int ref_obj_id)
 {
-       struct bpf_verifier_state *vstate = env->cur_state;
+       struct bpf_func_state *state;
+       struct bpf_reg_state *reg;
        int err;
-       int i;
 
        err = release_reference_state(cur_func(env), ref_obj_id);
        if (err)
                return err;
 
-       for (i = 0; i <= vstate->curframe; i++)
-               release_reg_references(env, vstate->frame[i], ref_obj_id);
+       bpf_for_each_reg_in_vstate(env->cur_state, state, reg, ({
+               if (reg->ref_obj_id == ref_obj_id)
+                       __mark_reg_unknown(env, reg);
+       }));
 
        return 0;
 }
@@ -9335,34 +9302,14 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)
        return 0;
 }
 
-static void __find_good_pkt_pointers(struct bpf_func_state *state,
-                                    struct bpf_reg_state *dst_reg,
-                                    enum bpf_reg_type type, int new_range)
-{
-       struct bpf_reg_state *reg;
-       int i;
-
-       for (i = 0; i < MAX_BPF_REG; i++) {
-               reg = &state->regs[i];
-               if (reg->type == type && reg->id == dst_reg->id)
-                       /* keep the maximum range already checked */
-                       reg->range = max(reg->range, new_range);
-       }
-
-       bpf_for_each_spilled_reg(i, state, reg) {
-               if (!reg)
-                       continue;
-               if (reg->type == type && reg->id == dst_reg->id)
-                       reg->range = max(reg->range, new_range);
-       }
-}
-
 static void find_good_pkt_pointers(struct bpf_verifier_state *vstate,
                                   struct bpf_reg_state *dst_reg,
                                   enum bpf_reg_type type,
                                   bool range_right_open)
 {
-       int new_range, i;
+       struct bpf_func_state *state;
+       struct bpf_reg_state *reg;
+       int new_range;
 
        if (dst_reg->off < 0 ||
            (dst_reg->off == 0 && range_right_open))
@@ -9427,9 +9374,11 @@ static void find_good_pkt_pointers(struct bpf_verifier_state *vstate,
         * the range won't allow anything.
         * dst_reg->off is known < MAX_PACKET_OFF, therefore it fits in a u16.
         */
-       for (i = 0; i <= vstate->curframe; i++)
-               __find_good_pkt_pointers(vstate->frame[i], dst_reg, type,
-                                        new_range);
+       bpf_for_each_reg_in_vstate(vstate, state, reg, ({
+               if (reg->type == type && reg->id == dst_reg->id)
+                       /* keep the maximum range already checked */
+                       reg->range = max(reg->range, new_range);
+       }));
 }
 
 static int is_branch32_taken(struct bpf_reg_state *reg, u32 val, u8 opcode)
@@ -9918,7 +9867,7 @@ static void mark_ptr_or_null_reg(struct bpf_func_state *state,
 
                if (!reg_may_point_to_spin_lock(reg)) {
                        /* For not-NULL ptr, reg->ref_obj_id will be reset
-                        * in release_reg_references().
+                        * in release_reference().
                         *
                         * reg->id is still used by spin_lock ptr. Other
                         * than spin_lock ptr type, reg->id can be reset.
@@ -9928,22 +9877,6 @@ static void mark_ptr_or_null_reg(struct bpf_func_state *state,
        }
 }
 
-static void __mark_ptr_or_null_regs(struct bpf_func_state *state, u32 id,
-                                   bool is_null)
-{
-       struct bpf_reg_state *reg;
-       int i;
-
-       for (i = 0; i < MAX_BPF_REG; i++)
-               mark_ptr_or_null_reg(state, &state->regs[i], id, is_null);
-
-       bpf_for_each_spilled_reg(i, state, reg) {
-               if (!reg)
-                       continue;
-               mark_ptr_or_null_reg(state, reg, id, is_null);
-       }
-}
-
 /* The logic is similar to find_good_pkt_pointers(), both could eventually
  * be folded together at some point.
  */
@@ -9951,10 +9884,9 @@ static void mark_ptr_or_null_regs(struct bpf_verifier_state *vstate, u32 regno,
                                  bool is_null)
 {
        struct bpf_func_state *state = vstate->frame[vstate->curframe];
-       struct bpf_reg_state *regs = state->regs;
+       struct bpf_reg_state *regs = state->regs, *reg;
        u32 ref_obj_id = regs[regno].ref_obj_id;
        u32 id = regs[regno].id;
-       int i;
 
        if (ref_obj_id && ref_obj_id == id && is_null)
                /* regs[regno] is in the " == NULL" branch.
@@ -9963,8 +9895,9 @@ static void mark_ptr_or_null_regs(struct bpf_verifier_state *vstate, u32 regno,
                 */
                WARN_ON_ONCE(release_reference_state(state, id));
 
-       for (i = 0; i <= vstate->curframe; i++)
-               __mark_ptr_or_null_regs(vstate->frame[i], id, is_null);
+       bpf_for_each_reg_in_vstate(vstate, state, reg, ({
+               mark_ptr_or_null_reg(state, reg, id, is_null);
+       }));
 }
 
 static bool try_match_pkt_pointers(const struct bpf_insn *insn,
@@ -10077,23 +10010,11 @@ static void find_equal_scalars(struct bpf_verifier_state *vstate,
 {
        struct bpf_func_state *state;
        struct bpf_reg_state *reg;
-       int i, j;
 
-       for (i = 0; i <= vstate->curframe; i++) {
-               state = vstate->frame[i];
-               for (j = 0; j < MAX_BPF_REG; j++) {
-                       reg = &state->regs[j];
-                       if (reg->type == SCALAR_VALUE && reg->id == known_reg->id)
-                               *reg = *known_reg;
-               }
-
-               bpf_for_each_spilled_reg(j, state, reg) {
-                       if (!reg)
-                               continue;
-                       if (reg->type == SCALAR_VALUE && reg->id == known_reg->id)
-                               *reg = *known_reg;
-               }
-       }
+       bpf_for_each_reg_in_vstate(vstate, state, reg, ({
+               if (reg->type == SCALAR_VALUE && reg->id == known_reg->id)
+                       *reg = *known_reg;
+       }));
 }
 
 static int check_cond_jmp_op(struct bpf_verifier_env *env,