bpf, x86: Simplify the parsing logic of structure parameters
authorPu Lehui <pulehui@huawei.com>
Thu, 5 Jan 2023 03:50:26 +0000 (11:50 +0800)
committerMartin KaFai Lau <martin.lau@kernel.org>
Tue, 10 Jan 2023 23:53:22 +0000 (15:53 -0800)
Extra_nregs of structure parameters and nr_args can be
added directly at the beginning, and using a flip flag
to identifiy structure parameters. Meantime, renaming
some variables to make them more sense.

Signed-off-by: Pu Lehui <pulehui@huawei.com>
Acked-by: Yonghong Song <yhs@fb.com>
Link: https://lore.kernel.org/r/20230105035026.3091988-1-pulehui@huaweicloud.com
Signed-off-by: Martin KaFai Lau <martin.lau@kernel.org>
arch/x86/net/bpf_jit_comp.c

index 8db6077..1056bbf 100644 (file)
@@ -1857,62 +1857,59 @@ emit_jmp:
        return proglen;
 }
 
-static void save_regs(const struct btf_func_model *m, u8 **prog, int nr_args,
+static void save_regs(const struct btf_func_model *m, u8 **prog, int nr_regs,
                      int stack_size)
 {
-       int i, j, arg_size, nr_regs;
+       int i, j, arg_size;
+       bool next_same_struct = false;
+
        /* Store function arguments to stack.
         * For a function that accepts two pointers the sequence will be:
         * mov QWORD PTR [rbp-0x10],rdi
         * mov QWORD PTR [rbp-0x8],rsi
         */
-       for (i = 0, j = 0; i < min(nr_args, 6); i++) {
-               if (m->arg_flags[i] & BTF_FMODEL_STRUCT_ARG) {
-                       nr_regs = (m->arg_size[i] + 7) / 8;
+       for (i = 0, j = 0; i < min(nr_regs, 6); i++) {
+               /* The arg_size is at most 16 bytes, enforced by the verifier. */
+               arg_size = m->arg_size[j];
+               if (arg_size > 8) {
                        arg_size = 8;
-               } else {
-                       nr_regs = 1;
-                       arg_size = m->arg_size[i];
+                       next_same_struct = !next_same_struct;
                }
 
-               while (nr_regs) {
-                       emit_stx(prog, bytes_to_bpf_size(arg_size),
-                                BPF_REG_FP,
-                                j == 5 ? X86_REG_R9 : BPF_REG_1 + j,
-                                -(stack_size - j * 8));
-                       nr_regs--;
-                       j++;
-               }
+               emit_stx(prog, bytes_to_bpf_size(arg_size),
+                        BPF_REG_FP,
+                        i == 5 ? X86_REG_R9 : BPF_REG_1 + i,
+                        -(stack_size - i * 8));
+
+               j = next_same_struct ? j : j + 1;
        }
 }
 
-static void restore_regs(const struct btf_func_model *m, u8 **prog, int nr_args,
+static void restore_regs(const struct btf_func_model *m, u8 **prog, int nr_regs,
                         int stack_size)
 {
-       int i, j, arg_size, nr_regs;
+       int i, j, arg_size;
+       bool next_same_struct = false;
 
        /* Restore function arguments from stack.
         * For a function that accepts two pointers the sequence will be:
         * EMIT4(0x48, 0x8B, 0x7D, 0xF0); mov rdi,QWORD PTR [rbp-0x10]
         * EMIT4(0x48, 0x8B, 0x75, 0xF8); mov rsi,QWORD PTR [rbp-0x8]
         */
-       for (i = 0, j = 0; i < min(nr_args, 6); i++) {
-               if (m->arg_flags[i] & BTF_FMODEL_STRUCT_ARG) {
-                       nr_regs = (m->arg_size[i] + 7) / 8;
+       for (i = 0, j = 0; i < min(nr_regs, 6); i++) {
+               /* The arg_size is at most 16 bytes, enforced by the verifier. */
+               arg_size = m->arg_size[j];
+               if (arg_size > 8) {
                        arg_size = 8;
-               } else {
-                       nr_regs = 1;
-                       arg_size = m->arg_size[i];
+                       next_same_struct = !next_same_struct;
                }
 
-               while (nr_regs) {
-                       emit_ldx(prog, bytes_to_bpf_size(arg_size),
-                                j == 5 ? X86_REG_R9 : BPF_REG_1 + j,
-                                BPF_REG_FP,
-                                -(stack_size - j * 8));
-                       nr_regs--;
-                       j++;
-               }
+               emit_ldx(prog, bytes_to_bpf_size(arg_size),
+                        i == 5 ? X86_REG_R9 : BPF_REG_1 + i,
+                        BPF_REG_FP,
+                        -(stack_size - i * 8));
+
+               j = next_same_struct ? j : j + 1;
        }
 }
 
@@ -2138,8 +2135,8 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
                                struct bpf_tramp_links *tlinks,
                                void *func_addr)
 {
-       int ret, i, nr_args = m->nr_args, extra_nregs = 0;
-       int regs_off, ip_off, args_off, stack_size = nr_args * 8, run_ctx_off;
+       int i, ret, nr_regs = m->nr_args, stack_size = 0;
+       int regs_off, nregs_off, ip_off, run_ctx_off;
        struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
        struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
        struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
@@ -2148,17 +2145,14 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
        u8 *prog;
        bool save_ret;
 
-       /* x86-64 supports up to 6 arguments. 7+ can be added in the future */
-       if (nr_args > 6)
-               return -ENOTSUPP;
-
-       for (i = 0; i < MAX_BPF_FUNC_ARGS; i++) {
+       /* extra registers for struct arguments */
+       for (i = 0; i < m->nr_args; i++)
                if (m->arg_flags[i] & BTF_FMODEL_STRUCT_ARG)
-                       extra_nregs += (m->arg_size[i] + 7) / 8 - 1;
-       }
-       if (nr_args + extra_nregs > 6)
+                       nr_regs += (m->arg_size[i] + 7) / 8 - 1;
+
+       /* x86-64 supports up to 6 arguments. 7+ can be added in the future */
+       if (nr_regs > 6)
                return -ENOTSUPP;
-       stack_size += extra_nregs * 8;
 
        /* Generated trampoline stack layout:
         *
@@ -2172,7 +2166,7 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
         *                 [ ...             ]
         * RBP - regs_off  [ reg_arg1        ]  program's ctx pointer
         *
-        * RBP - args_off  [ arg regs count  ]  always
+        * RBP - nregs_off [ regs count      ]  always
         *
         * RBP - ip_off    [ traced function ]  BPF_TRAMP_F_IP_ARG flag
         *
@@ -2184,11 +2178,12 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
        if (save_ret)
                stack_size += 8;
 
+       stack_size += nr_regs * 8;
        regs_off = stack_size;
 
-       /* args count  */
+       /* regs count  */
        stack_size += 8;
-       args_off = stack_size;
+       nregs_off = stack_size;
 
        if (flags & BPF_TRAMP_F_IP_ARG)
                stack_size += 8; /* room for IP address argument */
@@ -2221,11 +2216,11 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
        EMIT1(0x53);             /* push rbx */
 
        /* Store number of argument registers of the traced function:
-        *   mov rax, nr_args + extra_nregs
-        *   mov QWORD PTR [rbp - args_off], rax
+        *   mov rax, nr_regs
+        *   mov QWORD PTR [rbp - nregs_off], rax
         */
-       emit_mov_imm64(&prog, BPF_REG_0, 0, (u32) nr_args + extra_nregs);
-       emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -args_off);
+       emit_mov_imm64(&prog, BPF_REG_0, 0, (u32) nr_regs);
+       emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -nregs_off);
 
        if (flags & BPF_TRAMP_F_IP_ARG) {
                /* Store IP address of the traced function:
@@ -2236,7 +2231,7 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
                emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -ip_off);
        }
 
-       save_regs(m, &prog, nr_args, regs_off);
+       save_regs(m, &prog, nr_regs, regs_off);
 
        if (flags & BPF_TRAMP_F_CALL_ORIG) {
                /* arg1: mov rdi, im */
@@ -2266,7 +2261,7 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
        }
 
        if (flags & BPF_TRAMP_F_CALL_ORIG) {
-               restore_regs(m, &prog, nr_args, regs_off);
+               restore_regs(m, &prog, nr_regs, regs_off);
 
                if (flags & BPF_TRAMP_F_ORIG_STACK) {
                        emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, 8);
@@ -2307,7 +2302,7 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
                }
 
        if (flags & BPF_TRAMP_F_RESTORE_REGS)
-               restore_regs(m, &prog, nr_args, regs_off);
+               restore_regs(m, &prog, nr_regs, regs_off);
 
        /* This needs to be done regardless. If there were fmod_ret programs,
         * the return value is only updated on the stack and still needs to be