netfilter: nft_ct: reject direction for ct id
[platform/kernel/linux-starfive.git] / net / netfilter / nft_ct.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (c) 2008-2009 Patrick McHardy <kaber@trash.net>
4  * Copyright (c) 2016 Pablo Neira Ayuso <pablo@netfilter.org>
5  *
6  * Development of this code funded by Astaro AG (http://www.astaro.com/)
7  */
8
9 #include <linux/kernel.h>
10 #include <linux/init.h>
11 #include <linux/module.h>
12 #include <linux/netlink.h>
13 #include <linux/netfilter.h>
14 #include <linux/netfilter/nf_tables.h>
15 #include <net/netfilter/nf_tables_core.h>
16 #include <net/netfilter/nf_conntrack.h>
17 #include <net/netfilter/nf_conntrack_acct.h>
18 #include <net/netfilter/nf_conntrack_tuple.h>
19 #include <net/netfilter/nf_conntrack_helper.h>
20 #include <net/netfilter/nf_conntrack_ecache.h>
21 #include <net/netfilter/nf_conntrack_labels.h>
22 #include <net/netfilter/nf_conntrack_timeout.h>
23 #include <net/netfilter/nf_conntrack_l4proto.h>
24 #include <net/netfilter/nf_conntrack_expect.h>
25
/*
 * Private data of the ct helper object.  The code that fills these fields
 * is not shown next to this definition; the per-field notes below are
 * inferred from the names and should be confirmed against it.
 */
struct nft_ct_helper_obj  {
	struct nf_conntrack_helper *helper4;	/* presumably the IPv4 helper - confirm */
	struct nf_conntrack_helper *helper6;	/* presumably the IPv6 helper - confirm */
	u8 l4proto;				/* presumably the layer-4 protocol number - confirm */
};
31
#ifdef CONFIG_NF_CONNTRACK_ZONES
/*
 * Per-CPU conntrack templates used by nft_ct_set_zone_eval() ("ct zone set").
 * nft_ct_pcpu_template_refcnt counts the zone-set expressions using them:
 * the first such expression allocates the templates and the last one
 * releases them.  nft_ct_pcpu_mutex serialises the counter together with
 * the allocation and release.
 */
static DEFINE_PER_CPU(struct nf_conn *, nft_ct_pcpu_template);
static unsigned int nft_ct_pcpu_template_refcnt __read_mostly;
static DEFINE_MUTEX(nft_ct_pcpu_mutex);
#endif
37
38 static u64 nft_ct_get_eval_counter(const struct nf_conn_counter *c,
39                                    enum nft_ct_keys k,
40                                    enum ip_conntrack_dir d)
41 {
42         if (d < IP_CT_DIR_MAX)
43                 return k == NFT_CT_BYTES ? atomic64_read(&c[d].bytes) :
44                                            atomic64_read(&c[d].packets);
45
46         return nft_ct_get_eval_counter(c, k, IP_CT_DIR_ORIGINAL) +
47                nft_ct_get_eval_counter(c, k, IP_CT_DIR_REPLY);
48 }
49
/*
 * Evaluate a "ct <key>" load: copy the requested conntrack property of the
 * packet into the destination register.
 *
 * NFT_CT_STATE is answered even without a conntrack entry (untracked or
 * invalid state bits).  Every other key needs an entry.  Without one, or when
 * the requested data is unavailable (no master helper, address family
 * mismatch), the verdict is set to NFT_BREAK.  Tuple keys (addresses and
 * ports) read the tuple of priv->dir, which nft_ct_get_init() requires for
 * those keys.
 */
static void nft_ct_get_eval(const struct nft_expr *expr,
			    struct nft_regs *regs,
			    const struct nft_pktinfo *pkt)
{
	const struct nft_ct *priv = nft_expr_priv(expr);
	u32 *dest = &regs->data[priv->dreg];
	enum ip_conntrack_info ctinfo;
	const struct nf_conn *ct;
	const struct nf_conn_help *help;
	const struct nf_conntrack_tuple *tuple;
	const struct nf_conntrack_helper *helper;
	unsigned int state;

	ct = nf_ct_get(pkt->skb, &ctinfo);

	/* The state key is the only one that works without a conntrack entry. */
	switch (priv->key) {
	case NFT_CT_STATE:
		if (ct)
			state = NF_CT_STATE_BIT(ctinfo);
		else if (ctinfo == IP_CT_UNTRACKED)
			state = NF_CT_STATE_UNTRACKED_BIT;
		else
			state = NF_CT_STATE_INVALID_BIT;
		*dest = state;
		return;
	default:
		break;
	}

	if (ct == NULL)
		goto err;

	/* Keys that do not need a tuple direction. */
	switch (priv->key) {
	case NFT_CT_DIRECTION:
		nft_reg_store8(dest, CTINFO2DIR(ctinfo));
		return;
	case NFT_CT_STATUS:
		*dest = ct->status;
		return;
#ifdef CONFIG_NF_CONNTRACK_MARK
	case NFT_CT_MARK:
		*dest = READ_ONCE(ct->mark);
		return;
#endif
#ifdef CONFIG_NF_CONNTRACK_SECMARK
	case NFT_CT_SECMARK:
		*dest = ct->secmark;
		return;
#endif
	case NFT_CT_EXPIRATION:
		*dest = jiffies_to_msecs(nf_ct_expires(ct));
		return;
	case NFT_CT_HELPER:
		/* Report the name of the helper attached to the master
		 * connection; break if there is none.
		 */
		if (ct->master == NULL)
			goto err;
		help = nfct_help(ct->master);
		if (help == NULL)
			goto err;
		helper = rcu_dereference(help->helper);
		if (helper == NULL)
			goto err;
		strscpy_pad((char *)dest, helper->name, NF_CT_HELPER_NAME_LEN);
		return;
#ifdef CONFIG_NF_CONNTRACK_LABELS
	case NFT_CT_LABELS: {
		struct nf_conn_labels *labels = nf_ct_labels_find(ct);

		/* No labels extension: report all-zero labels. */
		if (labels)
			memcpy(dest, labels->bits, NF_CT_LABELS_MAX_SIZE);
		else
			memset(dest, 0, NF_CT_LABELS_MAX_SIZE);
		return;
	}
#endif
	case NFT_CT_BYTES:
	case NFT_CT_PKTS: {
		/* 0 when there is no accounting extension; a priv->dir of
		 * IP_CT_DIR_MAX sums both directions.
		 */
		const struct nf_conn_acct *acct = nf_conn_acct_find(ct);
		u64 count = 0;

		if (acct)
			count = nft_ct_get_eval_counter(acct->counter,
							priv->key, priv->dir);
		memcpy(dest, &count, sizeof(count));
		return;
	}
	case NFT_CT_AVGPKT: {
		/* Bytes divided by packets; 0 if nothing was counted. */
		const struct nf_conn_acct *acct = nf_conn_acct_find(ct);
		u64 avgcnt = 0, bcnt = 0, pcnt = 0;

		if (acct) {
			pcnt = nft_ct_get_eval_counter(acct->counter,
						       NFT_CT_PKTS, priv->dir);
			bcnt = nft_ct_get_eval_counter(acct->counter,
						       NFT_CT_BYTES, priv->dir);
			if (pcnt != 0)
				avgcnt = div64_u64(bcnt, pcnt);
		}

		memcpy(dest, &avgcnt, sizeof(avgcnt));
		return;
	}
	case NFT_CT_L3PROTOCOL:
		nft_reg_store8(dest, nf_ct_l3num(ct));
		return;
	case NFT_CT_PROTOCOL:
		nft_reg_store8(dest, nf_ct_protonum(ct));
		return;
#ifdef CONFIG_NF_CONNTRACK_ZONES
	case NFT_CT_ZONE: {
		/* nf_ct_zone_id() for the configured direction, otherwise
		 * the plain zone id.
		 */
		const struct nf_conntrack_zone *zone = nf_ct_zone(ct);
		u16 zoneid;

		if (priv->dir < IP_CT_DIR_MAX)
			zoneid = nf_ct_zone_id(zone, priv->dir);
		else
			zoneid = zone->id;

		nft_reg_store16(dest, zoneid);
		return;
	}
#endif
	case NFT_CT_ID:
		*dest = nf_ct_get_id(ct);
		return;
	default:
		break;
	}

	/* Remaining keys read the tuple of the configured direction. */
	tuple = &ct->tuplehash[priv->dir].tuple;
	switch (priv->key) {
	case NFT_CT_SRC:
		/* 4 bytes for IPv4, otherwise a 16-byte address. */
		memcpy(dest, tuple->src.u3.all,
		       nf_ct_l3num(ct) == NFPROTO_IPV4 ? 4 : 16);
		return;
	case NFT_CT_DST:
		memcpy(dest, tuple->dst.u3.all,
		       nf_ct_l3num(ct) == NFPROTO_IPV4 ? 4 : 16);
		return;
	case NFT_CT_PROTO_SRC:
		nft_reg_store16(dest, (__force u16)tuple->src.u.all);
		return;
	case NFT_CT_PROTO_DST:
		nft_reg_store16(dest, (__force u16)tuple->dst.u.all);
		return;
	/* Family-specific address keys break on a connection of the
	 * other layer-3 family.
	 */
	case NFT_CT_SRC_IP:
		if (nf_ct_l3num(ct) != NFPROTO_IPV4)
			goto err;
		*dest = (__force __u32)tuple->src.u3.ip;
		return;
	case NFT_CT_DST_IP:
		if (nf_ct_l3num(ct) != NFPROTO_IPV4)
			goto err;
		*dest = (__force __u32)tuple->dst.u3.ip;
		return;
	case NFT_CT_SRC_IP6:
		if (nf_ct_l3num(ct) != NFPROTO_IPV6)
			goto err;
		memcpy(dest, tuple->src.u3.ip6, sizeof(struct in6_addr));
		return;
	case NFT_CT_DST_IP6:
		if (nf_ct_l3num(ct) != NFPROTO_IPV6)
			goto err;
		memcpy(dest, tuple->dst.u3.ip6, sizeof(struct in6_addr));
		return;
	default:
		break;
	}
	return;
err:
	regs->verdict.code = NFT_BREAK;
}
221
#ifdef CONFIG_NF_CONNTRACK_ZONES
/*
 * Evaluate "ct zone set": attach a conntrack template that carries the zone
 * id from the source register to a packet without a conntrack entry.
 * Packets that already have an entry are left untouched.  If a temporary
 * template cannot be allocated, the verdict is set to NF_DROP.
 */
static void nft_ct_set_zone_eval(const struct nft_expr *expr,
				 struct nft_regs *regs,
				 const struct nft_pktinfo *pkt)
{
	struct nf_conntrack_zone zone = { .dir = NF_CT_DEFAULT_ZONE_DIR };
	const struct nft_ct *priv = nft_expr_priv(expr);
	struct sk_buff *skb = pkt->skb;
	enum ip_conntrack_info ctinfo;
	u16 value = nft_reg_load16(&regs->data[priv->sreg]);
	struct nf_conn *ct;

	ct = nf_ct_get(skb, &ctinfo);
	if (ct) /* already tracked */
		return;

	zone.id = value;

	/* Restrict the zone to one direction if the expression was
	 * configured with one; otherwise keep NF_CT_DEFAULT_ZONE_DIR.
	 */
	switch (priv->dir) {
	case IP_CT_DIR_ORIGINAL:
		zone.dir = NF_CT_ZONE_DIR_ORIG;
		break;
	case IP_CT_DIR_REPLY:
		zone.dir = NF_CT_ZONE_DIR_REPL;
		break;
	default:
		break;
	}

	ct = this_cpu_read(nft_ct_pcpu_template);

	/* A reference count of 1 presumably means only the per-CPU slot holds
	 * this template, so it is reused: take a reference for this skb and
	 * overwrite its zone.
	 */
	if (likely(refcount_read(&ct->ct_general.use) == 1)) {
		refcount_inc(&ct->ct_general.use);
		nf_ct_zone_add(ct, &zone);
	} else {
		/* previous skb got queued to userspace, allocate temporary
		 * one until percpu template can be reused.
		 */
		ct = nf_ct_tmpl_alloc(nft_net(pkt), &zone, GFP_ATOMIC);
		if (!ct) {
			regs->verdict.code = NF_DROP;
			return;
		}
		/* Marked confirmed, like the per-CPU templates set up in
		 * nft_ct_tmpl_alloc_pcpu().
		 */
		__set_bit(IPS_CONFIRMED_BIT, &ct->status);
	}

	nf_ct_set(skb, ct, IP_CT_NEW);
}
#endif
271
272 static void nft_ct_set_eval(const struct nft_expr *expr,
273                             struct nft_regs *regs,
274                             const struct nft_pktinfo *pkt)
275 {
276         const struct nft_ct *priv = nft_expr_priv(expr);
277         struct sk_buff *skb = pkt->skb;
278 #if defined(CONFIG_NF_CONNTRACK_MARK) || defined(CONFIG_NF_CONNTRACK_SECMARK)
279         u32 value = regs->data[priv->sreg];
280 #endif
281         enum ip_conntrack_info ctinfo;
282         struct nf_conn *ct;
283
284         ct = nf_ct_get(skb, &ctinfo);
285         if (ct == NULL || nf_ct_is_template(ct))
286                 return;
287
288         switch (priv->key) {
289 #ifdef CONFIG_NF_CONNTRACK_MARK
290         case NFT_CT_MARK:
291                 if (READ_ONCE(ct->mark) != value) {
292                         WRITE_ONCE(ct->mark, value);
293                         nf_conntrack_event_cache(IPCT_MARK, ct);
294                 }
295                 break;
296 #endif
297 #ifdef CONFIG_NF_CONNTRACK_SECMARK
298         case NFT_CT_SECMARK:
299                 if (ct->secmark != value) {
300                         ct->secmark = value;
301                         nf_conntrack_event_cache(IPCT_SECMARK, ct);
302                 }
303                 break;
304 #endif
305 #ifdef CONFIG_NF_CONNTRACK_LABELS
306         case NFT_CT_LABELS:
307                 nf_connlabels_replace(ct,
308                                       &regs->data[priv->sreg],
309                                       &regs->data[priv->sreg],
310                                       NF_CT_LABELS_MAX_SIZE / sizeof(u32));
311                 break;
312 #endif
313 #ifdef CONFIG_NF_CONNTRACK_EVENTS
314         case NFT_CT_EVENTMASK: {
315                 struct nf_conntrack_ecache *e = nf_ct_ecache_find(ct);
316                 u32 ctmask = regs->data[priv->sreg];
317
318                 if (e) {
319                         if (e->ctmask != ctmask)
320                                 e->ctmask = ctmask;
321                         break;
322                 }
323
324                 if (ctmask && !nf_ct_is_confirmed(ct))
325                         nf_ct_ecache_ext_add(ct, ctmask, 0, GFP_ATOMIC);
326                 break;
327         }
328 #endif
329         default:
330                 break;
331         }
332 }
333
/*
 * Netlink attribute policy shared by all "ct" expression ops.  NFTA_CT_KEY
 * is capped at 255 here.  Keys this file does not implement are rejected
 * later by the init functions with -EOPNOTSUPP.  NFTA_CT_DIRECTION is
 * further restricted to IP_CT_DIR_ORIGINAL/IP_CT_DIR_REPLY during init.
 */
static const struct nla_policy nft_ct_policy[NFTA_CT_MAX + 1] = {
	[NFTA_CT_DREG]		= { .type = NLA_U32 },
	[NFTA_CT_KEY]		= NLA_POLICY_MAX(NLA_BE32, 255),
	[NFTA_CT_DIRECTION]	= { .type = NLA_U8 },
	[NFTA_CT_SREG]		= { .type = NLA_U32 },
};
340
#ifdef CONFIG_NF_CONNTRACK_ZONES
/*
 * Drop the reference on every per-CPU zone template and clear the slots.
 * Templates are allocated in CPU order, so the first empty slot marks the
 * end of a possibly partial allocation.
 */
static void nft_ct_tmpl_put_pcpu(void)
{
	int cpu;

	for_each_possible_cpu(cpu) {
		struct nf_conn *tmpl = per_cpu(nft_ct_pcpu_template, cpu);

		if (tmpl == NULL)
			break;

		nf_ct_put(tmpl);
		per_cpu(nft_ct_pcpu_template, cpu) = NULL;
	}
}

/*
 * Make sure every possible CPU owns a zone template.
 *
 * Nothing is done while the templates are already in use (non-zero
 * refcount).  Otherwise each template is allocated in init_net with zone
 * id 0 and marked confirmed.  On allocation failure, every template created
 * so far is released and false is returned.  The caller in
 * nft_ct_set_init() holds nft_ct_pcpu_mutex.
 */
static bool nft_ct_tmpl_alloc_pcpu(void)
{
	struct nf_conntrack_zone zone = { .id = 0 };
	int cpu;

	if (nft_ct_pcpu_template_refcnt != 0)
		return true;

	for_each_possible_cpu(cpu) {
		struct nf_conn *tmpl = nf_ct_tmpl_alloc(&init_net, &zone,
							GFP_KERNEL);

		if (tmpl == NULL) {
			nft_ct_tmpl_put_pcpu();
			return false;
		}

		__set_bit(IPS_CONFIRMED_BIT, &tmpl->status);
		per_cpu(nft_ct_pcpu_template, cpu) = tmpl;
	}

	return true;
}
#endif
379
380 static int nft_ct_get_init(const struct nft_ctx *ctx,
381                            const struct nft_expr *expr,
382                            const struct nlattr * const tb[])
383 {
384         struct nft_ct *priv = nft_expr_priv(expr);
385         unsigned int len;
386         int err;
387
388         priv->key = ntohl(nla_get_be32(tb[NFTA_CT_KEY]));
389         priv->dir = IP_CT_DIR_MAX;
390         switch (priv->key) {
391         case NFT_CT_DIRECTION:
392                 if (tb[NFTA_CT_DIRECTION] != NULL)
393                         return -EINVAL;
394                 len = sizeof(u8);
395                 break;
396         case NFT_CT_STATE:
397         case NFT_CT_STATUS:
398 #ifdef CONFIG_NF_CONNTRACK_MARK
399         case NFT_CT_MARK:
400 #endif
401 #ifdef CONFIG_NF_CONNTRACK_SECMARK
402         case NFT_CT_SECMARK:
403 #endif
404         case NFT_CT_EXPIRATION:
405                 if (tb[NFTA_CT_DIRECTION] != NULL)
406                         return -EINVAL;
407                 len = sizeof(u32);
408                 break;
409 #ifdef CONFIG_NF_CONNTRACK_LABELS
410         case NFT_CT_LABELS:
411                 if (tb[NFTA_CT_DIRECTION] != NULL)
412                         return -EINVAL;
413                 len = NF_CT_LABELS_MAX_SIZE;
414                 break;
415 #endif
416         case NFT_CT_HELPER:
417                 if (tb[NFTA_CT_DIRECTION] != NULL)
418                         return -EINVAL;
419                 len = NF_CT_HELPER_NAME_LEN;
420                 break;
421
422         case NFT_CT_L3PROTOCOL:
423         case NFT_CT_PROTOCOL:
424                 /* For compatibility, do not report error if NFTA_CT_DIRECTION
425                  * attribute is specified.
426                  */
427                 len = sizeof(u8);
428                 break;
429         case NFT_CT_SRC:
430         case NFT_CT_DST:
431                 if (tb[NFTA_CT_DIRECTION] == NULL)
432                         return -EINVAL;
433
434                 switch (ctx->family) {
435                 case NFPROTO_IPV4:
436                         len = sizeof_field(struct nf_conntrack_tuple,
437                                            src.u3.ip);
438                         break;
439                 case NFPROTO_IPV6:
440                 case NFPROTO_INET:
441                         len = sizeof_field(struct nf_conntrack_tuple,
442                                            src.u3.ip6);
443                         break;
444                 default:
445                         return -EAFNOSUPPORT;
446                 }
447                 break;
448         case NFT_CT_SRC_IP:
449         case NFT_CT_DST_IP:
450                 if (tb[NFTA_CT_DIRECTION] == NULL)
451                         return -EINVAL;
452
453                 len = sizeof_field(struct nf_conntrack_tuple, src.u3.ip);
454                 break;
455         case NFT_CT_SRC_IP6:
456         case NFT_CT_DST_IP6:
457                 if (tb[NFTA_CT_DIRECTION] == NULL)
458                         return -EINVAL;
459
460                 len = sizeof_field(struct nf_conntrack_tuple, src.u3.ip6);
461                 break;
462         case NFT_CT_PROTO_SRC:
463         case NFT_CT_PROTO_DST:
464                 if (tb[NFTA_CT_DIRECTION] == NULL)
465                         return -EINVAL;
466                 len = sizeof_field(struct nf_conntrack_tuple, src.u.all);
467                 break;
468         case NFT_CT_BYTES:
469         case NFT_CT_PKTS:
470         case NFT_CT_AVGPKT:
471                 len = sizeof(u64);
472                 break;
473 #ifdef CONFIG_NF_CONNTRACK_ZONES
474         case NFT_CT_ZONE:
475                 len = sizeof(u16);
476                 break;
477 #endif
478         case NFT_CT_ID:
479                 if (tb[NFTA_CT_DIRECTION])
480                         return -EINVAL;
481
482                 len = sizeof(u32);
483                 break;
484         default:
485                 return -EOPNOTSUPP;
486         }
487
488         if (tb[NFTA_CT_DIRECTION] != NULL) {
489                 priv->dir = nla_get_u8(tb[NFTA_CT_DIRECTION]);
490                 switch (priv->dir) {
491                 case IP_CT_DIR_ORIGINAL:
492                 case IP_CT_DIR_REPLY:
493                         break;
494                 default:
495                         return -EINVAL;
496                 }
497         }
498
499         priv->len = len;
500         err = nft_parse_register_store(ctx, tb[NFTA_CT_DREG], &priv->dreg, NULL,
501                                        NFT_DATA_VALUE, len);
502         if (err < 0)
503                 return err;
504
505         err = nf_ct_netns_get(ctx->net, ctx->family);
506         if (err < 0)
507                 return err;
508
509         if (priv->key == NFT_CT_BYTES ||
510             priv->key == NFT_CT_PKTS  ||
511             priv->key == NFT_CT_AVGPKT)
512                 nf_ct_set_acct(ctx->net, true);
513
514         return 0;
515 }
516
/*
 * Release the key-specific resources taken by nft_ct_set_init(), excluding
 * the conntrack netns reference.  For NFT_CT_LABELS this is the connlabel
 * reference.  For NFT_CT_ZONE it is one user of the per-CPU templates; the
 * templates are freed with the last user.  Other keys hold nothing.
 */
static void __nft_ct_set_destroy(const struct nft_ctx *ctx, struct nft_ct *priv)
{
#ifdef CONFIG_NF_CONNTRACK_LABELS
	if (priv->key == NFT_CT_LABELS) {
		nf_connlabels_put(ctx->net);
		return;
	}
#endif
#ifdef CONFIG_NF_CONNTRACK_ZONES
	if (priv->key == NFT_CT_ZONE) {
		mutex_lock(&nft_ct_pcpu_mutex);
		nft_ct_pcpu_template_refcnt--;
		if (nft_ct_pcpu_template_refcnt == 0)
			nft_ct_tmpl_put_pcpu();
		mutex_unlock(&nft_ct_pcpu_mutex);
	}
#endif
}
537
538 static int nft_ct_set_init(const struct nft_ctx *ctx,
539                            const struct nft_expr *expr,
540                            const struct nlattr * const tb[])
541 {
542         struct nft_ct *priv = nft_expr_priv(expr);
543         unsigned int len;
544         int err;
545
546         priv->dir = IP_CT_DIR_MAX;
547         priv->key = ntohl(nla_get_be32(tb[NFTA_CT_KEY]));
548         switch (priv->key) {
549 #ifdef CONFIG_NF_CONNTRACK_MARK
550         case NFT_CT_MARK:
551                 if (tb[NFTA_CT_DIRECTION])
552                         return -EINVAL;
553                 len = sizeof_field(struct nf_conn, mark);
554                 break;
555 #endif
556 #ifdef CONFIG_NF_CONNTRACK_LABELS
557         case NFT_CT_LABELS:
558                 if (tb[NFTA_CT_DIRECTION])
559                         return -EINVAL;
560                 len = NF_CT_LABELS_MAX_SIZE;
561                 err = nf_connlabels_get(ctx->net, (len * BITS_PER_BYTE) - 1);
562                 if (err)
563                         return err;
564                 break;
565 #endif
566 #ifdef CONFIG_NF_CONNTRACK_ZONES
567         case NFT_CT_ZONE:
568                 mutex_lock(&nft_ct_pcpu_mutex);
569                 if (!nft_ct_tmpl_alloc_pcpu()) {
570                         mutex_unlock(&nft_ct_pcpu_mutex);
571                         return -ENOMEM;
572                 }
573                 nft_ct_pcpu_template_refcnt++;
574                 mutex_unlock(&nft_ct_pcpu_mutex);
575                 len = sizeof(u16);
576                 break;
577 #endif
578 #ifdef CONFIG_NF_CONNTRACK_EVENTS
579         case NFT_CT_EVENTMASK:
580                 if (tb[NFTA_CT_DIRECTION])
581                         return -EINVAL;
582                 len = sizeof(u32);
583                 break;
584 #endif
585 #ifdef CONFIG_NF_CONNTRACK_SECMARK
586         case NFT_CT_SECMARK:
587                 if (tb[NFTA_CT_DIRECTION])
588                         return -EINVAL;
589                 len = sizeof(u32);
590                 break;
591 #endif
592         default:
593                 return -EOPNOTSUPP;
594         }
595
596         if (tb[NFTA_CT_DIRECTION]) {
597                 priv->dir = nla_get_u8(tb[NFTA_CT_DIRECTION]);
598                 switch (priv->dir) {
599                 case IP_CT_DIR_ORIGINAL:
600                 case IP_CT_DIR_REPLY:
601                         break;
602                 default:
603                         err = -EINVAL;
604                         goto err1;
605                 }
606         }
607
608         priv->len = len;
609         err = nft_parse_register_load(tb[NFTA_CT_SREG], &priv->sreg, len);
610         if (err < 0)
611                 goto err1;
612
613         err = nf_ct_netns_get(ctx->net, ctx->family);
614         if (err < 0)
615                 goto err1;
616
617         return 0;
618
619 err1:
620         __nft_ct_set_destroy(ctx, priv);
621         return err;
622 }
623
624 static void nft_ct_get_destroy(const struct nft_ctx *ctx,
625                                const struct nft_expr *expr)
626 {
627         nf_ct_netns_put(ctx->net, ctx->family);
628 }
629
630 static void nft_ct_set_destroy(const struct nft_ctx *ctx,
631                                const struct nft_expr *expr)
632 {
633         struct nft_ct *priv = nft_expr_priv(expr);
634
635         __nft_ct_set_destroy(ctx, priv);
636         nf_ct_netns_put(ctx->net, ctx->family);
637 }
638
639 static int nft_ct_get_dump(struct sk_buff *skb,
640                            const struct nft_expr *expr, bool reset)
641 {
642         const struct nft_ct *priv = nft_expr_priv(expr);
643
644         if (nft_dump_register(skb, NFTA_CT_DREG, priv->dreg))
645                 goto nla_put_failure;
646         if (nla_put_be32(skb, NFTA_CT_KEY, htonl(priv->key)))
647                 goto nla_put_failure;
648
649         switch (priv->key) {
650         case NFT_CT_SRC:
651         case NFT_CT_DST:
652         case NFT_CT_SRC_IP:
653         case NFT_CT_DST_IP:
654         case NFT_CT_SRC_IP6:
655         case NFT_CT_DST_IP6:
656         case NFT_CT_PROTO_SRC:
657         case NFT_CT_PROTO_DST:
658                 if (nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir))
659                         goto nla_put_failure;
660                 break;
661         case NFT_CT_BYTES:
662         case NFT_CT_PKTS:
663         case NFT_CT_AVGPKT:
664         case NFT_CT_ZONE:
665                 if (priv->dir < IP_CT_DIR_MAX &&
666                     nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir))
667                         goto nla_put_failure;
668                 break;
669         default:
670                 break;
671         }
672
673         return 0;
674
675 nla_put_failure:
676         return -1;
677 }
678
679 static bool nft_ct_get_reduce(struct nft_regs_track *track,
680                               const struct nft_expr *expr)
681 {
682         const struct nft_ct *priv = nft_expr_priv(expr);
683         const struct nft_ct *ct;
684
685         if (!nft_reg_track_cmp(track, expr, priv->dreg)) {
686                 nft_reg_track_update(track, expr, priv->dreg, priv->len);
687                 return false;
688         }
689
690         ct = nft_expr_priv(track->regs[priv->dreg].selector);
691         if (priv->key != ct->key) {
692                 nft_reg_track_update(track, expr, priv->dreg, priv->len);
693                 return false;
694         }
695
696         if (!track->regs[priv->dreg].bitwise)
697                 return true;
698
699         return nft_expr_reduce_bitwise(track, expr);
700 }
701
702 static int nft_ct_set_dump(struct sk_buff *skb,
703                            const struct nft_expr *expr, bool reset)
704 {
705         const struct nft_ct *priv = nft_expr_priv(expr);
706
707         if (nft_dump_register(skb, NFTA_CT_SREG, priv->sreg))
708                 goto nla_put_failure;
709         if (nla_put_be32(skb, NFTA_CT_KEY, htonl(priv->key)))
710                 goto nla_put_failure;
711
712         switch (priv->key) {
713         case NFT_CT_ZONE:
714                 if (priv->dir < IP_CT_DIR_MAX &&
715                     nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir))
716                         goto nla_put_failure;
717                 break;
718         default:
719                 break;
720         }
721
722         return 0;
723
724 nla_put_failure:
725         return -1;
726 }
727
static struct nft_expr_type nft_ct_type;

