Merge tag 'for-netdev' of https://git.kernel.org/pub/scm/linux/kernel/git/bpf/bpf...
[platform/kernel/linux-rpi.git] / arch / x86 / net / bpf_jit_comp.c
index b808be7..8db6077 100644 (file)
@@ -1003,6 +1003,7 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image
                u8 b2 = 0, b3 = 0;
                u8 *start_of_ldx;
                s64 jmp_offset;
+               s16 insn_off;
                u8 jmp_cond;
                u8 *func;
                int nops;
@@ -1369,57 +1370,52 @@ st:                     if (is_imm8(insn->off))
                case BPF_LDX | BPF_PROBE_MEM | BPF_W:
                case BPF_LDX | BPF_MEM | BPF_DW:
                case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
+                       insn_off = insn->off;
+
                        if (BPF_MODE(insn->code) == BPF_PROBE_MEM) {
-                               /* Though the verifier prevents negative insn->off in BPF_PROBE_MEM
-                                * add abs(insn->off) to the limit to make sure that negative
-                                * offset won't be an issue.
-                                * insn->off is s16, so it won't affect valid pointers.
+                               /* Conservatively check that src_reg + insn->off is a kernel address:
+                                *   src_reg + insn->off >= TASK_SIZE_MAX + PAGE_SIZE
+                                * src_reg is used as scratch for src_reg += insn->off and restored
+                                * after emit_ldx if necessary
                                 */
-                               u64 limit = TASK_SIZE_MAX + PAGE_SIZE + abs(insn->off);
-                               u8 *end_of_jmp1, *end_of_jmp2;
 
-                               /* Conservatively check that src_reg + insn->off is a kernel address:
-                                * 1. src_reg + insn->off >= limit
-                                * 2. src_reg + insn->off doesn't become small positive.
-                                * Cannot do src_reg + insn->off >= limit in one branch,
-                                * since it needs two spare registers, but JIT has only one.
+                               u64 limit = TASK_SIZE_MAX + PAGE_SIZE;
+                               u8 *end_of_jmp;
+
+                               /* At end of these emitted checks, insn->off will have been added
+                                * to src_reg, so no need to do relative load with insn->off offset
                                 */
+                               insn_off = 0;
 
                                /* movabsq r11, limit */
                                EMIT2(add_1mod(0x48, AUX_REG), add_1reg(0xB8, AUX_REG));
                                EMIT((u32)limit, 4);
                                EMIT(limit >> 32, 4);
+
+                               if (insn->off) {
+                                       /* add src_reg, insn->off */
+                                       maybe_emit_1mod(&prog, src_reg, true);
+                                       EMIT2_off32(0x81, add_1reg(0xC0, src_reg), insn->off);
+                               }
+
                                /* cmp src_reg, r11 */
                                maybe_emit_mod(&prog, src_reg, AUX_REG, true);
                                EMIT2(0x39, add_2reg(0xC0, src_reg, AUX_REG));
-                               /* if unsigned '<' goto end_of_jmp2 */
-                               EMIT2(X86_JB, 0);
-                               end_of_jmp1 = prog;
-
-                               /* mov r11, src_reg */
-                               emit_mov_reg(&prog, true, AUX_REG, src_reg);
-                               /* add r11, insn->off */
-                               maybe_emit_1mod(&prog, AUX_REG, true);
-                               EMIT2_off32(0x81, add_1reg(0xC0, AUX_REG), insn->off);
-                               /* jmp if not carry to start_of_ldx
-                                * Otherwise ERR_PTR(-EINVAL) + 128 will be the user addr
-                                * that has to be rejected.
-                                */
-                               EMIT2(0x73 /* JNC */, 0);
-                               end_of_jmp2 = prog;
+
+                               /* if unsigned '>=', goto load */
+                               EMIT2(X86_JAE, 0);
+                               end_of_jmp = prog;
 
                                /* xor dst_reg, dst_reg */
                                emit_mov_imm32(&prog, false, dst_reg, 0);
                                /* jmp byte_after_ldx */
                                EMIT2(0xEB, 0);
 
-                               /* populate jmp_offset for JB above to jump to xor dst_reg */
-                               end_of_jmp1[-1] = end_of_jmp2 - end_of_jmp1;
-                               /* populate jmp_offset for JNC above to jump to start_of_ldx */
+                               /* populate jmp_offset for JAE above to jump to start_of_ldx */
                                start_of_ldx = prog;
-                               end_of_jmp2[-1] = start_of_ldx - end_of_jmp2;
+                               end_of_jmp[-1] = start_of_ldx - end_of_jmp;
                        }
-                       emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn->off);
+                       emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
                        if (BPF_MODE(insn->code) == BPF_PROBE_MEM) {
                                struct exception_table_entry *ex;
                                u8 *_insn = image + proglen + (start_of_ldx - temp);
@@ -1428,6 +1424,18 @@ st:                      if (is_imm8(insn->off))
                                /* populate jmp_offset for JMP above */
                                start_of_ldx[-1] = prog - start_of_ldx;
 
+                               if (insn->off && src_reg != dst_reg) {
+                                       /* sub src_reg, insn->off
+                                        * Restore src_reg after "add src_reg, insn->off" in prev
+                                        * if statement. But if src_reg == dst_reg, emit_ldx
+                                        * above already clobbered src_reg, so no need to restore.
+                                        * If add src_reg, insn->off was unnecessary, no need to
+                                        * restore either.
+                                        */
+                                       maybe_emit_1mod(&prog, src_reg, true);
+                                       EMIT2_off32(0x81, add_1reg(0xE8, src_reg), insn->off);
+                               }
+
                                if (!bpf_prog->aux->extable)
                                        break;