RISC-V: Fix a race condition during kernel stack overflow
[platform/kernel/linux-starfive.git] / lib / overflow_kunit.c
1 // SPDX-License-Identifier: GPL-2.0 OR MIT
2 /*
3  * Test cases for arithmetic overflow checks. See:
4  * "Running tests with kunit_tool" at Documentation/dev-tools/kunit/start.rst
5  *      ./tools/testing/kunit/kunit.py run overflow [--raw_output]
6  */
7 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
8
9 #include <kunit/test.h>
10 #include <linux/device.h>
11 #include <linux/kernel.h>
12 #include <linux/mm.h>
13 #include <linux/module.h>
14 #include <linux/overflow.h>
15 #include <linux/slab.h>
16 #include <linux/types.h>
17 #include <linux/vmalloc.h>
18
19 #define DEFINE_TEST_ARRAY_TYPED(t1, t2, t)                      \
20         static const struct test_ ## t1 ## _ ## t2 ## __ ## t { \
21                 t1 a;                                           \
22                 t2 b;                                           \
23                 t sum, diff, prod;                              \
24                 bool s_of, d_of, p_of;                          \
25         } t1 ## _ ## t2 ## __ ## t ## _tests[]
26
27 #define DEFINE_TEST_ARRAY(t)    DEFINE_TEST_ARRAY_TYPED(t, t, t)
28
29 DEFINE_TEST_ARRAY(u8) = {
30         {0, 0, 0, 0, 0, false, false, false},
31         {1, 1, 2, 0, 1, false, false, false},
32         {0, 1, 1, U8_MAX, 0, false, true, false},
33         {1, 0, 1, 1, 0, false, false, false},
34         {0, U8_MAX, U8_MAX, 1, 0, false, true, false},
35         {U8_MAX, 0, U8_MAX, U8_MAX, 0, false, false, false},
36         {1, U8_MAX, 0, 2, U8_MAX, true, true, false},
37         {U8_MAX, 1, 0, U8_MAX-1, U8_MAX, true, false, false},
38         {U8_MAX, U8_MAX, U8_MAX-1, 0, 1, true, false, true},
39
40         {U8_MAX, U8_MAX-1, U8_MAX-2, 1, 2, true, false, true},
41         {U8_MAX-1, U8_MAX, U8_MAX-2, U8_MAX, 2, true, true, true},
42
43         {1U << 3, 1U << 3, 1U << 4, 0, 1U << 6, false, false, false},
44         {1U << 4, 1U << 4, 1U << 5, 0, 0, false, false, true},
45         {1U << 4, 1U << 3, 3*(1U << 3), 1U << 3, 1U << 7, false, false, false},
46         {1U << 7, 1U << 7, 0, 0, 0, true, false, true},
47
48         {48, 32, 80, 16, 0, false, false, true},
49         {128, 128, 0, 0, 0, true, false, true},
50         {123, 234, 101, 145, 110, true, true, true},
51 };
52 DEFINE_TEST_ARRAY(u16) = {
53         {0, 0, 0, 0, 0, false, false, false},
54         {1, 1, 2, 0, 1, false, false, false},
55         {0, 1, 1, U16_MAX, 0, false, true, false},
56         {1, 0, 1, 1, 0, false, false, false},
57         {0, U16_MAX, U16_MAX, 1, 0, false, true, false},
58         {U16_MAX, 0, U16_MAX, U16_MAX, 0, false, false, false},
59         {1, U16_MAX, 0, 2, U16_MAX, true, true, false},
60         {U16_MAX, 1, 0, U16_MAX-1, U16_MAX, true, false, false},
61         {U16_MAX, U16_MAX, U16_MAX-1, 0, 1, true, false, true},
62
63         {U16_MAX, U16_MAX-1, U16_MAX-2, 1, 2, true, false, true},
64         {U16_MAX-1, U16_MAX, U16_MAX-2, U16_MAX, 2, true, true, true},
65
66         {1U << 7, 1U << 7, 1U << 8, 0, 1U << 14, false, false, false},
67         {1U << 8, 1U << 8, 1U << 9, 0, 0, false, false, true},
68         {1U << 8, 1U << 7, 3*(1U << 7), 1U << 7, 1U << 15, false, false, false},
69         {1U << 15, 1U << 15, 0, 0, 0, true, false, true},
70
71         {123, 234, 357, 65425, 28782, false, true, false},
72         {1234, 2345, 3579, 64425, 10146, false, true, true},
73 };
74 DEFINE_TEST_ARRAY(u32) = {
75         {0, 0, 0, 0, 0, false, false, false},
76         {1, 1, 2, 0, 1, false, false, false},
77         {0, 1, 1, U32_MAX, 0, false, true, false},
78         {1, 0, 1, 1, 0, false, false, false},
79         {0, U32_MAX, U32_MAX, 1, 0, false, true, false},
80         {U32_MAX, 0, U32_MAX, U32_MAX, 0, false, false, false},
81         {1, U32_MAX, 0, 2, U32_MAX, true, true, false},
82         {U32_MAX, 1, 0, U32_MAX-1, U32_MAX, true, false, false},
83         {U32_MAX, U32_MAX, U32_MAX-1, 0, 1, true, false, true},
84
85         {U32_MAX, U32_MAX-1, U32_MAX-2, 1, 2, true, false, true},
86         {U32_MAX-1, U32_MAX, U32_MAX-2, U32_MAX, 2, true, true, true},
87
88         {1U << 15, 1U << 15, 1U << 16, 0, 1U << 30, false, false, false},
89         {1U << 16, 1U << 16, 1U << 17, 0, 0, false, false, true},
90         {1U << 16, 1U << 15, 3*(1U << 15), 1U << 15, 1U << 31, false, false, false},
91         {1U << 31, 1U << 31, 0, 0, 0, true, false, true},
92
93         {-2U, 1U, -1U, -3U, -2U, false, false, false},
94         {-4U, 5U, 1U, -9U, -20U, true, false, true},
95 };
96
97 #if BITS_PER_LONG == 64
98 DEFINE_TEST_ARRAY(u64) = {
99         {0, 0, 0, 0, 0, false, false, false},
100         {1, 1, 2, 0, 1, false, false, false},
101         {0, 1, 1, U64_MAX, 0, false, true, false},
102         {1, 0, 1, 1, 0, false, false, false},
103         {0, U64_MAX, U64_MAX, 1, 0, false, true, false},
104         {U64_MAX, 0, U64_MAX, U64_MAX, 0, false, false, false},
105         {1, U64_MAX, 0, 2, U64_MAX, true, true, false},
106         {U64_MAX, 1, 0, U64_MAX-1, U64_MAX, true, false, false},
107         {U64_MAX, U64_MAX, U64_MAX-1, 0, 1, true, false, true},
108
109         {U64_MAX, U64_MAX-1, U64_MAX-2, 1, 2, true, false, true},
110         {U64_MAX-1, U64_MAX, U64_MAX-2, U64_MAX, 2, true, true, true},
111
112         {1ULL << 31, 1ULL << 31, 1ULL << 32, 0, 1ULL << 62, false, false, false},
113         {1ULL << 32, 1ULL << 32, 1ULL << 33, 0, 0, false, false, true},
114         {1ULL << 32, 1ULL << 31, 3*(1ULL << 31), 1ULL << 31, 1ULL << 63, false, false, false},
115         {1ULL << 63, 1ULL << 63, 0, 0, 0, true, false, true},
116         {1000000000ULL /* 10^9 */, 10000000000ULL /* 10^10 */,
117          11000000000ULL, 18446744064709551616ULL, 10000000000000000000ULL,
118          false, true, false},
119         {-15ULL, 10ULL, -5ULL, -25ULL, -150ULL, false, false, true},
120 };
121 #endif
122
123 DEFINE_TEST_ARRAY(s8) = {
124         {0, 0, 0, 0, 0, false, false, false},
125
126         {0, S8_MAX, S8_MAX, -S8_MAX, 0, false, false, false},
127         {S8_MAX, 0, S8_MAX, S8_MAX, 0, false, false, false},
128         {0, S8_MIN, S8_MIN, S8_MIN, 0, false, true, false},
129         {S8_MIN, 0, S8_MIN, S8_MIN, 0, false, false, false},
130
131         {-1, S8_MIN, S8_MAX, S8_MAX, S8_MIN, true, false, true},
132         {S8_MIN, -1, S8_MAX, -S8_MAX, S8_MIN, true, false, true},
133         {-1, S8_MAX, S8_MAX-1, S8_MIN, -S8_MAX, false, false, false},
134         {S8_MAX, -1, S8_MAX-1, S8_MIN, -S8_MAX, false, true, false},
135         {-1, -S8_MAX, S8_MIN, S8_MAX-1, S8_MAX, false, false, false},
136         {-S8_MAX, -1, S8_MIN, S8_MIN+2, S8_MAX, false, false, false},
137
138         {1, S8_MIN, -S8_MAX, -S8_MAX, S8_MIN, false, true, false},
139         {S8_MIN, 1, -S8_MAX, S8_MAX, S8_MIN, false, true, false},
140         {1, S8_MAX, S8_MIN, S8_MIN+2, S8_MAX, true, false, false},
141         {S8_MAX, 1, S8_MIN, S8_MAX-1, S8_MAX, true, false, false},
142
143         {S8_MIN, S8_MIN, 0, 0, 0, true, false, true},
144         {S8_MAX, S8_MAX, -2, 0, 1, true, false, true},
145
146         {-4, -32, -36, 28, -128, false, false, true},
147         {-4, 32, 28, -36, -128, false, false, false},
148 };
149
150 DEFINE_TEST_ARRAY(s16) = {
151         {0, 0, 0, 0, 0, false, false, false},
152
153         {0, S16_MAX, S16_MAX, -S16_MAX, 0, false, false, false},
154         {S16_MAX, 0, S16_MAX, S16_MAX, 0, false, false, false},
155         {0, S16_MIN, S16_MIN, S16_MIN, 0, false, true, false},
156         {S16_MIN, 0, S16_MIN, S16_MIN, 0, false, false, false},
157
158         {-1, S16_MIN, S16_MAX, S16_MAX, S16_MIN, true, false, true},
159         {S16_MIN, -1, S16_MAX, -S16_MAX, S16_MIN, true, false, true},
160         {-1, S16_MAX, S16_MAX-1, S16_MIN, -S16_MAX, false, false, false},
161         {S16_MAX, -1, S16_MAX-1, S16_MIN, -S16_MAX, false, true, false},
162         {-1, -S16_MAX, S16_MIN, S16_MAX-1, S16_MAX, false, false, false},
163         {-S16_MAX, -1, S16_MIN, S16_MIN+2, S16_MAX, false, false, false},
164
165         {1, S16_MIN, -S16_MAX, -S16_MAX, S16_MIN, false, true, false},
166         {S16_MIN, 1, -S16_MAX, S16_MAX, S16_MIN, false, true, false},
167         {1, S16_MAX, S16_MIN, S16_MIN+2, S16_MAX, true, false, false},
168         {S16_MAX, 1, S16_MIN, S16_MAX-1, S16_MAX, true, false, false},
169
170         {S16_MIN, S16_MIN, 0, 0, 0, true, false, true},
171         {S16_MAX, S16_MAX, -2, 0, 1, true, false, true},
172 };
173 DEFINE_TEST_ARRAY(s32) = {
174         {0, 0, 0, 0, 0, false, false, false},
175
176         {0, S32_MAX, S32_MAX, -S32_MAX, 0, false, false, false},
177         {S32_MAX, 0, S32_MAX, S32_MAX, 0, false, false, false},
178         {0, S32_MIN, S32_MIN, S32_MIN, 0, false, true, false},
179         {S32_MIN, 0, S32_MIN, S32_MIN, 0, false, false, false},
180
181         {-1, S32_MIN, S32_MAX, S32_MAX, S32_MIN, true, false, true},
182         {S32_MIN, -1, S32_MAX, -S32_MAX, S32_MIN, true, false, true},
183         {-1, S32_MAX, S32_MAX-1, S32_MIN, -S32_MAX, false, false, false},
184         {S32_MAX, -1, S32_MAX-1, S32_MIN, -S32_MAX, false, true, false},
185         {-1, -S32_MAX, S32_MIN, S32_MAX-1, S32_MAX, false, false, false},
186         {-S32_MAX, -1, S32_MIN, S32_MIN+2, S32_MAX, false, false, false},
187
188         {1, S32_MIN, -S32_MAX, -S32_MAX, S32_MIN, false, true, false},
189         {S32_MIN, 1, -S32_MAX, S32_MAX, S32_MIN, false, true, false},
190         {1, S32_MAX, S32_MIN, S32_MIN+2, S32_MAX, true, false, false},
191         {S32_MAX, 1, S32_MIN, S32_MAX-1, S32_MAX, true, false, false},
192
193         {S32_MIN, S32_MIN, 0, 0, 0, true, false, true},
194         {S32_MAX, S32_MAX, -2, 0, 1, true, false, true},
195 };
196
197 #if BITS_PER_LONG == 64
198 DEFINE_TEST_ARRAY(s64) = {
199         {0, 0, 0, 0, 0, false, false, false},
200
201         {0, S64_MAX, S64_MAX, -S64_MAX, 0, false, false, false},
202         {S64_MAX, 0, S64_MAX, S64_MAX, 0, false, false, false},
203         {0, S64_MIN, S64_MIN, S64_MIN, 0, false, true, false},
204         {S64_MIN, 0, S64_MIN, S64_MIN, 0, false, false, false},
205
206         {-1, S64_MIN, S64_MAX, S64_MAX, S64_MIN, true, false, true},
207         {S64_MIN, -1, S64_MAX, -S64_MAX, S64_MIN, true, false, true},
208         {-1, S64_MAX, S64_MAX-1, S64_MIN, -S64_MAX, false, false, false},
209         {S64_MAX, -1, S64_MAX-1, S64_MIN, -S64_MAX, false, true, false},
210         {-1, -S64_MAX, S64_MIN, S64_MAX-1, S64_MAX, false, false, false},
211         {-S64_MAX, -1, S64_MIN, S64_MIN+2, S64_MAX, false, false, false},
212
213         {1, S64_MIN, -S64_MAX, -S64_MAX, S64_MIN, false, true, false},
214         {S64_MIN, 1, -S64_MAX, S64_MAX, S64_MIN, false, true, false},
215         {1, S64_MAX, S64_MIN, S64_MIN+2, S64_MAX, true, false, false},
216         {S64_MAX, 1, S64_MIN, S64_MAX-1, S64_MAX, true, false, false},
217
218         {S64_MIN, S64_MIN, 0, 0, 0, true, false, true},
219         {S64_MAX, S64_MAX, -2, 0, 1, true, false, true},
220
221         {-1, -1, -2, 0, 1, false, false, false},
222         {-1, -128, -129, 127, 128, false, false, false},
223         {-128, -1, -129, -127, 128, false, false, false},
224         {0, -S64_MAX, -S64_MAX, S64_MAX, 0, false, false, false},
225 };
226 #endif
227
228 #define check_one_op(t, fmt, op, sym, a, b, r, of) do {                 \
229         int _a_orig = a, _a_bump = a + 1;                               \
230         int _b_orig = b, _b_bump = b + 1;                               \
231         bool _of;                                                       \
232         t _r;                                                           \
233                                                                         \
234         _of = check_ ## op ## _overflow(a, b, &_r);                     \
235         KUNIT_EXPECT_EQ_MSG(test, _of, of,                              \
236                 "expected "fmt" "sym" "fmt" to%s overflow (type %s)\n", \
237                 a, b, of ? "" : " not", #t);                            \
238         KUNIT_EXPECT_EQ_MSG(test, _r, r,                                \
239                 "expected "fmt" "sym" "fmt" == "fmt", got "fmt" (type %s)\n", \
240                 a, b, r, _r, #t);                                       \
241         /* Check for internal macro side-effects. */                    \
242         _of = check_ ## op ## _overflow(_a_orig++, _b_orig++, &_r);     \
243         KUNIT_EXPECT_EQ_MSG(test, _a_orig, _a_bump, "Unexpected " #op " macro side-effect!\n"); \
244         KUNIT_EXPECT_EQ_MSG(test, _b_orig, _b_bump, "Unexpected " #op " macro side-effect!\n"); \
245 } while (0)
246
247 #define DEFINE_TEST_FUNC_TYPED(n, t, fmt)                               \
248 static void do_test_ ## n(struct kunit *test, const struct test_ ## n *p) \
249 {                                                                       \
250         check_one_op(t, fmt, add, "+", p->a, p->b, p->sum, p->s_of);    \
251         check_one_op(t, fmt, add, "+", p->b, p->a, p->sum, p->s_of);    \
252         check_one_op(t, fmt, sub, "-", p->a, p->b, p->diff, p->d_of);   \
253         check_one_op(t, fmt, mul, "*", p->a, p->b, p->prod, p->p_of);   \
254         check_one_op(t, fmt, mul, "*", p->b, p->a, p->prod, p->p_of);   \
255 }                                                                       \
256                                                                         \
257 static void n ## _overflow_test(struct kunit *test) {                   \
258         unsigned i;                                                     \
259                                                                         \
260         for (i = 0; i < ARRAY_SIZE(n ## _tests); ++i)                   \
261                 do_test_ ## n(test, &n ## _tests[i]);                   \
262         kunit_info(test, "%zu %s arithmetic tests finished\n",          \
263                 ARRAY_SIZE(n ## _tests), #n);                           \
264 }
265
266 #define DEFINE_TEST_FUNC(t, fmt)                                        \
267         DEFINE_TEST_FUNC_TYPED(t ## _ ## t ## __ ## t, t, fmt)
268
269 DEFINE_TEST_FUNC(u8, "%d");
270 DEFINE_TEST_FUNC(s8, "%d");
271 DEFINE_TEST_FUNC(u16, "%d");
272 DEFINE_TEST_FUNC(s16, "%d");
273 DEFINE_TEST_FUNC(u32, "%u");
274 DEFINE_TEST_FUNC(s32, "%d");
275 #if BITS_PER_LONG == 64
276 DEFINE_TEST_FUNC(u64, "%llu");
277 DEFINE_TEST_FUNC(s64, "%lld");
278 #endif
279
280 DEFINE_TEST_ARRAY_TYPED(u32, u32, u8) = {
281         {0, 0, 0, 0, 0, false, false, false},
282         {U8_MAX, 2, 1, U8_MAX - 2, U8_MAX - 1, true, false, true},
283         {U8_MAX + 1, 0, 0, 0, 0, true, true, false},
284 };
285 DEFINE_TEST_FUNC_TYPED(u32_u32__u8, u8, "%d");
286
287 DEFINE_TEST_ARRAY_TYPED(u32, u32, int) = {
288         {0, 0, 0, 0, 0, false, false, false},
289         {U32_MAX, 0, -1, -1, 0, true, true, false},
290 };
291 DEFINE_TEST_FUNC_TYPED(u32_u32__int, int, "%d");
292
293 DEFINE_TEST_ARRAY_TYPED(u8, u8, int) = {
294         {0, 0, 0, 0, 0, false, false, false},
295         {U8_MAX, U8_MAX, 2 * U8_MAX, 0, U8_MAX * U8_MAX, false, false, false},
296         {1, 2, 3, -1, 2, false, false, false},
297 };
298 DEFINE_TEST_FUNC_TYPED(u8_u8__int, int, "%d");
299
300 DEFINE_TEST_ARRAY_TYPED(int, int, u8) = {
301         {0, 0, 0, 0, 0, false, false, false},
302         {1, 2, 3, U8_MAX, 2, false, true, false},
303         {-1, 0, U8_MAX, U8_MAX, 0, true, true, false},
304 };
305 DEFINE_TEST_FUNC_TYPED(int_int__u8, u8, "%d");
306
307 /* Args are: value, shift, type, expected result, overflow expected */
308 #define TEST_ONE_SHIFT(a, s, t, expect, of)     do {                    \
309         typeof(a) __a = (a);                                            \
310         typeof(s) __s = (s);                                            \
311         t __e = (expect);                                               \
312         t __d;                                                          \
313         bool __of = check_shl_overflow(__a, __s, &__d);                 \
314         if (__of != of) {                                               \
315                 KUNIT_EXPECT_EQ_MSG(test, __of, of,                     \
316                         "expected (%s)(%s << %s) to%s overflow\n",      \
317                         #t, #a, #s, of ? "" : " not");                  \
318         } else if (!__of && __d != __e) {                               \
319                 KUNIT_EXPECT_EQ_MSG(test, __d, __e,                     \
320                         "expected (%s)(%s << %s) == %s\n",              \
321                         #t, #a, #s, #expect);                           \
322                 if ((t)-1 < 0)                                          \
323                         kunit_info(test, "got %lld\n", (s64)__d);       \
324                 else                                                    \
325                         kunit_info(test, "got %llu\n", (u64)__d);       \
326         }                                                               \
327         count++;                                                        \
328 } while (0)
329
330 static void shift_sane_test(struct kunit *test)
331 {
332         int count = 0;
333
334         /* Sane shifts. */
335         TEST_ONE_SHIFT(1, 0, u8, 1 << 0, false);
336         TEST_ONE_SHIFT(1, 4, u8, 1 << 4, false);
337         TEST_ONE_SHIFT(1, 7, u8, 1 << 7, false);
338         TEST_ONE_SHIFT(0xF, 4, u8, 0xF << 4, false);
339         TEST_ONE_SHIFT(1, 0, u16, 1 << 0, false);
340         TEST_ONE_SHIFT(1, 10, u16, 1 << 10, false);
341         TEST_ONE_SHIFT(1, 15, u16, 1 << 15, false);
342         TEST_ONE_SHIFT(0xFF, 8, u16, 0xFF << 8, false);
343         TEST_ONE_SHIFT(1, 0, int, 1 << 0, false);
344         TEST_ONE_SHIFT(1, 16, int, 1 << 16, false);
345         TEST_ONE_SHIFT(1, 30, int, 1 << 30, false);
346         TEST_ONE_SHIFT(1, 0, s32, 1 << 0, false);
347         TEST_ONE_SHIFT(1, 16, s32, 1 << 16, false);
348         TEST_ONE_SHIFT(1, 30, s32, 1 << 30, false);
349         TEST_ONE_SHIFT(1, 0, unsigned int, 1U << 0, false);
350         TEST_ONE_SHIFT(1, 20, unsigned int, 1U << 20, false);
351         TEST_ONE_SHIFT(1, 31, unsigned int, 1U << 31, false);
352         TEST_ONE_SHIFT(0xFFFFU, 16, unsigned int, 0xFFFFU << 16, false);
353         TEST_ONE_SHIFT(1, 0, u32, 1U << 0, false);
354         TEST_ONE_SHIFT(1, 20, u32, 1U << 20, false);
355         TEST_ONE_SHIFT(1, 31, u32, 1U << 31, false);
356         TEST_ONE_SHIFT(0xFFFFU, 16, u32, 0xFFFFU << 16, false);
357         TEST_ONE_SHIFT(1, 0, u64, 1ULL << 0, false);
358         TEST_ONE_SHIFT(1, 40, u64, 1ULL << 40, false);
359         TEST_ONE_SHIFT(1, 63, u64, 1ULL << 63, false);
360         TEST_ONE_SHIFT(0xFFFFFFFFULL, 32, u64, 0xFFFFFFFFULL << 32, false);
361
362         /* Sane shift: start and end with 0, without a too-wide shift. */
363         TEST_ONE_SHIFT(0, 7, u8, 0, false);
364         TEST_ONE_SHIFT(0, 15, u16, 0, false);
365         TEST_ONE_SHIFT(0, 31, unsigned int, 0, false);
366         TEST_ONE_SHIFT(0, 31, u32, 0, false);
367         TEST_ONE_SHIFT(0, 63, u64, 0, false);
368
369         /* Sane shift: start and end with 0, without reaching signed bit. */
370         TEST_ONE_SHIFT(0, 6, s8, 0, false);
371         TEST_ONE_SHIFT(0, 14, s16, 0, false);
372         TEST_ONE_SHIFT(0, 30, int, 0, false);
373         TEST_ONE_SHIFT(0, 30, s32, 0, false);
374         TEST_ONE_SHIFT(0, 62, s64, 0, false);
375
376         kunit_info(test, "%d sane shift tests finished\n", count);
377 }
378
379 static void shift_overflow_test(struct kunit *test)
380 {
381         int count = 0;
382
383         /* Overflow: shifted the bit off the end. */
384         TEST_ONE_SHIFT(1, 8, u8, 0, true);
385         TEST_ONE_SHIFT(1, 16, u16, 0, true);
386         TEST_ONE_SHIFT(1, 32, unsigned int, 0, true);
387         TEST_ONE_SHIFT(1, 32, u32, 0, true);
388         TEST_ONE_SHIFT(1, 64, u64, 0, true);
389
390         /* Overflow: shifted into the signed bit. */
391         TEST_ONE_SHIFT(1, 7, s8, 0, true);
392         TEST_ONE_SHIFT(1, 15, s16, 0, true);
393         TEST_ONE_SHIFT(1, 31, int, 0, true);
394         TEST_ONE_SHIFT(1, 31, s32, 0, true);
395         TEST_ONE_SHIFT(1, 63, s64, 0, true);
396
397         /* Overflow: high bit falls off unsigned types. */
398         /* 10010110 */
399         TEST_ONE_SHIFT(150, 1, u8, 0, true);
400         /* 1000100010010110 */
401         TEST_ONE_SHIFT(34966, 1, u16, 0, true);
402         /* 10000100000010001000100010010110 */
403         TEST_ONE_SHIFT(2215151766U, 1, u32, 0, true);
404         TEST_ONE_SHIFT(2215151766U, 1, unsigned int, 0, true);
405         /* 1000001000010000010000000100000010000100000010001000100010010110 */
406         TEST_ONE_SHIFT(9372061470395238550ULL, 1, u64, 0, true);
407
408         /* Overflow: bit shifted into signed bit on signed types. */
409         /* 01001011 */
410         TEST_ONE_SHIFT(75, 1, s8, 0, true);
411         /* 0100010001001011 */
412         TEST_ONE_SHIFT(17483, 1, s16, 0, true);
413         /* 01000010000001000100010001001011 */
414         TEST_ONE_SHIFT(1107575883, 1, s32, 0, true);
415         TEST_ONE_SHIFT(1107575883, 1, int, 0, true);
416         /* 0100000100001000001000000010000001000010000001000100010001001011 */
417         TEST_ONE_SHIFT(4686030735197619275LL, 1, s64, 0, true);
418
419         /* Overflow: bit shifted past signed bit on signed types. */
420         /* 01001011 */
421         TEST_ONE_SHIFT(75, 2, s8, 0, true);
422         /* 0100010001001011 */
423         TEST_ONE_SHIFT(17483, 2, s16, 0, true);
424         /* 01000010000001000100010001001011 */
425         TEST_ONE_SHIFT(1107575883, 2, s32, 0, true);
426         TEST_ONE_SHIFT(1107575883, 2, int, 0, true);
427         /* 0100000100001000001000000010000001000010000001000100010001001011 */
428         TEST_ONE_SHIFT(4686030735197619275LL, 2, s64, 0, true);
429
430         kunit_info(test, "%d overflow shift tests finished\n", count);
431 }
432
433 static void shift_truncate_test(struct kunit *test)
434 {
435         int count = 0;
436
437         /* Overflow: values larger than destination type. */
438         TEST_ONE_SHIFT(0x100, 0, u8, 0, true);
439         TEST_ONE_SHIFT(0xFF, 0, s8, 0, true);
440         TEST_ONE_SHIFT(0x10000U, 0, u16, 0, true);
441         TEST_ONE_SHIFT(0xFFFFU, 0, s16, 0, true);
442         TEST_ONE_SHIFT(0x100000000ULL, 0, u32, 0, true);
443         TEST_ONE_SHIFT(0x100000000ULL, 0, unsigned int, 0, true);
444         TEST_ONE_SHIFT(0xFFFFFFFFUL, 0, s32, 0, true);
445         TEST_ONE_SHIFT(0xFFFFFFFFUL, 0, int, 0, true);
446         TEST_ONE_SHIFT(0xFFFFFFFFFFFFFFFFULL, 0, s64, 0, true);
447
448         /* Overflow: shifted at or beyond entire type's bit width. */
449         TEST_ONE_SHIFT(0, 8, u8, 0, true);
450         TEST_ONE_SHIFT(0, 9, u8, 0, true);
451         TEST_ONE_SHIFT(0, 8, s8, 0, true);
452         TEST_ONE_SHIFT(0, 9, s8, 0, true);
453         TEST_ONE_SHIFT(0, 16, u16, 0, true);
454         TEST_ONE_SHIFT(0, 17, u16, 0, true);
455         TEST_ONE_SHIFT(0, 16, s16, 0, true);
456         TEST_ONE_SHIFT(0, 17, s16, 0, true);
457         TEST_ONE_SHIFT(0, 32, u32, 0, true);
458         TEST_ONE_SHIFT(0, 33, u32, 0, true);
459         TEST_ONE_SHIFT(0, 32, int, 0, true);
460         TEST_ONE_SHIFT(0, 33, int, 0, true);
461         TEST_ONE_SHIFT(0, 32, s32, 0, true);
462         TEST_ONE_SHIFT(0, 33, s32, 0, true);
463         TEST_ONE_SHIFT(0, 64, u64, 0, true);
464         TEST_ONE_SHIFT(0, 65, u64, 0, true);
465         TEST_ONE_SHIFT(0, 64, s64, 0, true);
466         TEST_ONE_SHIFT(0, 65, s64, 0, true);
467
468         kunit_info(test, "%d truncate shift tests finished\n", count);
469 }
470
471 static void shift_nonsense_test(struct kunit *test)
472 {
473         int count = 0;
474
475         /* Nonsense: negative initial value. */
476         TEST_ONE_SHIFT(-1, 0, s8, 0, true);
477         TEST_ONE_SHIFT(-1, 0, u8, 0, true);
478         TEST_ONE_SHIFT(-5, 0, s16, 0, true);
479         TEST_ONE_SHIFT(-5, 0, u16, 0, true);
480         TEST_ONE_SHIFT(-10, 0, int, 0, true);
481         TEST_ONE_SHIFT(-10, 0, unsigned int, 0, true);
482         TEST_ONE_SHIFT(-100, 0, s32, 0, true);
483         TEST_ONE_SHIFT(-100, 0, u32, 0, true);
484         TEST_ONE_SHIFT(-10000, 0, s64, 0, true);
485         TEST_ONE_SHIFT(-10000, 0, u64, 0, true);
486
487         /* Nonsense: negative shift values. */
488         TEST_ONE_SHIFT(0, -5, s8, 0, true);
489         TEST_ONE_SHIFT(0, -5, u8, 0, true);
490         TEST_ONE_SHIFT(0, -10, s16, 0, true);
491         TEST_ONE_SHIFT(0, -10, u16, 0, true);
492         TEST_ONE_SHIFT(0, -15, int, 0, true);
493         TEST_ONE_SHIFT(0, -15, unsigned int, 0, true);
494         TEST_ONE_SHIFT(0, -20, s32, 0, true);
495         TEST_ONE_SHIFT(0, -20, u32, 0, true);
496         TEST_ONE_SHIFT(0, -30, s64, 0, true);
497         TEST_ONE_SHIFT(0, -30, u64, 0, true);
498
499         /*
500          * Corner case: for unsigned types, we fail when we've shifted
501          * through the entire width of bits. For signed types, we might
502          * want to match this behavior, but that would mean noticing if
503          * we shift through all but the signed bit, and this is not
504          * currently detected (but we'll notice an overflow into the
505          * signed bit). So, for now, we will test this condition but
506          * mark it as not expected to overflow.
507          */
508         TEST_ONE_SHIFT(0, 7, s8, 0, false);
509         TEST_ONE_SHIFT(0, 15, s16, 0, false);
510         TEST_ONE_SHIFT(0, 31, int, 0, false);
511         TEST_ONE_SHIFT(0, 31, s32, 0, false);
512         TEST_ONE_SHIFT(0, 63, s64, 0, false);
513
514         kunit_info(test, "%d nonsense shift tests finished\n", count);
515 }
516 #undef TEST_ONE_SHIFT
517
518 /*
519  * Deal with the various forms of allocator arguments. See comments above
520  * the DEFINE_TEST_ALLOC() instances for mapping of the "bits".
521  */
522 #define alloc_GFP                (GFP_KERNEL | __GFP_NOWARN)
523 #define alloc010(alloc, arg, sz) alloc(sz, alloc_GFP)
524 #define alloc011(alloc, arg, sz) alloc(sz, alloc_GFP, NUMA_NO_NODE)
525 #define alloc000(alloc, arg, sz) alloc(sz)
526 #define alloc001(alloc, arg, sz) alloc(sz, NUMA_NO_NODE)
527 #define alloc110(alloc, arg, sz) alloc(arg, sz, alloc_GFP)
528 #define free0(free, arg, ptr)    free(ptr)
529 #define free1(free, arg, ptr)    free(arg, ptr)
530
531 /* Wrap around to 16K */
532 #define TEST_SIZE               (5 * 4096)
533
534 #define DEFINE_TEST_ALLOC(func, free_func, want_arg, want_gfp, want_node)\
535 static void test_ ## func (struct kunit *test, void *arg)               \
536 {                                                                       \
537         volatile size_t a = TEST_SIZE;                                  \
538         volatile size_t b = (SIZE_MAX / TEST_SIZE) + 1;                 \
539         void *ptr;                                                      \
540                                                                         \
541         /* Tiny allocation test. */                                     \
542         ptr = alloc ## want_arg ## want_gfp ## want_node (func, arg, 1);\
543         KUNIT_ASSERT_NOT_ERR_OR_NULL_MSG(test, ptr,                     \
544                             #func " failed regular allocation?!\n");    \
545         free ## want_arg (free_func, arg, ptr);                         \
546                                                                         \
547         /* Wrapped allocation test. */                                  \
548         ptr = alloc ## want_arg ## want_gfp ## want_node (func, arg,    \
549                                                           a * b);       \
550         KUNIT_ASSERT_NOT_ERR_OR_NULL_MSG(test, ptr,                     \
551                             #func " unexpectedly failed bad wrapping?!\n"); \
552         free ## want_arg (free_func, arg, ptr);                         \
553                                                                         \
554         /* Saturated allocation test. */                                \
555         ptr = alloc ## want_arg ## want_gfp ## want_node (func, arg,    \
556                                                    array_size(a, b));   \
557         if (ptr) {                                                      \
558                 KUNIT_FAIL(test, #func " missed saturation!\n");        \
559                 free ## want_arg (free_func, arg, ptr);                 \
560         }                                                               \
561 }
562
563 /*
564  * Allocator uses a trailing node argument --------+  (e.g. kmalloc_node())
565  * Allocator uses the gfp_t argument -----------+  |  (e.g. kmalloc())
566  * Allocator uses a special leading argument +  |  |  (e.g. devm_kmalloc())
567  *                                           |  |  |
568  */
569 DEFINE_TEST_ALLOC(kmalloc,       kfree,      0, 1, 0);
570 DEFINE_TEST_ALLOC(kmalloc_node,  kfree,      0, 1, 1);
571 DEFINE_TEST_ALLOC(kzalloc,       kfree,      0, 1, 0);
572 DEFINE_TEST_ALLOC(kzalloc_node,  kfree,      0, 1, 1);
573 DEFINE_TEST_ALLOC(__vmalloc,     vfree,      0, 1, 0);
574 DEFINE_TEST_ALLOC(kvmalloc,      kvfree,     0, 1, 0);
575 DEFINE_TEST_ALLOC(kvmalloc_node, kvfree,     0, 1, 1);
576 DEFINE_TEST_ALLOC(kvzalloc,      kvfree,     0, 1, 0);
577 DEFINE_TEST_ALLOC(kvzalloc_node, kvfree,     0, 1, 1);
578 DEFINE_TEST_ALLOC(devm_kmalloc,  devm_kfree, 1, 1, 0);
579 DEFINE_TEST_ALLOC(devm_kzalloc,  devm_kfree, 1, 1, 0);
580
581 static void overflow_allocation_test(struct kunit *test)
582 {
583         const char device_name[] = "overflow-test";
584         struct device *dev;
585         int count = 0;
586
587 #define check_allocation_overflow(alloc)        do {    \
588         count++;                                        \
589         test_ ## alloc(test, dev);                      \
590 } while (0)
591
592         /* Create dummy device for devm_kmalloc()-family tests. */
593         dev = root_device_register(device_name);
594         KUNIT_ASSERT_FALSE_MSG(test, IS_ERR(dev),
595                                "Cannot register test device\n");
596
597         check_allocation_overflow(kmalloc);
598         check_allocation_overflow(kmalloc_node);
599         check_allocation_overflow(kzalloc);
600         check_allocation_overflow(kzalloc_node);
601         check_allocation_overflow(__vmalloc);
602         check_allocation_overflow(kvmalloc);
603         check_allocation_overflow(kvmalloc_node);
604         check_allocation_overflow(kvzalloc);
605         check_allocation_overflow(kvzalloc_node);
606         check_allocation_overflow(devm_kmalloc);
607         check_allocation_overflow(devm_kzalloc);
608
609         device_unregister(dev);
610
611         kunit_info(test, "%d allocation overflow tests finished\n", count);
612 #undef check_allocation_overflow
613 }
614
615 struct __test_flex_array {
616         unsigned long flags;
617         size_t count;
618         unsigned long data[];
619 };
620
621 static void overflow_size_helpers_test(struct kunit *test)
622 {
623         /* Make sure struct_size() can be used in a constant expression. */
624         u8 ce_array[struct_size((struct __test_flex_array *)0, data, 55)];
625         struct __test_flex_array *obj;
626         int count = 0;
627         int var;
628         volatile int unconst = 0;
629
630         /* Verify constant expression against runtime version. */
631         var = 55;
632         OPTIMIZER_HIDE_VAR(var);
633         KUNIT_EXPECT_EQ(test, sizeof(ce_array), struct_size(obj, data, var));
634
635 #define check_one_size_helper(expected, func, args...)  do {    \
636         size_t _r = func(args);                                 \
637         KUNIT_EXPECT_EQ_MSG(test, _r, expected,                 \
638                 "expected " #func "(" #args ") to return %zu but got %zu instead\n", \
639                 (size_t)(expected), _r);                        \
640         count++;                                                \
641 } while (0)
642
643         var = 4;
644         check_one_size_helper(20,       size_mul, var++, 5);
645         check_one_size_helper(20,       size_mul, 4, var++);
646         check_one_size_helper(0,        size_mul, 0, 3);
647         check_one_size_helper(0,        size_mul, 3, 0);
648         check_one_size_helper(6,        size_mul, 2, 3);
649         check_one_size_helper(SIZE_MAX, size_mul, SIZE_MAX,  1);
650         check_one_size_helper(SIZE_MAX, size_mul, SIZE_MAX,  3);
651         check_one_size_helper(SIZE_MAX, size_mul, SIZE_MAX, -3);
652
653         var = 4;
654         check_one_size_helper(9,        size_add, var++, 5);
655         check_one_size_helper(9,        size_add, 4, var++);
656         check_one_size_helper(9,        size_add, 9, 0);
657         check_one_size_helper(9,        size_add, 0, 9);
658         check_one_size_helper(5,        size_add, 2, 3);
659         check_one_size_helper(SIZE_MAX, size_add, SIZE_MAX,  1);
660         check_one_size_helper(SIZE_MAX, size_add, SIZE_MAX,  3);
661         check_one_size_helper(SIZE_MAX, size_add, SIZE_MAX, -3);
662
663         var = 4;
664         check_one_size_helper(1,        size_sub, var--, 3);
665         check_one_size_helper(1,        size_sub, 4, var--);
666         check_one_size_helper(1,        size_sub, 3, 2);
667         check_one_size_helper(9,        size_sub, 9, 0);
668         check_one_size_helper(SIZE_MAX, size_sub, 9, -3);
669         check_one_size_helper(SIZE_MAX, size_sub, 0, 9);
670         check_one_size_helper(SIZE_MAX, size_sub, 2, 3);
671         check_one_size_helper(SIZE_MAX, size_sub, SIZE_MAX,  0);
672         check_one_size_helper(SIZE_MAX, size_sub, SIZE_MAX, 10);
673         check_one_size_helper(SIZE_MAX, size_sub, 0,  SIZE_MAX);
674         check_one_size_helper(SIZE_MAX, size_sub, 14, SIZE_MAX);
675         check_one_size_helper(SIZE_MAX - 2, size_sub, SIZE_MAX - 1,  1);
676         check_one_size_helper(SIZE_MAX - 4, size_sub, SIZE_MAX - 1,  3);
677         check_one_size_helper(1,                size_sub, SIZE_MAX - 1, -3);
678
679         var = 4;
680         check_one_size_helper(4 * sizeof(*obj->data),
681                               flex_array_size, obj, data, var++);
682         check_one_size_helper(5 * sizeof(*obj->data),
683                               flex_array_size, obj, data, var++);
684         check_one_size_helper(0, flex_array_size, obj, data, 0 + unconst);
685         check_one_size_helper(sizeof(*obj->data),
686                               flex_array_size, obj, data, 1 + unconst);
687         check_one_size_helper(7 * sizeof(*obj->data),
688                               flex_array_size, obj, data, 7 + unconst);
689         check_one_size_helper(SIZE_MAX,
690                               flex_array_size, obj, data, -1 + unconst);
691         check_one_size_helper(SIZE_MAX,
692                               flex_array_size, obj, data, SIZE_MAX - 4 + unconst);
693
694         var = 4;
695         check_one_size_helper(sizeof(*obj) + (4 * sizeof(*obj->data)),
696                               struct_size, obj, data, var++);
697         check_one_size_helper(sizeof(*obj) + (5 * sizeof(*obj->data)),
698                               struct_size, obj, data, var++);
699         check_one_size_helper(sizeof(*obj), struct_size, obj, data, 0 + unconst);
700         check_one_size_helper(sizeof(*obj) + sizeof(*obj->data),
701                               struct_size, obj, data, 1 + unconst);
702         check_one_size_helper(SIZE_MAX,
703                               struct_size, obj, data, -3 + unconst);
704         check_one_size_helper(SIZE_MAX,
705                               struct_size, obj, data, SIZE_MAX - 3 + unconst);
706
707         kunit_info(test, "%d overflow size helper tests finished\n", count);
708 #undef check_one_size_helper
709 }
710
711 static struct kunit_case overflow_test_cases[] = {
712         KUNIT_CASE(u8_u8__u8_overflow_test),
713         KUNIT_CASE(s8_s8__s8_overflow_test),
714         KUNIT_CASE(u16_u16__u16_overflow_test),
715         KUNIT_CASE(s16_s16__s16_overflow_test),
716         KUNIT_CASE(u32_u32__u32_overflow_test),
717         KUNIT_CASE(s32_s32__s32_overflow_test),
718 /* Clang 13 and earlier generate unwanted libcalls on 32-bit. */
719 #if BITS_PER_LONG == 64
720         KUNIT_CASE(u64_u64__u64_overflow_test),
721         KUNIT_CASE(s64_s64__s64_overflow_test),
722 #endif
723         KUNIT_CASE(u32_u32__u8_overflow_test),
724         KUNIT_CASE(u32_u32__int_overflow_test),
725         KUNIT_CASE(u8_u8__int_overflow_test),
726         KUNIT_CASE(int_int__u8_overflow_test),
727         KUNIT_CASE(shift_sane_test),
728         KUNIT_CASE(shift_overflow_test),
729         KUNIT_CASE(shift_truncate_test),
730         KUNIT_CASE(shift_nonsense_test),
731         KUNIT_CASE(overflow_allocation_test),
732         KUNIT_CASE(overflow_size_helpers_test),
733         {}
734 };
735
736 static struct kunit_suite overflow_test_suite = {
737         .name = "overflow",
738         .test_cases = overflow_test_cases,
739 };
740
741 kunit_test_suite(overflow_test_suite);
742
743 MODULE_LICENSE("Dual MIT/GPL");