/* Ops for "ct <key>" loads into a destination register (NFTA_CT_DREG). */
static const struct nft_expr_ops nft_ct_get_ops = {
	.type		= &nft_ct_type,
	.size		= NFT_EXPR_SIZE(sizeof(struct nft_ct)),
	.eval		= nft_ct_get_eval,
	.init		= nft_ct_get_init,
	.destroy	= nft_ct_get_destroy,
	.dump		= nft_ct_get_dump,
	.reduce		= nft_ct_get_reduce,
};
738
739 static bool nft_ct_set_reduce(struct nft_regs_track *track,
740                               const struct nft_expr *expr)
741 {
742         int i;
743
744         for (i = 0; i < NFT_REG32_NUM; i++) {
745                 if (!track->regs[i].selector)
746                         continue;
747
748                 if (track->regs[i].selector->ops != &nft_ct_get_ops)
749                         continue;
750
751                 __nft_reg_track_cancel(track, i);
752         }
753
754         return false;
755 }
756
#ifdef CONFIG_RETPOLINE
/*
 * Load ops chosen by nft_ct_select_ops() for the state, direction, status,
 * mark and secmark keys when CONFIG_RETPOLINE is set.  Evaluation goes
 * through nft_ct_get_fast_eval, which is not defined in this part of the
 * file.  Presumably the core calls it directly to avoid an indirect call
 * (confirm).  The reduce hook is nft_ct_set_reduce.
 */
