bpf: propagate nullness information for reg to reg comparisons
authorEduard Zingerman <eddyz87@gmail.com>
Tue, 15 Nov 2022 22:48:58 +0000 (00:48 +0200)
committerAlexei Starovoitov <ast@kernel.org>
Wed, 16 Nov 2022 01:38:36 +0000 (17:38 -0800)
Propagate nullness information for branches of register to register
equality compare instructions. The following rules are used:
- suppose register A maybe null
- suppose register B is not null
- for JNE A, B, ... - A is not null in the false branch
- for JEQ A, B, ... - A is not null in the true branch

E.g. for program like below:

  r6 = skb->sk;
  r7 = sk_fullsock(r6);
  r0 = sk_fullsock(r6);
  if (r0 == 0) return 0;    (a)
  if (r0 != r7) return 0;   (b)
  *r7->type;                (c)
  return 0;

It is safe to dereference r7 at point (c), because of (a) and (b).

Signed-off-by: Eduard Zingerman <eddyz87@gmail.com>
Acked-by: Yonghong Song <yhs@fb.com>
Link: https://lore.kernel.org/r/20221115224859.2452988-2-eddyz87@gmail.com
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
kernel/bpf/verifier.c

index be24774..0312d9c 100644 (file)
@@ -10267,6 +10267,7 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
        struct bpf_verifier_state *other_branch;
        struct bpf_reg_state *regs = this_branch->frame[this_branch->curframe]->regs;
        struct bpf_reg_state *dst_reg, *other_branch_regs, *src_reg = NULL;
+       struct bpf_reg_state *eq_branch_regs;
        u8 opcode = BPF_OP(insn->code);
        bool is_jmp32;
        int pred = -1;
@@ -10376,8 +10377,8 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
        /* detect if we are comparing against a constant value so we can adjust
         * our min/max values for our dst register.
         * this is only legit if both are scalars (or pointers to the same
-        * object, I suppose, but we don't support that right now), because
-        * otherwise the different base pointers mean the offsets aren't
+        * object, I suppose, see the PTR_MAYBE_NULL related if block below),
+        * because otherwise the different base pointers mean the offsets aren't
         * comparable.
         */
        if (BPF_SRC(insn->code) == BPF_X) {
@@ -10426,6 +10427,36 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
                find_equal_scalars(other_branch, &other_branch_regs[insn->dst_reg]);
        }
 
+       /* if one pointer register is compared to another pointer
+        * register check if PTR_MAYBE_NULL could be lifted.
+        * E.g. register A - maybe null
+        *      register B - not null
+        * for JNE A, B, ... - A is not null in the false branch;
+        * for JEQ A, B, ... - A is not null in the true branch.
+        */
+       if (!is_jmp32 && BPF_SRC(insn->code) == BPF_X &&
+           __is_pointer_value(false, src_reg) && __is_pointer_value(false, dst_reg) &&
+           type_may_be_null(src_reg->type) != type_may_be_null(dst_reg->type)) {
+               eq_branch_regs = NULL;
+               switch (opcode) {
+               case BPF_JEQ:
+                       eq_branch_regs = other_branch_regs;
+                       break;
+               case BPF_JNE:
+                       eq_branch_regs = regs;
+                       break;
+               default:
+                       /* do nothing */
+                       break;
+               }
+               if (eq_branch_regs) {
+                       if (type_may_be_null(src_reg->type))
+                               mark_ptr_not_null_reg(&eq_branch_regs[insn->src_reg]);
+                       else
+                               mark_ptr_not_null_reg(&eq_branch_regs[insn->dst_reg]);
+               }
+       }
+
        /* detect if R == 0 where R is returned from bpf_map_lookup_elem().
         * NOTE: these optimizations below are related with pointer comparison
         *       which will never be JMP32.