Merge tag 'vfio-v6.1-rc6' of https://github.com/awilliam/linux-vfio
[platform/kernel/linux-starfive.git] / lib / test_objagg.c
1 // SPDX-License-Identifier: BSD-3-Clause OR GPL-2.0
2 /* Copyright (c) 2018 Mellanox Technologies. All rights reserved */
3
4 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
5
6 #include <linux/kernel.h>
7 #include <linux/module.h>
8 #include <linux/slab.h>
9 #include <linux/random.h>
10 #include <linux/objagg.h>
11
/* Key of a test object as passed to objagg_obj_get(): just an id. */
struct tokey {
	unsigned int id;
};

/* Number of per-key slots in struct world; see key_id_index(). */
#define NUM_KEYS 32
17
18 static int key_id_index(unsigned int key_id)
19 {
20         if (key_id >= NUM_KEYS) {
21                 WARN_ON(1);
22                 return 0;
23         }
24         return key_id;
25 }
26
/* Size of the payload carried by every root object. */
#define BUF_LEN 128

/*
 * Test-side state, handed to the objagg callbacks as their priv pointer.
 * @root_count: live roots, maintained by root_create()/root_destroy()
 * @delta_count: live deltas, maintained by delta_create()/delta_destroy()
 * @next_root_buf: payload copied into the next root that gets created
 * @objagg_objs: object obtained for each key slot
 * @key_refs: number of references the test currently holds per key slot
 */
struct world {
	unsigned int root_count;
	unsigned int delta_count;
	char next_root_buf[BUF_LEN];
	struct objagg_obj *objagg_objs[NUM_KEYS];
	unsigned int key_refs[NUM_KEYS];
};

/*
 * Private data of a root object. @key must stay the first member:
 * obj_to_key_id() reads the root private data through a struct tokey
 * pointer.
 */
struct root {
	struct tokey key;
	char buf[BUF_LEN];
};

