bpf: Support new sign-extension load insns
authorYonghong Song <yonghong.song@linux.dev>
Fri, 28 Jul 2023 01:11:56 +0000 (18:11 -0700)
committerAlexei Starovoitov <ast@kernel.org>
Fri, 28 Jul 2023 01:52:33 +0000 (18:52 -0700)
Add interpreter/jit support for new sign-extension load insns
which adds a new mode (BPF_MEMSX).
Also add verifier support to recognize these insns and to
do proper verification with new insns. In verifier, besides
to deduce proper bounds for the dst_reg, probed memory access
is also properly handled.

Acked-by: Eduard Zingerman <eddyz87@gmail.com>
Signed-off-by: Yonghong Song <yonghong.song@linux.dev>
Link: https://lore.kernel.org/r/20230728011156.3711870-1-yonghong.song@linux.dev
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
arch/x86/net/bpf_jit_comp.c
include/linux/filter.h
include/uapi/linux/bpf.h
kernel/bpf/core.c
kernel/bpf/verifier.c
tools/include/uapi/linux/bpf.h

index 83c4b45..54478a9 100644 (file)
@@ -779,6 +779,29 @@ static void emit_ldx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
        *pprog = prog;
 }
 
+/* LDSX: dst_reg = *(s8*)(src_reg + off) */
+static void emit_ldsx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
+{
+       u8 *prog = *pprog;
+
+       switch (size) {
+       case BPF_B:
+               /* Emit 'movsx rax, byte ptr [rax + off]' */
+               EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xBE);
+               break;
+       case BPF_H:
+               /* Emit 'movsx rax, word ptr [rax + off]' */
+               EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xBF);
+               break;
+       case BPF_W:
+               /* Emit 'movsx rax, dword ptr [rax+0x14]' */
+               EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x63);
+               break;
+       }
+       emit_insn_suffix(&prog, src_reg, dst_reg, off);
+       *pprog = prog;
+}
+
 /* STX: *(u8*)(dst_reg + off) = src_reg */
 static void emit_stx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
 {
@@ -1370,9 +1393,17 @@ st:                      if (is_imm8(insn->off))
                case BPF_LDX | BPF_PROBE_MEM | BPF_W:
                case BPF_LDX | BPF_MEM | BPF_DW:
                case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
+                       /* LDXS: dst_reg = *(s8*)(src_reg + off) */
+               case BPF_LDX | BPF_MEMSX | BPF_B:
+               case BPF_LDX | BPF_MEMSX | BPF_H:
+               case BPF_LDX | BPF_MEMSX | BPF_W:
+               case BPF_LDX | BPF_PROBE_MEMSX | BPF_B:
+               case BPF_LDX | BPF_PROBE_MEMSX | BPF_H:
+               case BPF_LDX | BPF_PROBE_MEMSX | BPF_W:
                        insn_off = insn->off;
 
-                       if (BPF_MODE(insn->code) == BPF_PROBE_MEM) {
+                       if (BPF_MODE(insn->code) == BPF_PROBE_MEM ||
+                           BPF_MODE(insn->code) == BPF_PROBE_MEMSX) {
                                /* Conservatively check that src_reg + insn->off is a kernel address:
                                 *   src_reg + insn->off >= TASK_SIZE_MAX + PAGE_SIZE
                                 * src_reg is used as scratch for src_reg += insn->off and restored
@@ -1415,8 +1446,13 @@ st:                      if (is_imm8(insn->off))
                                start_of_ldx = prog;
                                end_of_jmp[-1] = start_of_ldx - end_of_jmp;
                        }
-                       emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
-                       if (BPF_MODE(insn->code) == BPF_PROBE_MEM) {
+                       if (BPF_MODE(insn->code) == BPF_PROBE_MEMSX ||
+                           BPF_MODE(insn->code) == BPF_MEMSX)
+                               emit_ldsx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
+                       else
+                               emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
+                       if (BPF_MODE(insn->code) == BPF_PROBE_MEM ||
+                           BPF_MODE(insn->code) == BPF_PROBE_MEMSX) {
                                struct exception_table_entry *ex;
                                u8 *_insn = image + proglen + (start_of_ldx - temp);
                                s64 delta;
index f691140..a93242b 100644 (file)
@@ -69,6 +69,9 @@ struct ctl_table_header;
 /* unused opcode to mark special load instruction. Same as BPF_ABS */
 #define BPF_PROBE_MEM  0x20
 
+/* unused opcode to mark special ldsx instruction. Same as BPF_IND */
+#define BPF_PROBE_MEMSX        0x40
+
 /* unused opcode to mark call to interpreter with arguments */
 #define BPF_CALL_ARGS  0xe0
 
index 7fc98f4..14fd26b 100644 (file)
@@ -19,6 +19,7 @@
 
 /* ld/ldx fields */
 #define BPF_DW         0x18    /* double word (64-bit) */
+#define BPF_MEMSX      0x80    /* load with sign extension */
 #define BPF_ATOMIC     0xc0    /* atomic memory ops - op type in immediate */
 #define BPF_XADD       0xc0    /* exclusive add - legacy name */
 
index dc85240..01b72fc 100644 (file)
@@ -1610,6 +1610,9 @@ EXPORT_SYMBOL_GPL(__bpf_call_base);
        INSN_3(LDX, MEM, H),                    \
        INSN_3(LDX, MEM, W),                    \
        INSN_3(LDX, MEM, DW),                   \
+       INSN_3(LDX, MEMSX, B),                  \
+       INSN_3(LDX, MEMSX, H),                  \
+       INSN_3(LDX, MEMSX, W),                  \
        /*   Immediate based. */                \
        INSN_3(LD, IMM, DW)
 
@@ -1666,6 +1669,9 @@ static u64 ___bpf_prog_run(u64 *regs, const struct bpf_insn *insn)
                [BPF_LDX | BPF_PROBE_MEM | BPF_H] = &&LDX_PROBE_MEM_H,
                [BPF_LDX | BPF_PROBE_MEM | BPF_W] = &&LDX_PROBE_MEM_W,
                [BPF_LDX | BPF_PROBE_MEM | BPF_DW] = &&LDX_PROBE_MEM_DW,
+               [BPF_LDX | BPF_PROBE_MEMSX | BPF_B] = &&LDX_PROBE_MEMSX_B,
+               [BPF_LDX | BPF_PROBE_MEMSX | BPF_H] = &&LDX_PROBE_MEMSX_H,
+               [BPF_LDX | BPF_PROBE_MEMSX | BPF_W] = &&LDX_PROBE_MEMSX_W,
        };
 #undef BPF_INSN_3_LBL
 #undef BPF_INSN_2_LBL
@@ -1942,6 +1948,21 @@ out:
        LDST(DW, u64)
 #undef LDST
 
+#define LDSX(SIZEOP, SIZE)                                             \
+       LDX_MEMSX_##SIZEOP:                                             \
+               DST = *(SIZE *)(unsigned long) (SRC + insn->off);       \
+               CONT;                                                   \
+       LDX_PROBE_MEMSX_##SIZEOP:                                       \
+               bpf_probe_read_kernel(&DST, sizeof(SIZE),               \
+                                     (const void *)(long) (SRC + insn->off));  \
+               DST = *((SIZE *)&DST);                                  \
+               CONT;
+
+       LDSX(B,   s8)
+       LDSX(H,  s16)
+       LDSX(W,  s32)
+#undef LDSX
+
 #define ATOMIC_ALU_OP(BOP, KOP)                                                \
                case BOP:                                               \
                        if (BPF_SIZE(insn->code) == BPF_W)              \
index 71473c1..b154854 100644 (file)
@@ -5827,6 +5827,84 @@ static void coerce_reg_to_size(struct bpf_reg_state *reg, int size)
        __reg_combine_64_into_32(reg);
 }
 
+static void set_sext64_default_val(struct bpf_reg_state *reg, int size)
+{
+       if (size == 1) {
+               reg->smin_value = reg->s32_min_value = S8_MIN;
+               reg->smax_value = reg->s32_max_value = S8_MAX;
+       } else if (size == 2) {
+               reg->smin_value = reg->s32_min_value = S16_MIN;
+               reg->smax_value = reg->s32_max_value = S16_MAX;
+       } else {
+               /* size == 4 */
+               reg->smin_value = reg->s32_min_value = S32_MIN;
+               reg->smax_value = reg->s32_max_value = S32_MAX;
+       }
+       reg->umin_value = reg->u32_min_value = 0;
+       reg->umax_value = U64_MAX;
+       reg->u32_max_value = U32_MAX;
+       reg->var_off = tnum_unknown;
+}
+
+static void coerce_reg_to_size_sx(struct bpf_reg_state *reg, int size)
+{
+       s64 init_s64_max, init_s64_min, s64_max, s64_min, u64_cval;
+       u64 top_smax_value, top_smin_value;
+       u64 num_bits = size * 8;
+
+       if (tnum_is_const(reg->var_off)) {
+               u64_cval = reg->var_off.value;
+               if (size == 1)
+                       reg->var_off = tnum_const((s8)u64_cval);
+               else if (size == 2)
+                       reg->var_off = tnum_const((s16)u64_cval);
+               else
+                       /* size == 4 */
+                       reg->var_off = tnum_const((s32)u64_cval);
+
+               u64_cval = reg->var_off.value;
+               reg->smax_value = reg->smin_value = u64_cval;
+               reg->umax_value = reg->umin_value = u64_cval;
+               reg->s32_max_value = reg->s32_min_value = u64_cval;
+               reg->u32_max_value = reg->u32_min_value = u64_cval;
+               return;
+       }
+
+       top_smax_value = ((u64)reg->smax_value >> num_bits) << num_bits;
+       top_smin_value = ((u64)reg->smin_value >> num_bits) << num_bits;
+
+       if (top_smax_value != top_smin_value)
+               goto out;
+
+       /* find the s64_min and s64_min after sign extension */
+       if (size == 1) {
+               init_s64_max = (s8)reg->smax_value;
+               init_s64_min = (s8)reg->smin_value;
+       } else if (size == 2) {
+               init_s64_max = (s16)reg->smax_value;
+               init_s64_min = (s16)reg->smin_value;
+       } else {
+               init_s64_max = (s32)reg->smax_value;
+               init_s64_min = (s32)reg->smin_value;
+       }
+
+       s64_max = max(init_s64_max, init_s64_min);
+       s64_min = min(init_s64_max, init_s64_min);
+
+       /* both of s64_max/s64_min positive or negative */
+       if (s64_max >= 0 == s64_min >= 0) {
+               reg->smin_value = reg->s32_min_value = s64_min;
+               reg->smax_value = reg->s32_max_value = s64_max;
+               reg->umin_value = reg->u32_min_value = s64_min;
+               reg->umax_value = reg->u32_max_value = s64_max;
+               reg->var_off = tnum_range(s64_min, s64_max);
+               return;
+       }
+
+out:
+       set_sext64_default_val(reg, size);
+}
+
 static bool bpf_map_is_rdonly(const struct bpf_map *map)
 {
        /* A map is considered read-only if the following condition are true:
@@ -5847,7 +5925,8 @@ static bool bpf_map_is_rdonly(const struct bpf_map *map)
               !bpf_map_write_active(map);
 }
 
-static int bpf_map_direct_read(struct bpf_map *map, int off, int size, u64 *val)
+static int bpf_map_direct_read(struct bpf_map *map, int off, int size, u64 *val,
+                              bool is_ldsx)
 {
        void *ptr;
        u64 addr;
@@ -5860,13 +5939,13 @@ static int bpf_map_direct_read(struct bpf_map *map, int off, int size, u64 *val)
 
        switch (size) {
        case sizeof(u8):
-               *val = (u64)*(u8 *)ptr;
+               *val = is_ldsx ? (s64)*(s8 *)ptr : (u64)*(u8 *)ptr;
                break;
        case sizeof(u16):
-               *val = (u64)*(u16 *)ptr;
+               *val = is_ldsx ? (s64)*(s16 *)ptr : (u64)*(u16 *)ptr;
                break;
        case sizeof(u32):
-               *val = (u64)*(u32 *)ptr;
+               *val = is_ldsx ? (s64)*(s32 *)ptr : (u64)*(u32 *)ptr;
                break;
        case sizeof(u64):
                *val = *(u64 *)ptr;
@@ -6285,7 +6364,7 @@ static int check_stack_access_within_bounds(
  */
 static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regno,
                            int off, int bpf_size, enum bpf_access_type t,
-                           int value_regno, bool strict_alignment_once)
+                           int value_regno, bool strict_alignment_once, bool is_ldsx)
 {
        struct bpf_reg_state *regs = cur_regs(env);
        struct bpf_reg_state *reg = regs + regno;
@@ -6346,7 +6425,7 @@ static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regn
                                u64 val = 0;
 
                                err = bpf_map_direct_read(map, map_off, size,
-                                                         &val);
+                                                         &val, is_ldsx);
                                if (err)
                                        return err;
 