static const struct nft_expr_ops nft_ct_get_fast_ops = {
	.type		= &nft_ct_type,
	.size		= NFT_EXPR_SIZE(sizeof(struct nft_ct)),
	.eval		= nft_ct_get_fast_eval,
	.init		= nft_ct_get_init,
	.destroy	= nft_ct_get_destroy,
	.dump		= nft_ct_get_dump,
	.reduce		= nft_ct_set_reduce,
};
#endif
768
/*
 * Ops for "ct <key> set" stores from a source register (NFTA_CT_SREG).
 * Used for every key except zone when CONFIG_NF_CONNTRACK_ZONES is set.
 */
static const struct nft_expr_ops nft_ct_set_ops = {
	.type		= &nft_ct_type,
	.size		= NFT_EXPR_SIZE(sizeof(struct nft_ct)),
	.eval		= nft_ct_set_eval,
	.init		= nft_ct_set_init,
	.destroy	= nft_ct_set_destroy,
	.dump		= nft_ct_set_dump,
	.reduce		= nft_ct_set_reduce,
};
778
#ifdef CONFIG_NF_CONNTRACK_ZONES
/*
 * Store ops for "ct zone set".  Shares init/destroy/dump with
 * nft_ct_set_ops but evaluates with nft_ct_set_zone_eval, which attaches a
 * zone template to packets that are not yet tracked.
 */
