aac98a3c966e90e76824dfe5dfa362a857128efa
[platform/kernel/linux-starfive.git] / net / netfilter / nft_ct.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (c) 2008-2009 Patrick McHardy <kaber@trash.net>
4  * Copyright (c) 2016 Pablo Neira Ayuso <pablo@netfilter.org>
5  *
6  * Development of this code funded by Astaro AG (http://www.astaro.com/)
7  */
8
9 #include <linux/kernel.h>
10 #include <linux/init.h>
11 #include <linux/module.h>
12 #include <linux/netlink.h>
13 #include <linux/netfilter.h>
14 #include <linux/netfilter/nf_tables.h>
15 #include <net/netfilter/nf_tables_core.h>
16 #include <net/netfilter/nf_conntrack.h>
17 #include <net/netfilter/nf_conntrack_acct.h>
18 #include <net/netfilter/nf_conntrack_tuple.h>
19 #include <net/netfilter/nf_conntrack_helper.h>
20 #include <net/netfilter/nf_conntrack_ecache.h>
21 #include <net/netfilter/nf_conntrack_labels.h>
22 #include <net/netfilter/nf_conntrack_timeout.h>
23 #include <net/netfilter/nf_conntrack_l4proto.h>
24 #include <net/netfilter/nf_conntrack_expect.h>
25
26 struct nft_ct_helper_obj  {
27         struct nf_conntrack_helper *helper4;
28         struct nf_conntrack_helper *helper6;
29         u8 l4proto;
30 };
31
32 #ifdef CONFIG_NF_CONNTRACK_ZONES
33 static DEFINE_PER_CPU(struct nf_conn *, nft_ct_pcpu_template);
34 static unsigned int nft_ct_pcpu_template_refcnt __read_mostly;
35 static DEFINE_MUTEX(nft_ct_pcpu_mutex);
36 #endif
37
38 static u64 nft_ct_get_eval_counter(const struct nf_conn_counter *c,
39                                    enum nft_ct_keys k,
40                                    enum ip_conntrack_dir d)
41 {
42         if (d < IP_CT_DIR_MAX)
43                 return k == NFT_CT_BYTES ? atomic64_read(&c[d].bytes) :
44                                            atomic64_read(&c[d].packets);
45
46         return nft_ct_get_eval_counter(c, k, IP_CT_DIR_ORIGINAL) +
47                nft_ct_get_eval_counter(c, k, IP_CT_DIR_REPLY);
48 }
49
50 static void nft_ct_get_eval(const struct nft_expr *expr,
51                             struct nft_regs *regs,
52                             const struct nft_pktinfo *pkt)
53 {
54         const struct nft_ct *priv = nft_expr_priv(expr);
55         u32 *dest = &regs->data[priv->dreg];
56         enum ip_conntrack_info ctinfo;
57         const struct nf_conn *ct;
58         const struct nf_conn_help *help;
59         const struct nf_conntrack_tuple *tuple;
60         const struct nf_conntrack_helper *helper;
61         unsigned int state;
62
63         ct = nf_ct_get(pkt->skb, &ctinfo);
64
65         switch (priv->key) {
66         case NFT_CT_STATE:
67                 if (ct)
68                         state = NF_CT_STATE_BIT(ctinfo);
69                 else if (ctinfo == IP_CT_UNTRACKED)
70                         state = NF_CT_STATE_UNTRACKED_BIT;
71                 else
72                         state = NF_CT_STATE_INVALID_BIT;
73                 *dest = state;
74                 return;
75         default:
76                 break;
77         }
78
79         if (ct == NULL)
80                 goto err;
81
82         switch (priv->key) {
83         case NFT_CT_DIRECTION:
84                 nft_reg_store8(dest, CTINFO2DIR(ctinfo));
85                 return;
86         case NFT_CT_STATUS:
87                 *dest = ct->status;
88                 return;
89 #ifdef CONFIG_NF_CONNTRACK_MARK
90         case NFT_CT_MARK:
91                 *dest = READ_ONCE(ct->mark);
92                 return;
93 #endif
94 #ifdef CONFIG_NF_CONNTRACK_SECMARK
95         case NFT_CT_SECMARK:
96                 *dest = ct->secmark;
97                 return;
98 #endif
99         case NFT_CT_EXPIRATION:
100                 *dest = jiffies_to_msecs(nf_ct_expires(ct));
101                 return;
102         case NFT_CT_HELPER:
103                 if (ct->master == NULL)
104                         goto err;
105                 help = nfct_help(ct->master);
106                 if (help == NULL)
107                         goto err;
108                 helper = rcu_dereference(help->helper);
109                 if (helper == NULL)
110                         goto err;
111                 strscpy_pad((char *)dest, helper->name, NF_CT_HELPER_NAME_LEN);
112                 return;
113 #ifdef CONFIG_NF_CONNTRACK_LABELS
114         case NFT_CT_LABELS: {
115                 struct nf_conn_labels *labels = nf_ct_labels_find(ct);
116
117                 if (labels)
118                         memcpy(dest, labels->bits, NF_CT_LABELS_MAX_SIZE);
119                 else
120                         memset(dest, 0, NF_CT_LABELS_MAX_SIZE);
121                 return;
122         }
123 #endif
124         case NFT_CT_BYTES:
125         case NFT_CT_PKTS: {
126                 const struct nf_conn_acct *acct = nf_conn_acct_find(ct);
127                 u64 count = 0;
128
129                 if (acct)
130                         count = nft_ct_get_eval_counter(acct->counter,
131                                                         priv->key, priv->dir);
132                 memcpy(dest, &count, sizeof(count));
133                 return;
134         }
135         case NFT_CT_AVGPKT: {
136                 const struct nf_conn_acct *acct = nf_conn_acct_find(ct);
137                 u64 avgcnt = 0, bcnt = 0, pcnt = 0;
138
139                 if (acct) {
140                         pcnt = nft_ct_get_eval_counter(acct->counter,
141                                                        NFT_CT_PKTS, priv->dir);
142                         bcnt = nft_ct_get_eval_counter(acct->counter,
143                                                        NFT_CT_BYTES, priv->dir);
144                         if (pcnt != 0)
145                                 avgcnt = div64_u64(bcnt, pcnt);
146                 }
147
148                 memcpy(dest, &avgcnt, sizeof(avgcnt));
149                 return;
150         }
151         case NFT_CT_L3PROTOCOL:
152                 nft_reg_store8(dest, nf_ct_l3num(ct));
153                 return;
154         case NFT_CT_PROTOCOL:
155                 nft_reg_store8(dest, nf_ct_protonum(ct));
156                 return;
157 #ifdef CONFIG_NF_CONNTRACK_ZONES
158         case NFT_CT_ZONE: {
159                 const struct nf_conntrack_zone *zone = nf_ct_zone(ct);
160                 u16 zoneid;
161
162                 if (priv->dir < IP_CT_DIR_MAX)
163                         zoneid = nf_ct_zone_id(zone, priv->dir);
164                 else
165                         zoneid = zone->id;
166
167                 nft_reg_store16(dest, zoneid);
168                 return;
169         }
170 #endif
171         case NFT_CT_ID:
172                 *dest = nf_ct_get_id(ct);
173                 return;
174         default:
175                 break;
176         }
177
178         tuple = &ct->tuplehash[priv->dir].tuple;
179         switch (priv->key) {
180         case NFT_CT_SRC:
181                 memcpy(dest, tuple->src.u3.all,
182                        nf_ct_l3num(ct) == NFPROTO_IPV4 ? 4 : 16);
183                 return;
184         case NFT_CT_DST:
185                 memcpy(dest, tuple->dst.u3.all,
186                        nf_ct_l3num(ct) == NFPROTO_IPV4 ? 4 : 16);
187                 return;
188         case NFT_CT_PROTO_SRC:
189                 nft_reg_store16(dest, (__force u16)tuple->src.u.all);
190                 return;
191         case NFT_CT_PROTO_DST:
192                 nft_reg_store16(dest, (__force u16)tuple->dst.u.all);
193                 return;
194         case NFT_CT_SRC_IP:
195                 if (nf_ct_l3num(ct) != NFPROTO_IPV4)
196                         goto err;
197                 *dest = (__force __u32)tuple->src.u3.ip;
198                 return;
199         case NFT_CT_DST_IP:
200                 if (nf_ct_l3num(ct) != NFPROTO_IPV4)
201                         goto err;
202                 *dest = (__force __u32)tuple->dst.u3.ip;
203                 return;
204         case NFT_CT_SRC_IP6:
205                 if (nf_ct_l3num(ct) != NFPROTO_IPV6)
206                         goto err;
207                 memcpy(dest, tuple->src.u3.ip6, sizeof(struct in6_addr));
208                 return;
209         case NFT_CT_DST_IP6:
210                 if (nf_ct_l3num(ct) != NFPROTO_IPV6)
211                         goto err;
212                 memcpy(dest, tuple->dst.u3.ip6, sizeof(struct in6_addr));
213                 return;
214         default:
215                 break;
216         }
217         return;
218 err:
219         regs->verdict.code = NFT_BREAK;
220 }
221
222 #ifdef CONFIG_NF_CONNTRACK_ZONES
223 static void nft_ct_set_zone_eval(const struct nft_expr *expr,
224                                  struct nft_regs *regs,
225                                  const struct nft_pktinfo *pkt)
226 {
227         struct nf_conntrack_zone zone = { .dir = NF_CT_DEFAULT_ZONE_DIR };
228         const struct nft_ct *priv = nft_expr_priv(expr);
229         struct sk_buff *skb = pkt->skb;
230         enum ip_conntrack_info ctinfo;
231         u16 value = nft_reg_load16(&regs->data[priv->sreg]);
232         struct nf_conn *ct;
233
234         ct = nf_ct_get(skb, &ctinfo);
235         if (ct) /* already tracked */
236                 return;
237
238         zone.id = value;
239
240         switch (priv->dir) {
241         case IP_CT_DIR_ORIGINAL:
242                 zone.dir = NF_CT_ZONE_DIR_ORIG;
243                 break;
244         case IP_CT_DIR_REPLY:
245                 zone.dir = NF_CT_ZONE_DIR_REPL;
246                 break;
247         default:
248                 break;
249         }
250
251         ct = this_cpu_read(nft_ct_pcpu_template);
252
253         if (likely(refcount_read(&ct->ct_general.use) == 1)) {
254                 refcount_inc(&ct->ct_general.use);
255                 nf_ct_zone_add(ct, &zone);
256         } else {
257                 /* previous skb got queued to userspace, allocate temporary
258                  * one until percpu template can be reused.
259                  */
260                 ct = nf_ct_tmpl_alloc(nft_net(pkt), &zone, GFP_ATOMIC);
261                 if (!ct) {
262                         regs->verdict.code = NF_DROP;
263                         return;
264                 }
265                 __set_bit(IPS_CONFIRMED_BIT, &ct->status);
266         }
267
268         nf_ct_set(skb, ct, IP_CT_NEW);
269 }
270 #endif
271
272 static void nft_ct_set_eval(const struct nft_expr *expr,
273                             struct nft_regs *regs,
274                             const struct nft_pktinfo *pkt)
275 {
276         const struct nft_ct *priv = nft_expr_priv(expr);
277         struct sk_buff *skb = pkt->skb;
278 #if defined(CONFIG_NF_CONNTRACK_MARK) || defined(CONFIG_NF_CONNTRACK_SECMARK)
279         u32 value = regs->data[priv->sreg];
280 #endif
281         enum ip_conntrack_info ctinfo;
282         struct nf_conn *ct;
283
284         ct = nf_ct_get(skb, &ctinfo);
285         if (ct == NULL || nf_ct_is_template(ct))
286                 return;
287
288         switch (priv->key) {
289 #ifdef CONFIG_NF_CONNTRACK_MARK
290         case NFT_CT_MARK:
291                 if (READ_ONCE(ct->mark) != value) {
292                         WRITE_ONCE(ct->mark, value);
293                         nf_conntrack_event_cache(IPCT_MARK, ct);
294                 }
295                 break;
296 #endif
297 #ifdef CONFIG_NF_CONNTRACK_SECMARK
298         case NFT_CT_SECMARK:
299                 if (ct->secmark != value) {
300                         ct->secmark = value;
301                         nf_conntrack_event_cache(IPCT_SECMARK, ct);
302                 }
303                 break;
304 #endif
305 #ifdef CONFIG_NF_CONNTRACK_LABELS
306         case NFT_CT_LABELS:
307                 nf_connlabels_replace(ct,
308                                       &regs->data[priv->sreg],
309                                       &regs->data[priv->sreg],
310                                       NF_CT_LABELS_MAX_SIZE / sizeof(u32));
311                 break;
312 #endif
313 #ifdef CONFIG_NF_CONNTRACK_EVENTS
314         case NFT_CT_EVENTMASK: {
315                 struct nf_conntrack_ecache *e = nf_ct_ecache_find(ct);
316                 u32 ctmask = regs->data[priv->sreg];
317
318                 if (e) {
319                         if (e->ctmask != ctmask)
320                                 e->ctmask = ctmask;
321                         break;
322                 }
323
324                 if (ctmask && !nf_ct_is_confirmed(ct))
325                         nf_ct_ecache_ext_add(ct, ctmask, 0, GFP_ATOMIC);
326                 break;
327         }
328 #endif
329         default:
330                 break;
331         }
332 }
333
334 static const struct nla_policy nft_ct_policy[NFTA_CT_MAX + 1] = {
335         [NFTA_CT_DREG]          = { .type = NLA_U32 },
336         [NFTA_CT_KEY]           = NLA_POLICY_MAX(NLA_BE32, 255),
337         [NFTA_CT_DIRECTION]     = { .type = NLA_U8 },
338         [NFTA_CT_SREG]          = { .type = NLA_U32 },
339 };
340
341 #ifdef CONFIG_NF_CONNTRACK_ZONES
342 static void nft_ct_tmpl_put_pcpu(void)
343 {
344         struct nf_conn *ct;
345         int cpu;
346
347         for_each_possible_cpu(cpu) {
348                 ct = per_cpu(nft_ct_pcpu_template, cpu);
349                 if (!ct)
350                         break;
351                 nf_ct_put(ct);
352                 per_cpu(nft_ct_pcpu_template, cpu) = NULL;
353         }
354 }
355
356 static bool nft_ct_tmpl_alloc_pcpu(void)
357 {
358         struct nf_conntrack_zone zone = { .id = 0 };
359         struct nf_conn *tmp;
360         int cpu;
361
362         if (nft_ct_pcpu_template_refcnt)
363                 return true;
364
365         for_each_possible_cpu(cpu) {
366                 tmp = nf_ct_tmpl_alloc(&init_net, &zone, GFP_KERNEL);
367                 if (!tmp) {
368                         nft_ct_tmpl_put_pcpu();
369                         return false;
370                 }
371
372                 __set_bit(IPS_CONFIRMED_BIT, &tmp->status);
373                 per_cpu(nft_ct_pcpu_template, cpu) = tmp;
374         }
375
376         return true;
377 }
378 #endif
379
380 static int nft_ct_get_init(const struct nft_ctx *ctx,
381                            const struct nft_expr *expr,
382                            const struct nlattr * const tb[])
383 {
384         struct nft_ct *priv = nft_expr_priv(expr);
385         unsigned int len;
386         int err;
387
388         priv->key = ntohl(nla_get_be32(tb[NFTA_CT_KEY]));
389         priv->dir = IP_CT_DIR_MAX;
390         switch (priv->key) {
391         case NFT_CT_DIRECTION:
392                 if (tb[NFTA_CT_DIRECTION] != NULL)
393                         return -EINVAL;
394                 len = sizeof(u8);
395                 break;
396         case NFT_CT_STATE:
397         case NFT_CT_STATUS:
398 #ifdef CONFIG_NF_CONNTRACK_MARK
399         case NFT_CT_MARK:
400 #endif
401 #ifdef CONFIG_NF_CONNTRACK_SECMARK
402         case NFT_CT_SECMARK:
403 #endif
404         case NFT_CT_EXPIRATION:
405                 if (tb[NFTA_CT_DIRECTION] != NULL)
406                         return -EINVAL;
407                 len = sizeof(u32);
408                 break;
409 #ifdef CONFIG_NF_CONNTRACK_LABELS
410         case NFT_CT_LABELS:
411                 if (tb[NFTA_CT_DIRECTION] != NULL)
412                         return -EINVAL;
413                 len = NF_CT_LABELS_MAX_SIZE;
414                 break;
415 #endif
416         case NFT_CT_HELPER:
417                 if (tb[NFTA_CT_DIRECTION] != NULL)
418                         return -EINVAL;
419                 len = NF_CT_HELPER_NAME_LEN;
420                 break;
421
422         case NFT_CT_L3PROTOCOL:
423         case NFT_CT_PROTOCOL:
424                 /* For compatibility, do not report error if NFTA_CT_DIRECTION
425                  * attribute is specified.
426                  */
427                 len = sizeof(u8);
428                 break;
429         case NFT_CT_SRC:
430         case NFT_CT_DST:
431                 if (tb[NFTA_CT_DIRECTION] == NULL)
432                         return -EINVAL;
433
434                 switch (ctx->family) {
435                 case NFPROTO_IPV4:
436                         len = sizeof_field(struct nf_conntrack_tuple,
437                                            src.u3.ip);
438                         break;
439                 case NFPROTO_IPV6:
440                 case NFPROTO_INET:
441                         len = sizeof_field(struct nf_conntrack_tuple,
442                                            src.u3.ip6);
443                         break;
444                 default:
445                         return -EAFNOSUPPORT;
446                 }
447                 break;
448         case NFT_CT_SRC_IP:
449         case NFT_CT_DST_IP:
450                 if (tb[NFTA_CT_DIRECTION] == NULL)
451                         return -EINVAL;
452
453                 len = sizeof_field(struct nf_conntrack_tuple, src.u3.ip);
454                 break;
455         case NFT_CT_SRC_IP6:
456         case NFT_CT_DST_IP6:
457                 if (tb[NFTA_CT_DIRECTION] == NULL)
458                         return -EINVAL;
459
460                 len = sizeof_field(struct nf_conntrack_tuple, src.u3.ip6);
461                 break;
462         case NFT_CT_PROTO_SRC:
463         case NFT_CT_PROTO_DST:
464                 if (tb[NFTA_CT_DIRECTION] == NULL)
465                         return -EINVAL;
466                 len = sizeof_field(struct nf_conntrack_tuple, src.u.all);
467                 break;
468         case NFT_CT_BYTES:
469         case NFT_CT_PKTS:
470         case NFT_CT_AVGPKT:
471                 len = sizeof(u64);
472                 break;
473 #ifdef CONFIG_NF_CONNTRACK_ZONES
474         case NFT_CT_ZONE:
475                 len = sizeof(u16);
476                 break;
477 #endif
478         case NFT_CT_ID:
479                 len = sizeof(u32);
480                 break;
481         default:
482                 return -EOPNOTSUPP;
483         }
484
485         if (tb[NFTA_CT_DIRECTION] != NULL) {
486                 priv->dir = nla_get_u8(tb[NFTA_CT_DIRECTION]);
487                 switch (priv->dir) {
488                 case IP_CT_DIR_ORIGINAL:
489                 case IP_CT_DIR_REPLY:
490                         break;
491                 default:
492                         return -EINVAL;
493                 }
494         }
495
496         priv->len = len;
497         err = nft_parse_register_store(ctx, tb[NFTA_CT_DREG], &priv->dreg, NULL,
498                                        NFT_DATA_VALUE, len);
499         if (err < 0)
500                 return err;
501
502         err = nf_ct_netns_get(ctx->net, ctx->family);
503         if (err < 0)
504                 return err;
505
506         if (priv->key == NFT_CT_BYTES ||
507             priv->key == NFT_CT_PKTS  ||
508             priv->key == NFT_CT_AVGPKT)
509                 nf_ct_set_acct(ctx->net, true);
510
511         return 0;
512 }
513
/*
 * Release the key specific resources taken by nft_ct_set_init(): the
 * connlabel reference for NFT_CT_LABELS, or one reference on the per-CPU
 * templates for NFT_CT_ZONE (the templates are dropped with the last
 * reference).  Other keys hold nothing.
 */
static void __nft_ct_set_destroy(const struct nft_ctx *ctx, struct nft_ct *priv)
{
#ifdef CONFIG_NF_CONNTRACK_LABELS
	if (priv->key == NFT_CT_LABELS) {
		nf_connlabels_put(ctx->net);
		return;
	}
#endif
#ifdef CONFIG_NF_CONNTRACK_ZONES
	if (priv->key == NFT_CT_ZONE) {
		mutex_lock(&nft_ct_pcpu_mutex);
		nft_ct_pcpu_template_refcnt--;
		if (nft_ct_pcpu_template_refcnt == 0)
			nft_ct_tmpl_put_pcpu();
		mutex_unlock(&nft_ct_pcpu_mutex);
		return;
	}
#endif
}
534
535 static int nft_ct_set_init(const struct nft_ctx *ctx,
536                            const struct nft_expr *expr,
537                            const struct nlattr * const tb[])
538 {
539         struct nft_ct *priv = nft_expr_priv(expr);
540         unsigned int len;
541         int err;
542
543         priv->dir = IP_CT_DIR_MAX;
544         priv->key = ntohl(nla_get_be32(tb[NFTA_CT_KEY]));
545         switch (priv->key) {
546 #ifdef CONFIG_NF_CONNTRACK_MARK
547         case NFT_CT_MARK:
548                 if (tb[NFTA_CT_DIRECTION])
549                         return -EINVAL;
550                 len = sizeof_field(struct nf_conn, mark);
551                 break;
552 #endif
553 #ifdef CONFIG_NF_CONNTRACK_LABELS
554         case NFT_CT_LABELS:
555                 if (tb[NFTA_CT_DIRECTION])
556                         return -EINVAL;
557                 len = NF_CT_LABELS_MAX_SIZE;
558                 err = nf_connlabels_get(ctx->net, (len * BITS_PER_BYTE) - 1);
559                 if (err)
560                         return err;
561                 break;
562 #endif
563 #ifdef CONFIG_NF_CONNTRACK_ZONES
564         case NFT_CT_ZONE:
565                 mutex_lock(&nft_ct_pcpu_mutex);
566                 if (!nft_ct_tmpl_alloc_pcpu()) {
567                         mutex_unlock(&nft_ct_pcpu_mutex);
568                         return -ENOMEM;
569                 }
570                 nft_ct_pcpu_template_refcnt++;
571                 mutex_unlock(&nft_ct_pcpu_mutex);
572                 len = sizeof(u16);
573                 break;
574 #endif
575 #ifdef CONFIG_NF_CONNTRACK_EVENTS
576         case NFT_CT_EVENTMASK:
577                 if (tb[NFTA_CT_DIRECTION])
578                         return -EINVAL;
579                 len = sizeof(u32);
580                 break;
581 #endif
582 #ifdef CONFIG_NF_CONNTRACK_SECMARK
583         case NFT_CT_SECMARK:
584                 if (tb[NFTA_CT_DIRECTION])
585                         return -EINVAL;
586                 len = sizeof(u32);
587                 break;
588 #endif
589         default:
590                 return -EOPNOTSUPP;
591         }
592
593         if (tb[NFTA_CT_DIRECTION]) {
594                 priv->dir = nla_get_u8(tb[NFTA_CT_DIRECTION]);
595                 switch (priv->dir) {
596                 case IP_CT_DIR_ORIGINAL:
597                 case IP_CT_DIR_REPLY:
598                         break;
599                 default:
600                         err = -EINVAL;
601                         goto err1;
602                 }
603         }
604
605         priv->len = len;
606         err = nft_parse_register_load(tb[NFTA_CT_SREG], &priv->sreg, len);
607         if (err < 0)
608                 goto err1;
609
610         err = nf_ct_netns_get(ctx->net, ctx->family);
611         if (err < 0)
612                 goto err1;
613
614         return 0;
615
616 err1:
617         __nft_ct_set_destroy(ctx, priv);
618         return err;
619 }
620
621 static void nft_ct_get_destroy(const struct nft_ctx *ctx,
622                                const struct nft_expr *expr)
623 {
624         nf_ct_netns_put(ctx->net, ctx->family);
625 }
626
627 static void nft_ct_set_destroy(const struct nft_ctx *ctx,
628                                const struct nft_expr *expr)
629 {
630         struct nft_ct *priv = nft_expr_priv(expr);
631
632         __nft_ct_set_destroy(ctx, priv);
633         nf_ct_netns_put(ctx->net, ctx->family);
634 }
635
636 static int nft_ct_get_dump(struct sk_buff *skb,
637                            const struct nft_expr *expr, bool reset)
638 {
639         const struct nft_ct *priv = nft_expr_priv(expr);
640
641         if (nft_dump_register(skb, NFTA_CT_DREG, priv->dreg))
642                 goto nla_put_failure;
643         if (nla_put_be32(skb, NFTA_CT_KEY, htonl(priv->key)))
644                 goto nla_put_failure;
645
646         switch (priv->key) {
647         case NFT_CT_SRC:
648         case NFT_CT_DST:
649         case NFT_CT_SRC_IP:
650         case NFT_CT_DST_IP:
651         case NFT_CT_SRC_IP6:
652         case NFT_CT_DST_IP6:
653         case NFT_CT_PROTO_SRC:
654         case NFT_CT_PROTO_DST:
655                 if (nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir))
656                         goto nla_put_failure;
657                 break;
658         case NFT_CT_BYTES:
659         case NFT_CT_PKTS:
660         case NFT_CT_AVGPKT:
661         case NFT_CT_ZONE:
662                 if (priv->dir < IP_CT_DIR_MAX &&
663                     nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir))
664                         goto nla_put_failure;
665                 break;
666         default:
667                 break;
668         }
669
670         return 0;
671
672 nla_put_failure:
673         return -1;
674 }
675
676 static bool nft_ct_get_reduce(struct nft_regs_track *track,
677                               const struct nft_expr *expr)
678 {
679         const struct nft_ct *priv = nft_expr_priv(expr);
680         const struct nft_ct *ct;
681
682         if (!nft_reg_track_cmp(track, expr, priv->dreg)) {
683                 nft_reg_track_update(track, expr, priv->dreg, priv->len);
684                 return false;
685         }
686
687         ct = nft_expr_priv(track->regs[priv->dreg].selector);
688         if (priv->key != ct->key) {
689                 nft_reg_track_update(track, expr, priv->dreg, priv->len);
690                 return false;
691         }
692
693         if (!track->regs[priv->dreg].bitwise)
694                 return true;
695
696         return nft_expr_reduce_bitwise(track, expr);
697 }
698
699 static int nft_ct_set_dump(struct sk_buff *skb,
700                            const struct nft_expr *expr, bool reset)
701 {
702         const struct nft_ct *priv = nft_expr_priv(expr);
703
704         if (nft_dump_register(skb, NFTA_CT_SREG, priv->sreg))
705                 goto nla_put_failure;
706         if (nla_put_be32(skb, NFTA_CT_KEY, htonl(priv->key)))
707                 goto nla_put_failure;
708
709         switch (priv->key) {
710         case NFT_CT_ZONE:
711                 if (priv->dir < IP_CT_DIR_MAX &&
712                     nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir))
713                         goto nla_put_failure;
714                 break;
715         default:
716                 break;
717         }
718
719         return 0;
720
721 nla_put_failure:
722         return -1;
723 }
724
725 static struct nft_expr_type nft_ct_type;
726 static const struct nft_expr_ops nft_ct_get_ops = {
727         .type           = &nft_ct_type,
728         .size           = NFT_EXPR_SIZE(sizeof(struct nft_ct)),
729         .eval           = nft_ct_get_eval,
730         .init           = nft_ct_get_init,
731         .destroy        = nft_ct_get_destroy,
732         .dump           = nft_ct_get_dump,
733         .reduce         = nft_ct_get_reduce,
734 };
735
736 static bool nft_ct_set_reduce(struct nft_regs_track *track,
737                               const struct nft_expr *expr)
738 {
739         int i;
740
741         for (i = 0; i < NFT_REG32_NUM; i++) {
742                 if (!track->regs[i].selector)
743                         continue;
744
745                 if (track->regs[i].selector->ops != &nft_ct_get_ops)
746                         continue;
747
748                 __nft_reg_track_cancel(track, i);
749         }
750
751         return false;
752 }
753
754 #ifdef CONFIG_RETPOLINE
755 static const struct nft_expr_ops nft_ct_get_fast_ops = {
756         .type           = &nft_ct_type,
757         .size           = NFT_EXPR_SIZE(sizeof(struct nft_ct)),
758         .eval           = nft_ct_get_fast_eval,
759         .init           = nft_ct_get_init,
760         .destroy        = nft_ct_get_destroy,
761         .dump           = nft_ct_get_dump,
762         .reduce         = nft_ct_set_reduce,
763 };
764 #endif
765
766 static const struct nft_expr_ops nft_ct_set_ops = {
767         .type           = &nft_ct_type,
768         .size           = NFT_EXPR_SIZE(sizeof(struct nft_ct)),
769         .eval           = nft_ct_set_eval,
770         .init           = nft_ct_set_init,
771         .destroy        = nft_ct_set_destroy,
772         .dump           = nft_ct_set_dump,
773         .reduce         = nft_ct_set_reduce,
774 };
775
776 #ifdef CONFIG_NF_CONNTRACK_ZONES
777 static const struct nft_expr_ops nft_ct_set_zone_ops = {
778         .type           = &nft_ct_type,
779         .size           = NFT_EXPR_SIZE(sizeof(struct nft_ct)),
780         .eval           = nft_ct_set_zone_eval,
781         .init           = nft_ct_set_init,
782         .destroy        = nft_ct_set_destroy,
783         .dump           = nft_ct_set_dump,
784         .reduce         = nft_ct_set_reduce,
785 };
786 #endif
787
788 static const struct nft_expr_ops *
789 nft_ct_select_ops(const struct nft_ctx *ctx,
790                     const struct nlattr * const tb[])
791 {
792         if (tb[NFTA_CT_KEY] == NULL)
793                 return ERR_PTR(-EINVAL);
794
795         if (tb[NFTA_CT_DREG] && tb[NFTA_CT_SREG])
796                 return ERR_PTR(-EINVAL);
797
798         if (tb[NFTA_CT_DREG]) {
799 #ifdef CONFIG_RETPOLINE
800                 u32 k = ntohl(nla_get_be32(tb[NFTA_CT_KEY]));
801
802                 switch (k) {
803                 case NFT_CT_STATE:
804                 case NFT_CT_DIRECTION:
805                 case NFT_CT_STATUS:
806                 case NFT_CT_MARK:
807                 case NFT_CT_SECMARK:
808                         return &nft_ct_get_fast_ops;
809                 }
810 #endif
811                 return &nft_ct_get_ops;
812         }
813
814         if (tb[NFTA_CT_SREG]) {
815 #ifdef CONFIG_NF_CONNTRACK_ZONES
816                 if (nla_get_be32(tb[NFTA_CT_KEY]) == htonl(NFT_CT_ZONE))
817                         return &nft_ct_set_zone_ops;
818 #endif
819                 return &nft_ct_set_ops;
820         }
821
822         return ERR_PTR(-EINVAL);
823 }
824
825 static struct nft_expr_type nft_ct_type __read_mostly = {
826         .name           = "ct",
827         .select_ops     = nft_ct_select_ops,
828         .policy         = nft_ct_policy,
829         .maxattr        = NFTA_CT_MAX,
830         .owner          = THIS_MODULE,
831 };
832
833 static void nft_notrack_eval(const struct nft_expr *expr,
834                              struct nft_regs *regs,
835                              const struct nft_pktinfo *pkt)
836 {
837         struct sk_buff *skb = pkt->skb;
838         enum ip_conntrack_info ctinfo;
839         struct nf_conn *ct;
840
841         ct = nf_ct_get(pkt->skb, &ctinfo);
842         /* Previously seen (loopback or untracked)?  Ignore. */
843         if (ct || ctinfo == IP_CT_UNTRACKED)
844                 return;
845
846         nf_ct_set(skb, ct, IP_CT_UNTRACKED);
847 }
848
849 static struct nft_expr_type nft_notrack_type;
850 static const struct nft_expr_ops nft_notrack_ops = {
851         .type           = &nft_notrack_type,
852         .size           = NFT_EXPR_SIZE(0),
853         .eval           = nft_notrack_eval,
854         .reduce         = NFT_REDUCE_READONLY,
855 };
856
857 static struct nft_expr_type nft_notrack_type __read_mostly = {
858         .name           = "notrack",
859         .ops            = &nft_notrack_ops,
860         .owner          = THIS_MODULE,
861 };
862
863 #ifdef CONFIG_NF_CONNTRACK_TIMEOUT
864 static int
865 nft_ct_timeout_parse_policy(void *timeouts,
866                             const struct nf_conntrack_l4proto *l4proto,
867                             struct net *net, const struct nlattr *attr)
868 {
869         struct nlattr **tb;
870         int ret = 0;
871
872         tb = kcalloc(l4proto->ctnl_timeout.nlattr_max + 1, sizeof(*tb),
873                      GFP_KERNEL);
874
875         if (!tb)
876                 return -ENOMEM;
877
878         ret = nla_parse_nested_deprecated(tb,
879                                           l4proto->ctnl_timeout.nlattr_max,
880                                           attr,
881                                           l4proto->ctnl_timeout.nla_policy,
882                                           NULL);
883         if (ret < 0)
884                 goto err;
885
886         ret = l4proto->ctnl_timeout.nlattr_to_obj(tb, net, timeouts);
887
888 err:
889         kfree(tb);
890         return ret;
891 }
892
893 struct nft_ct_timeout_obj {
894         struct nf_ct_timeout    *timeout;
895         u8                      l4proto;
896 };
897
898 static void nft_ct_timeout_obj_eval(struct nft_object *obj,
899                                     struct nft_regs *regs,
900                                     const struct nft_pktinfo *pkt)
901 {
902         const struct nft_ct_timeout_obj *priv = nft_obj_data(obj);
903         struct nf_conn *ct = (struct nf_conn *)skb_nfct(pkt->skb);
904         struct nf_conn_timeout *timeout;
905         const unsigned int *values;
906
907         if (priv->l4proto != pkt->tprot)
908                 return;
909
910         if (!ct || nf_ct_is_template(ct) || nf_ct_is_confirmed(ct))
911                 return;
912
913         timeout = nf_ct_timeout_find(ct);
914         if (!timeout) {
915                 timeout = nf_ct_timeout_ext_add(ct, priv->timeout, GFP_ATOMIC);
916                 if (!timeout) {
917                         regs->verdict.code = NF_DROP;
918                         return;
919                 }
920         }
921
922         rcu_assign_pointer(timeout->timeout, priv->timeout);
923
924         /* adjust the timeout as per 'new' state. ct is unconfirmed,
925          * so the current timestamp must not be added.
926          */
927         values = nf_ct_timeout_data(timeout);
928         if (values)
929                 nf_ct_refresh(ct, pkt->skb, values[0]);
930 }
931
932 static int nft_ct_timeout_obj_init(const struct nft_ctx *ctx,
933                                    const struct nlattr * const tb[],
934                                    struct nft_object *obj)
935 {
936         struct nft_ct_timeout_obj *priv = nft_obj_data(obj);
937         const struct nf_conntrack_l4proto *l4proto;
938         struct nf_ct_timeout *timeout;
939         int l3num = ctx->family;
940         __u8 l4num;
941         int ret;
942
943         if (!tb[NFTA_CT_TIMEOUT_L4PROTO] ||
944             !tb[NFTA_CT_TIMEOUT_DATA])
945                 return -EINVAL;
946
947         if (tb[NFTA_CT_TIMEOUT_L3PROTO])
948                 l3num = ntohs(nla_get_be16(tb[NFTA_CT_TIMEOUT_L3PROTO]));
949
950         l4num = nla_get_u8(tb[NFTA_CT_TIMEOUT_L4PROTO]);
951         priv->l4proto = l4num;
952
953         l4proto = nf_ct_l4proto_find(l4num);
954
955         if (l4proto->l4proto != l4num) {
956                 ret = -EOPNOTSUPP;
957                 goto err_proto_put;
958         }
959
960         timeout = kzalloc(sizeof(struct nf_ct_timeout) +
961                           l4proto->ctnl_timeout.obj_size, GFP_KERNEL);
962         if (timeout == NULL) {
963                 ret = -ENOMEM;
964                 goto err_proto_put;
965         }
966
967         ret = nft_ct_timeout_parse_policy(&timeout->data, l4proto, ctx->net,
968                                           tb[NFTA_CT_TIMEOUT_DATA]);
969         if (ret < 0)
970                 goto err_free_timeout;
971
972         timeout->l3num = l3num;
973         timeout->l4proto = l4proto;
974
975         ret = nf_ct_netns_get(ctx->net, ctx->family);
976         if (ret < 0)
977                 goto err_free_timeout;
978
979         priv->timeout = timeout;
980         return 0;
981
982 err_free_timeout:
983         kfree(timeout);
984 err_proto_put:
985         return ret;
986 }
987
988 static void nft_ct_timeout_obj_destroy(const struct nft_ctx *ctx,
989                                        struct nft_object *obj)
990 {
991         struct nft_ct_timeout_obj *priv = nft_obj_data(obj);
992         struct nf_ct_timeout *timeout = priv->timeout;
993
994         nf_ct_untimeout(ctx->net, timeout);
995         nf_ct_netns_put(ctx->net, ctx->family);
996         kfree(priv->timeout);
997 }
998
999 static int nft_ct_timeout_obj_dump(struct sk_buff *skb,
1000                                    struct nft_object *obj, bool reset)
1001 {
1002         const struct nft_ct_timeout_obj *priv = nft_obj_data(obj);
1003         const struct nf_ct_timeout *timeout = priv->timeout;
1004         struct nlattr *nest_params;
1005         int ret;
1006
1007         if (nla_put_u8(skb, NFTA_CT_TIMEOUT_L4PROTO, timeout->l4proto->l4proto) ||
1008             nla_put_be16(skb, NFTA_CT_TIMEOUT_L3PROTO, htons(timeout->l3num)))
1009                 return -1;
1010
1011         nest_params = nla_nest_start(skb, NFTA_CT_TIMEOUT_DATA);
1012         if (!nest_params)
1013                 return -1;
1014
1015         ret = timeout->l4proto->ctnl_timeout.obj_to_nlattr(skb, &timeout->data);
1016         if (ret < 0)
1017                 return -1;
1018         nla_nest_end(skb, nest_params);
1019         return 0;
1020 }
1021
1022 static const struct nla_policy nft_ct_timeout_policy[NFTA_CT_TIMEOUT_MAX + 1] = {
1023         [NFTA_CT_TIMEOUT_L3PROTO] = {.type = NLA_U16 },
1024         [NFTA_CT_TIMEOUT_L4PROTO] = {.type = NLA_U8 },
1025         [NFTA_CT_TIMEOUT_DATA]    = {.type = NLA_NESTED },
1026 };
1027
1028 static struct nft_object_type nft_ct_timeout_obj_type;
1029
1030 static const struct nft_object_ops nft_ct_timeout_obj_ops = {
1031         .type           = &nft_ct_timeout_obj_type,
1032         .size           = sizeof(struct nft_ct_timeout_obj),
1033         .eval           = nft_ct_timeout_obj_eval,
1034         .init           = nft_ct_timeout_obj_init,
1035         .destroy        = nft_ct_timeout_obj_destroy,
1036         .dump           = nft_ct_timeout_obj_dump,
1037 };
1038
1039 static struct nft_object_type nft_ct_timeout_obj_type __read_mostly = {
1040         .type           = NFT_OBJECT_CT_TIMEOUT,
1041         .ops            = &nft_ct_timeout_obj_ops,
1042         .maxattr        = NFTA_CT_TIMEOUT_MAX,
1043         .policy         = nft_ct_timeout_policy,
1044         .owner          = THIS_MODULE,
1045 };
1046 #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */
1047
1048 static int nft_ct_helper_obj_init(const struct nft_ctx *ctx,
1049                                   const struct nlattr * const tb[],
1050                                   struct nft_object *obj)
1051 {
1052         struct nft_ct_helper_obj *priv = nft_obj_data(obj);
1053         struct nf_conntrack_helper *help4, *help6;
1054         char name[NF_CT_HELPER_NAME_LEN];
1055         int family = ctx->family;
1056         int err;
1057
1058         if (!tb[NFTA_CT_HELPER_NAME] || !tb[NFTA_CT_HELPER_L4PROTO])
1059                 return -EINVAL;
1060
1061         priv->l4proto = nla_get_u8(tb[NFTA_CT_HELPER_L4PROTO]);
1062         if (!priv->l4proto)
1063                 return -ENOENT;
1064
1065         nla_strscpy(name, tb[NFTA_CT_HELPER_NAME], sizeof(name));
1066
1067         if (tb[NFTA_CT_HELPER_L3PROTO])
1068                 family = ntohs(nla_get_be16(tb[NFTA_CT_HELPER_L3PROTO]));
1069
1070         help4 = NULL;
1071         help6 = NULL;
1072
1073         switch (family) {
1074         case NFPROTO_IPV4:
1075                 if (ctx->family == NFPROTO_IPV6)
1076                         return -EINVAL;
1077
1078                 help4 = nf_conntrack_helper_try_module_get(name, family,
1079                                                            priv->l4proto);
1080                 break;
1081         case NFPROTO_IPV6:
1082                 if (ctx->family == NFPROTO_IPV4)
1083                         return -EINVAL;
1084
1085                 help6 = nf_conntrack_helper_try_module_get(name, family,
1086                                                            priv->l4proto);
1087                 break;
1088         case NFPROTO_NETDEV:
1089         case NFPROTO_BRIDGE:
1090         case NFPROTO_INET:
1091                 help4 = nf_conntrack_helper_try_module_get(name, NFPROTO_IPV4,
1092                                                            priv->l4proto);
1093                 help6 = nf_conntrack_helper_try_module_get(name, NFPROTO_IPV6,
1094                                                            priv->l4proto);
1095                 break;
1096         default:
1097                 return -EAFNOSUPPORT;
1098         }
1099
1100         /* && is intentional; only error if INET found neither ipv4 or ipv6 */
1101         if (!help4 && !help6)
1102                 return -ENOENT;
1103
1104         priv->helper4 = help4;
1105         priv->helper6 = help6;
1106
1107         err = nf_ct_netns_get(ctx->net, ctx->family);
1108         if (err < 0)
1109                 goto err_put_helper;
1110
1111         return 0;
1112
1113 err_put_helper:
1114         if (priv->helper4)
1115                 nf_conntrack_helper_put(priv->helper4);
1116         if (priv->helper6)
1117                 nf_conntrack_helper_put(priv->helper6);
1118         return err;
1119 }
1120
1121 static void nft_ct_helper_obj_destroy(const struct nft_ctx *ctx,
1122                                       struct nft_object *obj)
1123 {
1124         struct nft_ct_helper_obj *priv = nft_obj_data(obj);
1125
1126         if (priv->helper4)
1127                 nf_conntrack_helper_put(priv->helper4);
1128         if (priv->helper6)
1129                 nf_conntrack_helper_put(priv->helper6);
1130
1131         nf_ct_netns_put(ctx->net, ctx->family);
1132 }
1133
1134 static void nft_ct_helper_obj_eval(struct nft_object *obj,
1135                                    struct nft_regs *regs,
1136                                    const struct nft_pktinfo *pkt)
1137 {
1138         const struct nft_ct_helper_obj *priv = nft_obj_data(obj);
1139         struct nf_conn *ct = (struct nf_conn *)skb_nfct(pkt->skb);
1140         struct nf_conntrack_helper *to_assign = NULL;
1141         struct nf_conn_help *help;
1142
1143         if (!ct ||
1144             nf_ct_is_confirmed(ct) ||
1145             nf_ct_is_template(ct) ||
1146             priv->l4proto != nf_ct_protonum(ct))
1147                 return;
1148
1149         switch (nf_ct_l3num(ct)) {
1150         case NFPROTO_IPV4:
1151                 to_assign = priv->helper4;
1152                 break;
1153         case NFPROTO_IPV6:
1154                 to_assign = priv->helper6;
1155                 break;
1156         default:
1157                 WARN_ON_ONCE(1);
1158                 return;
1159         }
1160
1161         if (!to_assign)
1162                 return;
1163
1164         if (test_bit(IPS_HELPER_BIT, &ct->status))
1165                 return;
1166
1167         help = nf_ct_helper_ext_add(ct, GFP_ATOMIC);
1168         if (help) {
1169                 rcu_assign_pointer(help->helper, to_assign);
1170                 set_bit(IPS_HELPER_BIT, &ct->status);
1171         }
1172 }
1173
1174 static int nft_ct_helper_obj_dump(struct sk_buff *skb,
1175                                   struct nft_object *obj, bool reset)
1176 {
1177         const struct nft_ct_helper_obj *priv = nft_obj_data(obj);
1178         const struct nf_conntrack_helper *helper;
1179         u16 family;
1180
1181         if (priv->helper4 && priv->helper6) {
1182                 family = NFPROTO_INET;
1183                 helper = priv->helper4;
1184         } else if (priv->helper6) {
1185                 family = NFPROTO_IPV6;
1186                 helper = priv->helper6;
1187         } else {
1188                 family = NFPROTO_IPV4;
1189                 helper = priv->helper4;
1190         }
1191
1192         if (nla_put_string(skb, NFTA_CT_HELPER_NAME, helper->name))
1193                 return -1;
1194
1195         if (nla_put_u8(skb, NFTA_CT_HELPER_L4PROTO, priv->l4proto))
1196                 return -1;
1197
1198         if (nla_put_be16(skb, NFTA_CT_HELPER_L3PROTO, htons(family)))
1199                 return -1;
1200
1201         return 0;
1202 }
1203
1204 static const struct nla_policy nft_ct_helper_policy[NFTA_CT_HELPER_MAX + 1] = {
1205         [NFTA_CT_HELPER_NAME] = { .type = NLA_STRING,
1206                                   .len = NF_CT_HELPER_NAME_LEN - 1 },
1207         [NFTA_CT_HELPER_L3PROTO] = { .type = NLA_U16 },
1208         [NFTA_CT_HELPER_L4PROTO] = { .type = NLA_U8 },
1209 };
1210
1211 static struct nft_object_type nft_ct_helper_obj_type;
1212 static const struct nft_object_ops nft_ct_helper_obj_ops = {
1213         .type           = &nft_ct_helper_obj_type,
1214         .size           = sizeof(struct nft_ct_helper_obj),
1215         .eval           = nft_ct_helper_obj_eval,
1216         .init           = nft_ct_helper_obj_init,
1217         .destroy        = nft_ct_helper_obj_destroy,
1218         .dump           = nft_ct_helper_obj_dump,
1219 };
1220
1221 static struct nft_object_type nft_ct_helper_obj_type __read_mostly = {
1222         .type           = NFT_OBJECT_CT_HELPER,
1223         .ops            = &nft_ct_helper_obj_ops,
1224         .maxattr        = NFTA_CT_HELPER_MAX,
1225         .policy         = nft_ct_helper_policy,
1226         .owner          = THIS_MODULE,
1227 };
1228
1229 struct nft_ct_expect_obj {
1230         u16             l3num;
1231         __be16          dport;
1232         u8              l4proto;
1233         u8              size;
1234         u32             timeout;
1235 };
1236
1237 static int nft_ct_expect_obj_init(const struct nft_ctx *ctx,
1238                                   const struct nlattr * const tb[],
1239                                   struct nft_object *obj)
1240 {
1241         struct nft_ct_expect_obj *priv = nft_obj_data(obj);
1242
1243         if (!tb[NFTA_CT_EXPECT_L4PROTO] ||
1244             !tb[NFTA_CT_EXPECT_DPORT] ||
1245             !tb[NFTA_CT_EXPECT_TIMEOUT] ||
1246             !tb[NFTA_CT_EXPECT_SIZE])
1247                 return -EINVAL;
1248
1249         priv->l3num = ctx->family;
1250         if (tb[NFTA_CT_EXPECT_L3PROTO])
1251                 priv->l3num = ntohs(nla_get_be16(tb[NFTA_CT_EXPECT_L3PROTO]));
1252
1253         switch (priv->l3num) {
1254         case NFPROTO_IPV4:
1255         case NFPROTO_IPV6:
1256                 if (priv->l3num != ctx->family)
1257                         return -EINVAL;
1258
1259                 fallthrough;
1260         case NFPROTO_INET:
1261                 break;
1262         default:
1263                 return -EOPNOTSUPP;
1264         }
1265
1266         priv->l4proto = nla_get_u8(tb[NFTA_CT_EXPECT_L4PROTO]);
1267         switch (priv->l4proto) {
1268         case IPPROTO_TCP:
1269         case IPPROTO_UDP:
1270         case IPPROTO_UDPLITE:
1271         case IPPROTO_DCCP:
1272         case IPPROTO_SCTP:
1273                 break;
1274         default:
1275                 return -EOPNOTSUPP;
1276         }
1277
1278         priv->dport = nla_get_be16(tb[NFTA_CT_EXPECT_DPORT]);
1279         priv->timeout = nla_get_u32(tb[NFTA_CT_EXPECT_TIMEOUT]);
1280         priv->size = nla_get_u8(tb[NFTA_CT_EXPECT_SIZE]);
1281
1282         return nf_ct_netns_get(ctx->net, ctx->family);
1283 }
1284
1285 static void nft_ct_expect_obj_destroy(const struct nft_ctx *ctx,
1286                                        struct nft_object *obj)
1287 {
1288         nf_ct_netns_put(ctx->net, ctx->family);
1289 }
1290
1291 static int nft_ct_expect_obj_dump(struct sk_buff *skb,
1292                                   struct nft_object *obj, bool reset)
1293 {
1294         const struct nft_ct_expect_obj *priv = nft_obj_data(obj);
1295
1296         if (nla_put_be16(skb, NFTA_CT_EXPECT_L3PROTO, htons(priv->l3num)) ||
1297             nla_put_u8(skb, NFTA_CT_EXPECT_L4PROTO, priv->l4proto) ||
1298             nla_put_be16(skb, NFTA_CT_EXPECT_DPORT, priv->dport) ||
1299             nla_put_u32(skb, NFTA_CT_EXPECT_TIMEOUT, priv->timeout) ||
1300             nla_put_u8(skb, NFTA_CT_EXPECT_SIZE, priv->size))
1301                 return -1;
1302
1303         return 0;
1304 }
1305
1306 static void nft_ct_expect_obj_eval(struct nft_object *obj,
1307                                    struct nft_regs *regs,
1308                                    const struct nft_pktinfo *pkt)
1309 {
1310         const struct nft_ct_expect_obj *priv = nft_obj_data(obj);
1311         struct nf_conntrack_expect *exp;
1312         enum ip_conntrack_info ctinfo;
1313         struct nf_conn_help *help;
1314         enum ip_conntrack_dir dir;
1315         u16 l3num = priv->l3num;
1316         struct nf_conn *ct;
1317
1318         ct = nf_ct_get(pkt->skb, &ctinfo);
1319         if (!ct || nf_ct_is_confirmed(ct) || nf_ct_is_template(ct)) {
1320                 regs->verdict.code = NFT_BREAK;
1321                 return;
1322         }
1323         dir = CTINFO2DIR(ctinfo);
1324
1325         help = nfct_help(ct);
1326         if (!help)
1327                 help = nf_ct_helper_ext_add(ct, GFP_ATOMIC);
1328         if (!help) {
1329                 regs->verdict.code = NF_DROP;
1330                 return;
1331         }
1332
1333         if (help->expecting[NF_CT_EXPECT_CLASS_DEFAULT] >= priv->size) {
1334                 regs->verdict.code = NFT_BREAK;
1335                 return;
1336         }
1337         if (l3num == NFPROTO_INET)
1338                 l3num = nf_ct_l3num(ct);
1339
1340         exp = nf_ct_expect_alloc(ct);
1341         if (exp == NULL) {
1342                 regs->verdict.code = NF_DROP;
1343                 return;
1344         }
1345         nf_ct_expect_init(exp, NF_CT_EXPECT_CLASS_DEFAULT, l3num,
1346                           &ct->tuplehash[!dir].tuple.src.u3,
1347                           &ct->tuplehash[!dir].tuple.dst.u3,
1348                           priv->l4proto, NULL, &priv->dport);
1349         exp->timeout.expires = jiffies + priv->timeout * HZ;
1350
1351         if (nf_ct_expect_related(exp, 0) != 0)
1352                 regs->verdict.code = NF_DROP;
1353 }
1354
1355 static const struct nla_policy nft_ct_expect_policy[NFTA_CT_EXPECT_MAX + 1] = {
1356         [NFTA_CT_EXPECT_L3PROTO]        = { .type = NLA_U16 },
1357         [NFTA_CT_EXPECT_L4PROTO]        = { .type = NLA_U8 },
1358         [NFTA_CT_EXPECT_DPORT]          = { .type = NLA_U16 },
1359         [NFTA_CT_EXPECT_TIMEOUT]        = { .type = NLA_U32 },
1360         [NFTA_CT_EXPECT_SIZE]           = { .type = NLA_U8 },
1361 };
1362
1363 static struct nft_object_type nft_ct_expect_obj_type;
1364
1365 static const struct nft_object_ops nft_ct_expect_obj_ops = {
1366         .type           = &nft_ct_expect_obj_type,
1367         .size           = sizeof(struct nft_ct_expect_obj),
1368         .eval           = nft_ct_expect_obj_eval,
1369         .init           = nft_ct_expect_obj_init,
1370         .destroy        = nft_ct_expect_obj_destroy,
1371         .dump           = nft_ct_expect_obj_dump,
1372 };
1373
1374 static struct nft_object_type nft_ct_expect_obj_type __read_mostly = {
1375         .type           = NFT_OBJECT_CT_EXPECT,
1376         .ops            = &nft_ct_expect_obj_ops,
1377         .maxattr        = NFTA_CT_EXPECT_MAX,
1378         .policy         = nft_ct_expect_policy,
1379         .owner          = THIS_MODULE,
1380 };
1381
1382 static int __init nft_ct_module_init(void)
1383 {
1384         int err;
1385
1386         BUILD_BUG_ON(NF_CT_LABELS_MAX_SIZE > NFT_REG_SIZE);
1387
1388         err = nft_register_expr(&nft_ct_type);
1389         if (err < 0)
1390                 return err;
1391
1392         err = nft_register_expr(&nft_notrack_type);
1393         if (err < 0)
1394                 goto err1;
1395
1396         err = nft_register_obj(&nft_ct_helper_obj_type);
1397         if (err < 0)
1398                 goto err2;
1399
1400         err = nft_register_obj(&nft_ct_expect_obj_type);
1401         if (err < 0)
1402                 goto err3;
1403 #ifdef CONFIG_NF_CONNTRACK_TIMEOUT
1404         err = nft_register_obj(&nft_ct_timeout_obj_type);
1405         if (err < 0)
1406                 goto err4;
1407 #endif
1408         return 0;
1409
1410 #ifdef CONFIG_NF_CONNTRACK_TIMEOUT
1411 err4:
1412         nft_unregister_obj(&nft_ct_expect_obj_type);
1413 #endif
1414 err3:
1415         nft_unregister_obj(&nft_ct_helper_obj_type);
1416 err2:
1417         nft_unregister_expr(&nft_notrack_type);
1418 err1:
1419         nft_unregister_expr(&nft_ct_type);
1420         return err;
1421 }
1422
1423 static void __exit nft_ct_module_exit(void)
1424 {
1425 #ifdef CONFIG_NF_CONNTRACK_TIMEOUT
1426         nft_unregister_obj(&nft_ct_timeout_obj_type);
1427 #endif
1428         nft_unregister_obj(&nft_ct_expect_obj_type);
1429         nft_unregister_obj(&nft_ct_helper_obj_type);
1430         nft_unregister_expr(&nft_notrack_type);
1431         nft_unregister_expr(&nft_ct_type);
1432 }
1433
1434 module_init(nft_ct_module_init);
1435 module_exit(nft_ct_module_exit);
1436
1437 MODULE_LICENSE("GPL");
1438 MODULE_AUTHOR("Patrick McHardy <kaber@trash.net>");
1439 MODULE_ALIAS_NFT_EXPR("ct");
1440 MODULE_ALIAS_NFT_EXPR("notrack");
1441 MODULE_ALIAS_NFT_OBJ(NFT_OBJECT_CT_HELPER);
1442 MODULE_ALIAS_NFT_OBJ(NFT_OBJECT_CT_TIMEOUT);
1443 MODULE_ALIAS_NFT_OBJ(NFT_OBJECT_CT_EXPECT);
1444 MODULE_DESCRIPTION("Netfilter nf_tables conntrack module");