bpf: Support kernel function call in x86-32
authorMartin KaFai Lau <kafai@fb.com>
Thu, 25 Mar 2021 01:51:49 +0000 (18:51 -0700)
committerAlexei Starovoitov <ast@kernel.org>
Sat, 27 Mar 2021 03:41:51 +0000 (20:41 -0700)
This patch adds kernel function call support to the x86-32 bpf jit.

Signed-off-by: Martin KaFai Lau <kafai@fb.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Link: https://lore.kernel.org/bpf/20210325015149.1545267-1-kafai@fb.com
arch/x86/net/bpf_jit_comp32.c

index d17b67c..0a7a287 100644 (file)
@@ -1390,6 +1390,19 @@ static inline void emit_push_r64(const u8 src[], u8 **pprog)
        *pprog = prog;
 }
 
+static void emit_push_r32(const u8 src[], u8 **pprog)
+{
+       u8 *prog = *pprog;
+       int cnt = 0;
+
+       /* mov ecx,dword ptr [ebp+off] */
+       EMIT3(0x8B, add_2reg(0x40, IA32_EBP, IA32_ECX), STACK_VAR(src_lo));
+       /* push ecx */
+       EMIT1(0x51);
+
+       *pprog = prog;
+}
+
 static u8 get_cond_jmp_opcode(const u8 op, bool is_cmp_lo)
 {
        u8 jmp_cond;
@@ -1459,6 +1472,174 @@ static u8 get_cond_jmp_opcode(const u8 op, bool is_cmp_lo)
        return jmp_cond;
 }
 