/* Private data of a delta object: its key id minus its parent's key id. */
struct delta {
	unsigned int key_id_diff;
};
45
46 static struct objagg_obj *world_obj_get(struct world *world,
47                                         struct objagg *objagg,
48                                         unsigned int key_id)
49 {
50         struct objagg_obj *objagg_obj;
51         struct tokey key;
52         int err;
53
54         key.id = key_id;
55         objagg_obj = objagg_obj_get(objagg, &key);
56         if (IS_ERR(objagg_obj)) {
57                 pr_err("Key %u: Failed to get object.\n", key_id);
58                 return objagg_obj;
59         }
60         if (!world->key_refs[key_id_index(key_id)]) {
61                 world->objagg_objs[key_id_index(key_id)] = objagg_obj;
62         } else if (world->objagg_objs[key_id_index(key_id)] != objagg_obj) {
63                 pr_err("Key %u: God another object for the same key.\n",
64                        key_id);
65                 err = -EINVAL;
66                 goto err_key_id_check;
67         }
68         world->key_refs[key_id_index(key_id)]++;
69         return objagg_obj;
70
71 err_key_id_check:
72         objagg_obj_put(objagg, objagg_obj);
73         return ERR_PTR(err);
74 }
75
76 static void world_obj_put(struct world *world, struct objagg *objagg,
77                           unsigned int key_id)
78 {
79         struct objagg_obj *objagg_obj;
80
81         if (!world->key_refs[key_id_index(key_id)])
82                 return;
83         objagg_obj = world->objagg_objs[key_id_index(key_id)];
84         objagg_obj_put(objagg, objagg_obj);
85         world->key_refs[key_id_index(key_id)]--;
86 }
87
88 #define MAX_KEY_ID_DIFF 5
89
90 static bool delta_check(void *priv, const void *parent_obj, const void *obj)
91 {
92         const struct tokey *parent_key = parent_obj;
93         const struct tokey *key = obj;
94         int diff = key->id - parent_key->id;
95
96         return diff >= 0 && diff <= MAX_KEY_ID_DIFF;
97 }
98
99 static void *delta_create(void *priv, void *parent_obj, void *obj)
100 {
101         struct tokey *parent_key = parent_obj;
102         struct world *world = priv;
103         struct tokey *key = obj;
104         int diff = key->id - parent_key->id;
105         struct delta *delta;
106
107         if (!delta_check(priv, parent_obj, obj))
108                 return ERR_PTR(-EINVAL);
109
110         delta = kzalloc(sizeof(*delta), GFP_KERNEL);
111         if (!delta)
112                 return ERR_PTR(-ENOMEM);
113         delta->key_id_diff = diff;
114         world->delta_count++;
115         return delta;
116 }
117
118 static void delta_destroy(void *priv, void *delta_priv)
119 {
120         struct delta *delta = delta_priv;
121         struct world *world = priv;
122
123         world->delta_count--;
124         kfree(delta);
125 }
126
127 static void *root_create(void *priv, void *obj, unsigned int id)
128 {
129         struct world *world = priv;
130         struct tokey *key = obj;
131         struct root *root;
132
133         root = kzalloc(sizeof(*root), GFP_KERNEL);
134         if (!root)
135                 return ERR_PTR(-ENOMEM);
136         memcpy(&root->key, key, sizeof(root->key));
137         memcpy(root->buf, world->next_root_buf, sizeof(root->buf));
138         world->root_count++;
139         return root;
140 }
141
142 static void root_destroy(void *priv, void *root_priv)
143 {
144         struct root *root = root_priv;
145         struct world *world = priv;
146
147         world->root_count--;
148         kfree(root);
149 }
150
151 static int test_nodelta_obj_get(struct world *world, struct objagg *objagg,
152                                 unsigned int key_id, bool should_create_root)
153 {
154         unsigned int orig_root_count = world->root_count;
155         struct objagg_obj *objagg_obj;
156         const struct root *root;
157         int err;
158
159         if (should_create_root)
160                 get_random_bytes(world->next_root_buf,
161                               sizeof(world->next_root_buf));
162
163         objagg_obj = world_obj_get(world, objagg, key_id);
164         if (IS_ERR(objagg_obj)) {
165                 pr_err("Key %u: Failed to get object.\n", key_id);
166                 return PTR_ERR(objagg_obj);
167         }
168         if (should_create_root) {
169                 if (world->root_count != orig_root_count + 1) {
170                         pr_err("Key %u: Root was not created\n", key_id);
171                         err = -EINVAL;
172                         goto err_check_root_count;
173                 }
174         } else {
175                 if (world->root_count != orig_root_count) {
176                         pr_err("Key %u: Root was incorrectly created\n",
177                                key_id);
178                         err = -EINVAL;
179                         goto err_check_root_count;
180                 }
181         }
182         root = objagg_obj_root_priv(objagg_obj);
183         if (root->key.id != key_id) {
184                 pr_err("Key %u: Root has unexpected key id\n", key_id);
185                 err = -EINVAL;
186                 goto err_check_key_id;
187         }
188         if (should_create_root &&
189             memcmp(world->next_root_buf, root->buf, sizeof(root->buf))) {
190                 pr_err("Key %u: Buffer does not match the expected content\n",
191                        key_id);
192                 err = -EINVAL;
193                 goto err_check_buf;
194         }
195         return 0;
196
197 err_check_buf:
198 err_check_key_id:
199 err_check_root_count:
200         objagg_obj_put(objagg, objagg_obj);
201         return err;
202 }
203
204 static int test_nodelta_obj_put(struct world *world, struct objagg *objagg,
205                                 unsigned int key_id, bool should_destroy_root)
206 {
207         unsigned int orig_root_count = world->root_count;
208
209         world_obj_put(world, objagg, key_id);
210
211         if (should_destroy_root) {
212                 if (world->root_count != orig_root_count - 1) {
213                         pr_err("Key %u: Root was not destroyed\n", key_id);
214                         return -EINVAL;
215                 }
216         } else {
217                 if (world->root_count != orig_root_count) {
218                         pr_err("Key %u: Root was incorrectly destroyed\n",
219                                key_id);
220                         return -EINVAL;
221                 }
222         }
223         return 0;
224 }
225
226 static int check_stats_zero(struct objagg *objagg)
227 {
228         const struct objagg_stats *stats;
229         int err = 0;
230
231         stats = objagg_stats_get(objagg);
232         if (IS_ERR(stats))
233                 return PTR_ERR(stats);
234
235         if (stats->stats_info_count != 0) {
236                 pr_err("Stats: Object count is not zero while it should be\n");
237                 err = -EINVAL;
238         }
239
240         objagg_stats_put(stats);
241         return err;
242 }
243
244 static int check_stats_nodelta(struct objagg *objagg)
245 {
246         const struct objagg_stats *stats;
247         int i;
248         int err;
249
250         stats = objagg_stats_get(objagg);
251         if (IS_ERR(stats))
252                 return PTR_ERR(stats);
253
254         if (stats->stats_info_count != NUM_KEYS) {
255                 pr_err("Stats: Unexpected object count (%u expected, %u returned)\n",
256                        NUM_KEYS, stats->stats_info_count);
257                 err = -EINVAL;
258                 goto stats_put;
259         }
260
261         for (i = 0; i < stats->stats_info_count; i++) {
262                 if (stats->stats_info[i].stats.user_count != 2) {
263                         pr_err("Stats: incorrect user count\n");
264                         err = -EINVAL;
265                         goto stats_put;
266                 }
267                 if (stats->stats_info[i].stats.delta_user_count != 2) {
268                         pr_err("Stats: incorrect delta user count\n");
269                         err = -EINVAL;
270                         goto stats_put;
271                 }
272         }
273         err = 0;
274
275 stats_put:
276         objagg_stats_put(stats);
277         return err;
278 }
279
280 static bool delta_check_dummy(void *priv, const void *parent_obj,
281                               const void *obj)
282 {
283         return false;
284 }
285
286 static void *delta_create_dummy(void *priv, void *parent_obj, void *obj)
287 {
288         return ERR_PTR(-EOPNOTSUPP);
289 }
290
291 static void delta_destroy_dummy(void *priv, void *delta_priv)
292 {
293 }
294
295 static const struct objagg_ops nodelta_ops = {
296         .obj_size = sizeof(struct tokey),
297         .delta_check = delta_check_dummy,
298         .delta_create = delta_create_dummy,
299         .delta_destroy = delta_destroy_dummy,
300         .root_create = root_create,
301         .root_destroy = root_destroy,
302 };
303
/*
 * test_nodelta - exercise objagg with delta creation disabled.
 *
 * Takes two references on each of NUM_KEYS keys. The first get of a key
 * must create a root and the second must reuse it. The stats are then
 * checked, and both references per key are dropped: only the second put
 * may destroy the root. Finally the stats must be empty again.
 *
 * Returns 0 on success or a negative errno.
 */
