LoongArch: BPF: Don't sign extend memory load operand
[platform/kernel/linux-starfive.git] / arch / loongarch / net / bpf_jit.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * BPF JIT compiler for LoongArch
4  *
5  * Copyright (C) 2022 Loongson Technology Corporation Limited
6  */
7 #include "bpf_jit.h"
8
9 #define REG_TCC         LOONGARCH_GPR_A6
10 #define TCC_SAVED       LOONGARCH_GPR_S5
11
12 #define SAVE_RA         BIT(0)
13 #define SAVE_TCC        BIT(1)
14
15 static const int regmap[] = {
16         /* return value from in-kernel function, and exit value for eBPF program */
17         [BPF_REG_0] = LOONGARCH_GPR_A5,
18         /* arguments from eBPF program to in-kernel function */
19         [BPF_REG_1] = LOONGARCH_GPR_A0,
20         [BPF_REG_2] = LOONGARCH_GPR_A1,
21         [BPF_REG_3] = LOONGARCH_GPR_A2,
22         [BPF_REG_4] = LOONGARCH_GPR_A3,
23         [BPF_REG_5] = LOONGARCH_GPR_A4,
24         /* callee saved registers that in-kernel function will preserve */
25         [BPF_REG_6] = LOONGARCH_GPR_S0,
26         [BPF_REG_7] = LOONGARCH_GPR_S1,
27         [BPF_REG_8] = LOONGARCH_GPR_S2,
28         [BPF_REG_9] = LOONGARCH_GPR_S3,
29         /* read-only frame pointer to access stack */
30         [BPF_REG_FP] = LOONGARCH_GPR_S4,
31         /* temporary register for blinding constants */
32         [BPF_REG_AX] = LOONGARCH_GPR_T0,
33 };
34
35 static void mark_call(struct jit_ctx *ctx)
36 {
37         ctx->flags |= SAVE_RA;
38 }
39
40 static void mark_tail_call(struct jit_ctx *ctx)
41 {
42         ctx->flags |= SAVE_TCC;
43 }
44
45 static bool seen_call(struct jit_ctx *ctx)
46 {
47         return (ctx->flags & SAVE_RA);
48 }
49
50 static bool seen_tail_call(struct jit_ctx *ctx)
51 {
52         return (ctx->flags & SAVE_TCC);
53 }
54
55 static u8 tail_call_reg(struct jit_ctx *ctx)
56 {
57         if (seen_call(ctx))
58                 return TCC_SAVED;
59
60         return REG_TCC;
61 }
62
63 /*
64  * eBPF prog stack layout:
65  *
66  *                                        high
67  * original $sp ------------> +-------------------------+ <--LOONGARCH_GPR_FP
68  *                            |           $ra           |
69  *                            +-------------------------+
70  *                            |           $fp           |
71  *                            +-------------------------+
72  *                            |           $s0           |
73  *                            +-------------------------+
74  *                            |           $s1           |
75  *                            +-------------------------+
76  *                            |           $s2           |
77  *                            +-------------------------+
78  *                            |           $s3           |
79  *                            +-------------------------+
80  *                            |           $s4           |
81  *                            +-------------------------+
82  *                            |           $s5           |
83  *                            +-------------------------+ <--BPF_REG_FP
84  *                            |  prog->aux->stack_depth |
85  *                            |        (optional)       |
86  * current $sp -------------> +-------------------------+
87  *                                        low
88  */
89 static void build_prologue(struct jit_ctx *ctx)
90 {
91         int stack_adjust = 0, store_offset, bpf_stack_adjust;
92
93         bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
94
95         /* To store ra, fp, s0, s1, s2, s3, s4 and s5. */
96         stack_adjust += sizeof(long) * 8;
97
98         stack_adjust = round_up(stack_adjust, 16);
99         stack_adjust += bpf_stack_adjust;
100
101         /*
102          * First instruction initializes the tail call count (TCC).
103          * On tail call we skip this instruction, and the TCC is
104          * passed in REG_TCC from the caller.
105          */
106         emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
107
108         emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, -stack_adjust);
109
110         store_offset = stack_adjust - sizeof(long);
111         emit_insn(ctx, std, LOONGARCH_GPR_RA, LOONGARCH_GPR_SP, store_offset);
112
113         store_offset -= sizeof(long);
114         emit_insn(ctx, std, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, store_offset);
115
116         store_offset -= sizeof(long);
117         emit_insn(ctx, std, LOONGARCH_GPR_S0, LOONGARCH_GPR_SP, store_offset);
118
119         store_offset -= sizeof(long);
120         emit_insn(ctx, std, LOONGARCH_GPR_S1, LOONGARCH_GPR_SP, store_offset);
121
122         store_offset -= sizeof(long);
123         emit_insn(ctx, std, LOONGARCH_GPR_S2, LOONGARCH_GPR_SP, store_offset);
124
125         store_offset -= sizeof(long);
126         emit_insn(ctx, std, LOONGARCH_GPR_S3, LOONGARCH_GPR_SP, store_offset);
127
128         store_offset -= sizeof(long);
129         emit_insn(ctx, std, LOONGARCH_GPR_S4, LOONGARCH_GPR_SP, store_offset);
130
131         store_offset -= sizeof(long);
132         emit_insn(ctx, std, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, store_offset);
133
134         emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_adjust);
135
136         if (bpf_stack_adjust)
137                 emit_insn(ctx, addid, regmap[BPF_REG_FP], LOONGARCH_GPR_SP, bpf_stack_adjust);
138
139         /*
140          * Program contains calls and tail calls, so REG_TCC need
141          * to be saved across calls.
142          */
143         if (seen_tail_call(ctx) && seen_call(ctx))
144                 move_reg(ctx, TCC_SAVED, REG_TCC);
145
146         ctx->stack_size = stack_adjust;
147 }
148
149 static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call)
150 {
151         int stack_adjust = ctx->stack_size;
152         int load_offset;
153
154         load_offset = stack_adjust - sizeof(long);
155         emit_insn(ctx, ldd, LOONGARCH_GPR_RA, LOONGARCH_GPR_SP, load_offset);
156
157         load_offset -= sizeof(long);
158         emit_insn(ctx, ldd, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, load_offset);
159
160         load_offset -= sizeof(long);
161         emit_insn(ctx, ldd, LOONGARCH_GPR_S0, LOONGARCH_GPR_SP, load_offset);
162
163         load_offset -= sizeof(long);
164         emit_insn(ctx, ldd, LOONGARCH_GPR_S1, LOONGARCH_GPR_SP, load_offset);
165
166         load_offset -= sizeof(long);
167         emit_insn(ctx, ldd, LOONGARCH_GPR_S2, LOONGARCH_GPR_SP, load_offset);
168
169         load_offset -= sizeof(long);
170         emit_insn(ctx, ldd, LOONGARCH_GPR_S3, LOONGARCH_GPR_SP, load_offset);
171
172         load_offset -= sizeof(long);
173         emit_insn(ctx, ldd, LOONGARCH_GPR_S4, LOONGARCH_GPR_SP, load_offset);
174
175         load_offset -= sizeof(long);
176         emit_insn(ctx, ldd, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, load_offset);
177
178         emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, stack_adjust);
179
180         if (!is_tail_call) {
181                 /* Set return value */
182                 move_reg(ctx, LOONGARCH_GPR_A0, regmap[BPF_REG_0]);
183                 /* Return to the caller */
184                 emit_insn(ctx, jirl, LOONGARCH_GPR_RA, LOONGARCH_GPR_ZERO, 0);
185         } else {
186                 /*
187                  * Call the next bpf prog and skip the first instruction
188                  * of TCC initialization.
189                  */
190                 emit_insn(ctx, jirl, LOONGARCH_GPR_T3, LOONGARCH_GPR_ZERO, 1);
191         }
192 }
193
194 static void build_epilogue(struct jit_ctx *ctx)
195 {
196         __build_epilogue(ctx, false);
197 }
198
199 bool bpf_jit_supports_kfunc_call(void)
200 {
201         return true;
202 }
203
204 /* initialized on the first pass of build_body() */
205 static int out_offset = -1;
206 static int emit_bpf_tail_call(struct jit_ctx *ctx)
207 {
208         int off;
209         u8 tcc = tail_call_reg(ctx);
210         u8 a1 = LOONGARCH_GPR_A1;
211         u8 a2 = LOONGARCH_GPR_A2;
212         u8 t1 = LOONGARCH_GPR_T1;
213         u8 t2 = LOONGARCH_GPR_T2;
214         u8 t3 = LOONGARCH_GPR_T3;
215         const int idx0 = ctx->idx;
216
217 #define cur_offset (ctx->idx - idx0)
218 #define jmp_offset (out_offset - (cur_offset))
219
220         /*
221          * a0: &ctx
222          * a1: &array
223          * a2: index
224          *
225          * if (index >= array->map.max_entries)
226          *       goto out;
227          */
228         off = offsetof(struct bpf_array, map.max_entries);
229         emit_insn(ctx, ldwu, t1, a1, off);
230         /* bgeu $a2, $t1, jmp_offset */
231         if (emit_tailcall_jmp(ctx, BPF_JGE, a2, t1, jmp_offset) < 0)
232                 goto toofar;
233
234         /*
235          * if (--TCC < 0)
236          *       goto out;
237          */
238         emit_insn(ctx, addid, REG_TCC, tcc, -1);
239         if (emit_tailcall_jmp(ctx, BPF_JSLT, REG_TCC, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
240                 goto toofar;
241
242         /*
243          * prog = array->ptrs[index];
244          * if (!prog)
245          *       goto out;
246          */
247         emit_insn(ctx, alsld, t2, a2, a1, 2);
248         off = offsetof(struct bpf_array, ptrs);
249         emit_insn(ctx, ldd, t2, t2, off);
250         /* beq $t2, $zero, jmp_offset */
251         if (emit_tailcall_jmp(ctx, BPF_JEQ, t2, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
252                 goto toofar;
253
254         /* goto *(prog->bpf_func + 4); */
255         off = offsetof(struct bpf_prog, bpf_func);
256         emit_insn(ctx, ldd, t3, t2, off);
257         __build_epilogue(ctx, true);
258
259         /* out: */
260         if (out_offset == -1)
261                 out_offset = cur_offset;
262         if (cur_offset != out_offset) {
263                 pr_err_once("tail_call out_offset = %d, expected %d!\n",
264                             cur_offset, out_offset);
265                 return -1;
266         }
267
268         return 0;
269
270 toofar:
271         pr_info_once("tail_call: jump too far\n");
272         return -1;
273 #undef cur_offset
274 #undef jmp_offset
275 }
276
277 static void emit_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
278 {
279         const u8 t1 = LOONGARCH_GPR_T1;
280         const u8 t2 = LOONGARCH_GPR_T2;
281         const u8 t3 = LOONGARCH_GPR_T3;
282         const u8 r0 = regmap[BPF_REG_0];
283         const u8 src = regmap[insn->src_reg];
284         const u8 dst = regmap[insn->dst_reg];
285         const s16 off = insn->off;
286         const s32 imm = insn->imm;
287         const bool isdw = BPF_SIZE(insn->code) == BPF_DW;
288
289         move_imm(ctx, t1, off, false);
290         emit_insn(ctx, addd, t1, dst, t1);
291         move_reg(ctx, t3, src);
292
293         switch (imm) {
294         /* lock *(size *)(dst + off) <op>= src */
295         case BPF_ADD:
296                 if (isdw)
297                         emit_insn(ctx, amaddd, t2, t1, src);
298                 else
299                         emit_insn(ctx, amaddw, t2, t1, src);
300                 break;
301         case BPF_AND:
302                 if (isdw)
303                         emit_insn(ctx, amandd, t2, t1, src);
304                 else
305                         emit_insn(ctx, amandw, t2, t1, src);
306                 break;
307         case BPF_OR:
308                 if (isdw)
309                         emit_insn(ctx, amord, t2, t1, src);
310                 else
311                         emit_insn(ctx, amorw, t2, t1, src);
312                 break;
313         case BPF_XOR:
314                 if (isdw)
315                         emit_insn(ctx, amxord, t2, t1, src);
316                 else
317                         emit_insn(ctx, amxorw, t2, t1, src);
318                 break;
319         /* src = atomic_fetch_<op>(dst + off, src) */
320         case BPF_ADD | BPF_FETCH:
321                 if (isdw) {
322                         emit_insn(ctx, amaddd, src, t1, t3);
323                 } else {
324                         emit_insn(ctx, amaddw, src, t1, t3);
325                         emit_zext_32(ctx, src, true);
326                 }
327                 break;
328         case BPF_AND | BPF_FETCH:
329                 if (isdw) {
330                         emit_insn(ctx, amandd, src, t1, t3);
331                 } else {
332                         emit_insn(ctx, amandw, src, t1, t3);
333                         emit_zext_32(ctx, src, true);
334                 }
335                 break;
336         case BPF_OR | BPF_FETCH:
337                 if (isdw) {
338                         emit_insn(ctx, amord, src, t1, t3);
339                 } else {
340                         emit_insn(ctx, amorw, src, t1, t3);
341                         emit_zext_32(ctx, src, true);
342                 }
343                 break;
344         case BPF_XOR | BPF_FETCH:
345                 if (isdw) {
346                         emit_insn(ctx, amxord, src, t1, t3);
347                 } else {
348                         emit_insn(ctx, amxorw, src, t1, t3);
349                         emit_zext_32(ctx, src, true);
350                 }
351                 break;
352         /* src = atomic_xchg(dst + off, src); */
353         case BPF_XCHG:
354                 if (isdw) {
355                         emit_insn(ctx, amswapd, src, t1, t3);
356                 } else {
357                         emit_insn(ctx, amswapw, src, t1, t3);
358                         emit_zext_32(ctx, src, true);
359                 }
360                 break;
361         /* r0 = atomic_cmpxchg(dst + off, r0, src); */
362         case BPF_CMPXCHG:
363                 move_reg(ctx, t2, r0);
364                 if (isdw) {
365                         emit_insn(ctx, lld, r0, t1, 0);
366                         emit_insn(ctx, bne, t2, r0, 4);
367                         move_reg(ctx, t3, src);
368                         emit_insn(ctx, scd, t3, t1, 0);
369                         emit_insn(ctx, beq, t3, LOONGARCH_GPR_ZERO, -4);
370                 } else {
371                         emit_insn(ctx, llw, r0, t1, 0);
372                         emit_zext_32(ctx, t2, true);
373                         emit_zext_32(ctx, r0, true);
374                         emit_insn(ctx, bne, t2, r0, 4);
375                         move_reg(ctx, t3, src);
376                         emit_insn(ctx, scw, t3, t1, 0);
377                         emit_insn(ctx, beq, t3, LOONGARCH_GPR_ZERO, -6);
378                         emit_zext_32(ctx, r0, true);
379                 }
380                 break;
381         }
382 }
383
384 static bool is_signed_bpf_cond(u8 cond)
385 {
386         return cond == BPF_JSGT || cond == BPF_JSLT ||
387                cond == BPF_JSGE || cond == BPF_JSLE;
388 }
389
390 static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool extra_pass)
391 {
392         u8 tm = -1;
393         u64 func_addr;
394         bool func_addr_fixed;
395         int i = insn - ctx->prog->insnsi;
396         int ret, jmp_offset;
397         const u8 code = insn->code;
398         const u8 cond = BPF_OP(code);
399         const u8 t1 = LOONGARCH_GPR_T1;
400         const u8 t2 = LOONGARCH_GPR_T2;
401         const u8 src = regmap[insn->src_reg];
402         const u8 dst = regmap[insn->dst_reg];
403         const s16 off = insn->off;
404         const s32 imm = insn->imm;
405         const u64 imm64 = (u64)(insn + 1)->imm << 32 | (u32)insn->imm;
406         const bool is32 = BPF_CLASS(insn->code) == BPF_ALU || BPF_CLASS(insn->code) == BPF_JMP32;
407
408         switch (code) {
409         /* dst = src */
410         case BPF_ALU | BPF_MOV | BPF_X:
411         case BPF_ALU64 | BPF_MOV | BPF_X:
412                 move_reg(ctx, dst, src);
413                 emit_zext_32(ctx, dst, is32);
414                 break;
415
416         /* dst = imm */
417         case BPF_ALU | BPF_MOV | BPF_K:
418         case BPF_ALU64 | BPF_MOV | BPF_K:
419                 move_imm(ctx, dst, imm, is32);
420                 break;
421
422         /* dst = dst + src */
423         case BPF_ALU | BPF_ADD | BPF_X:
424         case BPF_ALU64 | BPF_ADD | BPF_X:
425                 emit_insn(ctx, addd, dst, dst, src);
426                 emit_zext_32(ctx, dst, is32);
427                 break;
428
429         /* dst = dst + imm */
430         case BPF_ALU | BPF_ADD | BPF_K:
431         case BPF_ALU64 | BPF_ADD | BPF_K:
432                 if (is_signed_imm12(imm)) {
433                         emit_insn(ctx, addid, dst, dst, imm);
434                 } else {
435                         move_imm(ctx, t1, imm, is32);
436                         emit_insn(ctx, addd, dst, dst, t1);
437                 }
438                 emit_zext_32(ctx, dst, is32);
439                 break;
440
441         /* dst = dst - src */
442         case BPF_ALU | BPF_SUB | BPF_X:
443         case BPF_ALU64 | BPF_SUB | BPF_X:
444                 emit_insn(ctx, subd, dst, dst, src);
445                 emit_zext_32(ctx, dst, is32);
446                 break;
447
448         /* dst = dst - imm */
449         case BPF_ALU | BPF_SUB | BPF_K:
450         case BPF_ALU64 | BPF_SUB | BPF_K:
451                 if (is_signed_imm12(-imm)) {
452                         emit_insn(ctx, addid, dst, dst, -imm);
453                 } else {
454                         move_imm(ctx, t1, imm, is32);
455                         emit_insn(ctx, subd, dst, dst, t1);
456                 }
457                 emit_zext_32(ctx, dst, is32);
458                 break;
459
460         /* dst = dst * src */
461         case BPF_ALU | BPF_MUL | BPF_X:
462         case BPF_ALU64 | BPF_MUL | BPF_X:
463                 emit_insn(ctx, muld, dst, dst, src);
464                 emit_zext_32(ctx, dst, is32);
465                 break;
466
467         /* dst = dst * imm */
468         case BPF_ALU | BPF_MUL | BPF_K:
469         case BPF_ALU64 | BPF_MUL | BPF_K:
470                 move_imm(ctx, t1, imm, is32);
471                 emit_insn(ctx, muld, dst, dst, t1);
472                 emit_zext_32(ctx, dst, is32);
473                 break;
474
475         /* dst = dst / src */
476         case BPF_ALU | BPF_DIV | BPF_X:
477         case BPF_ALU64 | BPF_DIV | BPF_X:
478                 emit_zext_32(ctx, dst, is32);
479                 move_reg(ctx, t1, src);
480                 emit_zext_32(ctx, t1, is32);
481                 emit_insn(ctx, divdu, dst, dst, t1);
482                 emit_zext_32(ctx, dst, is32);
483                 break;
484
485         /* dst = dst / imm */
486         case BPF_ALU | BPF_DIV | BPF_K:
487         case BPF_ALU64 | BPF_DIV | BPF_K:
488                 move_imm(ctx, t1, imm, is32);
489                 emit_zext_32(ctx, dst, is32);
490                 emit_insn(ctx, divdu, dst, dst, t1);
491                 emit_zext_32(ctx, dst, is32);
492                 break;
493
494         /* dst = dst % src */
495         case BPF_ALU | BPF_MOD | BPF_X:
496         case BPF_ALU64 | BPF_MOD | BPF_X:
497                 emit_zext_32(ctx, dst, is32);
498                 move_reg(ctx, t1, src);
499                 emit_zext_32(ctx, t1, is32);
500                 emit_insn(ctx, moddu, dst, dst, t1);
501                 emit_zext_32(ctx, dst, is32);
502                 break;
503
504         /* dst = dst % imm */
505         case BPF_ALU | BPF_MOD | BPF_K:
506         case BPF_ALU64 | BPF_MOD | BPF_K:
507                 move_imm(ctx, t1, imm, is32);
508                 emit_zext_32(ctx, dst, is32);
509                 emit_insn(ctx, moddu, dst, dst, t1);
510                 emit_zext_32(ctx, dst, is32);
511                 break;
512
513         /* dst = -dst */
514         case BPF_ALU | BPF_NEG:
515         case BPF_ALU64 | BPF_NEG:
516                 move_imm(ctx, t1, imm, is32);
517                 emit_insn(ctx, subd, dst, LOONGARCH_GPR_ZERO, dst);
518                 emit_zext_32(ctx, dst, is32);
519                 break;
520
521         /* dst = dst & src */
522         case BPF_ALU | BPF_AND | BPF_X:
523         case BPF_ALU64 | BPF_AND | BPF_X:
524                 emit_insn(ctx, and, dst, dst, src);
525                 emit_zext_32(ctx, dst, is32);
526                 break;
527
528         /* dst = dst & imm */
529         case BPF_ALU | BPF_AND | BPF_K:
530         case BPF_ALU64 | BPF_AND | BPF_K:
531                 if (is_unsigned_imm12(imm)) {
532                         emit_insn(ctx, andi, dst, dst, imm);
533                 } else {
534                         move_imm(ctx, t1, imm, is32);
535                         emit_insn(ctx, and, dst, dst, t1);
536                 }
537                 emit_zext_32(ctx, dst, is32);
538                 break;
539
540         /* dst = dst | src */
541         case BPF_ALU | BPF_OR | BPF_X:
542         case BPF_ALU64 | BPF_OR | BPF_X:
543                 emit_insn(ctx, or, dst, dst, src);
544                 emit_zext_32(ctx, dst, is32);
545                 break;
546
547         /* dst = dst | imm */
548         case BPF_ALU | BPF_OR | BPF_K:
549         case BPF_ALU64 | BPF_OR | BPF_K:
550                 if (is_unsigned_imm12(imm)) {
551                         emit_insn(ctx, ori, dst, dst, imm);
552                 } else {
553                         move_imm(ctx, t1, imm, is32);
554                         emit_insn(ctx, or, dst, dst, t1);
555                 }
556                 emit_zext_32(ctx, dst, is32);
557                 break;
558
559         /* dst = dst ^ src */
560         case BPF_ALU | BPF_XOR | BPF_X:
561         case BPF_ALU64 | BPF_XOR | BPF_X:
562                 emit_insn(ctx, xor, dst, dst, src);
563                 emit_zext_32(ctx, dst, is32);
564                 break;
565
566         /* dst = dst ^ imm */
567         case BPF_ALU | BPF_XOR | BPF_K:
568         case BPF_ALU64 | BPF_XOR | BPF_K:
569                 if (is_unsigned_imm12(imm)) {
570                         emit_insn(ctx, xori, dst, dst, imm);
571                 } else {
572                         move_imm(ctx, t1, imm, is32);
573                         emit_insn(ctx, xor, dst, dst, t1);
574                 }
575                 emit_zext_32(ctx, dst, is32);
576                 break;
577
578         /* dst = dst << src (logical) */
579         case BPF_ALU | BPF_LSH | BPF_X:
580                 emit_insn(ctx, sllw, dst, dst, src);
581                 emit_zext_32(ctx, dst, is32);
582                 break;
583
584         case BPF_ALU64 | BPF_LSH | BPF_X:
585                 emit_insn(ctx, slld, dst, dst, src);
586                 break;
587
588         /* dst = dst << imm (logical) */
589         case BPF_ALU | BPF_LSH | BPF_K:
590                 emit_insn(ctx, slliw, dst, dst, imm);
591                 emit_zext_32(ctx, dst, is32);
592                 break;
593
594         case BPF_ALU64 | BPF_LSH | BPF_K:
595                 emit_insn(ctx, sllid, dst, dst, imm);
596                 break;
597
598         /* dst = dst >> src (logical) */
599         case BPF_ALU | BPF_RSH | BPF_X:
600                 emit_insn(ctx, srlw, dst, dst, src);
601                 emit_zext_32(ctx, dst, is32);
602                 break;
603
604         case BPF_ALU64 | BPF_RSH | BPF_X:
605                 emit_insn(ctx, srld, dst, dst, src);
606                 break;
607
608         /* dst = dst >> imm (logical) */
609         case BPF_ALU | BPF_RSH | BPF_K:
610                 emit_insn(ctx, srliw, dst, dst, imm);
611                 emit_zext_32(ctx, dst, is32);
612                 break;
613
614         case BPF_ALU64 | BPF_RSH | BPF_K:
615                 emit_insn(ctx, srlid, dst, dst, imm);
616                 break;
617
618         /* dst = dst >> src (arithmetic) */
619         case BPF_ALU | BPF_ARSH | BPF_X:
620                 emit_insn(ctx, sraw, dst, dst, src);
621                 emit_zext_32(ctx, dst, is32);
622                 break;
623
624         case BPF_ALU64 | BPF_ARSH | BPF_X:
625                 emit_insn(ctx, srad, dst, dst, src);
626                 break;
627
628         /* dst = dst >> imm (arithmetic) */
629         case BPF_ALU | BPF_ARSH | BPF_K:
630                 emit_insn(ctx, sraiw, dst, dst, imm);
631                 emit_zext_32(ctx, dst, is32);
632                 break;
633
634         case BPF_ALU64 | BPF_ARSH | BPF_K:
635                 emit_insn(ctx, sraid, dst, dst, imm);
636                 break;
637
638         /* dst = BSWAP##imm(dst) */
639         case BPF_ALU | BPF_END | BPF_FROM_LE:
640                 switch (imm) {
641                 case 16:
642                         /* zero-extend 16 bits into 64 bits */
643                         emit_insn(ctx, bstrpickd, dst, dst, 15, 0);
644                         break;
645                 case 32:
646                         /* zero-extend 32 bits into 64 bits */
647                         emit_zext_32(ctx, dst, is32);
648                         break;
649                 case 64:
650                         /* do nothing */
651                         break;
652                 }
653                 break;
654
655         case BPF_ALU | BPF_END | BPF_FROM_BE:
656                 switch (imm) {
657                 case 16:
658                         emit_insn(ctx, revb2h, dst, dst);
659                         /* zero-extend 16 bits into 64 bits */
660                         emit_insn(ctx, bstrpickd, dst, dst, 15, 0);
661                         break;
662                 case 32:
663                         emit_insn(ctx, revb2w, dst, dst);
664                         /* zero-extend 32 bits into 64 bits */
665                         emit_zext_32(ctx, dst, is32);
666                         break;
667                 case 64:
668                         emit_insn(ctx, revbd, dst, dst);
669                         break;
670                 }
671                 break;
672
673         /* PC += off if dst cond src */
674         case BPF_JMP | BPF_JEQ | BPF_X:
675         case BPF_JMP | BPF_JNE | BPF_X:
676         case BPF_JMP | BPF_JGT | BPF_X:
677         case BPF_JMP | BPF_JGE | BPF_X:
678         case BPF_JMP | BPF_JLT | BPF_X:
679         case BPF_JMP | BPF_JLE | BPF_X:
680         case BPF_JMP | BPF_JSGT | BPF_X:
681         case BPF_JMP | BPF_JSGE | BPF_X:
682         case BPF_JMP | BPF_JSLT | BPF_X:
683         case BPF_JMP | BPF_JSLE | BPF_X:
684         case BPF_JMP32 | BPF_JEQ | BPF_X:
685         case BPF_JMP32 | BPF_JNE | BPF_X:
686         case BPF_JMP32 | BPF_JGT | BPF_X:
687         case BPF_JMP32 | BPF_JGE | BPF_X:
688         case BPF_JMP32 | BPF_JLT | BPF_X:
689         case BPF_JMP32 | BPF_JLE | BPF_X:
690         case BPF_JMP32 | BPF_JSGT | BPF_X:
691         case BPF_JMP32 | BPF_JSGE | BPF_X:
692         case BPF_JMP32 | BPF_JSLT | BPF_X:
693         case BPF_JMP32 | BPF_JSLE | BPF_X:
694                 jmp_offset = bpf2la_offset(i, off, ctx);
695                 move_reg(ctx, t1, dst);
696                 move_reg(ctx, t2, src);
697                 if (is_signed_bpf_cond(BPF_OP(code))) {
698                         emit_sext_32(ctx, t1, is32);
699                         emit_sext_32(ctx, t2, is32);
700                 } else {
701                         emit_zext_32(ctx, t1, is32);
702                         emit_zext_32(ctx, t2, is32);
703                 }
704                 if (emit_cond_jmp(ctx, cond, t1, t2, jmp_offset) < 0)
705                         goto toofar;
706                 break;
707
708         /* PC += off if dst cond imm */
709         case BPF_JMP | BPF_JEQ | BPF_K:
710         case BPF_JMP | BPF_JNE | BPF_K:
711         case BPF_JMP | BPF_JGT | BPF_K:
712         case BPF_JMP | BPF_JGE | BPF_K:
713         case BPF_JMP | BPF_JLT | BPF_K:
714         case BPF_JMP | BPF_JLE | BPF_K:
715         case BPF_JMP | BPF_JSGT | BPF_K:
716         case BPF_JMP | BPF_JSGE | BPF_K:
717         case BPF_JMP | BPF_JSLT | BPF_K:
718         case BPF_JMP | BPF_JSLE | BPF_K:
719         case BPF_JMP32 | BPF_JEQ | BPF_K:
720         case BPF_JMP32 | BPF_JNE | BPF_K:
721         case BPF_JMP32 | BPF_JGT | BPF_K:
722         case BPF_JMP32 | BPF_JGE | BPF_K:
723         case BPF_JMP32 | BPF_JLT | BPF_K:
724         case BPF_JMP32 | BPF_JLE | BPF_K:
725         case BPF_JMP32 | BPF_JSGT | BPF_K:
726         case BPF_JMP32 | BPF_JSGE | BPF_K:
727         case BPF_JMP32 | BPF_JSLT | BPF_K:
728         case BPF_JMP32 | BPF_JSLE | BPF_K:
729                 jmp_offset = bpf2la_offset(i, off, ctx);
730                 if (imm) {
731                         move_imm(ctx, t1, imm, false);
732                         tm = t1;
733                 } else {
734                         /* If imm is 0, simply use zero register. */
735                         tm = LOONGARCH_GPR_ZERO;
736                 }
737                 move_reg(ctx, t2, dst);
738                 if (is_signed_bpf_cond(BPF_OP(code))) {
739                         emit_sext_32(ctx, tm, is32);
740                         emit_sext_32(ctx, t2, is32);
741                 } else {
742                         emit_zext_32(ctx, tm, is32);
743                         emit_zext_32(ctx, t2, is32);
744                 }
745                 if (emit_cond_jmp(ctx, cond, t2, tm, jmp_offset) < 0)
746                         goto toofar;
747                 break;
748
749         /* PC += off if dst & src */
750         case BPF_JMP | BPF_JSET | BPF_X:
751         case BPF_JMP32 | BPF_JSET | BPF_X:
752                 jmp_offset = bpf2la_offset(i, off, ctx);
753                 emit_insn(ctx, and, t1, dst, src);
754                 emit_zext_32(ctx, t1, is32);
755                 if (emit_cond_jmp(ctx, cond, t1, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
756                         goto toofar;
757                 break;
758
759         /* PC += off if dst & imm */
760         case BPF_JMP | BPF_JSET | BPF_K:
761         case BPF_JMP32 | BPF_JSET | BPF_K:
762                 jmp_offset = bpf2la_offset(i, off, ctx);
763                 move_imm(ctx, t1, imm, is32);
764                 emit_insn(ctx, and, t1, dst, t1);
765                 emit_zext_32(ctx, t1, is32);
766                 if (emit_cond_jmp(ctx, cond, t1, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
767                         goto toofar;
768                 break;
769
770         /* PC += off */
771         case BPF_JMP | BPF_JA:
772                 jmp_offset = bpf2la_offset(i, off, ctx);
773                 if (emit_uncond_jmp(ctx, jmp_offset) < 0)
774                         goto toofar;
775                 break;
776
777         /* function call */
778         case BPF_JMP | BPF_CALL:
779                 mark_call(ctx);
780                 ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
781                                             &func_addr, &func_addr_fixed);
782                 if (ret < 0)
783                         return ret;
784
785                 move_addr(ctx, t1, func_addr);
786                 emit_insn(ctx, jirl, t1, LOONGARCH_GPR_RA, 0);
787                 move_reg(ctx, regmap[BPF_REG_0], LOONGARCH_GPR_A0);
788                 break;
789
790         /* tail call */
791         case BPF_JMP | BPF_TAIL_CALL:
792                 mark_tail_call(ctx);
793                 if (emit_bpf_tail_call(ctx) < 0)
794                         return -EINVAL;
795                 break;
796
797         /* function return */
798         case BPF_JMP | BPF_EXIT:
799                 emit_sext_32(ctx, regmap[BPF_REG_0], true);
800
801                 if (i == ctx->prog->len - 1)
802                         break;
803
804                 jmp_offset = epilogue_offset(ctx);
805                 if (emit_uncond_jmp(ctx, jmp_offset) < 0)
806                         goto toofar;
807                 break;
808
809         /* dst = imm64 */
810         case BPF_LD | BPF_IMM | BPF_DW:
811                 move_imm(ctx, dst, imm64, is32);
812                 return 1;
813
814         /* dst = *(size *)(src + off) */
815         case BPF_LDX | BPF_MEM | BPF_B:
816         case BPF_LDX | BPF_MEM | BPF_H:
817         case BPF_LDX | BPF_MEM | BPF_W:
818         case BPF_LDX | BPF_MEM | BPF_DW:
819                 switch (BPF_SIZE(code)) {
820                 case BPF_B:
821                         if (is_signed_imm12(off)) {
822                                 emit_insn(ctx, ldbu, dst, src, off);
823                         } else {
824                                 move_imm(ctx, t1, off, is32);
825                                 emit_insn(ctx, ldxbu, dst, src, t1);
826                         }
827                         break;
828                 case BPF_H:
829                         if (is_signed_imm12(off)) {
830                                 emit_insn(ctx, ldhu, dst, src, off);
831                         } else {
832                                 move_imm(ctx, t1, off, is32);
833                                 emit_insn(ctx, ldxhu, dst, src, t1);
834                         }
835                         break;
836                 case BPF_W:
837                         if (is_signed_imm12(off)) {
838                                 emit_insn(ctx, ldwu, dst, src, off);
839                         } else if (is_signed_imm14(off)) {
840                                 emit_insn(ctx, ldptrw, dst, src, off);
841                         } else {
842                                 move_imm(ctx, t1, off, is32);
843                                 emit_insn(ctx, ldxwu, dst, src, t1);
844                         }
845                         break;
846                 case BPF_DW:
847                         move_imm(ctx, t1, off, is32);
848                         emit_insn(ctx, ldxd, dst, src, t1);
849                         break;
850                 }
851                 break;
852
853         /* *(size *)(dst + off) = imm */
854         case BPF_ST | BPF_MEM | BPF_B:
855         case BPF_ST | BPF_MEM | BPF_H:
856         case BPF_ST | BPF_MEM | BPF_W:
857         case BPF_ST | BPF_MEM | BPF_DW:
858                 switch (BPF_SIZE(code)) {
859                 case BPF_B:
860                         move_imm(ctx, t1, imm, is32);
861                         if (is_signed_imm12(off)) {
862                                 emit_insn(ctx, stb, t1, dst, off);
863                         } else {
864                                 move_imm(ctx, t2, off, is32);
865                                 emit_insn(ctx, stxb, t1, dst, t2);
866                         }
867                         break;
868                 case BPF_H:
869                         move_imm(ctx, t1, imm, is32);
870                         if (is_signed_imm12(off)) {
871                                 emit_insn(ctx, sth, t1, dst, off);
872                         } else {
873                                 move_imm(ctx, t2, off, is32);
874                                 emit_insn(ctx, stxh, t1, dst, t2);
875                         }
876                         break;
877                 case BPF_W:
878                         move_imm(ctx, t1, imm, is32);
879                         if (is_signed_imm12(off)) {
880                                 emit_insn(ctx, stw, t1, dst, off);
881                         } else if (is_signed_imm14(off)) {
882                                 emit_insn(ctx, stptrw, t1, dst, off);
883                         } else {
884                                 move_imm(ctx, t2, off, is32);
885                                 emit_insn(ctx, stxw, t1, dst, t2);
886                         }
887                         break;
888                 case BPF_DW:
889                         move_imm(ctx, t1, imm, is32);
890                         if (is_signed_imm12(off)) {
891                                 emit_insn(ctx, std, t1, dst, off);
892                         } else if (is_signed_imm14(off)) {
893                                 emit_insn(ctx, stptrd, t1, dst, off);
894                         } else {
895                                 move_imm(ctx, t2, off, is32);
896                                 emit_insn(ctx, stxd, t1, dst, t2);
897                         }
898                         break;
899                 }
900                 break;
901
902         /* *(size *)(dst + off) = src */
903         case BPF_STX | BPF_MEM | BPF_B:
904         case BPF_STX | BPF_MEM | BPF_H:
905         case BPF_STX | BPF_MEM | BPF_W:
906         case BPF_STX | BPF_MEM | BPF_DW:
907                 switch (BPF_SIZE(code)) {
908                 case BPF_B:
909                         if (is_signed_imm12(off)) {
910                                 emit_insn(ctx, stb, src, dst, off);
911                         } else {
912                                 move_imm(ctx, t1, off, is32);
913                                 emit_insn(ctx, stxb, src, dst, t1);
914                         }
915                         break;
916                 case BPF_H:
917                         if (is_signed_imm12(off)) {
918                                 emit_insn(ctx, sth, src, dst, off);
919                         } else {
920                                 move_imm(ctx, t1, off, is32);
921                                 emit_insn(ctx, stxh, src, dst, t1);
922                         }
923                         break;
924                 case BPF_W:
925                         if (is_signed_imm12(off)) {
926                                 emit_insn(ctx, stw, src, dst, off);
927                         } else if (is_signed_imm14(off)) {
928                                 emit_insn(ctx, stptrw, src, dst, off);
929                         } else {
930                                 move_imm(ctx, t1, off, is32);
931                                 emit_insn(ctx, stxw, src, dst, t1);
932                         }
933                         break;
934                 case BPF_DW:
935                         if (is_signed_imm12(off)) {
936                                 emit_insn(ctx, std, src, dst, off);
937                         } else if (is_signed_imm14(off)) {
938                                 emit_insn(ctx, stptrd, src, dst, off);
939                         } else {
940                                 move_imm(ctx, t1, off, is32);
941                                 emit_insn(ctx, stxd, src, dst, t1);
942                         }
943                         break;
944                 }
945                 break;
946
947         case BPF_STX | BPF_ATOMIC | BPF_W:
948         case BPF_STX | BPF_ATOMIC | BPF_DW:
949                 emit_atomic(insn, ctx);
950                 break;
951
952         /* Speculation barrier */
953         case BPF_ST | BPF_NOSPEC:
954                 break;
955
956         default:
957                 pr_err("bpf_jit: unknown opcode %02x\n", code);
958                 return -EINVAL;
959         }
960
961         return 0;
962
963 toofar:
964         pr_info_once("bpf_jit: opcode %02x, jump too far\n", code);
965         return -E2BIG;
966 }
967
968 static int build_body(struct jit_ctx *ctx, bool extra_pass)
969 {
970         int i;
971         const struct bpf_prog *prog = ctx->prog;
972
973         for (i = 0; i < prog->len; i++) {
974                 const struct bpf_insn *insn = &prog->insnsi[i];
975                 int ret;
976
977                 if (ctx->image == NULL)
978                         ctx->offset[i] = ctx->idx;
979
980                 ret = build_insn(insn, ctx, extra_pass);
981                 if (ret > 0) {
982                         i++;
983                         if (ctx->image == NULL)
984                                 ctx->offset[i] = ctx->idx;
985                         continue;
986                 }
987                 if (ret)
988                         return ret;
989         }
990
991         if (ctx->image == NULL)
992                 ctx->offset[i] = ctx->idx;
993
994         return 0;
995 }
996
997 /* Fill space with break instructions */
998 static void jit_fill_hole(void *area, unsigned int size)
999 {
1000         u32 *ptr;
1001
1002         /* We are guaranteed to have aligned memory */
1003         for (ptr = area; size >= sizeof(u32); size -= sizeof(u32))
1004                 *ptr++ = INSN_BREAK;
1005 }
1006
1007 static int validate_code(struct jit_ctx *ctx)
1008 {
1009         int i;
1010         union loongarch_instruction insn;
1011
1012         for (i = 0; i < ctx->idx; i++) {
1013                 insn = ctx->image[i];
1014                 /* Check INSN_BREAK */
1015                 if (insn.word == INSN_BREAK)
1016                         return -1;
1017         }
1018
1019         return 0;
1020 }
1021
1022 struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
1023 {
1024         bool tmp_blinded = false, extra_pass = false;
1025         u8 *image_ptr;
1026         int image_size;
1027         struct jit_ctx ctx;
1028         struct jit_data *jit_data;
1029         struct bpf_binary_header *header;
1030         struct bpf_prog *tmp, *orig_prog = prog;
1031
1032         /*
1033          * If BPF JIT was not enabled then we must fall back to
1034          * the interpreter.
1035          */
1036         if (!prog->jit_requested)
1037                 return orig_prog;
1038
1039         tmp = bpf_jit_blind_constants(prog);
1040         /*
1041          * If blinding was requested and we failed during blinding,
1042          * we must fall back to the interpreter. Otherwise, we save
1043          * the new JITed code.
1044          */
1045         if (IS_ERR(tmp))
1046                 return orig_prog;
1047
1048         if (tmp != prog) {
1049                 tmp_blinded = true;
1050                 prog = tmp;
1051         }
1052
1053         jit_data = prog->aux->jit_data;
1054         if (!jit_data) {
1055                 jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
1056                 if (!jit_data) {
1057                         prog = orig_prog;
1058                         goto out;
1059                 }
1060                 prog->aux->jit_data = jit_data;
1061         }
1062         if (jit_data->ctx.offset) {
1063                 ctx = jit_data->ctx;
1064                 image_ptr = jit_data->image;
1065                 header = jit_data->header;
1066                 extra_pass = true;
1067                 image_size = sizeof(u32) * ctx.idx;
1068                 goto skip_init_ctx;
1069         }
1070
1071         memset(&ctx, 0, sizeof(ctx));
1072         ctx.prog = prog;
1073
1074         ctx.offset = kvcalloc(prog->len + 1, sizeof(u32), GFP_KERNEL);
1075         if (ctx.offset == NULL) {
1076                 prog = orig_prog;
1077                 goto out_offset;
1078         }
1079
1080         /* 1. Initial fake pass to compute ctx->idx and set ctx->flags */
1081         build_prologue(&ctx);
1082         if (build_body(&ctx, extra_pass)) {
1083                 prog = orig_prog;
1084                 goto out_offset;
1085         }
1086         ctx.epilogue_offset = ctx.idx;
1087         build_epilogue(&ctx);
1088
1089         /* Now we know the actual image size.
1090          * As each LoongArch instruction is of length 32bit,
1091          * we are translating number of JITed intructions into
1092          * the size required to store these JITed code.
1093          */
1094         image_size = sizeof(u32) * ctx.idx;
1095         /* Now we know the size of the structure to make */
1096         header = bpf_jit_binary_alloc(image_size, &image_ptr,
1097                                       sizeof(u32), jit_fill_hole);
1098         if (header == NULL) {
1099                 prog = orig_prog;
1100                 goto out_offset;
1101         }
1102
1103         /* 2. Now, the actual pass to generate final JIT code */
1104         ctx.image = (union loongarch_instruction *)image_ptr;
1105
1106 skip_init_ctx:
1107         ctx.idx = 0;
1108
1109         build_prologue(&ctx);
1110         if (build_body(&ctx, extra_pass)) {
1111                 bpf_jit_binary_free(header);
1112                 prog = orig_prog;
1113                 goto out_offset;
1114         }
1115         build_epilogue(&ctx);
1116
1117         /* 3. Extra pass to validate JITed code */
1118         if (validate_code(&ctx)) {
1119                 bpf_jit_binary_free(header);
1120                 prog = orig_prog;
1121                 goto out_offset;
1122         }
1123
1124         /* And we're done */
1125         if (bpf_jit_enable > 1)
1126                 bpf_jit_dump(prog->len, image_size, 2, ctx.image);
1127
1128         /* Update the icache */
1129         flush_icache_range((unsigned long)header, (unsigned long)(ctx.image + ctx.idx));
1130
1131         if (!prog->is_func || extra_pass) {
1132                 if (extra_pass && ctx.idx != jit_data->ctx.idx) {
1133                         pr_err_once("multi-func JIT bug %d != %d\n",
1134                                     ctx.idx, jit_data->ctx.idx);
1135                         bpf_jit_binary_free(header);
1136                         prog->bpf_func = NULL;
1137                         prog->jited = 0;
1138                         prog->jited_len = 0;
1139                         goto out_offset;
1140                 }
1141                 bpf_jit_binary_lock_ro(header);
1142         } else {
1143                 jit_data->ctx = ctx;
1144                 jit_data->image = image_ptr;
1145                 jit_data->header = header;
1146         }
1147         prog->jited = 1;
1148         prog->jited_len = image_size;
1149         prog->bpf_func = (void *)ctx.image;
1150
1151         if (!prog->is_func || extra_pass) {
1152                 int i;
1153
1154                 /* offset[prog->len] is the size of program */
1155                 for (i = 0; i <= prog->len; i++)
1156                         ctx.offset[i] *= LOONGARCH_INSN_SIZE;
1157                 bpf_prog_fill_jited_linfo(prog, ctx.offset + 1);
1158
1159 out_offset:
1160                 kvfree(ctx.offset);
1161                 kfree(jit_data);
1162                 prog->aux->jit_data = NULL;
1163         }
1164
1165 out:
1166         if (tmp_blinded)
1167                 bpf_jit_prog_release_other(prog, prog == orig_prog ? tmp : orig_prog);
1168
1169         out_offset = -1;
1170
1171         return prog;
1172 }