bpf, x32: Fix bug for BPF_JMP | {BPF_JSGT, BPF_JSLE, BPF_JSLT, BPF_JSGE}
authorWang YanQing <udknight@gmail.com>
Sat, 27 Apr 2019 08:28:26 +0000 (16:28 +0800)
committerDaniel Borkmann <daniel@iogearbox.net>
Wed, 1 May 2019 21:32:16 +0000 (23:32 +0200)
The current method to compare 64-bit numbers for conditional jump is:

1) Compare the high 32-bit first.

2) If the high 32-bit isn't the same, then goto step 4.

3) Compare the low 32-bit.

4) Check the desired condition.

This method is right for unsigned comparison, but it is buggy for signed
comparison, because it does signed comparison for low 32-bit too.

There is only one sign bit in 64-bit number, that is the MSB in the 64-bit
number, it is wrong to treat low 32-bit as signed number and do the signed
comparison for it.

This patch fixes the bug and adds a testcase in selftests/bpf for such bug.

Signed-off-by: Wang YanQing <udknight@gmail.com>
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
arch/x86/net/bpf_jit_comp32.c
tools/testing/selftests/bpf/verifier/jit.c

index 0d9cdff..8097b88 100644 (file)
@@ -117,6 +117,8 @@ static bool is_simm32(s64 value)
 #define IA32_JLE 0x7E
 #define IA32_JG  0x7F
 
+#define COND_JMP_OPCODE_INVALID        (0xFF)
+
 /*
  * Map eBPF registers to IA32 32bit registers or stack scratch space.
  *
@@ -1613,6 +1615,75 @@ static inline void emit_push_r64(const u8 src[], u8 **pprog)
        *pprog = prog;
 }
 
+static u8 get_cond_jmp_opcode(const u8 op, bool is_cmp_lo)
+{
+       u8 jmp_cond;
+
+       /* Convert BPF opcode to x86 */
+       switch (op) {
+       case BPF_JEQ:
+               jmp_cond = IA32_JE;
+               break;
+       case BPF_JSET:
+       case BPF_JNE:
+               jmp_cond = IA32_JNE;
+               break;
+       case BPF_JGT:
+               /* GT is unsigned '>', JA in x86 */
+               jmp_cond = IA32_JA;
+               break;
+       case BPF_JLT:
+               /* LT is unsigned '<', JB in x86 */
+               jmp_cond = IA32_JB;
+               break;
+       case BPF_JGE:
+               /* GE is unsigned '>=', JAE in x86 */
+               jmp_cond = IA32_JAE;
+               break;
+       case BPF_JLE:
+               /* LE is unsigned '<=', JBE in x86 */
+               jmp_cond = IA32_JBE;
+               break;
+       case BPF_JSGT:
+               if (!is_cmp_lo)
+                       /* Signed '>', GT in x86 */
+                       jmp_cond = IA32_JG;
+               else
+                       /* GT is unsigned '>', JA in x86 */
+                       jmp_cond = IA32_JA;
+               break;
+       case BPF_JSLT:
+               if (!is_cmp_lo)
+                       /* Signed '<', LT in x86 */
+                       jmp_cond = IA32_JL;
+               else
+                       /* LT is unsigned '<', JB in x86 */
+                       jmp_cond = IA32_JB;
+               break;
+       case BPF_JSGE:
+               if (!is_cmp_lo)
+                       /* Signed '>=', GE in x86 */
+                       jmp_cond = IA32_JGE;
+               else
+                       /* GE is unsigned '>=', JAE in x86 */
+                       jmp_cond = IA32_JAE;
+               break;
+       case BPF_JSLE:
+               if (!is_cmp_lo)
+                       /* Signed '<=', LE in x86 */
+                       jmp_cond = IA32_JLE;
+               else
+                       /* LE is unsigned '<=', JBE in x86 */
+                       jmp_cond = IA32_JBE;
+               break;
+       default: /* to silence GCC warning */
+               jmp_cond = COND_JMP_OPCODE_INVALID;
+               break;
+       }
+
+       return jmp_cond;
+}
+
 static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                  int oldproglen, struct jit_context *ctx)
 {
@@ -2069,10 +2140,6 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                case BPF_JMP | BPF_JLT | BPF_X:
                case BPF_JMP | BPF_JGE | BPF_X:
                case BPF_JMP | BPF_JLE | BPF_X:
-               case BPF_JMP | BPF_JSGT | BPF_X:
-               case BPF_JMP | BPF_JSLE | BPF_X:
-               case BPF_JMP | BPF_JSLT | BPF_X:
-               case BPF_JMP | BPF_JSGE | BPF_X:
                case BPF_JMP32 | BPF_JEQ | BPF_X:
                case BPF_JMP32 | BPF_JNE | BPF_X:
                case BPF_JMP32 | BPF_JGT | BPF_X:
@@ -2118,6 +2185,40 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                        EMIT2(0x39, add_2reg(0xC0, dreg_lo, sreg_lo));
                        goto emit_cond_jmp;
                }