static int test_nodelta(void)
{
	struct world world = {};
	struct objagg *objagg;
	int i;
	int err;

	objagg = objagg_create(&nodelta_ops, NULL, &world);
	if (IS_ERR(objagg))
		return PTR_ERR(objagg);

	err = check_stats_zero(objagg);
	if (err)
		goto err_stats_first_zero;

	/* First round of gets, the root objects should be created */
	for (i = 0; i < NUM_KEYS; i++) {
		err = test_nodelta_obj_get(&world, objagg, i, true);
		if (err)
			goto err_obj_first_get;
	}

	/* Do the second round of gets, all roots are already created,
	 * make sure that no new root is created
	 */
	for (i = 0; i < NUM_KEYS; i++) {
		err = test_nodelta_obj_get(&world, objagg, i, false);
		if (err)
			goto err_obj_second_get;
	}

	err = check_stats_nodelta(objagg);
	if (err)
		goto err_stats_nodelta;

	/* First round of puts: every key keeps one reference, so no root
	 * may be destroyed yet.
	 */
	for (i = NUM_KEYS - 1; i >= 0; i--) {
		err = test_nodelta_obj_put(&world, objagg, i, false);
		if (err)
			goto err_obj_first_put;
	}
	/* Second round of puts drops the last reference of each key. */
	for (i = NUM_KEYS - 1; i >= 0; i--) {
		err = test_nodelta_obj_put(&world, objagg, i, true);
		if (err)
			goto err_obj_second_put;
	}

	err = check_stats_zero(objagg);
	if (err)
		goto err_stats_second_zero;

	objagg_destroy(objagg);
	return 0;

	/*
	 * Unwinding is driven by @i. The entry at @i that failed was already
	 * handled by the helper: a failed get drops its own reference, and a
	 * failed put has already performed the put. Only keys in [0, i) still
	 * need to drop the reference of the round that failed.
	 */
err_stats_nodelta:
err_obj_first_put:
err_obj_second_get:
	/* Drop the remaining second-round references of keys [0, i). */
	for (i--; i >= 0; i--)
		world_obj_put(&world, objagg, i);

	/* Every key still holds its first-round reference. */
	i = NUM_KEYS;
err_obj_first_get:
err_obj_second_put:
	for (i--; i >= 0; i--)
		world_obj_put(&world, objagg, i);
err_stats_first_zero:
err_stats_second_zero:
	objagg_destroy(objagg);
	return err;
}
373
374 static const struct objagg_ops delta_ops = {
375         .obj_size = sizeof(struct tokey),
376         .delta_check = delta_check,
377         .delta_create = delta_create,
378         .delta_destroy = delta_destroy,
379         .root_create = root_create,
380         .root_destroy = root_destroy,
381 };
382
/* Operation performed by one step of the delta test. */
enum action {
	ACTION_GET,
	ACTION_PUT,
};

/* Expected change of world.delta_count caused by one step. */
enum expect_delta {
	EXPECT_DELTA_SAME,
	EXPECT_DELTA_INC,
	EXPECT_DELTA_DEC,
};

/* Expected change of world.root_count caused by one step. */
enum expect_root {
	EXPECT_ROOT_SAME,
	EXPECT_ROOT_INC,
	EXPECT_ROOT_DEC,
};

/*
 * Expected stats entry for one object. Field order matters: ROOT() and
 * DELTA() below initialize this struct positionally.
 */
struct expect_stats_info {
	struct objagg_obj_stats stats;
	bool is_root;
	unsigned int key_id;
};