@@ -6516,8 +6595,11 @@ static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regn
 
        if (!err && size < BPF_REG_SIZE && value_regno >= 0 && t == BPF_READ &&
            regs[value_regno].type == SCALAR_VALUE) {
-               /* b/h/w load zero-extends, mark upper bits as known 0 */
-               coerce_reg_to_size(&regs[value_regno], size);
+               if (!is_ldsx)
+                       /* b/h/w load zero-extends, mark upper bits as known 0 */
+                       coerce_reg_to_size(&regs[value_regno], size);
+               else
+                       coerce_reg_to_size_sx(&regs[value_regno], size);
        }
        return err;
 }
@@ -6609,17 +6691,17 @@ static int check_atomic(struct bpf_verifier_env *env, int insn_idx, struct bpf_i
         * case to simulate the register fill.
         */
        err = check_mem_access(env, insn_idx, insn->dst_reg, insn->off,
-                              BPF_SIZE(insn->code), BPF_READ, -1, true);
+                              BPF_SIZE(insn->code), BPF_READ, -1, true, false);
        if (!err && load_reg >= 0)
                err = check_mem_access(env, insn_idx, insn->dst_reg, insn->off,
                                       BPF_SIZE(insn->code), BPF_READ, load_reg,
-                                      true);
+                                      true, false);
        if (err)
                return err;
 
        /* Check whether we can write into the same memory. */
        err = check_mem_access(env, insn_idx, insn->dst_reg, insn->off,
-                              BPF_SIZE(insn->code), BPF_WRITE, -1, true);
+                              BPF_SIZE(insn->code), BPF_WRITE, -1, true, false);
        if (err)
                return err;
 