+               case BPF_JMP | BPF_JSGT | BPF_X:
+               case BPF_JMP | BPF_JSLE | BPF_X:
+               case BPF_JMP | BPF_JSLT | BPF_X:
+               case BPF_JMP | BPF_JSGE | BPF_X: {
+                       u8 dreg_lo = dstk ? IA32_EAX : dst_lo;
+                       u8 dreg_hi = dstk ? IA32_EDX : dst_hi;
+                       u8 sreg_lo = sstk ? IA32_ECX : src_lo;
+                       u8 sreg_hi = sstk ? IA32_EBX : src_hi;
+
+                       if (dstk) {
+                               EMIT3(0x8B, add_2reg(0x40, IA32_EBP, IA32_EAX),
+                                     STACK_VAR(dst_lo));
+                               EMIT3(0x8B,
+                                     add_2reg(0x40, IA32_EBP,
+                                              IA32_EDX),
+                                     STACK_VAR(dst_hi));
+                       }
+
+                       if (sstk) {
+                               EMIT3(0x8B, add_2reg(0x40, IA32_EBP, IA32_ECX),
+                                     STACK_VAR(src_lo));
+                               EMIT3(0x8B,
+                                     add_2reg(0x40, IA32_EBP,
+                                              IA32_EBX),
+                                     STACK_VAR(src_hi));
+                       }
+
+                       /* cmp dreg_hi,sreg_hi */
+                       EMIT2(0x39, add_2reg(0xC0, dreg_hi, sreg_hi));
+                       EMIT2(IA32_JNE, 10);
+                       /* cmp dreg_lo,sreg_lo */
+                       EMIT2(0x39, add_2reg(0xC0, dreg_lo, sreg_lo));
+                       goto emit_cond_jmp_signed;
+               }
                case BPF_JMP | BPF_JSET | BPF_X:
                case BPF_JMP32 | BPF_JSET | BPF_X: {
                        bool is_jmp64 = BPF_CLASS(insn->code) == BPF_JMP;
@@ -2194,10 +2295,6 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                case BPF_JMP | BPF_JLT | BPF_K:
                case BPF_JMP | BPF_JGE | BPF_K:
                case BPF_JMP | BPF_JLE | BPF_K:
-               case BPF_JMP | BPF_JSGT | BPF_K:
-               case BPF_JMP | BPF_JSLE | BPF_K:
-               case BPF_JMP | BPF_JSLT | BPF_K:
-               case BPF_JMP | BPF_JSGE | BPF_K:
                case BPF_JMP32 | BPF_JEQ | BPF_K:
                case BPF_JMP32 | BPF_JNE | BPF_K:
                case BPF_JMP32 | BPF_JGT | BPF_K:
@@ -2238,50 +2335,9 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                        /* cmp dreg_lo,sreg_lo */
                        EMIT2(0x39, add_2reg(0xC0, dreg_lo, sreg_lo));
 
-emit_cond_jmp:         /* Convert BPF opcode to x86 */
-                       switch (BPF_OP(code)) {
-                       case BPF_JEQ:
-                               jmp_cond = IA32_JE;
-                               break;
-                       case BPF_JSET:
-                       case BPF_JNE:
-                               jmp_cond = IA32_JNE;
-                               break;
-                       case BPF_JGT:
-                               /* GT is unsigned '>', JA in x86 */
-                               jmp_cond = IA32_JA;
-                               break;
-                       case BPF_JLT:
-                               /* LT is unsigned '<', JB in x86 */
-                               jmp_cond = IA32_JB;
-                               break;
-                       case BPF_JGE:
-                               /* GE is unsigned '>=', JAE in x86 */
-                               jmp_cond = IA32_JAE;
-                               break;
-                       case BPF_JLE:
-                               /* LE is unsigned '<=', JBE in x86 */
-                               jmp_cond = IA32_JBE;
-                               break;
-                       case BPF_JSGT:
-                               /* Signed '>', GT in x86 */
-                               jmp_cond = IA32_JG;
-                               break;
-                       case BPF_JSLT:
-                               /* Signed '<', LT in x86 */
-                               jmp_cond = IA32_JL;
-                               break;
-                       case BPF_JSGE:
-                               /* Signed '>=', GE in x86 */
-                               jmp_cond = IA32_JGE;
-                               break;
-                       case BPF_JSLE:
-                               /* Signed '<=', LE in x86 */
-                               jmp_cond = IA32_JLE;
-                               break;
-                       default: /* to silence GCC warning */
+emit_cond_jmp:         jmp_cond = get_cond_jmp_opcode(BPF_OP(code), false);
+                       if (jmp_cond == COND_JMP_OPCODE_INVALID)
                                return -EFAULT;
-                       }
                        jmp_offset = addrs[i + insn->off] - addrs[i];
                        if (is_imm8(jmp_offset)) {
                                EMIT2(jmp_cond, jmp_offset);
@@ -2291,7 +2347,66 @@ emit_cond_jmp:           /* Convert BPF opcode to x86 */
                                pr_err("cond_jmp gen bug %llx\n", jmp_offset);
                                return -EFAULT;
                        }
+                       break;
+               }
+               case BPF_JMP | BPF_JSGT | BPF_K:
+               case BPF_JMP | BPF_JSLE | BPF_K:
+               case BPF_JMP | BPF_JSLT | BPF_K:
+               case BPF_JMP | BPF_JSGE | BPF_K: {
+                       u8 dreg_lo = dstk ? IA32_EAX : dst_lo;
+                       u8 dreg_hi = dstk ? IA32_EDX : dst_hi;
+                       u8 sreg_lo = IA32_ECX;
+                       u8 sreg_hi = IA32_EBX;
+                       u32 hi;
+
+                       if (dstk) {
+                               EMIT3(0x8B, add_2reg(0x40, IA32_EBP, IA32_EAX),
+                                     STACK_VAR(dst_lo));
+                               EMIT3(0x8B,
+                                     add_2reg(0x40, IA32_EBP,
+                                              IA32_EDX),
+                                     STACK_VAR(dst_hi));
+                       }
+
+                       /* mov ecx,imm32 */
+                       EMIT2_off32(0xC7, add_1reg(0xC0, IA32_ECX), imm32);
+                       hi = imm32 & (1 << 31) ? (u32)~0 : 0;
+                       /* mov ebx,imm32 */
+                       EMIT2_off32(0xC7, add_1reg(0xC0, IA32_EBX), hi);
+                       /* cmp dreg_hi,sreg_hi */
+                       EMIT2(0x39, add_2reg(0xC0, dreg_hi, sreg_hi));
+                       EMIT2(IA32_JNE, 10);
+                       /* cmp dreg_lo,sreg_lo */
+                       EMIT2(0x39, add_2reg(0xC0, dreg_lo, sreg_lo));
+
+                       /*
+                        * For simplicity of branch offset computation,
+                        * let's use fixed jump coding here.
+                        */
+emit_cond_jmp_signed:  /* Check the condition for low 32-bit comparison */
+                       jmp_cond = get_cond_jmp_opcode(BPF_OP(code), true);
+                       if (jmp_cond == COND_JMP_OPCODE_INVALID)
+                               return -EFAULT;
+                       jmp_offset = addrs[i + insn->off] - addrs[i] + 8;
+                       if (is_simm32(jmp_offset)) {
+                               EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset);
+                       } else {
+                               pr_err("cond_jmp gen bug %llx\n", jmp_offset);
+                               return -EFAULT;
+                       }
+                       EMIT2(0xEB, 6);
 