/* Expected objagg stats: the first @info_count entries of @info are valid. */
struct expect_stats {
	unsigned int info_count;
	struct expect_stats_info info[NUM_KEYS];
};

/*
 * One scripted step of the delta test and its expected outcome. Field
 * order matters: action_items[] initializes this struct positionally.
 */
struct action_item {
	unsigned int key_id;
	enum action action;
	enum expect_delta expect_delta;
	enum expect_root expect_root;
	struct expect_stats expect_stats;
};

/* Initializer for struct expect_stats with @count listed entries. */
#define EXPECT_STATS(count, ...)		\
{						\
	.info_count = count,			\
	.info = { __VA_ARGS__ }			\
}

/*
 * Expected entry for a root object. Going by action_items[],
 * @delta_user_count also counts users of the root's deltas.
 */
#define ROOT(key_id, user_count, delta_user_count)	\
	{{user_count, delta_user_count}, true, key_id}

/* Expected entry for a delta object; both user counts equal @user_count. */
#define DELTA(key_id, user_count)			\
	{{user_count, user_count}, false, key_id}
430
/*
 * Scripted sequence run by test_delta(). Each step gets or puts one key,
 * states how the delta and root counts must change, and lists the stats
 * expected afterwards. The trailing comment of each step shows the
 * references held after it:
 * - "r:" lists references on root objects;
 * - "d:" lists references on delta objects, written as key^parent-root.
 */