@@ -6865,7 +6947,7 @@ static int check_helper_mem_access(struct bpf_verifier_env *env, int regno,
                                return zero_size_allowed ? 0 : -EACCES;
 
                        return check_mem_access(env, env->insn_idx, regno, offset, BPF_B,
-                                               atype, -1, false);
+                                               atype, -1, false, false);
                }
 
                fallthrough;
@@ -7237,7 +7319,7 @@ static int process_dynptr_func(struct bpf_verifier_env *env, int regno, int insn
                /* we write BPF_DW bits (8 bytes) at a time */
                for (i = 0; i < BPF_DYNPTR_SIZE; i += 8) {
                        err = check_mem_access(env, insn_idx, regno,
-                                              i, BPF_DW, BPF_WRITE, -1, false);
+                                              i, BPF_DW, BPF_WRITE, -1, false, false);
                        if (err)
                                return err;
                }
@@ -7330,7 +7412,7 @@ static int process_iter_arg(struct bpf_verifier_env *env, int regno, int insn_id
 
                for (i = 0; i < nr_slots * 8; i += BPF_REG_SIZE) {
                        err = check_mem_access(env, insn_idx, regno,
-                                              i, BPF_DW, BPF_WRITE, -1, false);
+                                              i, BPF_DW, BPF_WRITE, -1, false, false);
                        if (err)
                                return err;
                }