+                       /* Check the condition for high 32-bit comparison */
+                       jmp_cond = get_cond_jmp_opcode(BPF_OP(code), false);
+                       if (jmp_cond == COND_JMP_OPCODE_INVALID)
+                               return -EFAULT;
+                       jmp_offset = addrs[i + insn->off] - addrs[i];
+                       if (is_simm32(jmp_offset)) {
+                               EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset);
+                       } else {
+                               pr_err("cond_jmp gen bug %llx\n", jmp_offset);
+                               return -EFAULT;
+                       }
                        break;
                }
                case BPF_JMP | BPF_JA:
index be488b4..c33adf3 100644 (file)
        .result = ACCEPT,
        .retval = 2,
 },
+{
+       "jit: jsgt, jslt",
+       .insns = {
+       BPF_LD_IMM64(BPF_REG_1, 0x80000000ULL),
+       BPF_LD_IMM64(BPF_REG_2, 0x0ULL),
+       BPF_JMP_REG(BPF_JSGT, BPF_REG_1, BPF_REG_2, 2),
+       BPF_MOV64_IMM(BPF_REG_0, 1),
+       BPF_EXIT_INSN(),
+
+       BPF_JMP_REG(BPF_JSLT, BPF_REG_2, BPF_REG_1, 2),
+       BPF_MOV64_IMM(BPF_REG_0, 1),
+       BPF_EXIT_INSN(),
+
+       BPF_MOV64_IMM(BPF_REG_0, 2),
+       BPF_EXIT_INSN(),
+       },
+       .result = ACCEPT,
+       .retval = 2,
+},