static const struct action_item action_items[] = {
	{
		1, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
		EXPECT_STATS(1, ROOT(1, 1, 1)),
	},	/* r: 1			d: */
	{
		7, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
		EXPECT_STATS(2, ROOT(1, 1, 1), ROOT(7, 1, 1)),
	},	/* r: 1, 7		d: */
	{
		3, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(3, ROOT(1, 1, 2), ROOT(7, 1, 1),
				DELTA(3, 1)),
	},	/* r: 1, 7		d: 3^1 */
	{
		5, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(4, ROOT(1, 1, 3), ROOT(7, 1, 1),
				DELTA(3, 1), DELTA(5, 1)),
	},	/* r: 1, 7		d: 3^1, 5^1 */
	{
		3, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(4, ROOT(1, 1, 4), ROOT(7, 1, 1),
				DELTA(3, 2), DELTA(5, 1)),
	},	/* r: 1, 7		d: 3^1, 3^1, 5^1 */
	{
		1, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(4, ROOT(1, 2, 5), ROOT(7, 1, 1),
				DELTA(3, 2), DELTA(5, 1)),
	},	/* r: 1, 1, 7		d: 3^1, 3^1, 5^1 */
	{
		30, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
		EXPECT_STATS(5, ROOT(1, 2, 5), ROOT(7, 1, 1), ROOT(30, 1, 1),
				DELTA(3, 2), DELTA(5, 1)),
	},	/* r: 1, 1, 7, 30	d: 3^1, 3^1, 5^1 */
	{
		8, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(6, ROOT(1, 2, 5), ROOT(7, 1, 2), ROOT(30, 1, 1),
				DELTA(3, 2), DELTA(5, 1), DELTA(8, 1)),
	},	/* r: 1, 1, 7, 30	d: 3^1, 3^1, 5^1, 8^7 */
	{
		8, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(6, ROOT(1, 2, 5), ROOT(7, 1, 3), ROOT(30, 1, 1),
				DELTA(3, 2), DELTA(8, 2), DELTA(5, 1)),
	},	/* r: 1, 1, 7, 30	d: 3^1, 3^1, 5^1, 8^7, 8^7 */
	{
		3, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(6, ROOT(1, 2, 4), ROOT(7, 1, 3), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(3, 1), DELTA(5, 1)),
	},	/* r: 1, 1, 7, 30	d: 3^1, 5^1, 8^7, 8^7 */
	{
		3, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(1, 2, 3), ROOT(7, 1, 3), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(5, 1)),
	},	/* r: 1, 1, 7, 30	d: 5^1, 8^7, 8^7 */
	{
		1, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(1, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(5, 1)),
	},	/* r: 1, 7, 30		d: 5^1, 8^7, 8^7 */
	{
		1, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(30, 1, 1), ROOT(1, 0, 1),
				DELTA(8, 2), DELTA(5, 1)),
	},	/* r: 7, 30		d: 5^1, 8^7, 8^7 */
	{
		5, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_DEC,
		EXPECT_STATS(3, ROOT(7, 1, 3), ROOT(30, 1, 1),
				DELTA(8, 2)),
	},	/* r: 7, 30		d: 8^7, 8^7 */
	{
		5, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
		EXPECT_STATS(4, ROOT(7, 1, 3), ROOT(30, 1, 1), ROOT(5, 1, 1),
				DELTA(8, 2)),
	},	/* r: 7, 30, 5		d: 8^7, 8^7 */
	{
		6, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(5, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 8^7, 8^7, 6^5 */
	{
		8, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 4), ROOT(5, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 3), DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 8^7, 8^7, 8^7, 6^5 */
	{
		8, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(5, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 8^7, 8^7, 6^5 */
	{
		8, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 2), ROOT(5, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 1), DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 8^7, 6^5 */
	{
		8, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_SAME,
		EXPECT_STATS(4, ROOT(5, 1, 2), ROOT(7, 1, 1), ROOT(30, 1, 1),
				DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 6^5 */
	{
		8, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(5, 1, 3), ROOT(7, 1, 1), ROOT(30, 1, 1),
				DELTA(6, 1), DELTA(8, 1)),
	},	/* r: 7, 30, 5		d: 6^5, 8^5 */
	{
		7, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_DEC,
		EXPECT_STATS(4, ROOT(5, 1, 3), ROOT(30, 1, 1),
				DELTA(6, 1), DELTA(8, 1)),
	},	/* r: 30, 5		d: 6^5, 8^5 */
	{
		30, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_DEC,
		EXPECT_STATS(3, ROOT(5, 1, 3),
				DELTA(6, 1), DELTA(8, 1)),
	},	/* r: 5			d: 6^5, 8^5 */
	{
		5, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(3, ROOT(5, 0, 2),
				DELTA(6, 1), DELTA(8, 1)),
	},	/* r:			d: 6^5, 8^5 */
	{
		6, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_SAME,
		EXPECT_STATS(2, ROOT(5, 0, 1),
				DELTA(8, 1)),
	},	/* r:			d: 8^5 */
	{
		8, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_DEC,
		EXPECT_STATS(0, ),
	},	/* r:			d: */
};
560
561 static int check_expect(struct world *world,
562                         const struct action_item *action_item,
563                         unsigned int orig_delta_count,
564                         unsigned int orig_root_count)
565 {
566         unsigned int key_id = action_item->key_id;
567
568         switch (action_item->expect_delta) {
569         case EXPECT_DELTA_SAME:
570                 if (orig_delta_count != world->delta_count) {
571                         pr_err("Key %u: Delta count changed while expected to remain the same.\n",
572                                key_id);
573                         return -EINVAL;
574                 }
575                 break;
576         case EXPECT_DELTA_INC:
577                 if (WARN_ON(action_item->action == ACTION_PUT))
578                         return -EINVAL;
579                 if (orig_delta_count + 1 != world->delta_count) {
580                         pr_err("Key %u: Delta count was not incremented.\n",
581                                key_id);
582                         return -EINVAL;
583                 }
584                 break;
585         case EXPECT_DELTA_DEC:
586                 if (WARN_ON(action_item->action == ACTION_GET))
587                         return -EINVAL;
588                 if (orig_delta_count - 1 != world->delta_count) {
589                         pr_err("Key %u: Delta count was not decremented.\n",
590                                key_id);
591                         return -EINVAL;
592                 }
593                 break;
594         }
595
596         switch (action_item->expect_root) {
597         case EXPECT_ROOT_SAME:
598                 if (orig_root_count != world->root_count) {
599                         pr_err("Key %u: Root count changed while expected to remain the same.\n",
600                                key_id);
601                         return -EINVAL;
602                 }
603                 break;
604         case EXPECT_ROOT_INC:
605                 if (WARN_ON(action_item->action == ACTION_PUT))
606                         return -EINVAL;
607                 if (orig_root_count + 1 != world->root_count) {
608                         pr_err("Key %u: Root count was not incremented.\n",
609                                key_id);
610                         return -EINVAL;
611                 }
612                 break;
613         case EXPECT_ROOT_DEC:
614                 if (WARN_ON(action_item->action == ACTION_GET))
615                         return -EINVAL;
616                 if (orig_root_count - 1 != world->root_count) {
617                         pr_err("Key %u: Root count was not decremented.\n",
618                                key_id);
619                         return -EINVAL;
620                 }
621         }
622
623         return 0;
624 }
625
626 static unsigned int obj_to_key_id(struct objagg_obj *objagg_obj)
627 {
628         const struct tokey *root_key;
629         const struct delta *delta;
630         unsigned int key_id;
631
632         root_key = objagg_obj_root_priv(objagg_obj);
633         key_id = root_key->id;
634         delta = objagg_obj_delta_priv(objagg_obj);
635         if (delta)
636                 key_id += delta->key_id_diff;
637         return key_id;
638 }
639
640 static int
641 check_expect_stats_nums(const struct objagg_obj_stats_info *stats_info,
642                         const struct expect_stats_info *expect_stats_info,
643                         const char **errmsg)
644 {
645         if (stats_info->is_root != expect_stats_info->is_root) {
646                 if (errmsg)
647                         *errmsg = "Incorrect root/delta indication";
648                 return -EINVAL;
649         }
650         if (stats_info->stats.user_count !=
651             expect_stats_info->stats.user_count) {
652                 if (errmsg)
653                         *errmsg = "Incorrect user count";
654                 return -EINVAL;
655         }
656         if (stats_info->stats.delta_user_count !=
657             expect_stats_info->stats.delta_user_count) {
658                 if (errmsg)
659                         *errmsg = "Incorrect delta user count";
660                 return -EINVAL;
661         }
662         return 0;
663 }
664
665 static int
666 check_expect_stats_key_id(const struct objagg_obj_stats_info *stats_info,
667                           const struct expect_stats_info *expect_stats_info,
668                           const char **errmsg)
669 {
670         if (obj_to_key_id(stats_info->objagg_obj) !=
671             expect_stats_info->key_id) {
672                 if (errmsg)
673                         *errmsg = "incorrect key id";
674                 return -EINVAL;
675         }
676         return 0;
677 }
678
679 static int check_expect_stats_neigh(const struct objagg_stats *stats,
680                                     const struct expect_stats *expect_stats,
681                                     int pos)
682 {
683         int i;
684         int err;
685
686         for (i = pos - 1; i >= 0; i--) {
687                 err = check_expect_stats_nums(&stats->stats_info[i],
688                                               &expect_stats->info[pos], NULL);
689                 if (err)
690                         break;
691                 err = check_expect_stats_key_id(&stats->stats_info[i],
692                                                 &expect_stats->info[pos], NULL);
693                 if (!err)
694                         return 0;
695         }
696         for (i = pos + 1; i < stats->stats_info_count; i++) {
697                 err = check_expect_stats_nums(&stats->stats_info[i],
698                                               &expect_stats->info[pos], NULL);
699                 if (err)
700                         break;
701                 err = check_expect_stats_key_id(&stats->stats_info[i],
702                                                 &expect_stats->info[pos], NULL);
703                 if (!err)
704                         return 0;
705         }
706         return -EINVAL;
707 }
708
709 static int __check_expect_stats(const struct objagg_stats *stats,
710                                 const struct expect_stats *expect_stats,
711                                 const char **errmsg)
712 {
713         int i;
714         int err;
715
716         if (stats->stats_info_count != expect_stats->info_count) {
717                 *errmsg = "Unexpected object count";
718                 return -EINVAL;
719         }
720
721         for (i = 0; i < stats->stats_info_count; i++) {
722                 err = check_expect_stats_nums(&stats->stats_info[i],
723                                               &expect_stats->info[i], errmsg);
724                 if (err)
725                         return err;
726                 err = check_expect_stats_key_id(&stats->stats_info[i],
727                                                 &expect_stats->info[i], errmsg);
728                 if (err) {
729                         /* It is possible that one of the neighbor stats with
730                          * same numbers have the correct key id, so check it
731                          */
732                         err = check_expect_stats_neigh(stats, expect_stats, i);
733                         if (err)
734                                 return err;
735                 }
736         }
737         return 0;
738 }
739
/*
 * check_expect_stats - snapshot the stats of @objagg and compare them
 * against @expect_stats.
 * @errmsg: receives a description of the failure
 *
 * Returns 0 on match, or a negative errno.
 */
static int check_expect_stats(struct objagg *objagg,
			      const struct expect_stats *expect_stats,
			      const char **errmsg)
{
	const struct objagg_stats *snapshot = objagg_stats_get(objagg);
	int ret;

	if (IS_ERR(snapshot)) {
		*errmsg = "objagg_stats_get() failed.";
		return PTR_ERR(snapshot);
	}

	ret = __check_expect_stats(snapshot, expect_stats, errmsg);
	objagg_stats_put(snapshot);
	return ret;
}
756
757 static int test_delta_action_item(struct world *world,
758                                   struct objagg *objagg,
759                                   const struct action_item *action_item,
760                                   bool inverse)
761 {
762         unsigned int orig_delta_count = world->delta_count;
763         unsigned int orig_root_count = world->root_count;
764         unsigned int key_id = action_item->key_id;
765         enum action action = action_item->action;
766         struct objagg_obj *objagg_obj;
767         const char *errmsg;
768         int err;
769
770         if (inverse)
771                 action = action == ACTION_GET ? ACTION_PUT : ACTION_GET;
772
773         switch (action) {
774         case ACTION_GET:
775                 objagg_obj = world_obj_get(world, objagg, key_id);
776                 if (IS_ERR(objagg_obj))
777                         return PTR_ERR(objagg_obj);
778                 break;
779         case ACTION_PUT:
780                 world_obj_put(world, objagg, key_id);
781                 break;
782         }
783
784         if (inverse)
785                 return 0;
786         err = check_expect(world, action_item,
787                            orig_delta_count, orig_root_count);
788         if (err)
789                 goto errout;
790
791         err = check_expect_stats(objagg, &action_item->expect_stats, &errmsg);
792         if (err) {
793                 pr_err("Key %u: Stats: %s\n", action_item->key_id, errmsg);
794                 goto errout;
795         }
796
797         return 0;
798
799 errout:
800         /* This can only happen when action is not inversed.
801          * So in case of an error, cleanup by doing inverse action.
802          */
803         test_delta_action_item(world, objagg, action_item, true);
804         return err;
805 }
806
807 static int test_delta(void)
808 {
809         struct world world = {};
810         struct objagg *objagg;
811         int i;
812         int err;
813
814         objagg = objagg_create(&delta_ops, NULL, &world);
815         if (IS_ERR(objagg))
816                 return PTR_ERR(objagg);
817
818         for (i = 0; i < ARRAY_SIZE(action_items); i++) {
819                 err = test_delta_action_item(&world, objagg,
820                                              &action_items[i], false);
821                 if (err)
822                         goto err_do_action_item;
823         }
824
825         objagg_destroy(objagg);
826         return 0;
827
828 err_do_action_item:
829         for (i--; i >= 0; i--)
830                 test_delta_action_item(&world, objagg, &action_items[i], true);
831
832         objagg_destroy(objagg);
833         return err;
834 }
835
/*
 * One hints test scenario: the ordered list of keys to get, plus the stats
 * expected before and after re-laying the objects out according to hints.
 */
struct hints_case {
	/* keys passed to world_obj_get(), in this order */
	const unsigned int *key_ids;
	/* number of entries in key_ids */
	size_t key_ids_count;
	/* checked against the objagg created without hints */
	struct expect_stats expect_stats;
	/* checked against the hints and against the objagg created from them */
	struct expect_stats expect_stats_hints;
};
842
/*
 * Keys got, in this order, into both objaggs of the hints test. A repeated
 * key takes another reference on the object already obtained for it.
 */
static const unsigned int hints_case_key_ids[] = {
	1, 7, 3, 5, 3, 1, 30, 8, 8, 5, 6, 8,
};
846
/*
 * Expected results for hints_case_key_ids.
 *
 * The user counts below equal how often each key occurs in
 * hints_case_key_ids.
 *
 * NOTE(review): the macro definitions are not visible here. From the data,
 * the argument layout appears to be:
 *   - EXPECT_STATS(object count, ...)
 *   - ROOT(key, user count, delta user count), where a root's delta user
 *     count is its own users plus the users of the deltas built on it
 *   - DELTA(key, user count)
 * Confirm against the macro definitions.
 */
static const struct hints_case hints_case = {
	.key_ids = hints_case_key_ids,
	.key_ids_count = ARRAY_SIZE(hints_case_key_ids),
	.expect_stats =
		EXPECT_STATS(7, ROOT(1, 2, 7), ROOT(7, 1, 4), ROOT(30, 1, 1),
				DELTA(8, 3), DELTA(3, 2),
				DELTA(5, 2), DELTA(6, 1)),
	.expect_stats_hints =
		EXPECT_STATS(7, ROOT(3, 2, 9), ROOT(1, 2, 2), ROOT(30, 1, 1),
				DELTA(8, 3), DELTA(5, 2),
				DELTA(6, 1), DELTA(7, 1)),
};
859
860 static void __pr_debug_stats(const struct objagg_stats *stats)
861 {
862         int i;
863
864         for (i = 0; i < stats->stats_info_count; i++)
865                 pr_debug("Stat index %d key %u: u %d, d %d, %s\n", i,
866                          obj_to_key_id(stats->stats_info[i].objagg_obj),
867                          stats->stats_info[i].stats.user_count,
868                          stats->stats_info[i].stats.delta_user_count,
869                          stats->stats_info[i].is_root ? "root" : "noroot");
870 }
871
/*
 * Dump the current stats of @objagg via pr_debug(). This is best effort:
 * if the stats cannot be obtained, nothing is printed.
 */
static void pr_debug_stats(struct objagg *objagg)
{
	const struct objagg_stats *stats = objagg_stats_get(objagg);

	if (!IS_ERR(stats)) {
		__pr_debug_stats(stats);
		objagg_stats_put(stats);
	}
}
882
/*
 * Dump the stats of @objagg_hints via pr_debug(). This is best effort: if
 * the stats cannot be obtained, nothing is printed.
 */
static void pr_debug_hints_stats(struct objagg_hints *objagg_hints)
{
	const struct objagg_stats *stats = objagg_hints_stats_get(objagg_hints);

	if (!IS_ERR(stats)) {
		__pr_debug_stats(stats);
		objagg_stats_put(stats);
	}
}
893
/*
 * Compare the stats of @objagg_hints against @expect_stats.
 *
 * @errmsg is handed through to __check_expect_stats(). Callers print it on
 * failure, so it presumably receives a description of the mismatch.
 *
 * Return: the PTR_ERR() of objagg_hints_stats_get() if the stats cannot be
 * obtained, otherwise the result of __check_expect_stats().
 */
static int check_expect_hints_stats(struct objagg_hints *objagg_hints,
				    const struct expect_stats *expect_stats,
				    const char **errmsg)
{
	const struct objagg_stats *stats = objagg_hints_stats_get(objagg_hints);
	int ret;

	if (IS_ERR(stats))
		return PTR_ERR(stats);

	ret = __check_expect_stats(stats, expect_stats, errmsg);
	objagg_stats_put(stats);

	return ret;
}
908
909 static int test_hints_case(const struct hints_case *hints_case)
910 {
911         struct objagg_obj *objagg_obj;
912         struct objagg_hints *hints;
913         struct world world2 = {};
914         struct world world = {};
915         struct objagg *objagg2;
916         struct objagg *objagg;
917         const char *errmsg;
918         int i;
919         int err;
920
921         objagg = objagg_create(&delta_ops, NULL, &world);
922         if (IS_ERR(objagg))
923                 return PTR_ERR(objagg);
924
925         for (i = 0; i < hints_case->key_ids_count; i++) {
926                 objagg_obj = world_obj_get(&world, objagg,
927                                            hints_case->key_ids[i]);
928                 if (IS_ERR(objagg_obj)) {
929                         err = PTR_ERR(objagg_obj);
930                         goto err_world_obj_get;
931                 }
932         }
933
934         pr_debug_stats(objagg);
935         err = check_expect_stats(objagg, &hints_case->expect_stats, &errmsg);
936         if (err) {
937                 pr_err("Stats: %s\n", errmsg);
938                 goto err_check_expect_stats;
939         }
940
941         hints = objagg_hints_get(objagg, OBJAGG_OPT_ALGO_SIMPLE_GREEDY);
942         if (IS_ERR(hints)) {
943                 err = PTR_ERR(hints);
944                 goto err_hints_get;
945         }
946
947         pr_debug_hints_stats(hints);
948         err = check_expect_hints_stats(hints, &hints_case->expect_stats_hints,
949                                        &errmsg);
950         if (err) {
951                 pr_err("Hints stats: %s\n", errmsg);
952                 goto err_check_expect_hints_stats;
953         }
954
955         objagg2 = objagg_create(&delta_ops, hints, &world2);
956         if (IS_ERR(objagg2))
957                 return PTR_ERR(objagg2);
958
959         for (i = 0; i < hints_case->key_ids_count; i++) {
960                 objagg_obj = world_obj_get(&world2, objagg2,
961                                            hints_case->key_ids[i]);
962                 if (IS_ERR(objagg_obj)) {
963                         err = PTR_ERR(objagg_obj);
964                         goto err_world2_obj_get;
965                 }
966         }
967
968         pr_debug_stats(objagg2);
969         err = check_expect_stats(objagg2, &hints_case->expect_stats_hints,
970                                  &errmsg);
971         if (err) {
972                 pr_err("Stats2: %s\n", errmsg);
973                 goto err_check_expect_stats2;
974         }
975
976         err = 0;
977
978 err_check_expect_stats2:
979 err_world2_obj_get:
980         for (i--; i >= 0; i--)
981                 world_obj_put(&world2, objagg, hints_case->key_ids[i]);
982         i = hints_case->key_ids_count;
983         objagg_destroy(objagg2);
984 err_check_expect_hints_stats:
985         objagg_hints_put(hints);
986 err_hints_get:
987 err_check_expect_stats:
988 err_world_obj_get:
989         for (i--; i >= 0; i--)
990                 world_obj_put(&world, objagg, hints_case->key_ids[i]);
991
992         objagg_destroy(objagg);
993         return err;
994 }
995 static int test_hints(void)
996 {
997         return test_hints_case(&hints_case);
998 }
999
1000 static int __init test_objagg_init(void)
1001 {
1002         int err;
1003
1004         err = test_nodelta();
1005         if (err)
1006                 return err;
1007         err = test_delta();
1008         if (err)
1009                 return err;
1010         return test_hints();
1011 }
1012
/*
 * Module exit point; intentionally empty. The tests run from
 * test_objagg_init(), and those visible here destroy the objaggs they
 * create before returning. NOTE(review): presumably kept so the module can
 * be unloaded.
 */
static void __exit test_objagg_exit(void)
{
}
1016
/* Register the module entry/exit points and the module metadata. */
module_init(test_objagg_init);
module_exit(test_objagg_exit);
MODULE_LICENSE("Dual BSD/GPL");
MODULE_AUTHOR("Jiri Pirko <jiri@mellanox.com>");
MODULE_DESCRIPTION("Test module for objagg");