@@ -9474,7 +9556,7 @@ static int check_helper_call(struct bpf_verifier_env *env, struct bpf_insn *insn
         */
        for (i = 0; i < meta.access_size; i++) {
                err = check_mem_access(env, insn_idx, meta.regno, i, BPF_B,
-                                      BPF_WRITE, -1, false);
+                                      BPF_WRITE, -1, false, false);
                if (err)
                        return err;
        }
@@ -16202,7 +16284,7 @@ static int save_aux_ptr_type(struct bpf_verifier_env *env, enum bpf_reg_type typ
                         * Have to support a use case when one path through
                         * the program yields TRUSTED pointer while another
                         * is UNTRUSTED. Fallback to UNTRUSTED to generate
-                        * BPF_PROBE_MEM.
+                        * BPF_PROBE_MEM/BPF_PROBE_MEMSX.
                         */
                        *prev_type = PTR_TO_BTF_ID | PTR_UNTRUSTED;
                } else {
@@ -16343,7 +16425,8 @@ static int do_check(struct bpf_verifier_env *env)
                         */
                        err = check_mem_access(env, env->insn_idx, insn->src_reg,
                                               insn->off, BPF_SIZE(insn->code),
-                                              BPF_READ, insn->dst_reg, false);
+                                              BPF_READ, insn->dst_reg, false,
+                                              BPF_MODE(insn->code) == BPF_MEMSX);
                        if (err)
                                return err;
 