static const struct nft_expr_ops nft_ct_set_zone_ops = {
	.type		= &nft_ct_type,
	.size		= NFT_EXPR_SIZE(sizeof(struct nft_ct)),
	.eval		= nft_ct_set_zone_eval,
	.init		= nft_ct_set_init,
	.destroy	= nft_ct_set_destroy,
	.dump		= nft_ct_set_dump,
	.reduce		= nft_ct_set_reduce,
};
#endif
790
791 static const struct nft_expr_ops *
792 nft_ct_select_ops(const struct nft_ctx *ctx,
793                     const struct nlattr * const tb[])
794 {
795         if (tb[NFTA_CT_KEY] == NULL)
796                 return ERR_PTR(-EINVAL);
797
798         if (tb[NFTA_CT_DREG] && tb[NFTA_CT_SREG])
799                 return ERR_PTR(-EINVAL);
800
801         if (tb[NFTA_CT_DREG]) {
802 #ifdef CONFIG_RETPOLINE
803                 u32 k = ntohl(nla_get_be32(tb[NFTA_CT_KEY]));
804
805                 switch (k) {
806                 case NFT_CT_STATE:
807                 case NFT_CT_DIRECTION:
808                 case NFT_CT_STATUS:
809                 case NFT_CT_MARK:
810                 case NFT_CT_SECMARK:
811                         return &nft_ct_get_fast_ops;
812                 }
813 #endif
814                 return &nft_ct_get_ops;
815         }
816
817         if (tb[NFTA_CT_SREG]) {
818 #ifdef CONFIG_NF_CONNTRACK_ZONES
819                 if (nla_get_be32(tb[NFTA_CT_KEY]) == htonl(NFT_CT_ZONE))
820                         return &nft_ct_set_zone_ops;
821 #endif
822                 return &nft_ct_set_ops;
823         }
824
825         return ERR_PTR(-EINVAL);
826 }
827
/*
 * The "ct" expression type.  The concrete ops (load, store or zone store)
 * are chosen per expression by nft_ct_select_ops().
 */
