Merge branch 'for-linus' of git://git.kernel.org/pub/scm/linux/kernel/git/dtor/input
[platform/kernel/linux-rpi.git] / lib / test_objagg.c
1 // SPDX-License-Identifier: BSD-3-Clause OR GPL-2.0
2 /* Copyright (c) 2018 Mellanox Technologies. All rights reserved */
3
4 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
5
6 #include <linux/kernel.h>
7 #include <linux/module.h>
8 #include <linux/slab.h>
9 #include <linux/random.h>
10 #include <linux/objagg.h>
11
/* Object key handed to objagg; the test identifies every object by its id. */
struct tokey {
	unsigned int id;
};
15
/* Number of distinct key ids the tests use; valid ids are 0..NUM_KEYS-1. */
#define NUM_KEYS 32

/*
 * Map a key id onto an index into the per-key arrays of struct world.
 * An out-of-range id triggers a warning and is clamped to index 0 so the
 * caller never indexes past the end of those arrays.
 */
static int key_id_index(unsigned int key_id)
{
	if (key_id < NUM_KEYS)
		return key_id;

	WARN_ON(1);
	return 0;
}
26
/* Size of the payload stored in every root object. */
#define BUF_LEN 128

/*
 * Test-side bookkeeping, passed to objagg_create() as the ops priv pointer.
 * root_count/delta_count: root/delta privs created and not yet destroyed.
 * next_root_buf: payload that root_create() copies into the next root.
 * objagg_objs/key_refs: per key id, the object handed out by objagg and
 * how many references the test currently holds on it.
 */
struct world {
	unsigned int root_count;
	unsigned int delta_count;
	char next_root_buf[BUF_LEN];
	struct objagg_obj *objagg_objs[NUM_KEYS];
	unsigned int key_refs[NUM_KEYS];
};

/* Root private data made by root_create(): the key plus a payload copy. */
struct root {
	struct tokey key;
	char buf[BUF_LEN];
};

