selftests/vm/pkeys: refill shadow register after implicit kernel write
[platform/kernel/linux-starfive.git] / tools / testing / selftests / vm / protection_keys.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Tests Memory Protection Keys (see Documentation/core-api/protection-keys.rst)
4  *
5  * There are examples in here of:
6  *  * how to set protection keys on memory
7  *  * how to set/clear bits in pkey registers (the rights register)
8  *  * how to handle SEGV_PKUERR signals and extract pkey-relevant
9  *    information from the siginfo
10  *
11  * Things to add:
12  *      make sure KSM and KSM COW breaking works
13  *      prefault pages in at malloc, or not
14  *      protect MPX bounds tables with protection keys?
15  *      make sure VMA splitting/merging is working correctly
16  *      OOMs can destroy mm->mmap (see exit_mmap()), so make sure it is immune to pkeys
17  *      look for pkey "leaks" where it is still set on a VMA but "freed" back to the kernel
18  *      do a plain mprotect() to a mprotect_pkey() area and make sure the pkey sticks
19  *
20  * Compile like this:
21  *      gcc      -o protection_keys    -O2 -g -std=gnu99 -pthread -Wall protection_keys.c -lrt -ldl -lm
22  *      gcc -m32 -o protection_keys_32 -O2 -g -std=gnu99 -pthread -Wall protection_keys.c -lrt -ldl -lm
23  */
24 #define _GNU_SOURCE
25 #define __SANE_USERSPACE_TYPES__
26 #include <errno.h>
27 #include <linux/futex.h>
28 #include <time.h>
29 #include <sys/time.h>
30 #include <sys/syscall.h>
31 #include <string.h>
32 #include <stdio.h>
33 #include <stdint.h>
34 #include <stdbool.h>
35 #include <signal.h>
36 #include <assert.h>
37 #include <stdlib.h>
38 #include <ucontext.h>
39 #include <sys/mman.h>
40 #include <sys/types.h>
41 #include <sys/wait.h>
42 #include <sys/stat.h>
43 #include <fcntl.h>
44 #include <unistd.h>
45 #include <sys/ptrace.h>
46 #include <setjmp.h>
47
48 #include "pkey-helpers.h"
49
50 int iteration_nr = 1;
51 int test_nr;
52
53 u64 shadow_pkey_reg;
54 int dprint_in_signal;
55 char dprint_in_signal_buffer[DPRINT_IN_SIGNAL_BUF_SIZE];
56
57 void cat_into_file(char *str, char *file)
58 {
59         int fd = open(file, O_RDWR);
60         int ret;
61
62         dprintf2("%s(): writing '%s' to '%s'\n", __func__, str, file);
63         /*
64          * these need to be raw because they are called under
65          * pkey_assert()
66          */
67         if (fd < 0) {
68                 fprintf(stderr, "error opening '%s'\n", str);
69                 perror("error: ");
70                 exit(__LINE__);
71         }
72
73         ret = write(fd, str, strlen(str));
74         if (ret != strlen(str)) {
75                 perror("write to file failed");
76                 fprintf(stderr, "filename: '%s' str: '%s'\n", file, str);
77                 exit(__LINE__);
78         }
79         close(fd);
80 }
81
82 #if CONTROL_TRACING > 0
83 static int warned_tracing;
84 int tracing_root_ok(void)
85 {
86         if (geteuid() != 0) {
87                 if (!warned_tracing)
88                         fprintf(stderr, "WARNING: not run as root, "
89                                         "can not do tracing control\n");
90                 warned_tracing = 1;
91                 return 0;
92         }
93         return 1;
94 }
95 #endif
96
97 void tracing_on(void)
98 {
99 #if CONTROL_TRACING > 0
100 #define TRACEDIR "/sys/kernel/debug/tracing"
101         char pidstr[32];
102
103         if (!tracing_root_ok())
104                 return;
105
106         sprintf(pidstr, "%d", getpid());
107         cat_into_file("0", TRACEDIR "/tracing_on");
108         cat_into_file("\n", TRACEDIR "/trace");
109         if (1) {
110                 cat_into_file("function_graph", TRACEDIR "/current_tracer");
111                 cat_into_file("1", TRACEDIR "/options/funcgraph-proc");
112         } else {
113                 cat_into_file("nop", TRACEDIR "/current_tracer");
114         }
115         cat_into_file(pidstr, TRACEDIR "/set_ftrace_pid");
116         cat_into_file("1", TRACEDIR "/tracing_on");
117         dprintf1("enabled tracing\n");
118 #endif
119 }
120
121 void tracing_off(void)
122 {
123 #if CONTROL_TRACING > 0
124         if (!tracing_root_ok())
125                 return;
126         cat_into_file("0", "/sys/kernel/debug/tracing/tracing_on");
127 #endif
128 }
129
130 void abort_hooks(void)
131 {
132         fprintf(stderr, "running %s()...\n", __func__);
133         tracing_off();
134 #ifdef SLEEP_ON_ABORT
135         sleep(SLEEP_ON_ABORT);
136 #endif
137 }
138
139 /*
140  * This attempts to have roughly a page of instructions followed by a few
141  * instructions that do a write, and another page of instructions.  That
142  * way, we are pretty sure that the write is in the second page of
143  * instructions and has at least a page of padding behind it.
144  *
145  * *That* lets us be sure to madvise() away the write instruction, which
146  * will then fault, which makes sure that the fault code handles
147  * execute-only memory properly.
148  */
149 #ifdef __powerpc64__
150 /* This way, both 4K and 64K alignment are maintained */
151 __attribute__((__aligned__(65536)))
152 #else
153 __attribute__((__aligned__(PAGE_SIZE)))
154 #endif
155 void lots_o_noops_around_write(int *write_to_me)
156 {
157         dprintf3("running %s()\n", __func__);
158         __page_o_noops();
159         /* Assume this happens in the second page of instructions: */
160         *write_to_me = __LINE__;
161         /* pad out by another page: */
162         __page_o_noops();
163         dprintf3("%s() done\n", __func__);
164 }
165
166 void dump_mem(void *dumpme, int len_bytes)
167 {
168         char *c = (void *)dumpme;
169         int i;
170
171         for (i = 0; i < len_bytes; i += sizeof(u64)) {
172                 u64 *ptr = (u64 *)(c + i);
173                 dprintf1("dump[%03d][@%p]: %016llx\n", i, ptr, *ptr);
174         }
175 }
176
177 static u32 hw_pkey_get(int pkey, unsigned long flags)
178 {
179         u64 pkey_reg = __read_pkey_reg();
180
181         dprintf1("%s(pkey=%d, flags=%lx) = %x / %d\n",
182                         __func__, pkey, flags, 0, 0);
183         dprintf2("%s() raw pkey_reg: %016llx\n", __func__, pkey_reg);
184
185         return (u32) get_pkey_bits(pkey_reg, pkey);
186 }
187
188 static int hw_pkey_set(int pkey, unsigned long rights, unsigned long flags)
189 {
190         u32 mask = (PKEY_DISABLE_ACCESS|PKEY_DISABLE_WRITE);
191         u64 old_pkey_reg = __read_pkey_reg();
192         u64 new_pkey_reg;
193
194         /* make sure that 'rights' only contains the bits we expect: */
195         assert(!(rights & ~mask));
196
197         /* modify bits accordingly in old pkey_reg and assign it */
198         new_pkey_reg = set_pkey_bits(old_pkey_reg, pkey, rights);
199
200         __write_pkey_reg(new_pkey_reg);
201
202         dprintf3("%s(pkey=%d, rights=%lx, flags=%lx) = %x"
203                 " pkey_reg now: %016llx old_pkey_reg: %016llx\n",
204                 __func__, pkey, rights, flags, 0, __read_pkey_reg(),
205                 old_pkey_reg);
206         return 0;
207 }
208
209 void pkey_disable_set(int pkey, int flags)
210 {
211         unsigned long syscall_flags = 0;
212         int ret;
213         int pkey_rights;
214         u64 orig_pkey_reg = read_pkey_reg();
215
216         dprintf1("START->%s(%d, 0x%x)\n", __func__,
217                 pkey, flags);
218         pkey_assert(flags & (PKEY_DISABLE_ACCESS | PKEY_DISABLE_WRITE));
219
220         pkey_rights = hw_pkey_get(pkey, syscall_flags);
221
222         dprintf1("%s(%d) hw_pkey_get(%d): %x\n", __func__,
223                         pkey, pkey, pkey_rights);
224
225         pkey_assert(pkey_rights >= 0);
226
227         pkey_rights |= flags;
228
229         ret = hw_pkey_set(pkey, pkey_rights, syscall_flags);
230         assert(!ret);
231         /* pkey_reg and flags have the same format */
232         shadow_pkey_reg = set_pkey_bits(shadow_pkey_reg, pkey, pkey_rights);
233         dprintf1("%s(%d) shadow: 0x%016llx\n",
234                 __func__, pkey, shadow_pkey_reg);
235
236         pkey_assert(ret >= 0);
237
238         pkey_rights = hw_pkey_get(pkey, syscall_flags);
239         dprintf1("%s(%d) hw_pkey_get(%d): %x\n", __func__,
240                         pkey, pkey, pkey_rights);
241
242         dprintf1("%s(%d) pkey_reg: 0x%016llx\n",
243                 __func__, pkey, read_pkey_reg());
244         if (flags)
245                 pkey_assert(read_pkey_reg() >= orig_pkey_reg);
246         dprintf1("END<---%s(%d, 0x%x)\n", __func__,
247                 pkey, flags);
248 }
249
250 void pkey_disable_clear(int pkey, int flags)
251 {
252         unsigned long syscall_flags = 0;
253         int ret;
254         int pkey_rights = hw_pkey_get(pkey, syscall_flags);
255         u64 orig_pkey_reg = read_pkey_reg();
256
257         pkey_assert(flags & (PKEY_DISABLE_ACCESS | PKEY_DISABLE_WRITE));
258
259         dprintf1("%s(%d) hw_pkey_get(%d): %x\n", __func__,
260                         pkey, pkey, pkey_rights);
261         pkey_assert(pkey_rights >= 0);
262
263         pkey_rights &= ~flags;
264
265         ret = hw_pkey_set(pkey, pkey_rights, 0);
266         shadow_pkey_reg = set_pkey_bits(shadow_pkey_reg, pkey, pkey_rights);
267         pkey_assert(ret >= 0);
268
269         pkey_rights = hw_pkey_get(pkey, syscall_flags);
270         dprintf1("%s(%d) hw_pkey_get(%d): %x\n", __func__,
271                         pkey, pkey, pkey_rights);
272
273         dprintf1("%s(%d) pkey_reg: 0x%016llx\n", __func__,
274                         pkey, read_pkey_reg());
275         if (flags)
276                 assert(read_pkey_reg() <= orig_pkey_reg);
277 }
278
279 void pkey_write_allow(int pkey)
280 {
281         pkey_disable_clear(pkey, PKEY_DISABLE_WRITE);
282 }
283 void pkey_write_deny(int pkey)
284 {
285         pkey_disable_set(pkey, PKEY_DISABLE_WRITE);
286 }
287 void pkey_access_allow(int pkey)
288 {
289         pkey_disable_clear(pkey, PKEY_DISABLE_ACCESS);
290 }
291 void pkey_access_deny(int pkey)
292 {
293         pkey_disable_set(pkey, PKEY_DISABLE_ACCESS);
294 }
295
296 /* Failed address bound checks: */
297 #ifndef SEGV_BNDERR
298 # define SEGV_BNDERR            3
299 #endif
300
301 #ifndef SEGV_PKUERR
302 # define SEGV_PKUERR            4
303 #endif
304
305 static char *si_code_str(int si_code)
306 {
307         if (si_code == SEGV_MAPERR)
308                 return "SEGV_MAPERR";
309         if (si_code == SEGV_ACCERR)
310                 return "SEGV_ACCERR";
311         if (si_code == SEGV_BNDERR)
312                 return "SEGV_BNDERR";
313         if (si_code == SEGV_PKUERR)
314                 return "SEGV_PKUERR";
315         return "UNKNOWN";
316 }
317
318 int pkey_faults;
319 int last_si_pkey = -1;
320 void signal_handler(int signum, siginfo_t *si, void *vucontext)
321 {
322         ucontext_t *uctxt = vucontext;
323         int trapno;
324         unsigned long ip;
325         char *fpregs;
326 #if defined(__i386__) || defined(__x86_64__) /* arch */
327         u32 *pkey_reg_ptr;
328         int pkey_reg_offset;
329 #endif /* arch */
330         u64 siginfo_pkey;
331         u32 *si_pkey_ptr;
332
333         dprint_in_signal = 1;
334         dprintf1(">>>>===============SIGSEGV============================\n");
335         dprintf1("%s()::%d, pkey_reg: 0x%016llx shadow: %016llx\n",
336                         __func__, __LINE__,
337                         __read_pkey_reg(), shadow_pkey_reg);
338
339         trapno = uctxt->uc_mcontext.gregs[REG_TRAPNO];
340         ip = uctxt->uc_mcontext.gregs[REG_IP_IDX];
341         fpregs = (char *) uctxt->uc_mcontext.fpregs;
342
343         dprintf2("%s() trapno: %d ip: 0x%016lx info->si_code: %s/%d\n",
344                         __func__, trapno, ip, si_code_str(si->si_code),
345                         si->si_code);
346
347 #if defined(__i386__) || defined(__x86_64__) /* arch */
348 #ifdef __i386__
349         /*
350          * 32-bit has some extra padding so that userspace can tell whether
351          * the XSTATE header is present in addition to the "legacy" FPU
352          * state.  We just assume that it is here.
353          */
354         fpregs += 0x70;
355 #endif /* i386 */
356         pkey_reg_offset = pkey_reg_xstate_offset();
357         pkey_reg_ptr = (void *)(&fpregs[pkey_reg_offset]);
358
359         /*
360          * If we got a PKEY fault, we *HAVE* to have at least one bit set in
361          * here.
362          */
363         dprintf1("pkey_reg_xstate_offset: %d\n", pkey_reg_xstate_offset());
364         if (DEBUG_LEVEL > 4)
365                 dump_mem(pkey_reg_ptr - 128, 256);
366         pkey_assert(*pkey_reg_ptr);
367 #endif /* arch */
368
369         dprintf1("siginfo: %p\n", si);
370         dprintf1(" fpregs: %p\n", fpregs);
371
372         if ((si->si_code == SEGV_MAPERR) ||
373             (si->si_code == SEGV_ACCERR) ||
374             (si->si_code == SEGV_BNDERR)) {
375                 printf("non-PK si_code, exiting...\n");
376                 exit(4);
377         }
378
379         si_pkey_ptr = siginfo_get_pkey_ptr(si);
380         dprintf1("si_pkey_ptr: %p\n", si_pkey_ptr);
381         dump_mem((u8 *)si_pkey_ptr - 8, 24);
382         siginfo_pkey = *si_pkey_ptr;
383         pkey_assert(siginfo_pkey < NR_PKEYS);
384         last_si_pkey = siginfo_pkey;
385
386         /*
387          * need __read_pkey_reg() version so we do not do shadow_pkey_reg
388          * checking
389          */
390         dprintf1("signal pkey_reg from  pkey_reg: %016llx\n",
391                         __read_pkey_reg());
392         dprintf1("pkey from siginfo: %016llx\n", siginfo_pkey);
393 #if defined(__i386__) || defined(__x86_64__) /* arch */
394         dprintf1("signal pkey_reg from xsave: %08x\n", *pkey_reg_ptr);
395         *(u64 *)pkey_reg_ptr = 0x00000000;
396         dprintf1("WARNING: set PKEY_REG=0 to allow faulting instruction to continue\n");
397 #elif defined(__powerpc64__) /* arch */
398         /* restore access and let the faulting instruction continue */
399         pkey_access_allow(siginfo_pkey);
400 #endif /* arch */
401         pkey_faults++;
402         dprintf1("<<<<==================================================\n");
403         dprint_in_signal = 0;
404 }
405
406 int wait_all_children(void)
407 {
408         int status;
409         return waitpid(-1, &status, 0);
410 }
411
412 void sig_chld(int x)
413 {
414         dprint_in_signal = 1;
415         dprintf2("[%d] SIGCHLD: %d\n", getpid(), x);
416         dprint_in_signal = 0;
417 }
418
419 void setup_sigsegv_handler(void)
420 {
421         int r, rs;
422         struct sigaction newact;
423         struct sigaction oldact;
424
425         /* #PF is mapped to sigsegv */
426         int signum  = SIGSEGV;
427
428         newact.sa_handler = 0;
429         newact.sa_sigaction = signal_handler;
430
431         /*sigset_t - signals to block while in the handler */
432         /* get the old signal mask. */
433         rs = sigprocmask(SIG_SETMASK, 0, &newact.sa_mask);
434         pkey_assert(rs == 0);
435
436         /* call sa_sigaction, not sa_handler*/
437         newact.sa_flags = SA_SIGINFO;
438
439         newact.sa_restorer = 0;  /* void(*)(), obsolete */
440         r = sigaction(signum, &newact, &oldact);
441         r = sigaction(SIGALRM, &newact, &oldact);
442         pkey_assert(r == 0);
443 }
444
445 void setup_handlers(void)
446 {
447         signal(SIGCHLD, &sig_chld);
448         setup_sigsegv_handler();
449 }
450
451 pid_t fork_lazy_child(void)
452 {
453         pid_t forkret;
454
455         forkret = fork();
456         pkey_assert(forkret >= 0);
457         dprintf3("[%d] fork() ret: %d\n", getpid(), forkret);
458
459         if (!forkret) {
460                 /* in the child */
461                 while (1) {
462                         dprintf1("child sleeping...\n");
463                         sleep(30);
464                 }
465         }
466         return forkret;
467 }
468
469 int sys_mprotect_pkey(void *ptr, size_t size, unsigned long orig_prot,
470                 unsigned long pkey)
471 {
472         int sret;
473
474         dprintf2("%s(0x%p, %zx, prot=%lx, pkey=%lx)\n", __func__,
475                         ptr, size, orig_prot, pkey);
476
477         errno = 0;
478         sret = syscall(SYS_mprotect_key, ptr, size, orig_prot, pkey);
479         if (errno) {
480                 dprintf2("SYS_mprotect_key sret: %d\n", sret);
481                 dprintf2("SYS_mprotect_key prot: 0x%lx\n", orig_prot);
482                 dprintf2("SYS_mprotect_key failed, errno: %d\n", errno);
483                 if (DEBUG_LEVEL >= 2)
484                         perror("SYS_mprotect_pkey");
485         }
486         return sret;
487 }
488
489 int sys_pkey_alloc(unsigned long flags, unsigned long init_val)
490 {
491         int ret = syscall(SYS_pkey_alloc, flags, init_val);
492         dprintf1("%s(flags=%lx, init_val=%lx) syscall ret: %d errno: %d\n",
493                         __func__, flags, init_val, ret, errno);
494         return ret;
495 }
496
497 int alloc_pkey(void)
498 {
499         int ret;
500         unsigned long init_val = 0x0;
501
502         dprintf1("%s()::%d, pkey_reg: 0x%016llx shadow: %016llx\n",
503                         __func__, __LINE__, __read_pkey_reg(), shadow_pkey_reg);
504         ret = sys_pkey_alloc(0, init_val);
505         /*
506          * pkey_alloc() sets PKEY register, so we need to reflect it in
507          * shadow_pkey_reg:
508          */
509         dprintf4("%s()::%d, ret: %d pkey_reg: 0x%016llx"
510                         " shadow: 0x%016llx\n",
511                         __func__, __LINE__, ret, __read_pkey_reg(),
512                         shadow_pkey_reg);
513         if (ret > 0) {
514                 /* clear both the bits: */
515                 shadow_pkey_reg = set_pkey_bits(shadow_pkey_reg, ret,
516                                                 ~PKEY_MASK);
517                 dprintf4("%s()::%d, ret: %d pkey_reg: 0x%016llx"
518                                 " shadow: 0x%016llx\n",
519                                 __func__,
520                                 __LINE__, ret, __read_pkey_reg(),
521                                 shadow_pkey_reg);
522                 /*
523                  * move the new state in from init_val
524                  * (remember, we cheated and init_val == pkey_reg format)
525                  */
526                 shadow_pkey_reg = set_pkey_bits(shadow_pkey_reg, ret,
527                                                 init_val);
528         }
529         dprintf4("%s()::%d, ret: %d pkey_reg: 0x%016llx"
530                         " shadow: 0x%016llx\n",
531                         __func__, __LINE__, ret, __read_pkey_reg(),
532                         shadow_pkey_reg);
533         dprintf1("%s()::%d errno: %d\n", __func__, __LINE__, errno);
534         /* for shadow checking: */
535         read_pkey_reg();
536         dprintf4("%s()::%d, ret: %d pkey_reg: 0x%016llx"
537                  " shadow: 0x%016llx\n",
538                 __func__, __LINE__, ret, __read_pkey_reg(),
539                 shadow_pkey_reg);
540         return ret;
541 }
542
543 int sys_pkey_free(unsigned long pkey)
544 {
545         int ret = syscall(SYS_pkey_free, pkey);
546         dprintf1("%s(pkey=%ld) syscall ret: %d\n", __func__, pkey, ret);
547         return ret;
548 }
549
550 /*
551  * I had a bug where pkey bits could be set by mprotect() but
552  * not cleared.  This ensures we get lots of random bit sets
553  * and clears on the vma and pte pkey bits.
554  */
555 int alloc_random_pkey(void)
556 {
557         int max_nr_pkey_allocs;
558         int ret;
559         int i;
560         int alloced_pkeys[NR_PKEYS];
561         int nr_alloced = 0;
562         int random_index;
563         memset(alloced_pkeys, 0, sizeof(alloced_pkeys));
564
565         /* allocate every possible key and make a note of which ones we got */
566         max_nr_pkey_allocs = NR_PKEYS;
567         for (i = 0; i < max_nr_pkey_allocs; i++) {
568                 int new_pkey = alloc_pkey();
569                 if (new_pkey < 0)
570                         break;
571                 alloced_pkeys[nr_alloced++] = new_pkey;
572         }
573
574         pkey_assert(nr_alloced > 0);
575         /* select a random one out of the allocated ones */
576         random_index = rand() % nr_alloced;
577         ret = alloced_pkeys[random_index];
578         /* now zero it out so we don't free it next */
579         alloced_pkeys[random_index] = 0;
580
581         /* go through the allocated ones that we did not want and free them */
582         for (i = 0; i < nr_alloced; i++) {
583                 int free_ret;
584                 if (!alloced_pkeys[i])
585                         continue;
586                 free_ret = sys_pkey_free(alloced_pkeys[i]);
587                 pkey_assert(!free_ret);
588         }
589         dprintf1("%s()::%d, ret: %d pkey_reg: 0x%016llx"
590                          " shadow: 0x%016llx\n", __func__,
591                         __LINE__, ret, __read_pkey_reg(), shadow_pkey_reg);
592         return ret;
593 }
594
595 int mprotect_pkey(void *ptr, size_t size, unsigned long orig_prot,
596                 unsigned long pkey)
597 {
598         int nr_iterations = random() % 100;
599         int ret;
600
601         while (0) {
602                 int rpkey = alloc_random_pkey();
603                 ret = sys_mprotect_pkey(ptr, size, orig_prot, pkey);
604                 dprintf1("sys_mprotect_pkey(%p, %zx, prot=0x%lx, pkey=%ld) ret: %d\n",
605                                 ptr, size, orig_prot, pkey, ret);
606                 if (nr_iterations-- < 0)
607                         break;
608
609                 dprintf1("%s()::%d, ret: %d pkey_reg: 0x%016llx"
610                         " shadow: 0x%016llx\n",
611                         __func__, __LINE__, ret, __read_pkey_reg(),
612                         shadow_pkey_reg);
613                 sys_pkey_free(rpkey);
614                 dprintf1("%s()::%d, ret: %d pkey_reg: 0x%016llx"
615                         " shadow: 0x%016llx\n",
616                         __func__, __LINE__, ret, __read_pkey_reg(),
617                         shadow_pkey_reg);
618         }
619         pkey_assert(pkey < NR_PKEYS);
620
621         ret = sys_mprotect_pkey(ptr, size, orig_prot, pkey);
622         dprintf1("mprotect_pkey(%p, %zx, prot=0x%lx, pkey=%ld) ret: %d\n",
623                         ptr, size, orig_prot, pkey, ret);
624         pkey_assert(!ret);
625         dprintf1("%s()::%d, ret: %d pkey_reg: 0x%016llx"
626                         " shadow: 0x%016llx\n", __func__,
627                         __LINE__, ret, __read_pkey_reg(), shadow_pkey_reg);
628         return ret;
629 }
630
631 struct pkey_malloc_record {
632         void *ptr;
633         long size;
634         int prot;
635 };
636 struct pkey_malloc_record *pkey_malloc_records;
637 struct pkey_malloc_record *pkey_last_malloc_record;
638 long nr_pkey_malloc_records;
639 void record_pkey_malloc(void *ptr, long size, int prot)
640 {
641         long i;
642         struct pkey_malloc_record *rec = NULL;
643
644         for (i = 0; i < nr_pkey_malloc_records; i++) {
645                 rec = &pkey_malloc_records[i];
646                 /* find a free record */
647                 if (rec)
648                         break;
649         }
650         if (!rec) {
651                 /* every record is full */
652                 size_t old_nr_records = nr_pkey_malloc_records;
653                 size_t new_nr_records = (nr_pkey_malloc_records * 2 + 1);
654                 size_t new_size = new_nr_records * sizeof(struct pkey_malloc_record);
655                 dprintf2("new_nr_records: %zd\n", new_nr_records);
656                 dprintf2("new_size: %zd\n", new_size);
657                 pkey_malloc_records = realloc(pkey_malloc_records, new_size);
658                 pkey_assert(pkey_malloc_records != NULL);
659                 rec = &pkey_malloc_records[nr_pkey_malloc_records];
660                 /*
661                  * realloc() does not initialize memory, so zero it from
662                  * the first new record all the way to the end.
663                  */
664                 for (i = 0; i < new_nr_records - old_nr_records; i++)
665                         memset(rec + i, 0, sizeof(*rec));
666         }
667         dprintf3("filling malloc record[%d/%p]: {%p, %ld}\n",
668                 (int)(rec - pkey_malloc_records), rec, ptr, size);
669         rec->ptr = ptr;
670         rec->size = size;
671         rec->prot = prot;
672         pkey_last_malloc_record = rec;
673         nr_pkey_malloc_records++;
674 }
675
676 void free_pkey_malloc(void *ptr)
677 {
678         long i;
679         int ret;
680         dprintf3("%s(%p)\n", __func__, ptr);
681         for (i = 0; i < nr_pkey_malloc_records; i++) {
682                 struct pkey_malloc_record *rec = &pkey_malloc_records[i];
683                 dprintf4("looking for ptr %p at record[%ld/%p]: {%p, %ld}\n",
684                                 ptr, i, rec, rec->ptr, rec->size);
685                 if ((ptr <  rec->ptr) ||
686                     (ptr >= rec->ptr + rec->size))
687                         continue;
688
689                 dprintf3("found ptr %p at record[%ld/%p]: {%p, %ld}\n",
690                                 ptr, i, rec, rec->ptr, rec->size);
691                 nr_pkey_malloc_records--;
692                 ret = munmap(rec->ptr, rec->size);
693                 dprintf3("munmap ret: %d\n", ret);
694                 pkey_assert(!ret);
695                 dprintf3("clearing rec->ptr, rec: %p\n", rec);
696                 rec->ptr = NULL;
697                 dprintf3("done clearing rec->ptr, rec: %p\n", rec);
698                 return;
699         }
700         pkey_assert(false);
701 }
702
703
704 void *malloc_pkey_with_mprotect(long size, int prot, u16 pkey)
705 {
706         void *ptr;
707         int ret;
708
709         read_pkey_reg();
710         dprintf1("doing %s(size=%ld, prot=0x%x, pkey=%d)\n", __func__,
711                         size, prot, pkey);
712         pkey_assert(pkey < NR_PKEYS);
713         ptr = mmap(NULL, size, prot, MAP_ANONYMOUS|MAP_PRIVATE, -1, 0);
714         pkey_assert(ptr != (void *)-1);
715         ret = mprotect_pkey((void *)ptr, PAGE_SIZE, prot, pkey);
716         pkey_assert(!ret);
717         record_pkey_malloc(ptr, size, prot);
718         read_pkey_reg();
719
720         dprintf1("%s() for pkey %d @ %p\n", __func__, pkey, ptr);
721         return ptr;
722 }
723
724 void *malloc_pkey_anon_huge(long size, int prot, u16 pkey)
725 {
726         int ret;
727         void *ptr;
728
729         dprintf1("doing %s(size=%ld, prot=0x%x, pkey=%d)\n", __func__,
730                         size, prot, pkey);
731         /*
732          * Guarantee we can fit at least one huge page in the resulting
733          * allocation by allocating space for 2:
734          */
735         size = ALIGN_UP(size, HPAGE_SIZE * 2);
736         ptr = mmap(NULL, size, PROT_NONE, MAP_ANONYMOUS|MAP_PRIVATE, -1, 0);
737         pkey_assert(ptr != (void *)-1);
738         record_pkey_malloc(ptr, size, prot);
739         mprotect_pkey(ptr, size, prot, pkey);
740
741         dprintf1("unaligned ptr: %p\n", ptr);
742         ptr = ALIGN_PTR_UP(ptr, HPAGE_SIZE);
743         dprintf1("  aligned ptr: %p\n", ptr);
744         ret = madvise(ptr, HPAGE_SIZE, MADV_HUGEPAGE);
745         dprintf1("MADV_HUGEPAGE ret: %d\n", ret);
746         ret = madvise(ptr, HPAGE_SIZE, MADV_WILLNEED);
747         dprintf1("MADV_WILLNEED ret: %d\n", ret);
748         memset(ptr, 0, HPAGE_SIZE);
749
750         dprintf1("mmap()'d thp for pkey %d @ %p\n", pkey, ptr);
751         return ptr;
752 }
753
754 int hugetlb_setup_ok;
755 #define SYSFS_FMT_NR_HUGE_PAGES "/sys/kernel/mm/hugepages/hugepages-%ldkB/nr_hugepages"
756 #define GET_NR_HUGE_PAGES 10
757 void setup_hugetlbfs(void)
758 {
759         int err;
760         int fd;
761         char buf[256];
762         long hpagesz_kb;
763         long hpagesz_mb;
764
765         if (geteuid() != 0) {
766                 fprintf(stderr, "WARNING: not run as root, can not do hugetlb test\n");
767                 return;
768         }
769
770         cat_into_file(__stringify(GET_NR_HUGE_PAGES), "/proc/sys/vm/nr_hugepages");
771
772         /*
773          * Now go make sure that we got the pages and that they
774          * are PMD-level pages. Someone might have made PUD-level
775          * pages the default.
776          */
777         hpagesz_kb = HPAGE_SIZE / 1024;
778         hpagesz_mb = hpagesz_kb / 1024;
779         sprintf(buf, SYSFS_FMT_NR_HUGE_PAGES, hpagesz_kb);
780         fd = open(buf, O_RDONLY);
781         if (fd < 0) {
782                 fprintf(stderr, "opening sysfs %ldM hugetlb config: %s\n",
783                         hpagesz_mb, strerror(errno));
784                 return;
785         }
786
787         /* -1 to guarantee leaving the trailing \0 */
788         err = read(fd, buf, sizeof(buf)-1);
789         close(fd);
790         if (err <= 0) {
791                 fprintf(stderr, "reading sysfs %ldM hugetlb config: %s\n",
792                         hpagesz_mb, strerror(errno));
793                 return;
794         }
795
796         if (atoi(buf) != GET_NR_HUGE_PAGES) {
797                 fprintf(stderr, "could not confirm %ldM pages, got: '%s' expected %d\n",
798                         hpagesz_mb, buf, GET_NR_HUGE_PAGES);
799                 return;
800         }
801
802         hugetlb_setup_ok = 1;
803 }
804
805 void *malloc_pkey_hugetlb(long size, int prot, u16 pkey)
806 {
807         void *ptr;
808         int flags = MAP_ANONYMOUS|MAP_PRIVATE|MAP_HUGETLB;
809
810         if (!hugetlb_setup_ok)
811                 return PTR_ERR_ENOTSUP;
812
813         dprintf1("doing %s(%ld, %x, %x)\n", __func__, size, prot, pkey);
814         size = ALIGN_UP(size, HPAGE_SIZE * 2);
815         pkey_assert(pkey < NR_PKEYS);
816         ptr = mmap(NULL, size, PROT_NONE, flags, -1, 0);
817         pkey_assert(ptr != (void *)-1);
818         mprotect_pkey(ptr, size, prot, pkey);
819
820         record_pkey_malloc(ptr, size, prot);
821
822         dprintf1("mmap()'d hugetlbfs for pkey %d @ %p\n", pkey, ptr);
823         return ptr;
824 }
825
826 void *malloc_pkey_mmap_dax(long size, int prot, u16 pkey)
827 {
828         void *ptr;
829         int fd;
830
831         dprintf1("doing %s(size=%ld, prot=0x%x, pkey=%d)\n", __func__,
832                         size, prot, pkey);
833         pkey_assert(pkey < NR_PKEYS);
834         fd = open("/dax/foo", O_RDWR);
835         pkey_assert(fd >= 0);
836
837         ptr = mmap(0, size, prot, MAP_SHARED, fd, 0);
838         pkey_assert(ptr != (void *)-1);
839
840         mprotect_pkey(ptr, size, prot, pkey);
841
842         record_pkey_malloc(ptr, size, prot);
843
844         dprintf1("mmap()'d for pkey %d @ %p\n", pkey, ptr);
845         close(fd);
846         return ptr;
847 }
848
849 void *(*pkey_malloc[])(long size, int prot, u16 pkey) = {
850
851         malloc_pkey_with_mprotect,
852         malloc_pkey_with_mprotect_subpage,
853         malloc_pkey_anon_huge,
854         malloc_pkey_hugetlb
855 /* can not do direct with the pkey_mprotect() API:
856         malloc_pkey_mmap_direct,
857         malloc_pkey_mmap_dax,
858 */
859 };
860
861 void *malloc_pkey(long size, int prot, u16 pkey)
862 {
863         void *ret;
864         static int malloc_type;
865         int nr_malloc_types = ARRAY_SIZE(pkey_malloc);
866
867         pkey_assert(pkey < NR_PKEYS);
868
869         while (1) {
870                 pkey_assert(malloc_type < nr_malloc_types);
871
872                 ret = pkey_malloc[malloc_type](size, prot, pkey);
873                 pkey_assert(ret != (void *)-1);
874
875                 malloc_type++;
876                 if (malloc_type >= nr_malloc_types)
877                         malloc_type = (random()%nr_malloc_types);
878
879                 /* try again if the malloc_type we tried is unsupported */
880                 if (ret == PTR_ERR_ENOTSUP)
881                         continue;
882
883                 break;
884         }
885
886         dprintf3("%s(%ld, prot=%x, pkey=%x) returning: %p\n", __func__,
887                         size, prot, pkey, ret);
888         return ret;
889 }
890
891 int last_pkey_faults;
892 #define UNKNOWN_PKEY -2
893 void expected_pkey_fault(int pkey)
894 {
895         dprintf2("%s(): last_pkey_faults: %d pkey_faults: %d\n",
896                         __func__, last_pkey_faults, pkey_faults);
897         dprintf2("%s(%d): last_si_pkey: %d\n", __func__, pkey, last_si_pkey);
898         pkey_assert(last_pkey_faults + 1 == pkey_faults);
899
900        /*
901         * For exec-only memory, we do not know the pkey in
902         * advance, so skip this check.
903         */
904         if (pkey != UNKNOWN_PKEY)
905                 pkey_assert(last_si_pkey == pkey);
906
907 #if defined(__i386__) || defined(__x86_64__) /* arch */
908         /*
909          * The signal handler shold have cleared out PKEY register to let the
910          * test program continue.  We now have to restore it.
911          */
912         if (__read_pkey_reg() != 0)
913 #else /* arch */
914         if (__read_pkey_reg() != shadow_pkey_reg)
915 #endif /* arch */
916                 pkey_assert(0);
917
918         __write_pkey_reg(shadow_pkey_reg);
919         dprintf1("%s() set pkey_reg=%016llx to restore state after signal "
920                        "nuked it\n", __func__, shadow_pkey_reg);
921         last_pkey_faults = pkey_faults;
922         last_si_pkey = -1;
923 }
924
925 #define do_not_expect_pkey_fault(msg)   do {                    \
926         if (last_pkey_faults != pkey_faults)                    \
927                 dprintf0("unexpected PKey fault: %s\n", msg);   \
928         pkey_assert(last_pkey_faults == pkey_faults);           \
929 } while (0)
930
931 int test_fds[10] = { -1 };
932 int nr_test_fds;
933 void __save_test_fd(int fd)
934 {
935         pkey_assert(fd >= 0);
936         pkey_assert(nr_test_fds < ARRAY_SIZE(test_fds));
937         test_fds[nr_test_fds] = fd;
938         nr_test_fds++;
939 }
940
941 int get_test_read_fd(void)
942 {
943         int test_fd = open("/etc/passwd", O_RDONLY);
944         __save_test_fd(test_fd);
945         return test_fd;
946 }
947
948 void close_test_fds(void)
949 {
950         int i;
951
952         for (i = 0; i < nr_test_fds; i++) {
953                 if (test_fds[i] < 0)
954                         continue;
955                 close(test_fds[i]);
956                 test_fds[i] = -1;
957         }
958         nr_test_fds = 0;
959 }
960
961 #define barrier() __asm__ __volatile__("": : :"memory")
962 __attribute__((noinline)) int read_ptr(int *ptr)
963 {
964         /*
965          * Keep GCC from optimizing this away somehow
966          */
967         barrier();
968         return *ptr;
969 }
970
971 void test_pkey_alloc_free_attach_pkey0(int *ptr, u16 pkey)
972 {
973         int i, err;
974         int max_nr_pkey_allocs;
975         int alloced_pkeys[NR_PKEYS];
976         int nr_alloced = 0;
977         long size;
978
979         pkey_assert(pkey_last_malloc_record);
980         size = pkey_last_malloc_record->size;
981         /*
982          * This is a bit of a hack.  But mprotect() requires
983          * huge-page-aligned sizes when operating on hugetlbfs.
984          * So, make sure that we use something that's a multiple
985          * of a huge page when we can.
986          */
987         if (size >= HPAGE_SIZE)
988                 size = HPAGE_SIZE;
989
990         /* allocate every possible key and make sure key-0 never got allocated */
991         max_nr_pkey_allocs = NR_PKEYS;
992         for (i = 0; i < max_nr_pkey_allocs; i++) {
993                 int new_pkey = alloc_pkey();
994                 pkey_assert(new_pkey != 0);
995
996                 if (new_pkey < 0)
997                         break;
998                 alloced_pkeys[nr_alloced++] = new_pkey;
999         }
1000         /* free all the allocated keys */
1001         for (i = 0; i < nr_alloced; i++) {
1002                 int free_ret;
1003
1004                 if (!alloced_pkeys[i])
1005                         continue;
1006                 free_ret = sys_pkey_free(alloced_pkeys[i]);
1007                 pkey_assert(!free_ret);
1008         }
1009
1010         /* attach key-0 in various modes */
1011         err = sys_mprotect_pkey(ptr, size, PROT_READ, 0);
1012         pkey_assert(!err);
1013         err = sys_mprotect_pkey(ptr, size, PROT_WRITE, 0);
1014         pkey_assert(!err);
1015         err = sys_mprotect_pkey(ptr, size, PROT_EXEC, 0);
1016         pkey_assert(!err);
1017         err = sys_mprotect_pkey(ptr, size, PROT_READ|PROT_WRITE, 0);
1018         pkey_assert(!err);
1019         err = sys_mprotect_pkey(ptr, size, PROT_READ|PROT_WRITE|PROT_EXEC, 0);
1020         pkey_assert(!err);
1021 }
1022
1023 void test_read_of_write_disabled_region(int *ptr, u16 pkey)
1024 {
1025         int ptr_contents;
1026
1027         dprintf1("disabling write access to PKEY[1], doing read\n");
1028         pkey_write_deny(pkey);
1029         ptr_contents = read_ptr(ptr);
1030         dprintf1("*ptr: %d\n", ptr_contents);
1031         dprintf1("\n");
1032 }
1033 void test_read_of_access_disabled_region(int *ptr, u16 pkey)
1034 {
1035         int ptr_contents;
1036
1037         dprintf1("disabling access to PKEY[%02d], doing read @ %p\n", pkey, ptr);
1038         read_pkey_reg();
1039         pkey_access_deny(pkey);
1040         ptr_contents = read_ptr(ptr);
1041         dprintf1("*ptr: %d\n", ptr_contents);
1042         expected_pkey_fault(pkey);
1043 }
1044
1045 void test_read_of_access_disabled_region_with_page_already_mapped(int *ptr,
1046                 u16 pkey)
1047 {
1048         int ptr_contents;
1049
1050         dprintf1("disabling access to PKEY[%02d], doing read @ %p\n",
1051                                 pkey, ptr);
1052         ptr_contents = read_ptr(ptr);
1053         dprintf1("reading ptr before disabling the read : %d\n",
1054                         ptr_contents);
1055         read_pkey_reg();
1056         pkey_access_deny(pkey);
1057         ptr_contents = read_ptr(ptr);
1058         dprintf1("*ptr: %d\n", ptr_contents);
1059         expected_pkey_fault(pkey);
1060 }
1061
1062 void test_write_of_write_disabled_region_with_page_already_mapped(int *ptr,
1063                 u16 pkey)
1064 {
1065         *ptr = __LINE__;
1066         dprintf1("disabling write access; after accessing the page, "
1067                 "to PKEY[%02d], doing write\n", pkey);
1068         pkey_write_deny(pkey);
1069         *ptr = __LINE__;
1070         expected_pkey_fault(pkey);
1071 }
1072
1073 void test_write_of_write_disabled_region(int *ptr, u16 pkey)
1074 {
1075         dprintf1("disabling write access to PKEY[%02d], doing write\n", pkey);
1076         pkey_write_deny(pkey);
1077         *ptr = __LINE__;
1078         expected_pkey_fault(pkey);
1079 }
1080 void test_write_of_access_disabled_region(int *ptr, u16 pkey)
1081 {
1082         dprintf1("disabling access to PKEY[%02d], doing write\n", pkey);
1083         pkey_access_deny(pkey);
1084         *ptr = __LINE__;
1085         expected_pkey_fault(pkey);
1086 }
1087
1088 void test_write_of_access_disabled_region_with_page_already_mapped(int *ptr,
1089                         u16 pkey)
1090 {
1091         *ptr = __LINE__;
1092         dprintf1("disabling access; after accessing the page, "
1093                 " to PKEY[%02d], doing write\n", pkey);
1094         pkey_access_deny(pkey);
1095         *ptr = __LINE__;
1096         expected_pkey_fault(pkey);
1097 }
1098
1099 void test_kernel_write_of_access_disabled_region(int *ptr, u16 pkey)
1100 {
1101         int ret;
1102         int test_fd = get_test_read_fd();
1103
1104         dprintf1("disabling access to PKEY[%02d], "
1105                  "having kernel read() to buffer\n", pkey);
1106         pkey_access_deny(pkey);
1107         ret = read(test_fd, ptr, 1);
1108         dprintf1("read ret: %d\n", ret);
1109         pkey_assert(ret);
1110 }
1111 void test_kernel_write_of_write_disabled_region(int *ptr, u16 pkey)
1112 {
1113         int ret;
1114         int test_fd = get_test_read_fd();
1115
1116         pkey_write_deny(pkey);
1117         ret = read(test_fd, ptr, 100);
1118         dprintf1("read ret: %d\n", ret);
1119         if (ret < 0 && (DEBUG_LEVEL > 0))
1120                 perror("verbose read result (OK for this to be bad)");
1121         pkey_assert(ret);
1122 }
1123
1124 void test_kernel_gup_of_access_disabled_region(int *ptr, u16 pkey)
1125 {
1126         int pipe_ret, vmsplice_ret;
1127         struct iovec iov;
1128         int pipe_fds[2];
1129
1130         pipe_ret = pipe(pipe_fds);
1131
1132         pkey_assert(pipe_ret == 0);
1133         dprintf1("disabling access to PKEY[%02d], "
1134                  "having kernel vmsplice from buffer\n", pkey);
1135         pkey_access_deny(pkey);
1136         iov.iov_base = ptr;
1137         iov.iov_len = PAGE_SIZE;
1138         vmsplice_ret = vmsplice(pipe_fds[1], &iov, 1, SPLICE_F_GIFT);
1139         dprintf1("vmsplice() ret: %d\n", vmsplice_ret);
1140         pkey_assert(vmsplice_ret == -1);
1141
1142         close(pipe_fds[0]);
1143         close(pipe_fds[1]);
1144 }
1145
1146 void test_kernel_gup_write_to_write_disabled_region(int *ptr, u16 pkey)
1147 {
1148         int ignored = 0xdada;
1149         int futex_ret;
1150         int some_int = __LINE__;
1151
1152         dprintf1("disabling write to PKEY[%02d], "
1153                  "doing futex gunk in buffer\n", pkey);
1154         *ptr = some_int;
1155         pkey_write_deny(pkey);
1156         futex_ret = syscall(SYS_futex, ptr, FUTEX_WAIT, some_int-1, NULL,
1157                         &ignored, ignored);
1158         if (DEBUG_LEVEL > 0)
1159                 perror("futex");
1160         dprintf1("futex() ret: %d\n", futex_ret);
1161 }
1162
1163 /* Assumes that all pkeys other than 'pkey' are unallocated */
1164 void test_pkey_syscalls_on_non_allocated_pkey(int *ptr, u16 pkey)
1165 {
1166         int err;
1167         int i;
1168
1169         /* Note: 0 is the default pkey, so don't mess with it */
1170         for (i = 1; i < NR_PKEYS; i++) {
1171                 if (pkey == i)
1172                         continue;
1173
1174                 dprintf1("trying get/set/free to non-allocated pkey: %2d\n", i);
1175                 err = sys_pkey_free(i);
1176                 pkey_assert(err);
1177
1178                 err = sys_pkey_free(i);
1179                 pkey_assert(err);
1180
1181                 err = sys_mprotect_pkey(ptr, PAGE_SIZE, PROT_READ, i);
1182                 pkey_assert(err);
1183         }
1184 }
1185
1186 /* Assumes that all pkeys other than 'pkey' are unallocated */
1187 void test_pkey_syscalls_bad_args(int *ptr, u16 pkey)
1188 {
1189         int err;
1190         int bad_pkey = NR_PKEYS+99;
1191
1192         /* pass a known-invalid pkey in: */
1193         err = sys_mprotect_pkey(ptr, PAGE_SIZE, PROT_READ, bad_pkey);
1194         pkey_assert(err);
1195 }
1196
1197 void become_child(void)
1198 {
1199         pid_t forkret;
1200
1201         forkret = fork();
1202         pkey_assert(forkret >= 0);
1203         dprintf3("[%d] fork() ret: %d\n", getpid(), forkret);
1204
1205         if (!forkret) {
1206                 /* in the child */
1207                 return;
1208         }
1209         exit(0);
1210 }
1211
1212 /* Assumes that all pkeys other than 'pkey' are unallocated */
1213 void test_pkey_alloc_exhaust(int *ptr, u16 pkey)
1214 {
1215         int err;
1216         int allocated_pkeys[NR_PKEYS] = {0};
1217         int nr_allocated_pkeys = 0;
1218         int i;
1219
1220         for (i = 0; i < NR_PKEYS*3; i++) {
1221                 int new_pkey;
1222                 dprintf1("%s() alloc loop: %d\n", __func__, i);
1223                 new_pkey = alloc_pkey();
1224                 dprintf4("%s()::%d, err: %d pkey_reg: 0x%016llx"
1225                                 " shadow: 0x%016llx\n",
1226                                 __func__, __LINE__, err, __read_pkey_reg(),
1227                                 shadow_pkey_reg);
1228                 read_pkey_reg(); /* for shadow checking */
1229                 dprintf2("%s() errno: %d ENOSPC: %d\n", __func__, errno, ENOSPC);
1230                 if ((new_pkey == -1) && (errno == ENOSPC)) {
1231                         dprintf2("%s() failed to allocate pkey after %d tries\n",
1232                                 __func__, nr_allocated_pkeys);
1233                 } else {
1234                         /*
1235                          * Ensure the number of successes never
1236                          * exceeds the number of keys supported
1237                          * in the hardware.
1238                          */
1239                         pkey_assert(nr_allocated_pkeys < NR_PKEYS);
1240                         allocated_pkeys[nr_allocated_pkeys++] = new_pkey;
1241                 }
1242
1243                 /*
1244                  * Make sure that allocation state is properly
1245                  * preserved across fork().
1246                  */
1247                 if (i == NR_PKEYS*2)
1248                         become_child();
1249         }
1250
1251         dprintf3("%s()::%d\n", __func__, __LINE__);
1252
1253         /*
1254          * On x86:
1255          * There are 16 pkeys supported in hardware.  Three are
1256          * allocated by the time we get here:
1257          *   1. The default key (0)
1258          *   2. One possibly consumed by an execute-only mapping.
1259          *   3. One allocated by the test code and passed in via
1260          *      'pkey' to this function.
1261          * Ensure that we can allocate at least another 13 (16-3).
1262          *
1263          * On powerpc:
1264          * There are either 5, 28, 29 or 32 pkeys supported in
1265          * hardware depending on the page size (4K or 64K) and
1266          * platform (powernv or powervm). Four are allocated by
1267          * the time we get here. These include pkey-0, pkey-1,
1268          * exec-only pkey and the one allocated by the test code.
1269          * Ensure that we can allocate the remaining.
1270          */
1271         pkey_assert(i >= (NR_PKEYS - get_arch_reserved_keys() - 1));
1272
1273         for (i = 0; i < nr_allocated_pkeys; i++) {
1274                 err = sys_pkey_free(allocated_pkeys[i]);
1275                 pkey_assert(!err);
1276                 read_pkey_reg(); /* for shadow checking */
1277         }
1278 }
1279
1280 /*
1281  * pkey 0 is special.  It is allocated by default, so you do not
1282  * have to call pkey_alloc() to use it first.  Make sure that it
1283  * is usable.
1284  */
1285 void test_mprotect_with_pkey_0(int *ptr, u16 pkey)
1286 {
1287         long size;
1288         int prot;
1289
1290         assert(pkey_last_malloc_record);
1291         size = pkey_last_malloc_record->size;
1292         /*
1293          * This is a bit of a hack.  But mprotect() requires
1294          * huge-page-aligned sizes when operating on hugetlbfs.
1295          * So, make sure that we use something that's a multiple
1296          * of a huge page when we can.
1297          */
1298         if (size >= HPAGE_SIZE)
1299                 size = HPAGE_SIZE;
1300         prot = pkey_last_malloc_record->prot;
1301
1302         /* Use pkey 0 */
1303         mprotect_pkey(ptr, size, prot, 0);
1304
1305         /* Make sure that we can set it back to the original pkey. */
1306         mprotect_pkey(ptr, size, prot, pkey);
1307 }
1308
1309 void test_ptrace_of_child(int *ptr, u16 pkey)
1310 {
1311         __attribute__((__unused__)) int peek_result;
1312         pid_t child_pid;
1313         void *ignored = 0;
1314         long ret;
1315         int status;
1316         /*
1317          * This is the "control" for our little expermient.  Make sure
1318          * we can always access it when ptracing.
1319          */
1320         int *plain_ptr_unaligned = malloc(HPAGE_SIZE);
1321         int *plain_ptr = ALIGN_PTR_UP(plain_ptr_unaligned, PAGE_SIZE);
1322
1323         /*
1324          * Fork a child which is an exact copy of this process, of course.
1325          * That means we can do all of our tests via ptrace() and then plain
1326          * memory access and ensure they work differently.
1327          */
1328         child_pid = fork_lazy_child();
1329         dprintf1("[%d] child pid: %d\n", getpid(), child_pid);
1330
1331         ret = ptrace(PTRACE_ATTACH, child_pid, ignored, ignored);
1332         if (ret)
1333                 perror("attach");
1334         dprintf1("[%d] attach ret: %ld %d\n", getpid(), ret, __LINE__);
1335         pkey_assert(ret != -1);
1336         ret = waitpid(child_pid, &status, WUNTRACED);
1337         if ((ret != child_pid) || !(WIFSTOPPED(status))) {
1338                 fprintf(stderr, "weird waitpid result %ld stat %x\n",
1339                                 ret, status);
1340                 pkey_assert(0);
1341         }
1342         dprintf2("waitpid ret: %ld\n", ret);
1343         dprintf2("waitpid status: %d\n", status);
1344
1345         pkey_access_deny(pkey);
1346         pkey_write_deny(pkey);
1347
1348         /* Write access, untested for now:
1349         ret = ptrace(PTRACE_POKEDATA, child_pid, peek_at, data);
1350         pkey_assert(ret != -1);
1351         dprintf1("poke at %p: %ld\n", peek_at, ret);
1352         */
1353
1354         /*
1355          * Try to access the pkey-protected "ptr" via ptrace:
1356          */
1357         ret = ptrace(PTRACE_PEEKDATA, child_pid, ptr, ignored);
1358         /* expect it to work, without an error: */
1359         pkey_assert(ret != -1);
1360         /* Now access from the current task, and expect an exception: */
1361         peek_result = read_ptr(ptr);
1362         expected_pkey_fault(pkey);
1363
1364         /*
1365          * Try to access the NON-pkey-protected "plain_ptr" via ptrace:
1366          */
1367         ret = ptrace(PTRACE_PEEKDATA, child_pid, plain_ptr, ignored);
1368         /* expect it to work, without an error: */
1369         pkey_assert(ret != -1);
1370         /* Now access from the current task, and expect NO exception: */
1371         peek_result = read_ptr(plain_ptr);
1372         do_not_expect_pkey_fault("read plain pointer after ptrace");
1373
1374         ret = ptrace(PTRACE_DETACH, child_pid, ignored, 0);
1375         pkey_assert(ret != -1);
1376
1377         ret = kill(child_pid, SIGKILL);
1378         pkey_assert(ret != -1);
1379
1380         wait(&status);
1381
1382         free(plain_ptr_unaligned);
1383 }
1384
1385 void *get_pointer_to_instructions(void)
1386 {
1387         void *p1;
1388
1389         p1 = ALIGN_PTR_UP(&lots_o_noops_around_write, PAGE_SIZE);
1390         dprintf3("&lots_o_noops: %p\n", &lots_o_noops_around_write);
1391         /* lots_o_noops_around_write should be page-aligned already */
1392         assert(p1 == &lots_o_noops_around_write);
1393
1394         /* Point 'p1' at the *second* page of the function: */
1395         p1 += PAGE_SIZE;
1396
1397         /*
1398          * Try to ensure we fault this in on next touch to ensure
1399          * we get an instruction fault as opposed to a data one
1400          */
1401         madvise(p1, PAGE_SIZE, MADV_DONTNEED);
1402
1403         return p1;
1404 }
1405
1406 void test_executing_on_unreadable_memory(int *ptr, u16 pkey)
1407 {
1408         void *p1;
1409         int scratch;
1410         int ptr_contents;
1411         int ret;
1412
1413         p1 = get_pointer_to_instructions();
1414         lots_o_noops_around_write(&scratch);
1415         ptr_contents = read_ptr(p1);
1416         dprintf2("ptr (%p) contents@%d: %x\n", p1, __LINE__, ptr_contents);
1417
1418         ret = mprotect_pkey(p1, PAGE_SIZE, PROT_EXEC, (u64)pkey);
1419         pkey_assert(!ret);
1420         pkey_access_deny(pkey);
1421
1422         dprintf2("pkey_reg: %016llx\n", read_pkey_reg());
1423
1424         /*
1425          * Make sure this is an *instruction* fault
1426          */
1427         madvise(p1, PAGE_SIZE, MADV_DONTNEED);
1428         lots_o_noops_around_write(&scratch);
1429         do_not_expect_pkey_fault("executing on PROT_EXEC memory");
1430         expect_fault_on_read_execonly_key(p1, pkey);
1431 }
1432
1433 void test_implicit_mprotect_exec_only_memory(int *ptr, u16 pkey)
1434 {
1435         void *p1;
1436         int scratch;
1437         int ptr_contents;
1438         int ret;
1439
1440         dprintf1("%s() start\n", __func__);
1441
1442         p1 = get_pointer_to_instructions();
1443         lots_o_noops_around_write(&scratch);
1444         ptr_contents = read_ptr(p1);
1445         dprintf2("ptr (%p) contents@%d: %x\n", p1, __LINE__, ptr_contents);
1446
1447         /* Use a *normal* mprotect(), not mprotect_pkey(): */
1448         ret = mprotect(p1, PAGE_SIZE, PROT_EXEC);
1449         pkey_assert(!ret);
1450
1451         /*
1452          * Reset the shadow, assuming that the above mprotect()
1453          * correctly changed PKRU, but to an unknown value since
1454          * the actual alllocated pkey is unknown.
1455          */
1456         shadow_pkey_reg = __read_pkey_reg();
1457
1458         dprintf2("pkey_reg: %016llx\n", read_pkey_reg());
1459
1460         /* Make sure this is an *instruction* fault */
1461         madvise(p1, PAGE_SIZE, MADV_DONTNEED);
1462         lots_o_noops_around_write(&scratch);
1463         do_not_expect_pkey_fault("executing on PROT_EXEC memory");
1464         expect_fault_on_read_execonly_key(p1, UNKNOWN_PKEY);
1465
1466         /*
1467          * Put the memory back to non-PROT_EXEC.  Should clear the
1468          * exec-only pkey off the VMA and allow it to be readable
1469          * again.  Go to PROT_NONE first to check for a kernel bug
1470          * that did not clear the pkey when doing PROT_NONE.
1471          */
1472         ret = mprotect(p1, PAGE_SIZE, PROT_NONE);
1473         pkey_assert(!ret);
1474
1475         ret = mprotect(p1, PAGE_SIZE, PROT_READ|PROT_EXEC);
1476         pkey_assert(!ret);
1477         ptr_contents = read_ptr(p1);
1478         do_not_expect_pkey_fault("plain read on recently PROT_EXEC area");
1479 }
1480
1481 void test_mprotect_pkey_on_unsupported_cpu(int *ptr, u16 pkey)
1482 {
1483         int size = PAGE_SIZE;
1484         int sret;
1485
1486         if (cpu_has_pkeys()) {
1487                 dprintf1("SKIP: %s: no CPU support\n", __func__);
1488                 return;
1489         }
1490
1491         sret = syscall(SYS_mprotect_key, ptr, size, PROT_READ, pkey);
1492         pkey_assert(sret < 0);
1493 }
1494
1495 void (*pkey_tests[])(int *ptr, u16 pkey) = {
1496         test_read_of_write_disabled_region,
1497         test_read_of_access_disabled_region,
1498         test_read_of_access_disabled_region_with_page_already_mapped,
1499         test_write_of_write_disabled_region,
1500         test_write_of_write_disabled_region_with_page_already_mapped,
1501         test_write_of_access_disabled_region,
1502         test_write_of_access_disabled_region_with_page_already_mapped,
1503         test_kernel_write_of_access_disabled_region,
1504         test_kernel_write_of_write_disabled_region,
1505         test_kernel_gup_of_access_disabled_region,
1506         test_kernel_gup_write_to_write_disabled_region,
1507         test_executing_on_unreadable_memory,
1508         test_implicit_mprotect_exec_only_memory,
1509         test_mprotect_with_pkey_0,
1510         test_ptrace_of_child,
1511         test_pkey_syscalls_on_non_allocated_pkey,
1512         test_pkey_syscalls_bad_args,
1513         test_pkey_alloc_exhaust,
1514         test_pkey_alloc_free_attach_pkey0,
1515 };
1516
1517 void run_tests_once(void)
1518 {
1519         int *ptr;
1520         int prot = PROT_READ|PROT_WRITE;
1521
1522         for (test_nr = 0; test_nr < ARRAY_SIZE(pkey_tests); test_nr++) {
1523                 int pkey;
1524                 int orig_pkey_faults = pkey_faults;
1525
1526                 dprintf1("======================\n");
1527                 dprintf1("test %d preparing...\n", test_nr);
1528
1529                 tracing_on();
1530                 pkey = alloc_random_pkey();
1531                 dprintf1("test %d starting with pkey: %d\n", test_nr, pkey);
1532                 ptr = malloc_pkey(PAGE_SIZE, prot, pkey);
1533                 dprintf1("test %d starting...\n", test_nr);
1534                 pkey_tests[test_nr](ptr, pkey);
1535                 dprintf1("freeing test memory: %p\n", ptr);
1536                 free_pkey_malloc(ptr);
1537                 sys_pkey_free(pkey);
1538
1539                 dprintf1("pkey_faults: %d\n", pkey_faults);
1540                 dprintf1("orig_pkey_faults: %d\n", orig_pkey_faults);
1541
1542                 tracing_off();
1543                 close_test_fds();
1544
1545                 printf("test %2d PASSED (iteration %d)\n", test_nr, iteration_nr);
1546                 dprintf1("======================\n\n");
1547         }
1548         iteration_nr++;
1549 }
1550
1551 void pkey_setup_shadow(void)
1552 {
1553         shadow_pkey_reg = __read_pkey_reg();
1554 }
1555
1556 int main(void)
1557 {
1558         int nr_iterations = 22;
1559         int pkeys_supported = is_pkeys_supported();
1560
1561         srand((unsigned int)time(NULL));
1562
1563         setup_handlers();
1564
1565         printf("has pkeys: %d\n", pkeys_supported);
1566
1567         if (!pkeys_supported) {
1568                 int size = PAGE_SIZE;
1569                 int *ptr;
1570
1571                 printf("running PKEY tests for unsupported CPU/OS\n");
1572
1573                 ptr  = mmap(NULL, size, PROT_NONE, MAP_ANONYMOUS|MAP_PRIVATE, -1, 0);
1574                 assert(ptr != (void *)-1);
1575                 test_mprotect_pkey_on_unsupported_cpu(ptr, 1);
1576                 exit(0);
1577         }
1578
1579         pkey_setup_shadow();
1580         printf("startup pkey_reg: %016llx\n", read_pkey_reg());
1581         setup_hugetlbfs();
1582
1583         while (nr_iterations-- > 0)
1584                 run_tests_once();
1585
1586         printf("done (all tests OK)\n");
1587         return 0;
1588 }