@@ -16380,7 +16463,7 @@ static int do_check(struct bpf_verifier_env *env)
                        /* check that memory (dst_reg + off) is writeable */
                        err = check_mem_access(env, env->insn_idx, insn->dst_reg,
                                               insn->off, BPF_SIZE(insn->code),
-                                              BPF_WRITE, insn->src_reg, false);
+                                              BPF_WRITE, insn->src_reg, false, false);
                        if (err)
                                return err;
 
@@ -16405,7 +16488,7 @@ static int do_check(struct bpf_verifier_env *env)
                        /* check that memory (dst_reg + off) is writeable */
                        err = check_mem_access(env, env->insn_idx, insn->dst_reg,
                                               insn->off, BPF_SIZE(insn->code),
-                                              BPF_WRITE, -1, false);
+                                              BPF_WRITE, -1, false, false);
                        if (err)
                                return err;
 
@@ -16833,7 +16916,8 @@ static int resolve_pseudo_ldimm64(struct bpf_verifier_env *env)
 
        for (i = 0; i < insn_cnt; i++, insn++) {
                if (BPF_CLASS(insn->code) == BPF_LDX &&
-                   (BPF_MODE(insn->code) != BPF_MEM || insn->imm != 0)) {
+                   ((BPF_MODE(insn->code) != BPF_MEM && BPF_MODE(insn->code) != BPF_MEMSX) ||
+                   insn->imm != 0)) {
                        verbose(env, "BPF_LDX uses reserved fields\n");
                        return -EINVAL;
                }
@@ -17531,7 +17615,10 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
                if (insn->code == (BPF_LDX | BPF_MEM | BPF_B) ||
                    insn->code == (BPF_LDX | BPF_MEM | BPF_H) ||
                    insn->code == (BPF_LDX | BPF_MEM | BPF_W) ||
-                   insn->code == (BPF_LDX | BPF_MEM | BPF_DW)) {
+                   insn->code == (BPF_LDX | BPF_MEM | BPF_DW) ||
+                   insn->code == (BPF_LDX | BPF_MEMSX | BPF_B) ||
+                   insn->code == (BPF_LDX | BPF_MEMSX | BPF_H) ||
+                   insn->code == (BPF_LDX | BPF_MEMSX | BPF_W)) {
                        type = BPF_READ;
                } else if (insn->code == (BPF_STX | BPF_MEM | BPF_B) ||
                           insn->code == (BPF_STX | BPF_MEM | BPF_H) ||
@@ -17590,8 +17677,12 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
                 */
                case PTR_TO_BTF_ID | MEM_ALLOC | PTR_UNTRUSTED:
                        if (type == BPF_READ) {
-                               insn->code = BPF_LDX | BPF_PROBE_MEM |
-                                       BPF_SIZE((insn)->code);
+                               if (BPF_MODE(insn->code) == BPF_MEM)
+                                       insn->code = BPF_LDX | BPF_PROBE_MEM |
+                                                    BPF_SIZE((insn)->code);
+                               else
+                                       insn->code = BPF_LDX | BPF_PROBE_MEMSX |
+                                                    BPF_SIZE((insn)->code);
                                env->prog->aux->num_exentries++;
                        }
                        continue;
@@ -17779,7 +17870,8 @@ static int jit_subprogs(struct bpf_verifier_env *env)
                insn = func[i]->insnsi;
                for (j = 0; j < func[i]->len; j++, insn++) {
                        if (BPF_CLASS(insn->code) == BPF_LDX &&
-                           BPF_MODE(insn->code) == BPF_PROBE_MEM)
+                           (BPF_MODE(insn->code) == BPF_PROBE_MEM ||
+                            BPF_MODE(insn->code) == BPF_PROBE_MEMSX))
                                num_exentries++;
                }
                func[i]->aux->num_exentries = num_exentries;
index 7fc98f4..14fd26b 100644 (file)
@@ -19,6 +19,7 @@
 
 /* ld/ldx fields */
 #define BPF_DW         0x18    /* double word (64-bit) */
+#define BPF_MEMSX      0x80    /* load with sign extension */
 #define BPF_ATOMIC     0xc0    /* atomic memory ops - op type in immediate */
 #define BPF_XADD       0xc0    /* exclusive add - legacy name */