static struct nft_expr_type nft_ct_type __read_mostly = {
	.name		= "ct",
	.select_ops	= nft_ct_select_ops,
	.policy		= nft_ct_policy,
	.maxattr	= NFTA_CT_MAX,
	.owner		= THIS_MODULE,
};
835
836 static void nft_notrack_eval(const struct nft_expr *expr,
837                              struct nft_regs *regs,
838                              const struct nft_pktinfo *pkt)
839 {
840         struct sk_buff *skb = pkt->skb;
841         enum ip_conntrack_info ctinfo;
842         struct nf_conn *ct;
843
844         ct = nf_ct_get(pkt->skb, &ctinfo);
845         /* Previously seen (loopback or untracked)?  Ignore. */
846         if (ct || ctinfo == IP_CT_UNTRACKED)
847                 return;
848
849         nf_ct_set(skb, ct, IP_CT_UNTRACKED);
850 }
851
static struct nft_expr_type nft_notrack_type;

/*
 * Ops for the "notrack" expression: no private data (NFT_EXPR_SIZE(0)),
 * evaluated by nft_notrack_eval, and NFT_REDUCE_READONLY as reduce hook.
 */
static const struct nft_expr_ops nft_notrack_ops = {
	.type		= &nft_notrack_type,
	.size		= NFT_EXPR_SIZE(0),
	.eval		= nft_notrack_eval,
	.reduce		= NFT_REDUCE_READONLY,
};