/* Delta private data made by delta_create(): distance from the root key id. */
struct delta {
	unsigned int key_id_diff;
};
45
46 static struct objagg_obj *world_obj_get(struct world *world,
47                                         struct objagg *objagg,
48                                         unsigned int key_id)
49 {
50         struct objagg_obj *objagg_obj;
51         struct tokey key;
52         int err;
53
54         key.id = key_id;
55         objagg_obj = objagg_obj_get(objagg, &key);
56         if (IS_ERR(objagg_obj)) {
57                 pr_err("Key %u: Failed to get object.\n", key_id);
58                 return objagg_obj;
59         }
60         if (!world->key_refs[key_id_index(key_id)]) {
61                 world->objagg_objs[key_id_index(key_id)] = objagg_obj;
62         } else if (world->objagg_objs[key_id_index(key_id)] != objagg_obj) {
63                 pr_err("Key %u: God another object for the same key.\n",
64                        key_id);
65                 err = -EINVAL;
66                 goto err_key_id_check;
67         }
68         world->key_refs[key_id_index(key_id)]++;
69         return objagg_obj;
70
71 err_key_id_check:
72         objagg_obj_put(objagg, objagg_obj);
73         return ERR_PTR(err);
74 }
75
76 static void world_obj_put(struct world *world, struct objagg *objagg,
77                           unsigned int key_id)
78 {
79         struct objagg_obj *objagg_obj;
80
81         if (!world->key_refs[key_id_index(key_id)])
82                 return;
83         objagg_obj = world->objagg_objs[key_id_index(key_id)];
84         objagg_obj_put(objagg, objagg_obj);
85         world->key_refs[key_id_index(key_id)]--;
86 }
87
88 #define MAX_KEY_ID_DIFF 5
89
90 static void *delta_create(void *priv, void *parent_obj, void *obj)
91 {
92         struct tokey *parent_key = parent_obj;
93         struct world *world = priv;
94         struct tokey *key = obj;
95         int diff = key->id - parent_key->id;
96         struct delta *delta;
97
98         if (diff < 0 || diff > MAX_KEY_ID_DIFF)
99                 return ERR_PTR(-EINVAL);
100
101         delta = kzalloc(sizeof(*delta), GFP_KERNEL);
102         if (!delta)
103                 return ERR_PTR(-ENOMEM);
104         delta->key_id_diff = diff;
105         world->delta_count++;
106         return delta;
107 }
108
109 static void delta_destroy(void *priv, void *delta_priv)
110 {
111         struct delta *delta = delta_priv;
112         struct world *world = priv;
113
114         world->delta_count--;
115         kfree(delta);
116 }
117
118 static void *root_create(void *priv, void *obj)
119 {
120         struct world *world = priv;
121         struct tokey *key = obj;
122         struct root *root;
123
124         root = kzalloc(sizeof(*root), GFP_KERNEL);
125         if (!root)
126                 return ERR_PTR(-ENOMEM);
127         memcpy(&root->key, key, sizeof(root->key));
128         memcpy(root->buf, world->next_root_buf, sizeof(root->buf));
129         world->root_count++;
130         return root;
131 }
132
133 static void root_destroy(void *priv, void *root_priv)
134 {
135         struct root *root = root_priv;
136         struct world *world = priv;
137
138         world->root_count--;
139         kfree(root);
140 }
141
142 static int test_nodelta_obj_get(struct world *world, struct objagg *objagg,
143                                 unsigned int key_id, bool should_create_root)
144 {
145         unsigned int orig_root_count = world->root_count;
146         struct objagg_obj *objagg_obj;
147         const struct root *root;
148         int err;
149
150         if (should_create_root)
151                 prandom_bytes(world->next_root_buf,
152                               sizeof(world->next_root_buf));
153
154         objagg_obj = world_obj_get(world, objagg, key_id);
155         if (IS_ERR(objagg_obj)) {
156                 pr_err("Key %u: Failed to get object.\n", key_id);
157                 return PTR_ERR(objagg_obj);
158         }
159         if (should_create_root) {
160                 if (world->root_count != orig_root_count + 1) {
161                         pr_err("Key %u: Root was not created\n", key_id);
162                         err = -EINVAL;
163                         goto err_check_root_count;
164                 }
165         } else {
166                 if (world->root_count != orig_root_count) {
167                         pr_err("Key %u: Root was incorrectly created\n",
168                                key_id);
169                         err = -EINVAL;
170                         goto err_check_root_count;
171                 }
172         }
173         root = objagg_obj_root_priv(objagg_obj);
174         if (root->key.id != key_id) {
175                 pr_err("Key %u: Root has unexpected key id\n", key_id);
176                 err = -EINVAL;
177                 goto err_check_key_id;
178         }
179         if (should_create_root &&
180             memcmp(world->next_root_buf, root->buf, sizeof(root->buf))) {
181                 pr_err("Key %u: Buffer does not match the expected content\n",
182                        key_id);
183                 err = -EINVAL;
184                 goto err_check_buf;
185         }
186         return 0;
187
188 err_check_buf:
189 err_check_key_id:
190 err_check_root_count:
191         objagg_obj_put(objagg, objagg_obj);
192         return err;
193 }
194
195 static int test_nodelta_obj_put(struct world *world, struct objagg *objagg,
196                                 unsigned int key_id, bool should_destroy_root)
197 {
198         unsigned int orig_root_count = world->root_count;
199
200         world_obj_put(world, objagg, key_id);
201
202         if (should_destroy_root) {
203                 if (world->root_count != orig_root_count - 1) {
204                         pr_err("Key %u: Root was not destroyed\n", key_id);
205                         return -EINVAL;
206                 }
207         } else {
208                 if (world->root_count != orig_root_count) {
209                         pr_err("Key %u: Root was incorrectly destroyed\n",
210                                key_id);
211                         return -EINVAL;
212                 }
213         }
214         return 0;
215 }
216
217 static int check_stats_zero(struct objagg *objagg)
218 {
219         const struct objagg_stats *stats;
220         int err = 0;
221
222         stats = objagg_stats_get(objagg);
223         if (IS_ERR(stats))
224                 return PTR_ERR(stats);
225
226         if (stats->stats_info_count != 0) {
227                 pr_err("Stats: Object count is not zero while it should be\n");
228                 err = -EINVAL;
229         }
230
231         objagg_stats_put(stats);
232         return err;
233 }
234
235 static int check_stats_nodelta(struct objagg *objagg)
236 {
237         const struct objagg_stats *stats;
238         int i;
239         int err;
240
241         stats = objagg_stats_get(objagg);
242         if (IS_ERR(stats))
243                 return PTR_ERR(stats);
244
245         if (stats->stats_info_count != NUM_KEYS) {
246                 pr_err("Stats: Unexpected object count (%u expected, %u returned)\n",
247                        NUM_KEYS, stats->stats_info_count);
248                 err = -EINVAL;
249                 goto stats_put;
250         }
251
252         for (i = 0; i < stats->stats_info_count; i++) {
253                 if (stats->stats_info[i].stats.user_count != 2) {
254                         pr_err("Stats: incorrect user count\n");
255                         err = -EINVAL;
256                         goto stats_put;
257                 }
258                 if (stats->stats_info[i].stats.delta_user_count != 2) {
259                         pr_err("Stats: incorrect delta user count\n");
260                         err = -EINVAL;
261                         goto stats_put;
262                 }
263         }
264         err = 0;
265
266 stats_put:
267         objagg_stats_put(stats);
268         return err;
269 }
270
271 static void *delta_create_dummy(void *priv, void *parent_obj, void *obj)
272 {
273         return ERR_PTR(-EOPNOTSUPP);
274 }
275
276 static void delta_destroy_dummy(void *priv, void *delta_priv)
277 {
278 }
279
280 static const struct objagg_ops nodelta_ops = {
281         .obj_size = sizeof(struct tokey),
282         .delta_create = delta_create_dummy,
283         .delta_destroy = delta_destroy_dummy,
284         .root_create = root_create,
285         .root_destroy = root_destroy,
286 };
287
288 static int test_nodelta(void)
289 {
290         struct world world = {};
291         struct objagg *objagg;
292         int i;
293         int err;
294
295         objagg = objagg_create(&nodelta_ops, &world);
296         if (IS_ERR(objagg))
297                 return PTR_ERR(objagg);
298
299         err = check_stats_zero(objagg);
300         if (err)
301                 goto err_stats_first_zero;
302
303         /* First round of gets, the root objects should be created */
304         for (i = 0; i < NUM_KEYS; i++) {
305                 err = test_nodelta_obj_get(&world, objagg, i, true);
306                 if (err)
307                         goto err_obj_first_get;
308         }
309
310         /* Do the second round of gets, all roots are already created,
311          * make sure that no new root is created
312          */
313         for (i = 0; i < NUM_KEYS; i++) {
314                 err = test_nodelta_obj_get(&world, objagg, i, false);
315                 if (err)
316                         goto err_obj_second_get;
317         }
318
319         err = check_stats_nodelta(objagg);
320         if (err)
321                 goto err_stats_nodelta;
322
323         for (i = NUM_KEYS - 1; i >= 0; i--) {
324                 err = test_nodelta_obj_put(&world, objagg, i, false);
325                 if (err)
326                         goto err_obj_first_put;
327         }
328         for (i = NUM_KEYS - 1; i >= 0; i--) {
329                 err = test_nodelta_obj_put(&world, objagg, i, true);
330                 if (err)
331                         goto err_obj_second_put;
332         }
333
334         err = check_stats_zero(objagg);
335         if (err)
336                 goto err_stats_second_zero;
337
338         objagg_destroy(objagg);
339         return 0;
340
341 err_stats_nodelta:
342 err_obj_first_put:
343 err_obj_second_get:
344         for (i--; i >= 0; i--)
345                 world_obj_put(&world, objagg, i);
346
347         i = NUM_KEYS;
348 err_obj_first_get:
349 err_obj_second_put:
350         for (i--; i >= 0; i--)
351                 world_obj_put(&world, objagg, i);
352 err_stats_first_zero:
353 err_stats_second_zero:
354         objagg_destroy(objagg);
355         return err;
356 }
357
358 static const struct objagg_ops delta_ops = {
359         .obj_size = sizeof(struct tokey),
360         .delta_create = delta_create,
361         .delta_destroy = delta_destroy,
362         .root_create = root_create,
363         .root_destroy = root_destroy,
364 };
365
366 enum action {
367         ACTION_GET,
368         ACTION_PUT,
369 };
370
371 enum expect_delta {
372         EXPECT_DELTA_SAME,
373         EXPECT_DELTA_INC,
374         EXPECT_DELTA_DEC,
375 };
376
377 enum expect_root {
378         EXPECT_ROOT_SAME,
379         EXPECT_ROOT_INC,
380         EXPECT_ROOT_DEC,
381 };
382
383 struct expect_stats_info {
384         struct objagg_obj_stats stats;
385         bool is_root;
386         unsigned int key_id;
387 };
388
389 struct expect_stats {
390         unsigned int info_count;
391         struct expect_stats_info info[NUM_KEYS];
392 };
393
394 struct action_item {
395         unsigned int key_id;
396         enum action action;
397         enum expect_delta expect_delta;
398         enum expect_root expect_root;
399         struct expect_stats expect_stats;
400 };
401
402 #define EXPECT_STATS(count, ...)                \
403 {                                               \
404         .info_count = count,                    \
405         .info = { __VA_ARGS__ }                 \
406 }
407
408 #define ROOT(key_id, user_count, delta_user_count)      \
409         {{user_count, delta_user_count}, true, key_id}
410
411 #define DELTA(key_id, user_count)                       \
412         {{user_count, user_count}, false, key_id}
413
414 static const struct action_item action_items[] = {
415         {
416                 1, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
417                 EXPECT_STATS(1, ROOT(1, 1, 1)),
418         },      /* r: 1                 d: */
419         {
420                 7, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
421                 EXPECT_STATS(2, ROOT(1, 1, 1), ROOT(7, 1, 1)),
422         },      /* r: 1, 7              d: */
423         {
424                 3, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
425                 EXPECT_STATS(3, ROOT(1, 1, 2), ROOT(7, 1, 1),
426                                 DELTA(3, 1)),
427         },      /* r: 1, 7              d: 3^1 */
428         {
429                 5, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
430                 EXPECT_STATS(4, ROOT(1, 1, 3), ROOT(7, 1, 1),
431                                 DELTA(3, 1), DELTA(5, 1)),
432         },      /* r: 1, 7              d: 3^1, 5^1 */
433         {
434                 3, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
435                 EXPECT_STATS(4, ROOT(1, 1, 4), ROOT(7, 1, 1),
436                                 DELTA(3, 2), DELTA(5, 1)),
437         },      /* r: 1, 7              d: 3^1, 3^1, 5^1 */
438         {
439                 1, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
440                 EXPECT_STATS(4, ROOT(1, 2, 5), ROOT(7, 1, 1),
441                                 DELTA(3, 2), DELTA(5, 1)),
442         },      /* r: 1, 1, 7           d: 3^1, 3^1, 5^1 */
443         {
444                 30, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
445                 EXPECT_STATS(5, ROOT(1, 2, 5), ROOT(7, 1, 1), ROOT(30, 1, 1),
446                                 DELTA(3, 2), DELTA(5, 1)),
447         },      /* r: 1, 1, 7, 30       d: 3^1, 3^1, 5^1 */
448         {
449                 8, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
450                 EXPECT_STATS(6, ROOT(1, 2, 5), ROOT(7, 1, 2), ROOT(30, 1, 1),
451                                 DELTA(3, 2), DELTA(5, 1), DELTA(8, 1)),
452         },      /* r: 1, 1, 7, 30       d: 3^1, 3^1, 5^1, 8^7 */
453         {
454                 8, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
455                 EXPECT_STATS(6, ROOT(1, 2, 5), ROOT(7, 1, 3), ROOT(30, 1, 1),
456                                 DELTA(3, 2), DELTA(8, 2), DELTA(5, 1)),
457         },      /* r: 1, 1, 7, 30       d: 3^1, 3^1, 5^1, 8^7, 8^7 */
458         {
459                 3, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
460                 EXPECT_STATS(6, ROOT(1, 2, 4), ROOT(7, 1, 3), ROOT(30, 1, 1),
461                                 DELTA(8, 2), DELTA(3, 1), DELTA(5, 1)),
462         },      /* r: 1, 1, 7, 30       d: 3^1, 5^1, 8^7, 8^7 */
463         {
464                 3, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_SAME,
465                 EXPECT_STATS(5, ROOT(1, 2, 3), ROOT(7, 1, 3), ROOT(30, 1, 1),
466                                 DELTA(8, 2), DELTA(5, 1)),
467         },      /* r: 1, 1, 7, 30       d: 5^1, 8^7, 8^7 */
468         {
469                 1, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
470                 EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(1, 1, 2), ROOT(30, 1, 1),
471                                 DELTA(8, 2), DELTA(5, 1)),
472         },      /* r: 1, 7, 30          d: 5^1, 8^7, 8^7 */
473         {
474                 1, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
475                 EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(30, 1, 1), ROOT(1, 0, 1),
476                                 DELTA(8, 2), DELTA(5, 1)),
477         },      /* r: 7, 30             d: 5^1, 8^7, 8^7 */
478         {
479                 5, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_DEC,
480                 EXPECT_STATS(3, ROOT(7, 1, 3), ROOT(30, 1, 1),
481                                 DELTA(8, 2)),
482         },      /* r: 7, 30             d: 8^7, 8^7 */
483         {
484                 5, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
485                 EXPECT_STATS(4, ROOT(7, 1, 3), ROOT(30, 1, 1), ROOT(5, 1, 1),
486                                 DELTA(8, 2)),
487         },      /* r: 7, 30, 5          d: 8^7, 8^7 */
488         {
489                 6, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
490                 EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(5, 1, 2), ROOT(30, 1, 1),
491                                 DELTA(8, 2), DELTA(6, 1)),
492         },      /* r: 7, 30, 5          d: 8^7, 8^7, 6^5 */
493         {
494                 8, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
495                 EXPECT_STATS(5, ROOT(7, 1, 4), ROOT(5, 1, 2), ROOT(30, 1, 1),
496                                 DELTA(8, 3), DELTA(6, 1)),
497         },      /* r: 7, 30, 5          d: 8^7, 8^7, 8^7, 6^5 */
498         {
499                 8, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
500                 EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(5, 1, 2), ROOT(30, 1, 1),
501                                 DELTA(8, 2), DELTA(6, 1)),
502         },      /* r: 7, 30, 5          d: 8^7, 8^7, 6^5 */
503         {
504                 8, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
505                 EXPECT_STATS(5, ROOT(7, 1, 2), ROOT(5, 1, 2), ROOT(30, 1, 1),
506                                 DELTA(8, 1), DELTA(6, 1)),
507         },      /* r: 7, 30, 5          d: 8^7, 6^5 */
508         {
509                 8, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_SAME,
510                 EXPECT_STATS(4, ROOT(5, 1, 2), ROOT(7, 1, 1), ROOT(30, 1, 1),
511                                 DELTA(6, 1)),
512         },      /* r: 7, 30, 5          d: 6^5 */
513         {
514                 8, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
515                 EXPECT_STATS(5, ROOT(5, 1, 3), ROOT(7, 1, 1), ROOT(30, 1, 1),
516                                 DELTA(6, 1), DELTA(8, 1)),
517         },      /* r: 7, 30, 5          d: 6^5, 8^5 */
518         {
519                 7, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_DEC,
520                 EXPECT_STATS(4, ROOT(5, 1, 3), ROOT(30, 1, 1),
521                                 DELTA(6, 1), DELTA(8, 1)),
522         },      /* r: 30, 5             d: 6^5, 8^5 */
523         {
524                 30, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_DEC,
525                 EXPECT_STATS(3, ROOT(5, 1, 3),
526                                 DELTA(6, 1), DELTA(8, 1)),
527         },      /* r: 5                 d: 6^5, 8^5 */
528         {
529                 5, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
530                 EXPECT_STATS(3, ROOT(5, 0, 2),
531                                 DELTA(6, 1), DELTA(8, 1)),
532         },      /* r:                   d: 6^5, 8^5 */
533         {
534                 6, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_SAME,
535                 EXPECT_STATS(2, ROOT(5, 0, 1),
536                                 DELTA(8, 1)),
537         },      /* r:                   d: 6^5 */
538         {
539                 8, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_DEC,
540                 EXPECT_STATS(0, ),
541         },      /* r:                   d: */
542 };
543
544 static int check_expect(struct world *world,
545                         const struct action_item *action_item,
546                         unsigned int orig_delta_count,
547                         unsigned int orig_root_count)
548 {
549         unsigned int key_id = action_item->key_id;
550
551         switch (action_item->expect_delta) {
552         case EXPECT_DELTA_SAME:
553                 if (orig_delta_count != world->delta_count) {
554                         pr_err("Key %u: Delta count changed while expected to remain the same.\n",
555                                key_id);
556                         return -EINVAL;
557                 }
558                 break;
559         case EXPECT_DELTA_INC:
560                 if (WARN_ON(action_item->action == ACTION_PUT))
561                         return -EINVAL;
562                 if (orig_delta_count + 1 != world->delta_count) {
563                         pr_err("Key %u: Delta count was not incremented.\n",
564                                key_id);
565                         return -EINVAL;
566                 }
567                 break;
568         case EXPECT_DELTA_DEC:
569                 if (WARN_ON(action_item->action == ACTION_GET))
570                         return -EINVAL;
571                 if (orig_delta_count - 1 != world->delta_count) {
572                         pr_err("Key %u: Delta count was not decremented.\n",
573                                key_id);
574                         return -EINVAL;
575                 }
576                 break;
577         }
578
579         switch (action_item->expect_root) {
580         case EXPECT_ROOT_SAME:
581                 if (orig_root_count != world->root_count) {
582                         pr_err("Key %u: Root count changed while expected to remain the same.\n",
583                                key_id);
584                         return -EINVAL;
585                 }
586                 break;
587         case EXPECT_ROOT_INC:
588                 if (WARN_ON(action_item->action == ACTION_PUT))
589                         return -EINVAL;
590                 if (orig_root_count + 1 != world->root_count) {
591                         pr_err("Key %u: Root count was not incremented.\n",
592                                key_id);
593                         return -EINVAL;
594                 }
595                 break;
596         case EXPECT_ROOT_DEC:
597                 if (WARN_ON(action_item->action == ACTION_GET))
598                         return -EINVAL;
599                 if (orig_root_count - 1 != world->root_count) {
600                         pr_err("Key %u: Root count was not decremented.\n",
601                                key_id);
602                         return -EINVAL;
603                 }
604         }
605
606         return 0;
607 }
608
609 static unsigned int obj_to_key_id(struct objagg_obj *objagg_obj)
610 {
611         const struct tokey *root_key;
612         const struct delta *delta;
613         unsigned int key_id;
614
615         root_key = objagg_obj_root_priv(objagg_obj);
616         key_id = root_key->id;
617         delta = objagg_obj_delta_priv(objagg_obj);
618         if (delta)
619                 key_id += delta->key_id_diff;
620         return key_id;
621 }
622
623 static int
624 check_expect_stats_nums(const struct objagg_obj_stats_info *stats_info,
625                         const struct expect_stats_info *expect_stats_info,
626                         const char **errmsg)
627 {
628         if (stats_info->is_root != expect_stats_info->is_root) {
629                 if (errmsg)
630                         *errmsg = "Incorrect root/delta indication";
631                 return -EINVAL;
632         }
633         if (stats_info->stats.user_count !=
634             expect_stats_info->stats.user_count) {
635                 if (errmsg)
636                         *errmsg = "Incorrect user count";
637                 return -EINVAL;
638         }
639         if (stats_info->stats.delta_user_count !=
640             expect_stats_info->stats.delta_user_count) {
641                 if (errmsg)
642                         *errmsg = "Incorrect delta user count";
643                 return -EINVAL;
644         }
645         return 0;
646 }
647
648 static int
649 check_expect_stats_key_id(const struct objagg_obj_stats_info *stats_info,
650                           const struct expect_stats_info *expect_stats_info,
651                           const char **errmsg)
652 {
653         if (obj_to_key_id(stats_info->objagg_obj) !=
654             expect_stats_info->key_id) {
655                 if (errmsg)
656                         *errmsg = "incorrect key id";
657                 return -EINVAL;
658         }
659         return 0;
660 }
661
662 static int check_expect_stats_neigh(const struct objagg_stats *stats,
663                                     const struct expect_stats *expect_stats,
664                                     int pos)
665 {
666         int i;
667         int err;
668
669         for (i = pos - 1; i >= 0; i--) {
670                 err = check_expect_stats_nums(&stats->stats_info[i],
671                                               &expect_stats->info[pos], NULL);
672                 if (err)
673                         break;
674                 err = check_expect_stats_key_id(&stats->stats_info[i],
675                                                 &expect_stats->info[pos], NULL);
676                 if (!err)
677                         return 0;
678         }
679         for (i = pos + 1; i < stats->stats_info_count; i++) {
680                 err = check_expect_stats_nums(&stats->stats_info[i],
681                                               &expect_stats->info[pos], NULL);
682                 if (err)
683                         break;
684                 err = check_expect_stats_key_id(&stats->stats_info[i],
685                                                 &expect_stats->info[pos], NULL);
686                 if (!err)
687                         return 0;
688         }
689         return -EINVAL;
690 }
691
692 static int __check_expect_stats(const struct objagg_stats *stats,
693                                 const struct expect_stats *expect_stats,
694                                 const char **errmsg)
695 {
696         int i;
697         int err;
698
699         if (stats->stats_info_count != expect_stats->info_count) {
700                 *errmsg = "Unexpected object count";
701                 return -EINVAL;
702         }
703
704         for (i = 0; i < stats->stats_info_count; i++) {
705                 err = check_expect_stats_nums(&stats->stats_info[i],
706                                               &expect_stats->info[i], errmsg);
707                 if (err)
708                         return err;
709                 err = check_expect_stats_key_id(&stats->stats_info[i],
710                                                 &expect_stats->info[i], errmsg);
711                 if (err) {
712                         /* It is possible that one of the neighbor stats with
713                          * same numbers have the correct key id, so check it
714                          */
715                         err = check_expect_stats_neigh(stats, expect_stats, i);
716                         if (err)
717                                 return err;
718                 }
719         }
720         return 0;
721 }
722
723 static int check_expect_stats(struct objagg *objagg,
724                               const struct expect_stats *expect_stats,
725                               const char **errmsg)
726 {
727         const struct objagg_stats *stats;
728         int err;
729
730         stats = objagg_stats_get(objagg);
731         if (IS_ERR(stats))
732                 return PTR_ERR(stats);
733         err = __check_expect_stats(stats, expect_stats, errmsg);
734         objagg_stats_put(stats);
735         return err;
736 }
737
738 static int test_delta_action_item(struct world *world,
739                                   struct objagg *objagg,
740                                   const struct action_item *action_item,
741                                   bool inverse)
742 {
743         unsigned int orig_delta_count = world->delta_count;
744         unsigned int orig_root_count = world->root_count;
745         unsigned int key_id = action_item->key_id;
746         enum action action = action_item->action;
747         struct objagg_obj *objagg_obj;
748         const char *errmsg;
749         int err;
750
751         if (inverse)
752                 action = action == ACTION_GET ? ACTION_PUT : ACTION_GET;
753
754         switch (action) {
755         case ACTION_GET:
756                 objagg_obj = world_obj_get(world, objagg, key_id);
757                 if (IS_ERR(objagg_obj))
758                         return PTR_ERR(objagg_obj);
759                 break;
760         case ACTION_PUT:
761                 world_obj_put(world, objagg, key_id);
762                 break;
763         }
764
765         if (inverse)
766                 return 0;
767         err = check_expect(world, action_item,
768                            orig_delta_count, orig_root_count);
769         if (err)
770                 goto errout;
771
772         errmsg = NULL;
773         err = check_expect_stats(objagg, &action_item->expect_stats, &errmsg);
774         if (err) {
775                 pr_err("Key %u: Stats: %s\n", action_item->key_id, errmsg);
776                 goto errout;
777         }
778
779         return 0;
780
781 errout:
782         /* This can only happen when action is not inversed.
783          * So in case of an error, cleanup by doing inverse action.
784          */
785         test_delta_action_item(world, objagg, action_item, true);
786         return err;
787 }
788
789 static int test_delta(void)
790 {
791         struct world world = {};
792         struct objagg *objagg;
793         int i;
794         int err;
795
796         objagg = objagg_create(&delta_ops, &world);
797         if (IS_ERR(objagg))
798                 return PTR_ERR(objagg);
799
800         for (i = 0; i < ARRAY_SIZE(action_items); i++) {
801                 err = test_delta_action_item(&world, objagg,
802                                              &action_items[i], false);
803                 if (err)
804                         goto err_do_action_item;
805         }
806
807         objagg_destroy(objagg);
808         return 0;
809
810 err_do_action_item:
811         for (i--; i >= 0; i--)
812                 test_delta_action_item(&world, objagg, &action_items[i], true);
813
814         objagg_destroy(objagg);
815         return err;
816 }
817
818 static int __init test_objagg_init(void)
819 {
820         int err;
821
822         err = test_nodelta();
823         if (err)
824                 return err;
825         return test_delta();
826 }
827
828 static void __exit test_objagg_exit(void)
829 {
830 }
831
832 module_init(test_objagg_init);
833 module_exit(test_objagg_exit);
834 MODULE_LICENSE("Dual BSD/GPL");
835 MODULE_AUTHOR("Jiri Pirko <jiri@mellanox.com>");
836 MODULE_DESCRIPTION("Test module for objagg");