+/* i386 kernel compiles with "-mregparm=3".  From gcc document:
+ *
+ * ==== snippet ====
+ * regparm (number)
+ *     On x86-32 targets, the regparm attribute causes the compiler
+ *     to pass arguments number one to (number) if they are of integral
+ *     type in registers EAX, EDX, and ECX instead of on the stack.
+ *     Functions that take a variable number of arguments continue
+ *     to be passed all of their arguments on the stack.
+ * ==== snippet ====
+ *
+ * The first three args of a function will be considered for
+ * putting into the 32bit register EAX, EDX, and ECX.
+ *
+ * Two 32bit registers are used to pass a 64bit arg.
+ *
+ * For example,
+ * void foo(u32 a, u32 b, u32 c, u32 d):
+ *     u32 a: EAX
+ *     u32 b: EDX
+ *     u32 c: ECX
+ *     u32 d: stack
+ *
+ * void foo(u64 a, u32 b, u32 c):
+ *     u64 a: EAX (lo32) EDX (hi32)
+ *     u32 b: ECX
+ *     u32 c: stack
+ *
+ * void foo(u32 a, u64 b, u32 c):
+ *     u32 a: EAX
+ *     u64 b: EDX (lo32) ECX (hi32)
+ *     u32 c: stack
+ *
+ * void foo(u32 a, u32 b, u64 c):
+ *     u32 a: EAX
+ *     u32 b: EDX
+ *     u64 c: stack
+ *
+ * The return value will be stored in the EAX (and EDX for 64bit value).
+ *
+ * For example,
+ * u32 foo(u32 a, u32 b, u32 c):
+ *     return value: EAX
+ *
+ * u64 foo(u32 a, u32 b, u32 c):
+ *     return value: EAX (lo32) EDX (hi32)
+ *
+ * Notes:
+ *     The verifier only accepts function having integer and pointers
+ *     as its args and return value, so it does not have
+ *     struct-by-value.
+ *
+ * emit_kfunc_call() finds out the btf_func_model by calling
+ * bpf_jit_find_kfunc_model().  A btf_func_model
+ * has the details about the number of args, size of each arg,
+ * and the size of the return value.
+ *
+ * It first decides how many args can be passed by EAX, EDX, and ECX.
+ * That will decide what args should be pushed to the stack:
+ * [first_stack_regno, last_stack_regno] are the bpf regnos
+ * that should be pushed to the stack.
+ *
+ * It will first push all args to the stack because the push
+ * will need to use ECX.  Then, it moves
+ * [BPF_REG_1, first_stack_regno) to EAX, EDX, and ECX.
+ *
+ * When emitting a call (0xE8), it needs to figure out
+ * the jmp_offset relative to the jit-insn address immediately
+ * following the call (0xE8) instruction.  At this point, it knows
+ * the end of the jit-insn address after completely translated the
+ * current (BPF_JMP | BPF_CALL) bpf-insn.  It is passed as "end_addr"
+ * to the emit_kfunc_call().  Thus, it can learn the "immediate-follow-call"
+ * address by figuring out how many jit-insn is generated between
+ * the call (0xE8) and the end_addr:
+ *     - 0-1 jit-insn (3 bytes each) to restore the esp pointer if there
+ *       is arg pushed to the stack.
+ *     - 0-2 jit-insns (3 bytes each) to handle the return value.
+ */
+static int emit_kfunc_call(const struct bpf_prog *bpf_prog, u8 *end_addr,
+                          const struct bpf_insn *insn, u8 **pprog)
+{
+       const u8 arg_regs[] = { IA32_EAX, IA32_EDX, IA32_ECX };
+       int i, cnt = 0, first_stack_regno, last_stack_regno;
+       int free_arg_regs = ARRAY_SIZE(arg_regs);
+       const struct btf_func_model *fm;
+       int bytes_in_stack = 0;
+       const u8 *cur_arg_reg;
+       u8 *prog = *pprog;
+       s64 jmp_offset;
+
+       fm = bpf_jit_find_kfunc_model(bpf_prog, insn);
+       if (!fm)
+               return -EINVAL;
+
+       first_stack_regno = BPF_REG_1;
+       for (i = 0; i < fm->nr_args; i++) {
+               int regs_needed = fm->arg_size[i] > sizeof(u32) ? 2 : 1;
+
+               if (regs_needed > free_arg_regs)
+                       break;
+
+               free_arg_regs -= regs_needed;
+               first_stack_regno++;
+       }
+
+       /* Push the args to the stack */
+       last_stack_regno = BPF_REG_0 + fm->nr_args;
+       for (i = last_stack_regno; i >= first_stack_regno; i--) {
+               if (fm->arg_size[i - 1] > sizeof(u32)) {
+                       emit_push_r64(bpf2ia32[i], &prog);
+                       bytes_in_stack += 8;
+               } else {
+                       emit_push_r32(bpf2ia32[i], &prog);
+                       bytes_in_stack += 4;
+               }
+       }
+
+       cur_arg_reg = &arg_regs[0];
+       for (i = BPF_REG_1; i < first_stack_regno; i++) {
+               /* mov e[adc]x,dword ptr [ebp+off] */
+               EMIT3(0x8B, add_2reg(0x40, IA32_EBP, *cur_arg_reg++),
+                     STACK_VAR(bpf2ia32[i][0]));
+               if (fm->arg_size[i - 1] > sizeof(u32))
+                       /* mov e[adc]x,dword ptr [ebp+off] */
+                       EMIT3(0x8B, add_2reg(0x40, IA32_EBP, *cur_arg_reg++),
+                             STACK_VAR(bpf2ia32[i][1]));
+       }
+
+       if (bytes_in_stack)
+               /* add esp,"bytes_in_stack" */
+               end_addr -= 3;
+
+       /* mov dword ptr [ebp+off],edx */
+       if (fm->ret_size > sizeof(u32))
+               end_addr -= 3;
+
+       /* mov dword ptr [ebp+off],eax */
+       if (fm->ret_size)
+               end_addr -= 3;
+
+       jmp_offset = (u8 *)__bpf_call_base + insn->imm - end_addr;
+       if (!is_simm32(jmp_offset)) {
+               pr_err("unsupported BPF kernel function jmp_offset:%lld\n",
+                      jmp_offset);
+               return -EINVAL;
+       }
+
+       EMIT1_off32(0xE8, jmp_offset);
+
+       if (fm->ret_size)
+               /* mov dword ptr [ebp+off],eax */
+               EMIT3(0x89, add_2reg(0x40, IA32_EBP, IA32_EAX),
+                     STACK_VAR(bpf2ia32[BPF_REG_0][0]));
+
+       if (fm->ret_size > sizeof(u32))
+               /* mov dword ptr [ebp+off],edx */
+               EMIT3(0x89, add_2reg(0x40, IA32_EBP, IA32_EDX),
+                     STACK_VAR(bpf2ia32[BPF_REG_0][1]));
+
+       if (bytes_in_stack)
+               /* add esp,"bytes_in_stack" */
+               EMIT3(0x83, add_1reg(0xC0, IA32_ESP), bytes_in_stack);
+
+       *pprog = prog;
+
+       return 0;
+}
+
 static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                  int oldproglen, struct jit_context *ctx)
 {
@@ -1888,6 +2069,18 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                        if (insn->src_reg == BPF_PSEUDO_CALL)
                                goto notyet;
 
+                       if (insn->src_reg == BPF_PSEUDO_KFUNC_CALL) {
+                               int err;
+
+                               err = emit_kfunc_call(bpf_prog,
+                                                     image + addrs[i],
+                                                     insn, &prog);
+
+                               if (err)
+                                       return err;
+                               break;
+                       }
+
                        func = (u8 *) __bpf_call_base + imm32;
                        jmp_offset = func - (image + addrs[i]);
 
@@ -2393,3 +2586,8 @@ out:
                                           tmp : orig_prog);
        return prog;
 }
+
+bool bpf_jit_supports_kfunc_call(void)
+{
+       return true;
+}