eb0e46905bf9d0b183d8752cd3046117f3222278
[platform/kernel/linux-rpi.git] / tools / testing / selftests / x86 / lam.c
1 // SPDX-License-Identifier: GPL-2.0
2 #define _GNU_SOURCE
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <string.h>
6 #include <sys/syscall.h>
7 #include <time.h>
8 #include <signal.h>
9 #include <setjmp.h>
10 #include <sys/mman.h>
11 #include <sys/utsname.h>
12 #include <sys/wait.h>
13 #include <sys/stat.h>
14 #include <fcntl.h>
15 #include <inttypes.h>
16 #include <sched.h>
17
18 #include <sys/uio.h>
19 #include <linux/io_uring.h>
20 #include "../kselftest.h"
21
22 #ifndef __x86_64__
23 # error This test is 64-bit only
24 #endif
25
26 /* LAM modes, these definitions were copied from kernel code */
27 #define LAM_NONE                0
28 #define LAM_U57_BITS            6
29
30 #define LAM_U57_MASK            (0x3fULL << 57)
31 /* arch prctl for LAM */
32 #define ARCH_GET_UNTAG_MASK     0x4001
33 #define ARCH_ENABLE_TAGGED_ADDR 0x4002
34 #define ARCH_GET_MAX_TAG_BITS   0x4003
35 #define ARCH_FORCE_TAGGED_SVA   0x4004
36
37 /* Specified test function bits */
38 #define FUNC_MALLOC             0x1
39 #define FUNC_BITS               0x2
40 #define FUNC_MMAP               0x4
41 #define FUNC_SYSCALL            0x8
42 #define FUNC_URING              0x10
43 #define FUNC_INHERITE           0x20
44 #define FUNC_PASID              0x40
45
46 #define TEST_MASK               0x7f
47
48 #define LOW_ADDR                (0x1UL << 30)
49 #define HIGH_ADDR               (0x3UL << 48)
50
51 #define MALLOC_LEN              32
52
53 #define PAGE_SIZE               (4 << 10)
54
55 #define STACK_SIZE              65536
56
57 #define barrier() ({                                            \
58                    __asm__ __volatile__("" : : : "memory");     \
59 })
60
61 #define URING_QUEUE_SZ 1
62 #define URING_BLOCK_SZ 2048
63
64 /* Pasid test define */
65 #define LAM_CMD_BIT 0x1
66 #define PAS_CMD_BIT 0x2
67 #define SVA_CMD_BIT 0x4
68
69 #define PAS_CMD(cmd1, cmd2, cmd3) (((cmd3) << 8) | ((cmd2) << 4) | ((cmd1) << 0))
70
71 struct testcases {
72         unsigned int later;
73         int expected; /* 2: SIGSEGV Error; 1: other errors */
74         unsigned long lam;
75         uint64_t addr;
76         uint64_t cmd;
77         int (*test_func)(struct testcases *test);
78         const char *msg;
79 };
80
81 /* Used by CQ of uring, source file handler and file's size */
82 struct file_io {
83         int file_fd;
84         off_t file_sz;
85         struct iovec iovecs[];
86 };
87
88 struct io_uring_queue {
89         unsigned int *head;
90         unsigned int *tail;
91         unsigned int *ring_mask;
92         unsigned int *ring_entries;
93         unsigned int *flags;
94         unsigned int *array;
95         union {
96                 struct io_uring_cqe *cqes;
97                 struct io_uring_sqe *sqes;
98         } queue;
99         size_t ring_sz;
100 };
101
102 struct io_ring {
103         int ring_fd;
104         struct io_uring_queue sq_ring;
105         struct io_uring_queue cq_ring;
106 };
107
108 int tests_cnt;
109 jmp_buf segv_env;
110
111 static void segv_handler(int sig)
112 {
113         ksft_print_msg("Get segmentation fault(%d).", sig);
114
115         siglongjmp(segv_env, 1);
116 }
117
118 static inline int cpu_has_lam(void)
119 {
120         unsigned int cpuinfo[4];
121
122         __cpuid_count(0x7, 1, cpuinfo[0], cpuinfo[1], cpuinfo[2], cpuinfo[3]);
123
124         return (cpuinfo[0] & (1 << 26));
125 }
126
127 /* Check 5-level page table feature in CPUID.(EAX=07H, ECX=00H):ECX.[bit 16] */
128 static inline int cpu_has_la57(void)
129 {
130         unsigned int cpuinfo[4];
131
132         __cpuid_count(0x7, 0, cpuinfo[0], cpuinfo[1], cpuinfo[2], cpuinfo[3]);
133
134         return (cpuinfo[2] & (1 << 16));
135 }
136
137 /*
138  * Set tagged address and read back untag mask.
139  * check if the untagged mask is expected.
140  *
141  * @return:
142  * 0: Set LAM mode successfully
143  * others: failed to set LAM
144  */
145 static int set_lam(unsigned long lam)
146 {
147         int ret = 0;
148         uint64_t ptr = 0;
149
150         if (lam != LAM_U57_BITS && lam != LAM_NONE)
151                 return -1;
152
153         /* Skip check return */
154         syscall(SYS_arch_prctl, ARCH_ENABLE_TAGGED_ADDR, lam);
155
156         /* Get untagged mask */
157         syscall(SYS_arch_prctl, ARCH_GET_UNTAG_MASK, &ptr);
158
159         /* Check mask returned is expected */
160         if (lam == LAM_U57_BITS)
161                 ret = (ptr != ~(LAM_U57_MASK));
162         else if (lam == LAM_NONE)
163                 ret = (ptr != -1ULL);
164
165         return ret;
166 }
167
168 static unsigned long get_default_tag_bits(void)
169 {
170         pid_t pid;
171         int lam = LAM_NONE;
172         int ret = 0;
173
174         pid = fork();
175         if (pid < 0) {
176                 perror("Fork failed.");
177         } else if (pid == 0) {
178                 /* Set LAM mode in child process */
179                 if (set_lam(LAM_U57_BITS) == 0)
180                         lam = LAM_U57_BITS;
181                 else
182                         lam = LAM_NONE;
183                 exit(lam);
184         } else {
185                 wait(&ret);
186                 lam = WEXITSTATUS(ret);
187         }
188
189         return lam;
190 }
191
192 /*
193  * Set tagged address and read back untag mask.
194  * check if the untag mask is expected.
195  */
196 static int get_lam(void)
197 {
198         uint64_t ptr = 0;
199         int ret = -1;
200         /* Get untagged mask */
201         if (syscall(SYS_arch_prctl, ARCH_GET_UNTAG_MASK, &ptr) == -1)
202                 return -1;
203
204         /* Check mask returned is expected */
205         if (ptr == ~(LAM_U57_MASK))
206                 ret = LAM_U57_BITS;
207         else if (ptr == -1ULL)
208                 ret = LAM_NONE;
209
210
211         return ret;
212 }
213
214 /* According to LAM mode, set metadata in high bits */
215 static uint64_t set_metadata(uint64_t src, unsigned long lam)
216 {
217         uint64_t metadata;
218
219         srand(time(NULL));
220
221         switch (lam) {
222         case LAM_U57_BITS: /* Set metadata in bits 62:57 */
223                 /* Get a random non-zero value as metadata */
224                 metadata = (rand() % ((1UL << LAM_U57_BITS) - 1) + 1) << 57;
225                 metadata |= (src & ~(LAM_U57_MASK));
226                 break;
227         default:
228                 metadata = src;
229                 break;
230         }
231
232         return metadata;
233 }
234
235 /*
236  * Set metadata in user pointer, compare new pointer with original pointer.
237  * both pointers should point to the same address.
238  *
239  * @return:
240  * 0: value on the pointer with metadate and value on original are same
241  * 1: not same.
242  */
243 static int handle_lam_test(void *src, unsigned int lam)
244 {
245         char *ptr;
246
247         strcpy((char *)src, "USER POINTER");
248
249         ptr = (char *)set_metadata((uint64_t)src, lam);
250         if (src == ptr)
251                 return 0;
252
253         /* Copy a string into the pointer with metadata */
254         strcpy((char *)ptr, "METADATA POINTER");
255
256         return (!!strcmp((char *)src, (char *)ptr));
257 }
258
259
260 int handle_max_bits(struct testcases *test)
261 {
262         unsigned long exp_bits = get_default_tag_bits();
263         unsigned long bits = 0;
264
265         if (exp_bits != LAM_NONE)
266                 exp_bits = LAM_U57_BITS;
267
268         /* Get LAM max tag bits */
269         if (syscall(SYS_arch_prctl, ARCH_GET_MAX_TAG_BITS, &bits) == -1)
270                 return 1;
271
272         return (exp_bits != bits);
273 }
274
275 /*
276  * Test lam feature through dereference pointer get from malloc.
277  * @return 0: Pass test. 1: Get failure during test 2: Get SIGSEGV
278  */
279 static int handle_malloc(struct testcases *test)
280 {
281         char *ptr = NULL;
282         int ret = 0;
283
284         if (test->later == 0 && test->lam != 0)
285                 if (set_lam(test->lam) == -1)
286                         return 1;
287
288         ptr = (char *)malloc(MALLOC_LEN);
289         if (ptr == NULL) {
290                 perror("malloc() failure\n");
291                 return 1;
292         }
293
294         /* Set signal handler */
295         if (sigsetjmp(segv_env, 1) == 0) {
296                 signal(SIGSEGV, segv_handler);
297                 ret = handle_lam_test(ptr, test->lam);
298         } else {
299                 ret = 2;
300         }
301
302         if (test->later != 0 && test->lam != 0)
303                 if (set_lam(test->lam) == -1 && ret == 0)
304                         ret = 1;
305
306         free(ptr);
307
308         return ret;
309 }
310
311 static int handle_mmap(struct testcases *test)
312 {
313         void *ptr;
314         unsigned int flags = MAP_PRIVATE | MAP_ANONYMOUS | MAP_FIXED;
315         int ret = 0;
316
317         if (test->later == 0 && test->lam != 0)
318                 if (set_lam(test->lam) != 0)
319                         return 1;
320
321         ptr = mmap((void *)test->addr, PAGE_SIZE, PROT_READ | PROT_WRITE,
322                    flags, -1, 0);
323         if (ptr == MAP_FAILED) {
324                 if (test->addr == HIGH_ADDR)
325                         if (!cpu_has_la57())
326                                 return 3; /* unsupport LA57 */
327                 return 1;
328         }
329
330         if (test->later != 0 && test->lam != 0)
331                 if (set_lam(test->lam) != 0)
332                         ret = 1;
333
334         if (ret == 0) {
335                 if (sigsetjmp(segv_env, 1) == 0) {
336                         signal(SIGSEGV, segv_handler);
337                         ret = handle_lam_test(ptr, test->lam);
338                 } else {
339                         ret = 2;
340                 }
341         }
342
343         munmap(ptr, PAGE_SIZE);
344         return ret;
345 }
346
347 static int handle_syscall(struct testcases *test)
348 {
349         struct utsname unme, *pu;
350         int ret = 0;
351
352         if (test->later == 0 && test->lam != 0)
353                 if (set_lam(test->lam) != 0)
354                         return 1;
355
356         if (sigsetjmp(segv_env, 1) == 0) {
357                 signal(SIGSEGV, segv_handler);
358                 pu = (struct utsname *)set_metadata((uint64_t)&unme, test->lam);
359                 ret = uname(pu);
360                 if (ret < 0)
361                         ret = 1;
362         } else {
363                 ret = 2;
364         }
365
366         if (test->later != 0 && test->lam != 0)
367                 if (set_lam(test->lam) != -1 && ret == 0)
368                         ret = 1;
369
370         return ret;
371 }
372
373 int sys_uring_setup(unsigned int entries, struct io_uring_params *p)
374 {
375         return (int)syscall(__NR_io_uring_setup, entries, p);
376 }
377
378 int sys_uring_enter(int fd, unsigned int to, unsigned int min, unsigned int flags)
379 {
380         return (int)syscall(__NR_io_uring_enter, fd, to, min, flags, NULL, 0);
381 }
382
383 /* Init submission queue and completion queue */
384 int mmap_io_uring(struct io_uring_params p, struct io_ring *s)
385 {
386         struct io_uring_queue *sring = &s->sq_ring;
387         struct io_uring_queue *cring = &s->cq_ring;
388
389         sring->ring_sz = p.sq_off.array + p.sq_entries * sizeof(unsigned int);
390         cring->ring_sz = p.cq_off.cqes + p.cq_entries * sizeof(struct io_uring_cqe);
391
392         if (p.features & IORING_FEAT_SINGLE_MMAP) {
393                 if (cring->ring_sz > sring->ring_sz)
394                         sring->ring_sz = cring->ring_sz;
395
396                 cring->ring_sz = sring->ring_sz;
397         }
398
399         void *sq_ptr = mmap(0, sring->ring_sz, PROT_READ | PROT_WRITE,
400                             MAP_SHARED | MAP_POPULATE, s->ring_fd,
401                             IORING_OFF_SQ_RING);
402
403         if (sq_ptr == MAP_FAILED) {
404                 perror("sub-queue!");
405                 return 1;
406         }
407
408         void *cq_ptr = sq_ptr;
409
410         if (!(p.features & IORING_FEAT_SINGLE_MMAP)) {
411                 cq_ptr = mmap(0, cring->ring_sz, PROT_READ | PROT_WRITE,
412                               MAP_SHARED | MAP_POPULATE, s->ring_fd,
413                               IORING_OFF_CQ_RING);
414                 if (cq_ptr == MAP_FAILED) {
415                         perror("cpl-queue!");
416                         munmap(sq_ptr, sring->ring_sz);
417                         return 1;
418                 }
419         }
420
421         sring->head = sq_ptr + p.sq_off.head;
422         sring->tail = sq_ptr + p.sq_off.tail;
423         sring->ring_mask = sq_ptr + p.sq_off.ring_mask;
424         sring->ring_entries = sq_ptr + p.sq_off.ring_entries;
425         sring->flags = sq_ptr + p.sq_off.flags;
426         sring->array = sq_ptr + p.sq_off.array;
427
428         /* Map a queue as mem map */
429         s->sq_ring.queue.sqes = mmap(0, p.sq_entries * sizeof(struct io_uring_sqe),
430                                      PROT_READ | PROT_WRITE, MAP_SHARED | MAP_POPULATE,
431                                      s->ring_fd, IORING_OFF_SQES);
432         if (s->sq_ring.queue.sqes == MAP_FAILED) {
433                 munmap(sq_ptr, sring->ring_sz);
434                 if (sq_ptr != cq_ptr) {
435                         ksft_print_msg("failed to mmap uring queue!");
436                         munmap(cq_ptr, cring->ring_sz);
437                         return 1;
438                 }
439         }
440
441         cring->head = cq_ptr + p.cq_off.head;
442         cring->tail = cq_ptr + p.cq_off.tail;
443         cring->ring_mask = cq_ptr + p.cq_off.ring_mask;
444         cring->ring_entries = cq_ptr + p.cq_off.ring_entries;
445         cring->queue.cqes = cq_ptr + p.cq_off.cqes;
446
447         return 0;
448 }
449
450 /* Init io_uring queues */
451 int setup_io_uring(struct io_ring *s)
452 {
453         struct io_uring_params para;
454
455         memset(&para, 0, sizeof(para));
456         s->ring_fd = sys_uring_setup(URING_QUEUE_SZ, &para);
457         if (s->ring_fd < 0)
458                 return 1;
459
460         return mmap_io_uring(para, s);
461 }
462
463 /*
464  * Get data from completion queue. the data buffer saved the file data
465  * return 0: success; others: error;
466  */
467 int handle_uring_cq(struct io_ring *s)
468 {
469         struct file_io *fi = NULL;
470         struct io_uring_queue *cring = &s->cq_ring;
471         struct io_uring_cqe *cqe;
472         unsigned int head;
473         off_t len = 0;
474
475         head = *cring->head;
476
477         do {
478                 barrier();
479                 if (head == *cring->tail)
480                         break;
481                 /* Get the entry */
482                 cqe = &cring->queue.cqes[head & *s->cq_ring.ring_mask];
483                 fi = (struct file_io *)cqe->user_data;
484                 if (cqe->res < 0)
485                         break;
486
487                 int blocks = (int)(fi->file_sz + URING_BLOCK_SZ - 1) / URING_BLOCK_SZ;
488
489                 for (int i = 0; i < blocks; i++)
490                         len += fi->iovecs[i].iov_len;
491
492                 head++;
493         } while (1);
494
495         *cring->head = head;
496         barrier();
497
498         return (len != fi->file_sz);
499 }
500
501 /*
502  * Submit squeue. specify via IORING_OP_READV.
503  * the buffer need to be set metadata according to LAM mode
504  */
505 int handle_uring_sq(struct io_ring *ring, struct file_io *fi, unsigned long lam)
506 {
507         int file_fd = fi->file_fd;
508         struct io_uring_queue *sring = &ring->sq_ring;
509         unsigned int index = 0, cur_block = 0, tail = 0, next_tail = 0;
510         struct io_uring_sqe *sqe;
511
512         off_t remain = fi->file_sz;
513         int blocks = (int)(remain + URING_BLOCK_SZ - 1) / URING_BLOCK_SZ;
514
515         while (remain) {
516                 off_t bytes = remain;
517                 void *buf;
518
519                 if (bytes > URING_BLOCK_SZ)
520                         bytes = URING_BLOCK_SZ;
521
522                 fi->iovecs[cur_block].iov_len = bytes;
523
524                 if (posix_memalign(&buf, URING_BLOCK_SZ, URING_BLOCK_SZ))
525                         return 1;
526
527                 fi->iovecs[cur_block].iov_base = (void *)set_metadata((uint64_t)buf, lam);
528                 remain -= bytes;
529                 cur_block++;
530         }
531
532         next_tail = *sring->tail;
533         tail = next_tail;
534         next_tail++;
535
536         barrier();
537
538         index = tail & *ring->sq_ring.ring_mask;
539
540         sqe = &ring->sq_ring.queue.sqes[index];
541         sqe->fd = file_fd;
542         sqe->flags = 0;
543         sqe->opcode = IORING_OP_READV;
544         sqe->addr = (unsigned long)fi->iovecs;
545         sqe->len = blocks;
546         sqe->off = 0;
547         sqe->user_data = (uint64_t)fi;
548
549         sring->array[index] = index;
550         tail = next_tail;
551
552         if (*sring->tail != tail) {
553                 *sring->tail = tail;
554                 barrier();
555         }
556
557         if (sys_uring_enter(ring->ring_fd, 1, 1, IORING_ENTER_GETEVENTS) < 0)
558                 return 1;
559
560         return 0;
561 }
562
563 /*
564  * Test LAM in async I/O and io_uring, read current binery through io_uring
565  * Set metadata in pointers to iovecs buffer.
566  */
567 int do_uring(unsigned long lam)
568 {
569         struct io_ring *ring;
570         struct file_io *fi;
571         struct stat st;
572         int ret = 1;
573         char path[PATH_MAX] = {0};
574
575         /* get current process path */
576         if (readlink("/proc/self/exe", path, PATH_MAX) <= 0)
577                 return 1;
578
579         int file_fd = open(path, O_RDONLY);
580
581         if (file_fd < 0)
582                 return 1;
583
584         if (fstat(file_fd, &st) < 0)
585                 return 1;
586
587         off_t file_sz = st.st_size;
588
589         int blocks = (int)(file_sz + URING_BLOCK_SZ - 1) / URING_BLOCK_SZ;
590
591         fi = malloc(sizeof(*fi) + sizeof(struct iovec) * blocks);
592         if (!fi)
593                 return 1;
594
595         fi->file_sz = file_sz;
596         fi->file_fd = file_fd;
597
598         ring = malloc(sizeof(*ring));
599         if (!ring)
600                 return 1;
601
602         memset(ring, 0, sizeof(struct io_ring));
603
604         if (setup_io_uring(ring))
605                 goto out;
606
607         if (handle_uring_sq(ring, fi, lam))
608                 goto out;
609
610         ret = handle_uring_cq(ring);
611
612 out:
613         free(ring);
614
615         for (int i = 0; i < blocks; i++) {
616                 if (fi->iovecs[i].iov_base) {
617                         uint64_t addr = ((uint64_t)fi->iovecs[i].iov_base);
618
619                         switch (lam) {
620                         case LAM_U57_BITS: /* Clear bits 62:57 */
621                                 addr = (addr & ~(LAM_U57_MASK));
622                                 break;
623                         }
624                         free((void *)addr);
625                         fi->iovecs[i].iov_base = NULL;
626                 }
627         }
628
629         free(fi);
630
631         return ret;
632 }
633
634 int handle_uring(struct testcases *test)
635 {
636         int ret = 0;
637
638         if (test->later == 0 && test->lam != 0)
639                 if (set_lam(test->lam) != 0)
640                         return 1;
641
642         if (sigsetjmp(segv_env, 1) == 0) {
643                 signal(SIGSEGV, segv_handler);
644                 ret = do_uring(test->lam);
645         } else {
646                 ret = 2;
647         }
648
649         return ret;
650 }
651
652 static int fork_test(struct testcases *test)
653 {
654         int ret, child_ret;
655         pid_t pid;
656
657         pid = fork();
658         if (pid < 0) {
659                 perror("Fork failed.");
660                 ret = 1;
661         } else if (pid == 0) {
662                 ret = test->test_func(test);
663                 exit(ret);
664         } else {
665                 wait(&child_ret);
666                 ret = WEXITSTATUS(child_ret);
667         }
668
669         return ret;
670 }
671
672 static int handle_execve(struct testcases *test)
673 {
674         int ret, child_ret;
675         int lam = test->lam;
676         pid_t pid;
677
678         pid = fork();
679         if (pid < 0) {
680                 perror("Fork failed.");
681                 ret = 1;
682         } else if (pid == 0) {
683                 char path[PATH_MAX];
684
685                 /* Set LAM mode in parent process */
686                 if (set_lam(lam) != 0)
687                         return 1;
688
689                 /* Get current binary's path and the binary was run by execve */
690                 if (readlink("/proc/self/exe", path, PATH_MAX) <= 0)
691                         exit(-1);
692
693                 /* run binary to get LAM mode and return to parent process */
694                 if (execlp(path, path, "-t 0x0", NULL) < 0) {
695                         perror("error on exec");
696                         exit(-1);
697                 }
698         } else {
699                 wait(&child_ret);
700                 ret = WEXITSTATUS(child_ret);
701                 if (ret != LAM_NONE)
702                         return 1;
703         }
704
705         return 0;
706 }
707
708 static int handle_inheritance(struct testcases *test)
709 {
710         int ret, child_ret;
711         int lam = test->lam;
712         pid_t pid;
713
714         /* Set LAM mode in parent process */
715         if (set_lam(lam) != 0)
716                 return 1;
717
718         pid = fork();
719         if (pid < 0) {
720                 perror("Fork failed.");
721                 return 1;
722         } else if (pid == 0) {
723                 /* Set LAM mode in parent process */
724                 int child_lam = get_lam();
725
726                 exit(child_lam);
727         } else {
728                 wait(&child_ret);
729                 ret = WEXITSTATUS(child_ret);
730
731                 if (lam != ret)
732                         return 1;
733         }
734
735         return 0;
736 }
737
738 static int thread_fn_get_lam(void *arg)
739 {
740         return get_lam();
741 }
742
743 static int thread_fn_set_lam(void *arg)
744 {
745         struct testcases *test = arg;
746
747         return set_lam(test->lam);
748 }
749
750 static int handle_thread(struct testcases *test)
751 {
752         char stack[STACK_SIZE];
753         int ret, child_ret;
754         int lam = 0;
755         pid_t pid;
756
757         /* Set LAM mode in parent process */
758         if (!test->later) {
759                 lam = test->lam;
760                 if (set_lam(lam) != 0)
761                         return 1;
762         }
763
764         pid = clone(thread_fn_get_lam, stack + STACK_SIZE,
765                     SIGCHLD | CLONE_FILES | CLONE_FS | CLONE_VM, NULL);
766         if (pid < 0) {
767                 perror("Clone failed.");
768                 return 1;
769         }
770
771         waitpid(pid, &child_ret, 0);
772         ret = WEXITSTATUS(child_ret);
773
774         if (lam != ret)
775                 return 1;
776
777         if (test->later) {
778                 if (set_lam(test->lam) != 0)
779                         return 1;
780         }
781
782         return 0;
783 }
784
785 static int handle_thread_enable(struct testcases *test)
786 {
787         char stack[STACK_SIZE];
788         int ret, child_ret;
789         int lam = test->lam;
790         pid_t pid;
791
792         pid = clone(thread_fn_set_lam, stack + STACK_SIZE,
793                     SIGCHLD | CLONE_FILES | CLONE_FS | CLONE_VM, test);
794         if (pid < 0) {
795                 perror("Clone failed.");
796                 return 1;
797         }
798
799         waitpid(pid, &child_ret, 0);
800         ret = WEXITSTATUS(child_ret);
801
802         if (lam != ret)
803                 return 1;
804
805         return 0;
806 }
807 static void run_test(struct testcases *test, int count)
808 {
809         int i, ret = 0;
810
811         for (i = 0; i < count; i++) {
812                 struct testcases *t = test + i;
813
814                 /* fork a process to run test case */
815                 tests_cnt++;
816                 ret = fork_test(t);
817
818                 /* return 3 is not support LA57, the case should be skipped */
819                 if (ret == 3) {
820                         ksft_test_result_skip(t->msg);
821                         continue;
822                 }
823
824                 if (ret != 0)
825                         ret = (t->expected == ret);
826                 else
827                         ret = !(t->expected);
828
829                 ksft_test_result(ret, t->msg);
830         }
831 }
832
833 static struct testcases uring_cases[] = {
834         {
835                 .later = 0,
836                 .lam = LAM_U57_BITS,
837                 .test_func = handle_uring,
838                 .msg = "URING: LAM_U57. Dereferencing pointer with metadata\n",
839         },
840         {
841                 .later = 1,
842                 .expected = 1,
843                 .lam = LAM_U57_BITS,
844                 .test_func = handle_uring,
845                 .msg = "URING:[Negative] Disable LAM. Dereferencing pointer with metadata.\n",
846         },
847 };
848
849 static struct testcases malloc_cases[] = {
850         {
851                 .later = 0,
852                 .lam = LAM_U57_BITS,
853                 .test_func = handle_malloc,
854                 .msg = "MALLOC: LAM_U57. Dereferencing pointer with metadata\n",
855         },
856         {
857                 .later = 1,
858                 .expected = 2,
859                 .lam = LAM_U57_BITS,
860                 .test_func = handle_malloc,
861                 .msg = "MALLOC:[Negative] Disable LAM. Dereferencing pointer with metadata.\n",
862         },
863 };
864
865 static struct testcases bits_cases[] = {
866         {
867                 .test_func = handle_max_bits,
868                 .msg = "BITS: Check default tag bits\n",
869         },
870 };
871
872 static struct testcases syscall_cases[] = {
873         {
874                 .later = 0,
875                 .lam = LAM_U57_BITS,
876                 .test_func = handle_syscall,
877                 .msg = "SYSCALL: LAM_U57. syscall with metadata\n",
878         },
879         {
880                 .later = 1,
881                 .expected = 1,
882                 .lam = LAM_U57_BITS,
883                 .test_func = handle_syscall,
884                 .msg = "SYSCALL:[Negative] Disable LAM. Dereferencing pointer with metadata.\n",
885         },
886 };
887
888 static struct testcases mmap_cases[] = {
889         {
890                 .later = 1,
891                 .expected = 0,
892                 .lam = LAM_U57_BITS,
893                 .addr = HIGH_ADDR,
894                 .test_func = handle_mmap,
895                 .msg = "MMAP: First mmap high address, then set LAM_U57.\n",
896         },
897         {
898                 .later = 0,
899                 .expected = 0,
900                 .lam = LAM_U57_BITS,
901                 .addr = HIGH_ADDR,
902                 .test_func = handle_mmap,
903                 .msg = "MMAP: First LAM_U57, then High address.\n",
904         },
905         {
906                 .later = 0,
907                 .expected = 0,
908                 .lam = LAM_U57_BITS,
909                 .addr = LOW_ADDR,
910                 .test_func = handle_mmap,
911                 .msg = "MMAP: First LAM_U57, then Low address.\n",
912         },
913 };
914
915 static struct testcases inheritance_cases[] = {
916         {
917                 .expected = 0,
918                 .lam = LAM_U57_BITS,
919                 .test_func = handle_inheritance,
920                 .msg = "FORK: LAM_U57, child process should get LAM mode same as parent\n",
921         },
922         {
923                 .expected = 0,
924                 .lam = LAM_U57_BITS,
925                 .test_func = handle_thread,
926                 .msg = "THREAD: LAM_U57, child thread should get LAM mode same as parent\n",
927         },
928         {
929                 .expected = 1,
930                 .lam = LAM_U57_BITS,
931                 .test_func = handle_thread_enable,
932                 .msg = "THREAD: [NEGATIVE] Enable LAM in child.\n",
933         },
934         {
935                 .expected = 1,
936                 .later = 1,
937                 .lam = LAM_U57_BITS,
938                 .test_func = handle_thread,
939                 .msg = "THREAD: [NEGATIVE] Enable LAM in parent after thread created.\n",
940         },
941         {
942                 .expected = 0,
943                 .lam = LAM_U57_BITS,
944                 .test_func = handle_execve,
945                 .msg = "EXECVE: LAM_U57, child process should get disabled LAM mode\n",
946         },
947 };
948
949 static void cmd_help(void)
950 {
951         printf("usage: lam [-h] [-t test list]\n");
952         printf("\t-t test list: run tests specified in the test list, default:0x%x\n", TEST_MASK);
953         printf("\t\t0x1:malloc; 0x2:max_bits; 0x4:mmap; 0x8:syscall; 0x10:io_uring; 0x20:inherit;\n");
954         printf("\t-h: help\n");
955 }
956
957 /* Check for file existence */
958 uint8_t file_Exists(const char *fileName)
959 {
960         struct stat buffer;
961
962         uint8_t ret = (stat(fileName, &buffer) == 0);
963
964         return ret;
965 }
966
967 /* Sysfs idxd files */
968 const char *dsa_configs[] = {
969         "echo 1 > /sys/bus/dsa/devices/dsa0/wq0.1/group_id",
970         "echo shared > /sys/bus/dsa/devices/dsa0/wq0.1/mode",
971         "echo 10 > /sys/bus/dsa/devices/dsa0/wq0.1/priority",
972         "echo 16 > /sys/bus/dsa/devices/dsa0/wq0.1/size",
973         "echo 15 > /sys/bus/dsa/devices/dsa0/wq0.1/threshold",
974         "echo user > /sys/bus/dsa/devices/dsa0/wq0.1/type",
975         "echo MyApp1 > /sys/bus/dsa/devices/dsa0/wq0.1/name",
976         "echo 1 > /sys/bus/dsa/devices/dsa0/engine0.1/group_id",
977         "echo dsa0 > /sys/bus/dsa/drivers/idxd/bind",
978         /* bind files and devices, generated a device file in /dev */
979         "echo wq0.1 > /sys/bus/dsa/drivers/user/bind",
980 };
981
982 /* DSA device file */
983 const char *dsaDeviceFile = "/dev/dsa/wq0.1";
984 /* file for io*/
985 const char *dsaPasidEnable = "/sys/bus/dsa/devices/dsa0/pasid_enabled";
986
987 /*
988  * DSA depends on kernel cmdline "intel_iommu=on,sm_on"
989  * return pasid_enabled (0: disable 1:enable)
990  */
991 int Check_DSA_Kernel_Setting(void)
992 {
993         char command[256] = "";
994         char buf[256] = "";
995         char *ptr;
996         int rv = -1;
997
998         snprintf(command, sizeof(command) - 1, "cat %s", dsaPasidEnable);
999
1000         FILE *cmd = popen(command, "r");
1001
1002         if (cmd) {
1003                 while (fgets(buf, sizeof(buf) - 1, cmd) != NULL);
1004
1005                 pclose(cmd);
1006                 rv = strtol(buf, &ptr, 16);
1007         }
1008
1009         return rv;
1010 }
1011
1012 /*
1013  * Config DSA's sysfs files as shared DSA's WQ.
1014  * Generated a device file /dev/dsa/wq0.1
1015  * Return:  0 OK; 1 Failed; 3 Skip(SVA disabled).
1016  */
1017 int Dsa_Init_Sysfs(void)
1018 {
1019         uint len = ARRAY_SIZE(dsa_configs);
1020         const char **p = dsa_configs;
1021
1022         if (file_Exists(dsaDeviceFile) == 1)
1023                 return 0;
1024
1025         /* check the idxd driver */
1026         if (file_Exists(dsaPasidEnable) != 1) {
1027                 printf("Please make sure idxd driver was loaded\n");
1028                 return 3;
1029         }
1030
1031         /* Check SVA feature */
1032         if (Check_DSA_Kernel_Setting() != 1) {
1033                 printf("Please enable SVA.(Add intel_iommu=on,sm_on in kernel cmdline)\n");
1034                 return 3;
1035         }
1036
1037         /* Check the idxd device file on /dev/dsa/ */
1038         for (int i = 0; i < len; i++) {
1039                 if (system(p[i]))
1040                         return 1;
1041         }
1042
1043         /* After config, /dev/dsa/wq0.1 should be generated */
1044         return (file_Exists(dsaDeviceFile) != 1);
1045 }
1046
1047 /*
1048  * Open DSA device file, triger API: iommu_sva_alloc_pasid
1049  */
1050 void *allocate_dsa_pasid(void)
1051 {
1052         int fd;
1053         void *wq;
1054
1055         fd = open(dsaDeviceFile, O_RDWR);
1056         if (fd < 0) {
1057                 perror("open");
1058                 return MAP_FAILED;
1059         }
1060
1061         wq = mmap(NULL, 0x1000, PROT_WRITE,
1062                            MAP_SHARED | MAP_POPULATE, fd, 0);
1063         if (wq == MAP_FAILED)
1064                 perror("mmap");
1065
1066         return wq;
1067 }
1068
1069 int set_force_svm(void)
1070 {
1071         int ret = 0;
1072
1073         ret = syscall(SYS_arch_prctl, ARCH_FORCE_TAGGED_SVA);
1074
1075         return ret;
1076 }
1077
1078 int handle_pasid(struct testcases *test)
1079 {
1080         uint tmp = test->cmd;
1081         uint runed = 0x0;
1082         int ret = 0;
1083         void *wq = NULL;
1084
1085         ret = Dsa_Init_Sysfs();
1086         if (ret != 0)
1087                 return ret;
1088
1089         for (int i = 0; i < 3; i++) {
1090                 int err = 0;
1091
1092                 if (tmp & 0x1) {
1093                         /* run set lam mode*/
1094                         if ((runed & 0x1) == 0) {
1095                                 err = set_lam(LAM_U57_BITS);
1096                                 runed = runed | 0x1;
1097                         } else
1098                                 err = 1;
1099                 } else if (tmp & 0x4) {
1100                         /* run force svm */
1101                         if ((runed & 0x4) == 0) {
1102                                 err = set_force_svm();
1103                                 runed = runed | 0x4;
1104                         } else
1105                                 err = 1;
1106                 } else if (tmp & 0x2) {
1107                         /* run allocate pasid */
1108                         if ((runed & 0x2) == 0) {
1109                                 runed = runed | 0x2;
1110                                 wq = allocate_dsa_pasid();
1111                                 if (wq == MAP_FAILED)
1112                                         err = 1;
1113                         } else
1114                                 err = 1;
1115                 }
1116
1117                 ret = ret + err;
1118                 if (ret > 0)
1119                         break;
1120
1121                 tmp = tmp >> 4;
1122         }
1123
1124         if (wq != MAP_FAILED && wq != NULL)
1125                 if (munmap(wq, 0x1000))
1126                         printf("munmap failed %d\n", errno);
1127
1128         if (runed != 0x7)
1129                 ret = 1;
1130
1131         return (ret != 0);
1132 }
1133
1134 /*
1135  * Pasid test depends on idxd and SVA, kernel should enable iommu and sm.
1136  * command line(intel_iommu=on,sm_on)
1137  */
1138 static struct testcases pasid_cases[] = {
1139         {
1140                 .expected = 1,
1141                 .cmd = PAS_CMD(LAM_CMD_BIT, PAS_CMD_BIT, SVA_CMD_BIT),
1142                 .test_func = handle_pasid,
1143                 .msg = "PASID: [Negative] Execute LAM, PASID, SVA in sequence\n",
1144         },
1145         {
1146                 .expected = 0,
1147                 .cmd = PAS_CMD(LAM_CMD_BIT, SVA_CMD_BIT, PAS_CMD_BIT),
1148                 .test_func = handle_pasid,
1149                 .msg = "PASID: Execute LAM, SVA, PASID in sequence\n",
1150         },
1151         {
1152                 .expected = 1,
1153                 .cmd = PAS_CMD(PAS_CMD_BIT, LAM_CMD_BIT, SVA_CMD_BIT),
1154                 .test_func = handle_pasid,
1155                 .msg = "PASID: [Negative] Execute PASID, LAM, SVA in sequence\n",
1156         },
1157         {
1158                 .expected = 0,
1159                 .cmd = PAS_CMD(PAS_CMD_BIT, SVA_CMD_BIT, LAM_CMD_BIT),
1160                 .test_func = handle_pasid,
1161                 .msg = "PASID: Execute PASID, SVA, LAM in sequence\n",
1162         },
1163         {
1164                 .expected = 0,
1165                 .cmd = PAS_CMD(SVA_CMD_BIT, LAM_CMD_BIT, PAS_CMD_BIT),
1166                 .test_func = handle_pasid,
1167                 .msg = "PASID: Execute SVA, LAM, PASID in sequence\n",
1168         },
1169         {
1170                 .expected = 0,
1171                 .cmd = PAS_CMD(SVA_CMD_BIT, PAS_CMD_BIT, LAM_CMD_BIT),
1172                 .test_func = handle_pasid,
1173                 .msg = "PASID: Execute SVA, PASID, LAM in sequence\n",
1174         },
1175 };
1176
1177 int main(int argc, char **argv)
1178 {
1179         int c = 0;
1180         unsigned int tests = TEST_MASK;
1181
1182         tests_cnt = 0;
1183
1184         if (!cpu_has_lam()) {
1185                 ksft_print_msg("Unsupported LAM feature!\n");
1186                 return -1;
1187         }
1188
1189         while ((c = getopt(argc, argv, "ht:")) != -1) {
1190                 switch (c) {
1191                 case 't':
1192                         tests = strtoul(optarg, NULL, 16);
1193                         if (tests && !(tests & TEST_MASK)) {
1194                                 ksft_print_msg("Invalid argument!\n");
1195                                 return -1;
1196                         }
1197                         break;
1198                 case 'h':
1199                         cmd_help();
1200                         return 0;
1201                 default:
1202                         ksft_print_msg("Invalid argument\n");
1203                         return -1;
1204                 }
1205         }
1206
1207         /*
1208          * When tests is 0, it is not a real test case;
1209          * the option used by test case(execve) to check the lam mode in
1210          * process generated by execve, the process read back lam mode and
1211          * check with lam mode in parent process.
1212          */
1213         if (!tests)
1214                 return (get_lam());
1215
1216         /* Run test cases */
1217         if (tests & FUNC_MALLOC)
1218                 run_test(malloc_cases, ARRAY_SIZE(malloc_cases));
1219
1220         if (tests & FUNC_BITS)
1221                 run_test(bits_cases, ARRAY_SIZE(bits_cases));
1222
1223         if (tests & FUNC_MMAP)
1224                 run_test(mmap_cases, ARRAY_SIZE(mmap_cases));
1225
1226         if (tests & FUNC_SYSCALL)
1227                 run_test(syscall_cases, ARRAY_SIZE(syscall_cases));
1228
1229         if (tests & FUNC_URING)
1230                 run_test(uring_cases, ARRAY_SIZE(uring_cases));
1231
1232         if (tests & FUNC_INHERITE)
1233                 run_test(inheritance_cases, ARRAY_SIZE(inheritance_cases));
1234
1235         if (tests & FUNC_PASID)
1236                 run_test(pasid_cases, ARRAY_SIZE(pasid_cases));
1237
1238         ksft_set_plan(tests_cnt);
1239
1240         return ksft_exit_pass();
1241 }