/* The "notrack" expression type; it has a single fixed ops table. */
static struct nft_expr_type nft_notrack_type __read_mostly = {
	.name		= "notrack",
	.ops		= &nft_notrack_ops,
	.owner		= THIS_MODULE,
};
865
866 #ifdef CONFIG_NF_CONNTRACK_TIMEOUT
867 static int
868 nft_ct_timeout_parse_policy(void *timeouts,
869                             const struct nf_conntrack_l4proto *l4proto,
870                             struct net *net, const struct nlattr *attr)
871 {
872         struct nlattr **tb;
873         int ret = 0;
874
875         tb = kcalloc(l4proto->ctnl_timeout.nlattr_max + 1, sizeof(*tb),
876                      GFP_KERNEL);
877
878         if (!tb)
879                 return -ENOMEM;
880
881         ret = nla_parse_nested_deprecated(tb,
882                                           l4proto->ctnl_timeout.nlattr_max,
883                                           attr,
884                                           l4proto->ctnl_timeout.nla_policy,
885                                           NULL);
886         if (ret < 0)
887                 goto err;
888
889         ret = l4proto->ctnl_timeout.nlattr_to_obj(tb, net, timeouts);
890
891 err:
892         kfree(tb);
893         return ret;
894 }
895
896 struct nft_ct_timeout_obj {
897         struct nf_ct_timeout    *timeout;
898         u8                      l4proto;
899 };
900
901 static void nft_ct_timeout_obj_eval(struct nft_object *obj,
902                                     struct nft_regs *regs,
903                                     const struct nft_pktinfo *pkt)
904 {
905         const struct nft_ct_timeout_obj *priv = nft_obj_data(obj);
906         struct nf_conn *ct = (struct nf_conn *)skb_nfct(pkt->skb);
907         struct nf_conn_timeout *timeout;
908         const unsigned int *values;
909
910         if (priv->l4proto != pkt->tprot)
911                 return;
912
913         if (!ct || nf_ct_is_template(ct) || nf_ct_is_confirmed(ct))
914                 return;
915
916         timeout = nf_ct_timeout_find(ct);
917         if (!timeout) {
918                 timeout = nf_ct_timeout_ext_add(ct, priv->timeout, GFP_ATOMIC);
919                 if (!timeout) {
920                         regs->verdict.code = NF_DROP;
921                         return;
922                 }
923         }
924
925         rcu_assign_pointer(timeout->timeout, priv->timeout);
926
927         /* adjust the timeout as per 'new' state. ct is unconfirmed,
928          * so the current timestamp must not be added.
929          */
930         values = nf_ct_timeout_data(timeout);
931         if (values)
932                 nf_ct_refresh(ct, pkt->skb, values[0]);
933 }
934
935 static int nft_ct_timeout_obj_init(const struct nft_ctx *ctx,
936                                    const struct nlattr * const tb[],
937                                    struct nft_object *obj)
938 {
939         struct nft_ct_timeout_obj *priv = nft_obj_data(obj);
940         const struct nf_conntrack_l4proto *l4proto;
941         struct nf_ct_timeout *timeout;
942         int l3num = ctx->family;
943         __u8 l4num;
944         int ret;
945
946         if (!tb[NFTA_CT_TIMEOUT_L4PROTO] ||
947             !tb[NFTA_CT_TIMEOUT_DATA])
948                 return -EINVAL;
949
950         if (tb[NFTA_CT_TIMEOUT_L3PROTO])
951                 l3num = ntohs(nla_get_be16(tb[NFTA_CT_TIMEOUT_L3PROTO]));
952
953         l4num = nla_get_u8(tb[NFTA_CT_TIMEOUT_L4PROTO]);
954         priv->l4proto = l4num;
955
956         l4proto = nf_ct_l4proto_find(l4num);
957
958         if (l4proto->l4proto != l4num) {
959                 ret = -EOPNOTSUPP;
960                 goto err_proto_put;
961         }
962
963         timeout = kzalloc(sizeof(struct nf_ct_timeout) +
964                           l4proto->ctnl_timeout.obj_size, GFP_KERNEL);
965         if (timeout == NULL) {
966                 ret = -ENOMEM;
967                 goto err_proto_put;
968         }
969
970         ret = nft_ct_timeout_parse_policy(&timeout->data, l4proto, ctx->net,
971                                           tb[NFTA_CT_TIMEOUT_DATA]);
972         if (ret < 0)
973                 goto err_free_timeout;
974
975         timeout->l3num = l3num;
976         timeout->l4proto = l4proto;
977
978         ret = nf_ct_netns_get(ctx->net, ctx->family);
979         if (ret < 0)
980                 goto err_free_timeout;
981
982         priv->timeout = timeout;
983         return 0;
984
985 err_free_timeout:
986         kfree(timeout);
987 err_proto_put:
988         return ret;
989 }
990
991 static void nft_ct_timeout_obj_destroy(const struct nft_ctx *ctx,
992                                        struct nft_object *obj)
993 {
994         struct nft_ct_timeout_obj *priv = nft_obj_data(obj);
995         struct nf_ct_timeout *timeout = priv->timeout;
996
997         nf_ct_untimeout(ctx->net, timeout);
998         nf_ct_netns_put(ctx->net, ctx->family);
999         kfree(priv->timeout);
1000 }
1001
1002 static int nft_ct_timeout_obj_dump(struct sk_buff *skb,
1003                                    struct nft_object *obj, bool reset)
1004 {
1005         const struct nft_ct_timeout_obj *priv = nft_obj_data(obj);
1006         const struct nf_ct_timeout *timeout = priv->timeout;
1007         struct nlattr *nest_params;
1008         int ret;
1009
1010         if (nla_put_u8(skb, NFTA_CT_TIMEOUT_L4PROTO, timeout->l4proto->l4proto) ||
1011             nla_put_be16(skb, NFTA_CT_TIMEOUT_L3PROTO, htons(timeout->l3num)))
1012                 return -1;
1013
1014         nest_params = nla_nest_start(skb, NFTA_CT_TIMEOUT_DATA);
1015         if (!nest_params)
1016                 return -1;
1017
1018         ret = timeout->l4proto->ctnl_timeout.obj_to_nlattr(skb, &timeout->data);
1019         if (ret < 0)
1020                 return -1;
1021         nla_nest_end(skb, nest_params);
1022         return 0;
1023 }
1024
1025 static const struct nla_policy nft_ct_timeout_policy[NFTA_CT_TIMEOUT_MAX + 1] = {
1026         [NFTA_CT_TIMEOUT_L3PROTO] = {.type = NLA_U16 },
1027         [NFTA_CT_TIMEOUT_L4PROTO] = {.type = NLA_U8 },
1028         [NFTA_CT_TIMEOUT_DATA]    = {.type = NLA_NESTED },
1029 };
1030
1031 static struct nft_object_type nft_ct_timeout_obj_type;
1032
1033 static const struct nft_object_ops nft_ct_timeout_obj_ops = {
1034         .type           = &nft_ct_timeout_obj_type,
1035         .size           = sizeof(struct nft_ct_timeout_obj),
1036         .eval           = nft_ct_timeout_obj_eval,
1037         .init           = nft_ct_timeout_obj_init,
1038         .destroy        = nft_ct_timeout_obj_destroy,
1039         .dump           = nft_ct_timeout_obj_dump,
1040 };
1041
1042 static struct nft_object_type nft_ct_timeout_obj_type __read_mostly = {
1043         .type           = NFT_OBJECT_CT_TIMEOUT,
1044         .ops            = &nft_ct_timeout_obj_ops,
1045         .maxattr        = NFTA_CT_TIMEOUT_MAX,
1046         .policy         = nft_ct_timeout_policy,
1047         .owner          = THIS_MODULE,
1048 };
1049 #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */
1050
1051 static int nft_ct_helper_obj_init(const struct nft_ctx *ctx,
1052                                   const struct nlattr * const tb[],
1053                                   struct nft_object *obj)
1054 {
1055         struct nft_ct_helper_obj *priv = nft_obj_data(obj);
1056         struct nf_conntrack_helper *help4, *help6;
1057         char name[NF_CT_HELPER_NAME_LEN];
1058         int family = ctx->family;
1059         int err;
1060
1061         if (!tb[NFTA_CT_HELPER_NAME] || !tb[NFTA_CT_HELPER_L4PROTO])
1062                 return -EINVAL;
1063
1064         priv->l4proto = nla_get_u8(tb[NFTA_CT_HELPER_L4PROTO]);
1065         if (!priv->l4proto)
1066                 return -ENOENT;
1067
1068         nla_strscpy(name, tb[NFTA_CT_HELPER_NAME], sizeof(name));
1069
1070         if (tb[NFTA_CT_HELPER_L3PROTO])
1071                 family = ntohs(nla_get_be16(tb[NFTA_CT_HELPER_L3PROTO]));
1072
1073         help4 = NULL;
1074         help6 = NULL;
1075
1076         switch (family) {
1077         case NFPROTO_IPV4:
1078                 if (ctx->family == NFPROTO_IPV6)
1079                         return -EINVAL;
1080
1081                 help4 = nf_conntrack_helper_try_module_get(name, family,
1082                                                            priv->l4proto);
1083                 break;
1084         case NFPROTO_IPV6:
1085                 if (ctx->family == NFPROTO_IPV4)
1086                         return -EINVAL;
1087
1088                 help6 = nf_conntrack_helper_try_module_get(name, family,
1089                                                            priv->l4proto);
1090                 break;
1091         case NFPROTO_NETDEV:
1092         case NFPROTO_BRIDGE:
1093         case NFPROTO_INET:
1094                 help4 = nf_conntrack_helper_try_module_get(name, NFPROTO_IPV4,
1095                                                            priv->l4proto);
1096                 help6 = nf_conntrack_helper_try_module_get(name, NFPROTO_IPV6,
1097                                                            priv->l4proto);
1098                 break;
1099         default:
1100                 return -EAFNOSUPPORT;
1101         }
1102
1103         /* && is intentional; only error if INET found neither ipv4 or ipv6 */
1104         if (!help4 && !help6)
1105                 return -ENOENT;
1106
1107         priv->helper4 = help4;
1108         priv->helper6 = help6;
1109
1110         err = nf_ct_netns_get(ctx->net, ctx->family);
1111         if (err < 0)
1112                 goto err_put_helper;
1113
1114         return 0;
1115
1116 err_put_helper:
1117         if (priv->helper4)
1118                 nf_conntrack_helper_put(priv->helper4);
1119         if (priv->helper6)
1120                 nf_conntrack_helper_put(priv->helper6);
1121         return err;
1122 }
1123
1124 static void nft_ct_helper_obj_destroy(const struct nft_ctx *ctx,
1125                                       struct nft_object *obj)
1126 {
1127         struct nft_ct_helper_obj *priv = nft_obj_data(obj);
1128
1129         if (priv->helper4)
1130                 nf_conntrack_helper_put(priv->helper4);
1131         if (priv->helper6)
1132                 nf_conntrack_helper_put(priv->helper6);
1133
1134         nf_ct_netns_put(ctx->net, ctx->family);
1135 }
1136
1137 static void nft_ct_helper_obj_eval(struct nft_object *obj,
1138                                    struct nft_regs *regs,
1139                                    const struct nft_pktinfo *pkt)
1140 {
1141         const struct nft_ct_helper_obj *priv = nft_obj_data(obj);
1142         struct nf_conn *ct = (struct nf_conn *)skb_nfct(pkt->skb);
1143         struct nf_conntrack_helper *to_assign = NULL;
1144         struct nf_conn_help *help;
1145
1146         if (!ct ||
1147             nf_ct_is_confirmed(ct) ||
1148             nf_ct_is_template(ct) ||
1149             priv->l4proto != nf_ct_protonum(ct))
1150                 return;
1151
1152         switch (nf_ct_l3num(ct)) {
1153         case NFPROTO_IPV4:
1154                 to_assign = priv->helper4;
1155                 break;
1156         case NFPROTO_IPV6:
1157                 to_assign = priv->helper6;
1158                 break;
1159         default:
1160                 WARN_ON_ONCE(1);
1161                 return;
1162         }
1163
1164         if (!to_assign)
1165                 return;
1166
1167         if (test_bit(IPS_HELPER_BIT, &ct->status))
1168                 return;
1169
1170         help = nf_ct_helper_ext_add(ct, GFP_ATOMIC);
1171         if (help) {
1172                 rcu_assign_pointer(help->helper, to_assign);
1173                 set_bit(IPS_HELPER_BIT, &ct->status);
1174         }
1175 }
1176
1177 static int nft_ct_helper_obj_dump(struct sk_buff *skb,
1178                                   struct nft_object *obj, bool reset)
1179 {
1180         const struct nft_ct_helper_obj *priv = nft_obj_data(obj);
1181         const struct nf_conntrack_helper *helper;
1182         u16 family;
1183
1184         if (priv->helper4 && priv->helper6) {
1185                 family = NFPROTO_INET;
1186                 helper = priv->helper4;
1187         } else if (priv->helper6) {
1188                 family = NFPROTO_IPV6;
1189                 helper = priv->helper6;
1190         } else {
1191                 family = NFPROTO_IPV4;
1192                 helper = priv->helper4;
1193         }
1194
1195         if (nla_put_string(skb, NFTA_CT_HELPER_NAME, helper->name))
1196                 return -1;
1197
1198         if (nla_put_u8(skb, NFTA_CT_HELPER_L4PROTO, priv->l4proto))
1199                 return -1;
1200
1201         if (nla_put_be16(skb, NFTA_CT_HELPER_L3PROTO, htons(family)))
1202                 return -1;
1203
1204         return 0;
1205 }
1206
1207 static const struct nla_policy nft_ct_helper_policy[NFTA_CT_HELPER_MAX + 1] = {
1208         [NFTA_CT_HELPER_NAME] = { .type = NLA_STRING,
1209                                   .len = NF_CT_HELPER_NAME_LEN - 1 },
1210         [NFTA_CT_HELPER_L3PROTO] = { .type = NLA_U16 },
1211         [NFTA_CT_HELPER_L4PROTO] = { .type = NLA_U8 },
1212 };
1213
1214 static struct nft_object_type nft_ct_helper_obj_type;
1215 static const struct nft_object_ops nft_ct_helper_obj_ops = {
1216         .type           = &nft_ct_helper_obj_type,
1217         .size           = sizeof(struct nft_ct_helper_obj),
1218         .eval           = nft_ct_helper_obj_eval,
1219         .init           = nft_ct_helper_obj_init,
1220         .destroy        = nft_ct_helper_obj_destroy,
1221         .dump           = nft_ct_helper_obj_dump,
1222 };
1223
1224 static struct nft_object_type nft_ct_helper_obj_type __read_mostly = {
1225         .type           = NFT_OBJECT_CT_HELPER,
1226         .ops            = &nft_ct_helper_obj_ops,
1227         .maxattr        = NFTA_CT_HELPER_MAX,
1228         .policy         = nft_ct_helper_policy,
1229         .owner          = THIS_MODULE,
1230 };
1231
1232 struct nft_ct_expect_obj {
1233         u16             l3num;
1234         __be16          dport;
1235         u8              l4proto;
1236         u8              size;
1237         u32             timeout;
1238 };
1239
1240 static int nft_ct_expect_obj_init(const struct nft_ctx *ctx,
1241                                   const struct nlattr * const tb[],
1242                                   struct nft_object *obj)
1243 {
1244         struct nft_ct_expect_obj *priv = nft_obj_data(obj);
1245
1246         if (!tb[NFTA_CT_EXPECT_L4PROTO] ||
1247             !tb[NFTA_CT_EXPECT_DPORT] ||
1248             !tb[NFTA_CT_EXPECT_TIMEOUT] ||
1249             !tb[NFTA_CT_EXPECT_SIZE])
1250                 return -EINVAL;
1251
1252         priv->l3num = ctx->family;
1253         if (tb[NFTA_CT_EXPECT_L3PROTO])
1254                 priv->l3num = ntohs(nla_get_be16(tb[NFTA_CT_EXPECT_L3PROTO]));
1255
1256         switch (priv->l3num) {
1257         case NFPROTO_IPV4:
1258         case NFPROTO_IPV6:
1259                 if (priv->l3num != ctx->family)
1260                         return -EINVAL;
1261
1262                 fallthrough;
1263         case NFPROTO_INET:
1264                 break;
1265         default:
1266                 return -EOPNOTSUPP;
1267         }
1268
1269         priv->l4proto = nla_get_u8(tb[NFTA_CT_EXPECT_L4PROTO]);
1270         switch (priv->l4proto) {
1271         case IPPROTO_TCP:
1272         case IPPROTO_UDP:
1273         case IPPROTO_UDPLITE:
1274         case IPPROTO_DCCP:
1275         case IPPROTO_SCTP:
1276                 break;
1277         default:
1278                 return -EOPNOTSUPP;
1279         }
1280
1281         priv->dport = nla_get_be16(tb[NFTA_CT_EXPECT_DPORT]);
1282         priv->timeout = nla_get_u32(tb[NFTA_CT_EXPECT_TIMEOUT]);
1283         priv->size = nla_get_u8(tb[NFTA_CT_EXPECT_SIZE]);
1284
1285         return nf_ct_netns_get(ctx->net, ctx->family);
1286 }
1287
1288 static void nft_ct_expect_obj_destroy(const struct nft_ctx *ctx,
1289                                        struct nft_object *obj)
1290 {
1291         nf_ct_netns_put(ctx->net, ctx->family);
1292 }
1293
1294 static int nft_ct_expect_obj_dump(struct sk_buff *skb,
1295                                   struct nft_object *obj, bool reset)
1296 {
1297         const struct nft_ct_expect_obj *priv = nft_obj_data(obj);
1298
1299         if (nla_put_be16(skb, NFTA_CT_EXPECT_L3PROTO, htons(priv->l3num)) ||
1300             nla_put_u8(skb, NFTA_CT_EXPECT_L4PROTO, priv->l4proto) ||
1301             nla_put_be16(skb, NFTA_CT_EXPECT_DPORT, priv->dport) ||
1302             nla_put_u32(skb, NFTA_CT_EXPECT_TIMEOUT, priv->timeout) ||
1303             nla_put_u8(skb, NFTA_CT_EXPECT_SIZE, priv->size))
1304                 return -1;
1305
1306         return 0;
1307 }
1308
1309 static void nft_ct_expect_obj_eval(struct nft_object *obj,
1310                                    struct nft_regs *regs,
1311                                    const struct nft_pktinfo *pkt)
1312 {
1313         const struct nft_ct_expect_obj *priv = nft_obj_data(obj);
1314         struct nf_conntrack_expect *exp;
1315         enum ip_conntrack_info ctinfo;
1316         struct nf_conn_help *help;
1317         enum ip_conntrack_dir dir;
1318         u16 l3num = priv->l3num;
1319         struct nf_conn *ct;
1320
1321         ct = nf_ct_get(pkt->skb, &ctinfo);
1322         if (!ct || nf_ct_is_confirmed(ct) || nf_ct_is_template(ct)) {
1323                 regs->verdict.code = NFT_BREAK;
1324                 return;
1325         }
1326         dir = CTINFO2DIR(ctinfo);
1327
1328         help = nfct_help(ct);
1329         if (!help)
1330                 help = nf_ct_helper_ext_add(ct, GFP_ATOMIC);
1331         if (!help) {
1332                 regs->verdict.code = NF_DROP;
1333                 return;
1334         }
1335
1336         if (help->expecting[NF_CT_EXPECT_CLASS_DEFAULT] >= priv->size) {
1337                 regs->verdict.code = NFT_BREAK;
1338                 return;
1339         }
1340         if (l3num == NFPROTO_INET)
1341                 l3num = nf_ct_l3num(ct);
1342
1343         exp = nf_ct_expect_alloc(ct);
1344         if (exp == NULL) {
1345                 regs->verdict.code = NF_DROP;
1346                 return;
1347         }
1348         nf_ct_expect_init(exp, NF_CT_EXPECT_CLASS_DEFAULT, l3num,
1349                           &ct->tuplehash[!dir].tuple.src.u3,
1350                           &ct->tuplehash[!dir].tuple.dst.u3,
1351                           priv->l4proto, NULL, &priv->dport);
1352         exp->timeout.expires = jiffies + priv->timeout * HZ;
1353
1354         if (nf_ct_expect_related(exp, 0) != 0)
1355                 regs->verdict.code = NF_DROP;
1356 }
1357
1358 static const struct nla_policy nft_ct_expect_policy[NFTA_CT_EXPECT_MAX + 1] = {
1359         [NFTA_CT_EXPECT_L3PROTO]        = { .type = NLA_U16 },
1360         [NFTA_CT_EXPECT_L4PROTO]        = { .type = NLA_U8 },
1361         [NFTA_CT_EXPECT_DPORT]          = { .type = NLA_U16 },
1362         [NFTA_CT_EXPECT_TIMEOUT]        = { .type = NLA_U32 },
1363         [NFTA_CT_EXPECT_SIZE]           = { .type = NLA_U8 },
1364 };
1365
1366 static struct nft_object_type nft_ct_expect_obj_type;
1367
1368 static const struct nft_object_ops nft_ct_expect_obj_ops = {
1369         .type           = &nft_ct_expect_obj_type,
1370         .size           = sizeof(struct nft_ct_expect_obj),
1371         .eval           = nft_ct_expect_obj_eval,
1372         .init           = nft_ct_expect_obj_init,
1373         .destroy        = nft_ct_expect_obj_destroy,
1374         .dump           = nft_ct_expect_obj_dump,
1375 };
1376
1377 static struct nft_object_type nft_ct_expect_obj_type __read_mostly = {
1378         .type           = NFT_OBJECT_CT_EXPECT,
1379         .ops            = &nft_ct_expect_obj_ops,
1380         .maxattr        = NFTA_CT_EXPECT_MAX,
1381         .policy         = nft_ct_expect_policy,
1382         .owner          = THIS_MODULE,
1383 };
1384
1385 static int __init nft_ct_module_init(void)
1386 {
1387         int err;
1388
1389         BUILD_BUG_ON(NF_CT_LABELS_MAX_SIZE > NFT_REG_SIZE);
1390
1391         err = nft_register_expr(&nft_ct_type);
1392         if (err < 0)
1393                 return err;
1394
1395         err = nft_register_expr(&nft_notrack_type);
1396         if (err < 0)
1397                 goto err1;
1398
1399         err = nft_register_obj(&nft_ct_helper_obj_type);
1400         if (err < 0)
1401                 goto err2;
1402
1403         err = nft_register_obj(&nft_ct_expect_obj_type);
1404         if (err < 0)
1405                 goto err3;
1406 #ifdef CONFIG_NF_CONNTRACK_TIMEOUT
1407         err = nft_register_obj(&nft_ct_timeout_obj_type);
1408         if (err < 0)
1409                 goto err4;
1410 #endif
1411         return 0;
1412
1413 #ifdef CONFIG_NF_CONNTRACK_TIMEOUT
1414 err4:
1415         nft_unregister_obj(&nft_ct_expect_obj_type);
1416 #endif
1417 err3:
1418         nft_unregister_obj(&nft_ct_helper_obj_type);
1419 err2:
1420         nft_unregister_expr(&nft_notrack_type);
1421 err1:
1422         nft_unregister_expr(&nft_ct_type);
1423         return err;
1424 }
1425
1426 static void __exit nft_ct_module_exit(void)
1427 {
1428 #ifdef CONFIG_NF_CONNTRACK_TIMEOUT
1429         nft_unregister_obj(&nft_ct_timeout_obj_type);
1430 #endif
1431         nft_unregister_obj(&nft_ct_expect_obj_type);
1432         nft_unregister_obj(&nft_ct_helper_obj_type);
1433         nft_unregister_expr(&nft_notrack_type);
1434         nft_unregister_expr(&nft_ct_type);
1435 }
1436
1437 module_init(nft_ct_module_init);
1438 module_exit(nft_ct_module_exit);
1439
1440 MODULE_LICENSE("GPL");
1441 MODULE_AUTHOR("Patrick McHardy <kaber@trash.net>");
1442 MODULE_ALIAS_NFT_EXPR("ct");
1443 MODULE_ALIAS_NFT_EXPR("notrack");
1444 MODULE_ALIAS_NFT_OBJ(NFT_OBJECT_CT_HELPER);
1445 MODULE_ALIAS_NFT_OBJ(NFT_OBJECT_CT_TIMEOUT);
1446 MODULE_ALIAS_NFT_OBJ(NFT_OBJECT_CT_EXPECT);
1447 MODULE_DESCRIPTION("Netfilter nf_tables conntrack module");