maple_tree: should get pivots boundary by type
[platform/kernel/linux-starfive.git] / lib / overflow_kunit.c
1 // SPDX-License-Identifier: GPL-2.0 OR MIT
2 /*
3  * Test cases for arithmetic overflow checks. See:
4  * "Running tests with kunit_tool" at Documentation/dev-tools/kunit/start.rst
5  *      ./tools/testing/kunit/kunit.py run overflow [--raw_output]
6  */
7 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
8
9 #include <kunit/test.h>
10 #include <linux/device.h>
11 #include <linux/kernel.h>
12 #include <linux/mm.h>
13 #include <linux/module.h>
14 #include <linux/overflow.h>
15 #include <linux/slab.h>
16 #include <linux/types.h>
17 #include <linux/vmalloc.h>
18
19 #define SKIP(cond, reason)              do {                    \
20         if (cond) {                                             \
21                 kunit_skip(test, reason);                       \
22                 return;                                         \
23         }                                                       \
24 } while (0)
25
26 /*
27  * Clang 11 and earlier generate unwanted libcalls for signed output
28  * on unsigned input.
29  */
30 #if defined(CONFIG_CC_IS_CLANG) && __clang_major__ <= 11
31 # define SKIP_SIGN_MISMATCH(t)  SKIP(t, "Clang 11 unwanted libcalls")
32 #else
33 # define SKIP_SIGN_MISMATCH(t)  do { } while (0)
34 #endif
35
36 /*
37  * Clang 13 and earlier generate unwanted libcalls for 64-bit tests on
38  * 32-bit hosts.
39  */
40 #if defined(CONFIG_CC_IS_CLANG) && __clang_major__ <= 13 &&     \
41     BITS_PER_LONG != 64
42 # define SKIP_64_ON_32(t)       SKIP(t, "Clang 13 unwanted libcalls")
43 #else
44 # define SKIP_64_ON_32(t)       do { } while (0)
45 #endif
46
47 #define DEFINE_TEST_ARRAY_TYPED(t1, t2, t)                      \
48         static const struct test_ ## t1 ## _ ## t2 ## __ ## t { \
49                 t1 a;                                           \
50                 t2 b;                                           \
51                 t sum, diff, prod;                              \
52                 bool s_of, d_of, p_of;                          \
53         } t1 ## _ ## t2 ## __ ## t ## _tests[]
54
55 #define DEFINE_TEST_ARRAY(t)    DEFINE_TEST_ARRAY_TYPED(t, t, t)
56
57 DEFINE_TEST_ARRAY(u8) = {
58         {0, 0, 0, 0, 0, false, false, false},
59         {1, 1, 2, 0, 1, false, false, false},
60         {0, 1, 1, U8_MAX, 0, false, true, false},
61         {1, 0, 1, 1, 0, false, false, false},
62         {0, U8_MAX, U8_MAX, 1, 0, false, true, false},
63         {U8_MAX, 0, U8_MAX, U8_MAX, 0, false, false, false},
64         {1, U8_MAX, 0, 2, U8_MAX, true, true, false},
65         {U8_MAX, 1, 0, U8_MAX-1, U8_MAX, true, false, false},
66         {U8_MAX, U8_MAX, U8_MAX-1, 0, 1, true, false, true},
67
68         {U8_MAX, U8_MAX-1, U8_MAX-2, 1, 2, true, false, true},
69         {U8_MAX-1, U8_MAX, U8_MAX-2, U8_MAX, 2, true, true, true},
70
71         {1U << 3, 1U << 3, 1U << 4, 0, 1U << 6, false, false, false},
72         {1U << 4, 1U << 4, 1U << 5, 0, 0, false, false, true},
73         {1U << 4, 1U << 3, 3*(1U << 3), 1U << 3, 1U << 7, false, false, false},
74         {1U << 7, 1U << 7, 0, 0, 0, true, false, true},
75
76         {48, 32, 80, 16, 0, false, false, true},
77         {128, 128, 0, 0, 0, true, false, true},
78         {123, 234, 101, 145, 110, true, true, true},
79 };
80 DEFINE_TEST_ARRAY(u16) = {
81         {0, 0, 0, 0, 0, false, false, false},
82         {1, 1, 2, 0, 1, false, false, false},
83         {0, 1, 1, U16_MAX, 0, false, true, false},
84         {1, 0, 1, 1, 0, false, false, false},
85         {0, U16_MAX, U16_MAX, 1, 0, false, true, false},
86         {U16_MAX, 0, U16_MAX, U16_MAX, 0, false, false, false},
87         {1, U16_MAX, 0, 2, U16_MAX, true, true, false},
88         {U16_MAX, 1, 0, U16_MAX-1, U16_MAX, true, false, false},
89         {U16_MAX, U16_MAX, U16_MAX-1, 0, 1, true, false, true},
90
91         {U16_MAX, U16_MAX-1, U16_MAX-2, 1, 2, true, false, true},
92         {U16_MAX-1, U16_MAX, U16_MAX-2, U16_MAX, 2, true, true, true},
93
94         {1U << 7, 1U << 7, 1U << 8, 0, 1U << 14, false, false, false},
95         {1U << 8, 1U << 8, 1U << 9, 0, 0, false, false, true},
96         {1U << 8, 1U << 7, 3*(1U << 7), 1U << 7, 1U << 15, false, false, false},
97         {1U << 15, 1U << 15, 0, 0, 0, true, false, true},
98
99         {123, 234, 357, 65425, 28782, false, true, false},
100         {1234, 2345, 3579, 64425, 10146, false, true, true},
101 };
102 DEFINE_TEST_ARRAY(u32) = {
103         {0, 0, 0, 0, 0, false, false, false},
104         {1, 1, 2, 0, 1, false, false, false},
105         {0, 1, 1, U32_MAX, 0, false, true, false},
106         {1, 0, 1, 1, 0, false, false, false},
107         {0, U32_MAX, U32_MAX, 1, 0, false, true, false},
108         {U32_MAX, 0, U32_MAX, U32_MAX, 0, false, false, false},
109         {1, U32_MAX, 0, 2, U32_MAX, true, true, false},
110         {U32_MAX, 1, 0, U32_MAX-1, U32_MAX, true, false, false},
111         {U32_MAX, U32_MAX, U32_MAX-1, 0, 1, true, false, true},
112
113         {U32_MAX, U32_MAX-1, U32_MAX-2, 1, 2, true, false, true},
114         {U32_MAX-1, U32_MAX, U32_MAX-2, U32_MAX, 2, true, true, true},
115
116         {1U << 15, 1U << 15, 1U << 16, 0, 1U << 30, false, false, false},
117         {1U << 16, 1U << 16, 1U << 17, 0, 0, false, false, true},
118         {1U << 16, 1U << 15, 3*(1U << 15), 1U << 15, 1U << 31, false, false, false},
119         {1U << 31, 1U << 31, 0, 0, 0, true, false, true},
120
121         {-2U, 1U, -1U, -3U, -2U, false, false, false},
122         {-4U, 5U, 1U, -9U, -20U, true, false, true},
123 };
124
125 DEFINE_TEST_ARRAY(u64) = {
126         {0, 0, 0, 0, 0, false, false, false},
127         {1, 1, 2, 0, 1, false, false, false},
128         {0, 1, 1, U64_MAX, 0, false, true, false},
129         {1, 0, 1, 1, 0, false, false, false},
130         {0, U64_MAX, U64_MAX, 1, 0, false, true, false},
131         {U64_MAX, 0, U64_MAX, U64_MAX, 0, false, false, false},
132         {1, U64_MAX, 0, 2, U64_MAX, true, true, false},
133         {U64_MAX, 1, 0, U64_MAX-1, U64_MAX, true, false, false},
134         {U64_MAX, U64_MAX, U64_MAX-1, 0, 1, true, false, true},
135
136         {U64_MAX, U64_MAX-1, U64_MAX-2, 1, 2, true, false, true},
137         {U64_MAX-1, U64_MAX, U64_MAX-2, U64_MAX, 2, true, true, true},
138
139         {1ULL << 31, 1ULL << 31, 1ULL << 32, 0, 1ULL << 62, false, false, false},
140         {1ULL << 32, 1ULL << 32, 1ULL << 33, 0, 0, false, false, true},
141         {1ULL << 32, 1ULL << 31, 3*(1ULL << 31), 1ULL << 31, 1ULL << 63, false, false, false},
142         {1ULL << 63, 1ULL << 63, 0, 0, 0, true, false, true},
143         {1000000000ULL /* 10^9 */, 10000000000ULL /* 10^10 */,
144          11000000000ULL, 18446744064709551616ULL, 10000000000000000000ULL,
145          false, true, false},
146         {-15ULL, 10ULL, -5ULL, -25ULL, -150ULL, false, false, true},
147 };
148
149 DEFINE_TEST_ARRAY(s8) = {
150         {0, 0, 0, 0, 0, false, false, false},
151
152         {0, S8_MAX, S8_MAX, -S8_MAX, 0, false, false, false},
153         {S8_MAX, 0, S8_MAX, S8_MAX, 0, false, false, false},
154         {0, S8_MIN, S8_MIN, S8_MIN, 0, false, true, false},
155         {S8_MIN, 0, S8_MIN, S8_MIN, 0, false, false, false},
156
157         {-1, S8_MIN, S8_MAX, S8_MAX, S8_MIN, true, false, true},
158         {S8_MIN, -1, S8_MAX, -S8_MAX, S8_MIN, true, false, true},
159         {-1, S8_MAX, S8_MAX-1, S8_MIN, -S8_MAX, false, false, false},
160         {S8_MAX, -1, S8_MAX-1, S8_MIN, -S8_MAX, false, true, false},
161         {-1, -S8_MAX, S8_MIN, S8_MAX-1, S8_MAX, false, false, false},
162         {-S8_MAX, -1, S8_MIN, S8_MIN+2, S8_MAX, false, false, false},
163
164         {1, S8_MIN, -S8_MAX, -S8_MAX, S8_MIN, false, true, false},
165         {S8_MIN, 1, -S8_MAX, S8_MAX, S8_MIN, false, true, false},
166         {1, S8_MAX, S8_MIN, S8_MIN+2, S8_MAX, true, false, false},
167         {S8_MAX, 1, S8_MIN, S8_MAX-1, S8_MAX, true, false, false},
168
169         {S8_MIN, S8_MIN, 0, 0, 0, true, false, true},
170         {S8_MAX, S8_MAX, -2, 0, 1, true, false, true},
171
172         {-4, -32, -36, 28, -128, false, false, true},
173         {-4, 32, 28, -36, -128, false, false, false},
174 };
175
176 DEFINE_TEST_ARRAY(s16) = {
177         {0, 0, 0, 0, 0, false, false, false},
178
179         {0, S16_MAX, S16_MAX, -S16_MAX, 0, false, false, false},
180         {S16_MAX, 0, S16_MAX, S16_MAX, 0, false, false, false},
181         {0, S16_MIN, S16_MIN, S16_MIN, 0, false, true, false},
182         {S16_MIN, 0, S16_MIN, S16_MIN, 0, false, false, false},
183
184         {-1, S16_MIN, S16_MAX, S16_MAX, S16_MIN, true, false, true},
185         {S16_MIN, -1, S16_MAX, -S16_MAX, S16_MIN, true, false, true},
186         {-1, S16_MAX, S16_MAX-1, S16_MIN, -S16_MAX, false, false, false},
187         {S16_MAX, -1, S16_MAX-1, S16_MIN, -S16_MAX, false, true, false},
188         {-1, -S16_MAX, S16_MIN, S16_MAX-1, S16_MAX, false, false, false},
189         {-S16_MAX, -1, S16_MIN, S16_MIN+2, S16_MAX, false, false, false},
190
191         {1, S16_MIN, -S16_MAX, -S16_MAX, S16_MIN, false, true, false},
192         {S16_MIN, 1, -S16_MAX, S16_MAX, S16_MIN, false, true, false},
193         {1, S16_MAX, S16_MIN, S16_MIN+2, S16_MAX, true, false, false},
194         {S16_MAX, 1, S16_MIN, S16_MAX-1, S16_MAX, true, false, false},
195
196         {S16_MIN, S16_MIN, 0, 0, 0, true, false, true},
197         {S16_MAX, S16_MAX, -2, 0, 1, true, false, true},
198 };
199 DEFINE_TEST_ARRAY(s32) = {
200         {0, 0, 0, 0, 0, false, false, false},
201
202         {0, S32_MAX, S32_MAX, -S32_MAX, 0, false, false, false},
203         {S32_MAX, 0, S32_MAX, S32_MAX, 0, false, false, false},
204         {0, S32_MIN, S32_MIN, S32_MIN, 0, false, true, false},
205         {S32_MIN, 0, S32_MIN, S32_MIN, 0, false, false, false},
206
207         {-1, S32_MIN, S32_MAX, S32_MAX, S32_MIN, true, false, true},
208         {S32_MIN, -1, S32_MAX, -S32_MAX, S32_MIN, true, false, true},
209         {-1, S32_MAX, S32_MAX-1, S32_MIN, -S32_MAX, false, false, false},
210         {S32_MAX, -1, S32_MAX-1, S32_MIN, -S32_MAX, false, true, false},
211         {-1, -S32_MAX, S32_MIN, S32_MAX-1, S32_MAX, false, false, false},
212         {-S32_MAX, -1, S32_MIN, S32_MIN+2, S32_MAX, false, false, false},
213
214         {1, S32_MIN, -S32_MAX, -S32_MAX, S32_MIN, false, true, false},
215         {S32_MIN, 1, -S32_MAX, S32_MAX, S32_MIN, false, true, false},
216         {1, S32_MAX, S32_MIN, S32_MIN+2, S32_MAX, true, false, false},
217         {S32_MAX, 1, S32_MIN, S32_MAX-1, S32_MAX, true, false, false},
218
219         {S32_MIN, S32_MIN, 0, 0, 0, true, false, true},
220         {S32_MAX, S32_MAX, -2, 0, 1, true, false, true},
221 };
222
223 DEFINE_TEST_ARRAY(s64) = {
224         {0, 0, 0, 0, 0, false, false, false},
225
226         {0, S64_MAX, S64_MAX, -S64_MAX, 0, false, false, false},
227         {S64_MAX, 0, S64_MAX, S64_MAX, 0, false, false, false},
228         {0, S64_MIN, S64_MIN, S64_MIN, 0, false, true, false},
229         {S64_MIN, 0, S64_MIN, S64_MIN, 0, false, false, false},
230
231         {-1, S64_MIN, S64_MAX, S64_MAX, S64_MIN, true, false, true},
232         {S64_MIN, -1, S64_MAX, -S64_MAX, S64_MIN, true, false, true},
233         {-1, S64_MAX, S64_MAX-1, S64_MIN, -S64_MAX, false, false, false},
234         {S64_MAX, -1, S64_MAX-1, S64_MIN, -S64_MAX, false, true, false},
235         {-1, -S64_MAX, S64_MIN, S64_MAX-1, S64_MAX, false, false, false},
236         {-S64_MAX, -1, S64_MIN, S64_MIN+2, S64_MAX, false, false, false},
237
238         {1, S64_MIN, -S64_MAX, -S64_MAX, S64_MIN, false, true, false},
239         {S64_MIN, 1, -S64_MAX, S64_MAX, S64_MIN, false, true, false},
240         {1, S64_MAX, S64_MIN, S64_MIN+2, S64_MAX, true, false, false},
241         {S64_MAX, 1, S64_MIN, S64_MAX-1, S64_MAX, true, false, false},
242
243         {S64_MIN, S64_MIN, 0, 0, 0, true, false, true},
244         {S64_MAX, S64_MAX, -2, 0, 1, true, false, true},
245
246         {-1, -1, -2, 0, 1, false, false, false},
247         {-1, -128, -129, 127, 128, false, false, false},
248         {-128, -1, -129, -127, 128, false, false, false},
249         {0, -S64_MAX, -S64_MAX, S64_MAX, 0, false, false, false},
250 };
251
252 #define check_one_op(t, fmt, op, sym, a, b, r, of) do {                 \
253         int _a_orig = a, _a_bump = a + 1;                               \
254         int _b_orig = b, _b_bump = b + 1;                               \
255         bool _of;                                                       \
256         t _r;                                                           \
257                                                                         \
258         _of = check_ ## op ## _overflow(a, b, &_r);                     \
259         KUNIT_EXPECT_EQ_MSG(test, _of, of,                              \
260                 "expected "fmt" "sym" "fmt" to%s overflow (type %s)\n", \
261                 a, b, of ? "" : " not", #t);                            \
262         KUNIT_EXPECT_EQ_MSG(test, _r, r,                                \
263                 "expected "fmt" "sym" "fmt" == "fmt", got "fmt" (type %s)\n", \
264                 a, b, r, _r, #t);                                       \
265         /* Check for internal macro side-effects. */                    \
266         _of = check_ ## op ## _overflow(_a_orig++, _b_orig++, &_r);     \
267         KUNIT_EXPECT_EQ_MSG(test, _a_orig, _a_bump, "Unexpected " #op " macro side-effect!\n"); \
268         KUNIT_EXPECT_EQ_MSG(test, _b_orig, _b_bump, "Unexpected " #op " macro side-effect!\n"); \
269 } while (0)
270
271 #define DEFINE_TEST_FUNC_TYPED(n, t, fmt)                               \
272 static void do_test_ ## n(struct kunit *test, const struct test_ ## n *p) \
273 {                                                                       \
274         check_one_op(t, fmt, add, "+", p->a, p->b, p->sum, p->s_of);    \
275         check_one_op(t, fmt, add, "+", p->b, p->a, p->sum, p->s_of);    \
276         check_one_op(t, fmt, sub, "-", p->a, p->b, p->diff, p->d_of);   \
277         check_one_op(t, fmt, mul, "*", p->a, p->b, p->prod, p->p_of);   \
278         check_one_op(t, fmt, mul, "*", p->b, p->a, p->prod, p->p_of);   \
279 }                                                                       \
280                                                                         \
281 static void n ## _overflow_test(struct kunit *test) {                   \
282         unsigned i;                                                     \
283                                                                         \
284         SKIP_64_ON_32(__same_type(t, u64));                             \
285         SKIP_64_ON_32(__same_type(t, s64));                             \
286         SKIP_SIGN_MISMATCH(__same_type(n ## _tests[0].a, u32) &&        \
287                            __same_type(n ## _tests[0].b, u32) &&        \
288                            __same_type(n ## _tests[0].sum, int));       \
289                                                                         \
290         for (i = 0; i < ARRAY_SIZE(n ## _tests); ++i)                   \
291                 do_test_ ## n(test, &n ## _tests[i]);                   \
292         kunit_info(test, "%zu %s arithmetic tests finished\n",          \
293                 ARRAY_SIZE(n ## _tests), #n);                           \
294 }
295
296 #define DEFINE_TEST_FUNC(t, fmt)                                        \
297         DEFINE_TEST_FUNC_TYPED(t ## _ ## t ## __ ## t, t, fmt)
298
299 DEFINE_TEST_FUNC(u8, "%d");
300 DEFINE_TEST_FUNC(s8, "%d");
301 DEFINE_TEST_FUNC(u16, "%d");
302 DEFINE_TEST_FUNC(s16, "%d");
303 DEFINE_TEST_FUNC(u32, "%u");
304 DEFINE_TEST_FUNC(s32, "%d");
305 DEFINE_TEST_FUNC(u64, "%llu");
306 DEFINE_TEST_FUNC(s64, "%lld");
307
308 DEFINE_TEST_ARRAY_TYPED(u32, u32, u8) = {
309         {0, 0, 0, 0, 0, false, false, false},
310         {U8_MAX, 2, 1, U8_MAX - 2, U8_MAX - 1, true, false, true},
311         {U8_MAX + 1, 0, 0, 0, 0, true, true, false},
312 };
313 DEFINE_TEST_FUNC_TYPED(u32_u32__u8, u8, "%d");
314
315 DEFINE_TEST_ARRAY_TYPED(u32, u32, int) = {
316         {0, 0, 0, 0, 0, false, false, false},
317         {U32_MAX, 0, -1, -1, 0, true, true, false},
318 };
319 DEFINE_TEST_FUNC_TYPED(u32_u32__int, int, "%d");
320
321 DEFINE_TEST_ARRAY_TYPED(u8, u8, int) = {
322         {0, 0, 0, 0, 0, false, false, false},
323         {U8_MAX, U8_MAX, 2 * U8_MAX, 0, U8_MAX * U8_MAX, false, false, false},
324         {1, 2, 3, -1, 2, false, false, false},
325 };
326 DEFINE_TEST_FUNC_TYPED(u8_u8__int, int, "%d");
327
328 DEFINE_TEST_ARRAY_TYPED(int, int, u8) = {
329         {0, 0, 0, 0, 0, false, false, false},
330         {1, 2, 3, U8_MAX, 2, false, true, false},
331         {-1, 0, U8_MAX, U8_MAX, 0, true, true, false},
332 };
333 DEFINE_TEST_FUNC_TYPED(int_int__u8, u8, "%d");
334
335 /* Args are: value, shift, type, expected result, overflow expected */
336 #define TEST_ONE_SHIFT(a, s, t, expect, of)     do {                    \
337         typeof(a) __a = (a);                                            \
338         typeof(s) __s = (s);                                            \
339         t __e = (expect);                                               \
340         t __d;                                                          \
341         bool __of = check_shl_overflow(__a, __s, &__d);                 \
342         if (__of != of) {                                               \
343                 KUNIT_EXPECT_EQ_MSG(test, __of, of,                     \
344                         "expected (%s)(%s << %s) to%s overflow\n",      \
345                         #t, #a, #s, of ? "" : " not");                  \
346         } else if (!__of && __d != __e) {                               \
347                 KUNIT_EXPECT_EQ_MSG(test, __d, __e,                     \
348                         "expected (%s)(%s << %s) == %s\n",              \
349                         #t, #a, #s, #expect);                           \
350                 if ((t)-1 < 0)                                          \
351                         kunit_info(test, "got %lld\n", (s64)__d);       \
352                 else                                                    \
353                         kunit_info(test, "got %llu\n", (u64)__d);       \
354         }                                                               \
355         count++;                                                        \
356 } while (0)
357
358 static void shift_sane_test(struct kunit *test)
359 {
360         int count = 0;
361
362         /* Sane shifts. */
363         TEST_ONE_SHIFT(1, 0, u8, 1 << 0, false);
364         TEST_ONE_SHIFT(1, 4, u8, 1 << 4, false);
365         TEST_ONE_SHIFT(1, 7, u8, 1 << 7, false);
366         TEST_ONE_SHIFT(0xF, 4, u8, 0xF << 4, false);
367         TEST_ONE_SHIFT(1, 0, u16, 1 << 0, false);
368         TEST_ONE_SHIFT(1, 10, u16, 1 << 10, false);
369         TEST_ONE_SHIFT(1, 15, u16, 1 << 15, false);
370         TEST_ONE_SHIFT(0xFF, 8, u16, 0xFF << 8, false);
371         TEST_ONE_SHIFT(1, 0, int, 1 << 0, false);
372         TEST_ONE_SHIFT(1, 16, int, 1 << 16, false);
373         TEST_ONE_SHIFT(1, 30, int, 1 << 30, false);
374         TEST_ONE_SHIFT(1, 0, s32, 1 << 0, false);
375         TEST_ONE_SHIFT(1, 16, s32, 1 << 16, false);
376         TEST_ONE_SHIFT(1, 30, s32, 1 << 30, false);
377         TEST_ONE_SHIFT(1, 0, unsigned int, 1U << 0, false);
378         TEST_ONE_SHIFT(1, 20, unsigned int, 1U << 20, false);
379         TEST_ONE_SHIFT(1, 31, unsigned int, 1U << 31, false);
380         TEST_ONE_SHIFT(0xFFFFU, 16, unsigned int, 0xFFFFU << 16, false);
381         TEST_ONE_SHIFT(1, 0, u32, 1U << 0, false);
382         TEST_ONE_SHIFT(1, 20, u32, 1U << 20, false);
383         TEST_ONE_SHIFT(1, 31, u32, 1U << 31, false);
384         TEST_ONE_SHIFT(0xFFFFU, 16, u32, 0xFFFFU << 16, false);
385         TEST_ONE_SHIFT(1, 0, u64, 1ULL << 0, false);
386         TEST_ONE_SHIFT(1, 40, u64, 1ULL << 40, false);
387         TEST_ONE_SHIFT(1, 63, u64, 1ULL << 63, false);
388         TEST_ONE_SHIFT(0xFFFFFFFFULL, 32, u64, 0xFFFFFFFFULL << 32, false);
389
390         /* Sane shift: start and end with 0, without a too-wide shift. */
391         TEST_ONE_SHIFT(0, 7, u8, 0, false);
392         TEST_ONE_SHIFT(0, 15, u16, 0, false);
393         TEST_ONE_SHIFT(0, 31, unsigned int, 0, false);
394         TEST_ONE_SHIFT(0, 31, u32, 0, false);
395         TEST_ONE_SHIFT(0, 63, u64, 0, false);
396
397         /* Sane shift: start and end with 0, without reaching signed bit. */
398         TEST_ONE_SHIFT(0, 6, s8, 0, false);
399         TEST_ONE_SHIFT(0, 14, s16, 0, false);
400         TEST_ONE_SHIFT(0, 30, int, 0, false);
401         TEST_ONE_SHIFT(0, 30, s32, 0, false);
402         TEST_ONE_SHIFT(0, 62, s64, 0, false);
403
404         kunit_info(test, "%d sane shift tests finished\n", count);
405 }
406
407 static void shift_overflow_test(struct kunit *test)
408 {
409         int count = 0;
410
411         /* Overflow: shifted the bit off the end. */
412         TEST_ONE_SHIFT(1, 8, u8, 0, true);
413         TEST_ONE_SHIFT(1, 16, u16, 0, true);
414         TEST_ONE_SHIFT(1, 32, unsigned int, 0, true);
415         TEST_ONE_SHIFT(1, 32, u32, 0, true);
416         TEST_ONE_SHIFT(1, 64, u64, 0, true);
417
418         /* Overflow: shifted into the signed bit. */
419         TEST_ONE_SHIFT(1, 7, s8, 0, true);
420         TEST_ONE_SHIFT(1, 15, s16, 0, true);
421         TEST_ONE_SHIFT(1, 31, int, 0, true);
422         TEST_ONE_SHIFT(1, 31, s32, 0, true);
423         TEST_ONE_SHIFT(1, 63, s64, 0, true);
424
425         /* Overflow: high bit falls off unsigned types. */
426         /* 10010110 */
427         TEST_ONE_SHIFT(150, 1, u8, 0, true);
428         /* 1000100010010110 */
429         TEST_ONE_SHIFT(34966, 1, u16, 0, true);
430         /* 10000100000010001000100010010110 */
431         TEST_ONE_SHIFT(2215151766U, 1, u32, 0, true);
432         TEST_ONE_SHIFT(2215151766U, 1, unsigned int, 0, true);
433         /* 1000001000010000010000000100000010000100000010001000100010010110 */
434         TEST_ONE_SHIFT(9372061470395238550ULL, 1, u64, 0, true);
435
436         /* Overflow: bit shifted into signed bit on signed types. */
437         /* 01001011 */
438         TEST_ONE_SHIFT(75, 1, s8, 0, true);
439         /* 0100010001001011 */
440         TEST_ONE_SHIFT(17483, 1, s16, 0, true);
441         /* 01000010000001000100010001001011 */
442         TEST_ONE_SHIFT(1107575883, 1, s32, 0, true);
443         TEST_ONE_SHIFT(1107575883, 1, int, 0, true);
444         /* 0100000100001000001000000010000001000010000001000100010001001011 */
445         TEST_ONE_SHIFT(4686030735197619275LL, 1, s64, 0, true);
446
447         /* Overflow: bit shifted past signed bit on signed types. */
448         /* 01001011 */
449         TEST_ONE_SHIFT(75, 2, s8, 0, true);
450         /* 0100010001001011 */
451         TEST_ONE_SHIFT(17483, 2, s16, 0, true);
452         /* 01000010000001000100010001001011 */
453         TEST_ONE_SHIFT(1107575883, 2, s32, 0, true);
454         TEST_ONE_SHIFT(1107575883, 2, int, 0, true);
455         /* 0100000100001000001000000010000001000010000001000100010001001011 */
456         TEST_ONE_SHIFT(4686030735197619275LL, 2, s64, 0, true);
457
458         kunit_info(test, "%d overflow shift tests finished\n", count);
459 }
460
461 static void shift_truncate_test(struct kunit *test)
462 {
463         int count = 0;
464
465         /* Overflow: values larger than destination type. */
466         TEST_ONE_SHIFT(0x100, 0, u8, 0, true);
467         TEST_ONE_SHIFT(0xFF, 0, s8, 0, true);
468         TEST_ONE_SHIFT(0x10000U, 0, u16, 0, true);
469         TEST_ONE_SHIFT(0xFFFFU, 0, s16, 0, true);
470         TEST_ONE_SHIFT(0x100000000ULL, 0, u32, 0, true);
471         TEST_ONE_SHIFT(0x100000000ULL, 0, unsigned int, 0, true);
472         TEST_ONE_SHIFT(0xFFFFFFFFUL, 0, s32, 0, true);
473         TEST_ONE_SHIFT(0xFFFFFFFFUL, 0, int, 0, true);
474         TEST_ONE_SHIFT(0xFFFFFFFFFFFFFFFFULL, 0, s64, 0, true);
475
476         /* Overflow: shifted at or beyond entire type's bit width. */
477         TEST_ONE_SHIFT(0, 8, u8, 0, true);
478         TEST_ONE_SHIFT(0, 9, u8, 0, true);
479         TEST_ONE_SHIFT(0, 8, s8, 0, true);
480         TEST_ONE_SHIFT(0, 9, s8, 0, true);
481         TEST_ONE_SHIFT(0, 16, u16, 0, true);
482         TEST_ONE_SHIFT(0, 17, u16, 0, true);
483         TEST_ONE_SHIFT(0, 16, s16, 0, true);
484         TEST_ONE_SHIFT(0, 17, s16, 0, true);
485         TEST_ONE_SHIFT(0, 32, u32, 0, true);
486         TEST_ONE_SHIFT(0, 33, u32, 0, true);
487         TEST_ONE_SHIFT(0, 32, int, 0, true);
488         TEST_ONE_SHIFT(0, 33, int, 0, true);
489         TEST_ONE_SHIFT(0, 32, s32, 0, true);
490         TEST_ONE_SHIFT(0, 33, s32, 0, true);
491         TEST_ONE_SHIFT(0, 64, u64, 0, true);
492         TEST_ONE_SHIFT(0, 65, u64, 0, true);
493         TEST_ONE_SHIFT(0, 64, s64, 0, true);
494         TEST_ONE_SHIFT(0, 65, s64, 0, true);
495
496         kunit_info(test, "%d truncate shift tests finished\n", count);
497 }
498
499 static void shift_nonsense_test(struct kunit *test)
500 {
501         int count = 0;
502
503         /* Nonsense: negative initial value. */
504         TEST_ONE_SHIFT(-1, 0, s8, 0, true);
505         TEST_ONE_SHIFT(-1, 0, u8, 0, true);
506         TEST_ONE_SHIFT(-5, 0, s16, 0, true);
507         TEST_ONE_SHIFT(-5, 0, u16, 0, true);
508         TEST_ONE_SHIFT(-10, 0, int, 0, true);
509         TEST_ONE_SHIFT(-10, 0, unsigned int, 0, true);
510         TEST_ONE_SHIFT(-100, 0, s32, 0, true);
511         TEST_ONE_SHIFT(-100, 0, u32, 0, true);
512         TEST_ONE_SHIFT(-10000, 0, s64, 0, true);
513         TEST_ONE_SHIFT(-10000, 0, u64, 0, true);
514
515         /* Nonsense: negative shift values. */
516         TEST_ONE_SHIFT(0, -5, s8, 0, true);
517         TEST_ONE_SHIFT(0, -5, u8, 0, true);
518         TEST_ONE_SHIFT(0, -10, s16, 0, true);
519         TEST_ONE_SHIFT(0, -10, u16, 0, true);
520         TEST_ONE_SHIFT(0, -15, int, 0, true);
521         TEST_ONE_SHIFT(0, -15, unsigned int, 0, true);
522         TEST_ONE_SHIFT(0, -20, s32, 0, true);
523         TEST_ONE_SHIFT(0, -20, u32, 0, true);
524         TEST_ONE_SHIFT(0, -30, s64, 0, true);
525         TEST_ONE_SHIFT(0, -30, u64, 0, true);
526
527         /*
528          * Corner case: for unsigned types, we fail when we've shifted
529          * through the entire width of bits. For signed types, we might
530          * want to match this behavior, but that would mean noticing if
531          * we shift through all but the signed bit, and this is not
532          * currently detected (but we'll notice an overflow into the
533          * signed bit). So, for now, we will test this condition but
534          * mark it as not expected to overflow.
535          */
536         TEST_ONE_SHIFT(0, 7, s8, 0, false);
537         TEST_ONE_SHIFT(0, 15, s16, 0, false);
538         TEST_ONE_SHIFT(0, 31, int, 0, false);
539         TEST_ONE_SHIFT(0, 31, s32, 0, false);
540         TEST_ONE_SHIFT(0, 63, s64, 0, false);
541
542         kunit_info(test, "%d nonsense shift tests finished\n", count);
543 }
544 #undef TEST_ONE_SHIFT
545
546 /*
547  * Deal with the various forms of allocator arguments. See comments above
548  * the DEFINE_TEST_ALLOC() instances for mapping of the "bits".
549  */
550 #define alloc_GFP                (GFP_KERNEL | __GFP_NOWARN)
551 #define alloc010(alloc, arg, sz) alloc(sz, alloc_GFP)
552 #define alloc011(alloc, arg, sz) alloc(sz, alloc_GFP, NUMA_NO_NODE)
553 #define alloc000(alloc, arg, sz) alloc(sz)
554 #define alloc001(alloc, arg, sz) alloc(sz, NUMA_NO_NODE)
555 #define alloc110(alloc, arg, sz) alloc(arg, sz, alloc_GFP)
556 #define free0(free, arg, ptr)    free(ptr)
557 #define free1(free, arg, ptr)    free(arg, ptr)
558
559 /* Wrap around to 16K */
560 #define TEST_SIZE               (5 * 4096)
561
562 #define DEFINE_TEST_ALLOC(func, free_func, want_arg, want_gfp, want_node)\
563 static void test_ ## func (struct kunit *test, void *arg)               \
564 {                                                                       \
565         volatile size_t a = TEST_SIZE;                                  \
566         volatile size_t b = (SIZE_MAX / TEST_SIZE) + 1;                 \
567         void *ptr;                                                      \
568                                                                         \
569         /* Tiny allocation test. */                                     \
570         ptr = alloc ## want_arg ## want_gfp ## want_node (func, arg, 1);\
571         KUNIT_ASSERT_NOT_ERR_OR_NULL_MSG(test, ptr,                     \
572                             #func " failed regular allocation?!\n");    \
573         free ## want_arg (free_func, arg, ptr);                         \
574                                                                         \
575         /* Wrapped allocation test. */                                  \
576         ptr = alloc ## want_arg ## want_gfp ## want_node (func, arg,    \
577                                                           a * b);       \
578         KUNIT_ASSERT_NOT_ERR_OR_NULL_MSG(test, ptr,                     \
579                             #func " unexpectedly failed bad wrapping?!\n"); \
580         free ## want_arg (free_func, arg, ptr);                         \
581                                                                         \
582         /* Saturated allocation test. */                                \
583         ptr = alloc ## want_arg ## want_gfp ## want_node (func, arg,    \
584                                                    array_size(a, b));   \
585         if (ptr) {                                                      \
586                 KUNIT_FAIL(test, #func " missed saturation!\n");        \
587                 free ## want_arg (free_func, arg, ptr);                 \
588         }                                                               \
589 }
590
591 /*
592  * Allocator uses a trailing node argument --------+  (e.g. kmalloc_node())
593  * Allocator uses the gfp_t argument -----------+  |  (e.g. kmalloc())
594  * Allocator uses a special leading argument +  |  |  (e.g. devm_kmalloc())
595  *                                           |  |  |
596  */
597 DEFINE_TEST_ALLOC(kmalloc,       kfree,      0, 1, 0);
598 DEFINE_TEST_ALLOC(kmalloc_node,  kfree,      0, 1, 1);
599 DEFINE_TEST_ALLOC(kzalloc,       kfree,      0, 1, 0);
600 DEFINE_TEST_ALLOC(kzalloc_node,  kfree,      0, 1, 1);
601 DEFINE_TEST_ALLOC(__vmalloc,     vfree,      0, 1, 0);
602 DEFINE_TEST_ALLOC(kvmalloc,      kvfree,     0, 1, 0);
603 DEFINE_TEST_ALLOC(kvmalloc_node, kvfree,     0, 1, 1);
604 DEFINE_TEST_ALLOC(kvzalloc,      kvfree,     0, 1, 0);
605 DEFINE_TEST_ALLOC(kvzalloc_node, kvfree,     0, 1, 1);
606 DEFINE_TEST_ALLOC(devm_kmalloc,  devm_kfree, 1, 1, 0);
607 DEFINE_TEST_ALLOC(devm_kzalloc,  devm_kfree, 1, 1, 0);
608
609 static void overflow_allocation_test(struct kunit *test)
610 {
611         const char device_name[] = "overflow-test";
612         struct device *dev;
613         int count = 0;
614
615 #define check_allocation_overflow(alloc)        do {    \
616         count++;                                        \
617         test_ ## alloc(test, dev);                      \
618 } while (0)
619
620         /* Create dummy device for devm_kmalloc()-family tests. */
621         dev = root_device_register(device_name);
622         KUNIT_ASSERT_FALSE_MSG(test, IS_ERR(dev),
623                                "Cannot register test device\n");
624
625         check_allocation_overflow(kmalloc);
626         check_allocation_overflow(kmalloc_node);
627         check_allocation_overflow(kzalloc);
628         check_allocation_overflow(kzalloc_node);
629         check_allocation_overflow(__vmalloc);
630         check_allocation_overflow(kvmalloc);
631         check_allocation_overflow(kvmalloc_node);
632         check_allocation_overflow(kvzalloc);
633         check_allocation_overflow(kvzalloc_node);
634         check_allocation_overflow(devm_kmalloc);
635         check_allocation_overflow(devm_kzalloc);
636
637         device_unregister(dev);
638
639         kunit_info(test, "%d allocation overflow tests finished\n", count);
640 #undef check_allocation_overflow
641 }
642
643 struct __test_flex_array {
644         unsigned long flags;
645         size_t count;
646         unsigned long data[];
647 };
648
649 static void overflow_size_helpers_test(struct kunit *test)
650 {
651         /* Make sure struct_size() can be used in a constant expression. */
652         u8 ce_array[struct_size((struct __test_flex_array *)0, data, 55)];
653         struct __test_flex_array *obj;
654         int count = 0;
655         int var;
656         volatile int unconst = 0;
657
658         /* Verify constant expression against runtime version. */
659         var = 55;
660         OPTIMIZER_HIDE_VAR(var);
661         KUNIT_EXPECT_EQ(test, sizeof(ce_array), struct_size(obj, data, var));
662
663 #define check_one_size_helper(expected, func, args...)  do {    \
664         size_t _r = func(args);                                 \
665         KUNIT_EXPECT_EQ_MSG(test, _r, expected,                 \
666                 "expected " #func "(" #args ") to return %zu but got %zu instead\n", \
667                 (size_t)(expected), _r);                        \
668         count++;                                                \
669 } while (0)
670
671         var = 4;
672         check_one_size_helper(20,       size_mul, var++, 5);
673         check_one_size_helper(20,       size_mul, 4, var++);
674         check_one_size_helper(0,        size_mul, 0, 3);
675         check_one_size_helper(0,        size_mul, 3, 0);
676         check_one_size_helper(6,        size_mul, 2, 3);
677         check_one_size_helper(SIZE_MAX, size_mul, SIZE_MAX,  1);
678         check_one_size_helper(SIZE_MAX, size_mul, SIZE_MAX,  3);
679         check_one_size_helper(SIZE_MAX, size_mul, SIZE_MAX, -3);
680
681         var = 4;
682         check_one_size_helper(9,        size_add, var++, 5);
683         check_one_size_helper(9,        size_add, 4, var++);
684         check_one_size_helper(9,        size_add, 9, 0);
685         check_one_size_helper(9,        size_add, 0, 9);
686         check_one_size_helper(5,        size_add, 2, 3);
687         check_one_size_helper(SIZE_MAX, size_add, SIZE_MAX,  1);
688         check_one_size_helper(SIZE_MAX, size_add, SIZE_MAX,  3);
689         check_one_size_helper(SIZE_MAX, size_add, SIZE_MAX, -3);
690
691         var = 4;
692         check_one_size_helper(1,        size_sub, var--, 3);
693         check_one_size_helper(1,        size_sub, 4, var--);
694         check_one_size_helper(1,        size_sub, 3, 2);
695         check_one_size_helper(9,        size_sub, 9, 0);
696         check_one_size_helper(SIZE_MAX, size_sub, 9, -3);
697         check_one_size_helper(SIZE_MAX, size_sub, 0, 9);
698         check_one_size_helper(SIZE_MAX, size_sub, 2, 3);
699         check_one_size_helper(SIZE_MAX, size_sub, SIZE_MAX,  0);
700         check_one_size_helper(SIZE_MAX, size_sub, SIZE_MAX, 10);
701         check_one_size_helper(SIZE_MAX, size_sub, 0,  SIZE_MAX);
702         check_one_size_helper(SIZE_MAX, size_sub, 14, SIZE_MAX);
703         check_one_size_helper(SIZE_MAX - 2, size_sub, SIZE_MAX - 1,  1);
704         check_one_size_helper(SIZE_MAX - 4, size_sub, SIZE_MAX - 1,  3);
705         check_one_size_helper(1,                size_sub, SIZE_MAX - 1, -3);
706
707         var = 4;
708         check_one_size_helper(4 * sizeof(*obj->data),
709                               flex_array_size, obj, data, var++);
710         check_one_size_helper(5 * sizeof(*obj->data),
711                               flex_array_size, obj, data, var++);
712         check_one_size_helper(0, flex_array_size, obj, data, 0 + unconst);
713         check_one_size_helper(sizeof(*obj->data),
714                               flex_array_size, obj, data, 1 + unconst);
715         check_one_size_helper(7 * sizeof(*obj->data),
716                               flex_array_size, obj, data, 7 + unconst);
717         check_one_size_helper(SIZE_MAX,
718                               flex_array_size, obj, data, -1 + unconst);
719         check_one_size_helper(SIZE_MAX,
720                               flex_array_size, obj, data, SIZE_MAX - 4 + unconst);
721
722         var = 4;
723         check_one_size_helper(sizeof(*obj) + (4 * sizeof(*obj->data)),
724                               struct_size, obj, data, var++);
725         check_one_size_helper(sizeof(*obj) + (5 * sizeof(*obj->data)),
726                               struct_size, obj, data, var++);
727         check_one_size_helper(sizeof(*obj), struct_size, obj, data, 0 + unconst);
728         check_one_size_helper(sizeof(*obj) + sizeof(*obj->data),
729                               struct_size, obj, data, 1 + unconst);
730         check_one_size_helper(SIZE_MAX,
731                               struct_size, obj, data, -3 + unconst);
732         check_one_size_helper(SIZE_MAX,
733                               struct_size, obj, data, SIZE_MAX - 3 + unconst);
734
735         kunit_info(test, "%d overflow size helper tests finished\n", count);
736 #undef check_one_size_helper
737 }
738
739 static struct kunit_case overflow_test_cases[] = {
740         KUNIT_CASE(u8_u8__u8_overflow_test),
741         KUNIT_CASE(s8_s8__s8_overflow_test),
742         KUNIT_CASE(u16_u16__u16_overflow_test),
743         KUNIT_CASE(s16_s16__s16_overflow_test),
744         KUNIT_CASE(u32_u32__u32_overflow_test),
745         KUNIT_CASE(s32_s32__s32_overflow_test),
746         KUNIT_CASE(u64_u64__u64_overflow_test),
747         KUNIT_CASE(s64_s64__s64_overflow_test),
748         KUNIT_CASE(u32_u32__int_overflow_test),
749         KUNIT_CASE(u32_u32__u8_overflow_test),
750         KUNIT_CASE(u8_u8__int_overflow_test),
751         KUNIT_CASE(int_int__u8_overflow_test),
752         KUNIT_CASE(shift_sane_test),
753         KUNIT_CASE(shift_overflow_test),
754         KUNIT_CASE(shift_truncate_test),
755         KUNIT_CASE(shift_nonsense_test),
756         KUNIT_CASE(overflow_allocation_test),
757         KUNIT_CASE(overflow_size_helpers_test),
758         {}
759 };
760
761 static struct kunit_suite overflow_test_suite = {
762         .name = "overflow",
763         .test_cases = overflow_test_cases,
764 };
765
766 kunit_test_suite(overflow_test_suite);
767
768 MODULE_LICENSE("Dual MIT/GPL");