net/handshake: Remove unneeded check from handshake_dup()
[platform/kernel/linux-starfive.git] / net / handshake / netlink.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Generic netlink handshake service
4  *
5  * Author: Chuck Lever <chuck.lever@oracle.com>
6  *
7  * Copyright (c) 2023, Oracle and/or its affiliates.
8  */
9
10 #include <linux/types.h>
11 #include <linux/socket.h>
12 #include <linux/kernel.h>
13 #include <linux/module.h>
14 #include <linux/skbuff.h>
15 #include <linux/mm.h>
16
17 #include <net/sock.h>
18 #include <net/genetlink.h>
19 #include <net/netns/generic.h>
20
21 #include <kunit/visibility.h>
22
23 #include <uapi/linux/handshake.h>
24 #include "handshake.h"
25 #include "genl.h"
26
27 #include <trace/events/handshake.h>
28
29 /**
30  * handshake_genl_notify - Notify handlers that a request is waiting
31  * @net: target network namespace
32  * @proto: handshake protocol
33  * @flags: memory allocation control flags
34  *
35  * Returns zero on success or a negative errno if notification failed.
36  */
37 int handshake_genl_notify(struct net *net, const struct handshake_proto *proto,
38                           gfp_t flags)
39 {
40         struct sk_buff *msg;
41         void *hdr;
42
43         /* Disable notifications during unit testing */
44         if (!test_bit(HANDSHAKE_F_PROTO_NOTIFY, &proto->hp_flags))
45                 return 0;
46
47         if (!genl_has_listeners(&handshake_nl_family, net,
48                                 proto->hp_handler_class))
49                 return -ESRCH;
50
51         msg = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
52         if (!msg)
53                 return -ENOMEM;
54
55         hdr = genlmsg_put(msg, 0, 0, &handshake_nl_family, 0,
56                           HANDSHAKE_CMD_READY);
57         if (!hdr)
58                 goto out_free;
59
60         if (nla_put_u32(msg, HANDSHAKE_A_ACCEPT_HANDLER_CLASS,
61                         proto->hp_handler_class) < 0) {
62                 genlmsg_cancel(msg, hdr);
63                 goto out_free;
64         }
65
66         genlmsg_end(msg, hdr);
67         return genlmsg_multicast_netns(&handshake_nl_family, net, msg,
68                                        0, proto->hp_handler_class, flags);
69
70 out_free:
71         nlmsg_free(msg);
72         return -EMSGSIZE;
73 }
74
75 /**
76  * handshake_genl_put - Create a generic netlink message header
77  * @msg: buffer in which to create the header
78  * @info: generic netlink message context
79  *
80  * Returns a ready-to-use header, or NULL.
81  */
82 struct nlmsghdr *handshake_genl_put(struct sk_buff *msg,
83                                     struct genl_info *info)
84 {
85         return genlmsg_put(msg, info->snd_portid, info->snd_seq,
86                            &handshake_nl_family, 0, info->genlhdr->cmd);
87 }
88 EXPORT_SYMBOL(handshake_genl_put);
89
90 /*
91  * dup() a kernel socket for use as a user space file descriptor
92  * in the current process. The kernel socket must have an
93  * instatiated struct file.
94  *
95  * Implicit argument: "current()"
96  */
97 static int handshake_dup(struct socket *sock)
98 {
99         struct file *file;
100         int newfd;
101
102         file = get_file(sock->file);
103         newfd = get_unused_fd_flags(O_CLOEXEC);
104         if (newfd < 0) {
105                 fput(file);
106                 return newfd;
107         }
108
109         fd_install(newfd, file);
110         return newfd;
111 }
112
113 int handshake_nl_accept_doit(struct sk_buff *skb, struct genl_info *info)
114 {
115         struct net *net = sock_net(skb->sk);
116         struct handshake_net *hn = handshake_pernet(net);
117         struct handshake_req *req = NULL;
118         struct socket *sock;
119         int class, fd, err;
120
121         err = -EOPNOTSUPP;
122         if (!hn)
123                 goto out_status;
124
125         err = -EINVAL;
126         if (GENL_REQ_ATTR_CHECK(info, HANDSHAKE_A_ACCEPT_HANDLER_CLASS))
127                 goto out_status;
128         class = nla_get_u32(info->attrs[HANDSHAKE_A_ACCEPT_HANDLER_CLASS]);
129
130         err = -EAGAIN;
131         req = handshake_req_next(hn, class);
132         if (!req)
133                 goto out_status;
134
135         sock = req->hr_sk->sk_socket;
136         fd = handshake_dup(sock);
137         if (fd < 0) {
138                 err = fd;
139                 goto out_complete;
140         }
141         err = req->hr_proto->hp_accept(req, info, fd);
142         if (err)
143                 goto out_complete;
144
145         trace_handshake_cmd_accept(net, req, req->hr_sk, fd);
146         return 0;
147
148 out_complete:
149         handshake_complete(req, -EIO, NULL);
150         fput(sock->file);
151 out_status:
152         trace_handshake_cmd_accept_err(net, req, NULL, err);
153         return err;
154 }
155
156 int handshake_nl_done_doit(struct sk_buff *skb, struct genl_info *info)
157 {
158         struct net *net = sock_net(skb->sk);
159         struct socket *sock = NULL;
160         struct handshake_req *req;
161         int fd, status, err;
162
163         if (GENL_REQ_ATTR_CHECK(info, HANDSHAKE_A_DONE_SOCKFD))
164                 return -EINVAL;
165         fd = nla_get_u32(info->attrs[HANDSHAKE_A_DONE_SOCKFD]);
166
167         err = 0;
168         sock = sockfd_lookup(fd, &err);
169         if (err) {
170                 err = -EBADF;
171                 goto out_status;
172         }
173
174         req = handshake_req_hash_lookup(sock->sk);
175         if (!req) {
176                 err = -EBUSY;
177                 fput(sock->file);
178                 goto out_status;
179         }
180
181         trace_handshake_cmd_done(net, req, sock->sk, fd);
182
183         status = -EIO;
184         if (info->attrs[HANDSHAKE_A_DONE_STATUS])
185                 status = nla_get_u32(info->attrs[HANDSHAKE_A_DONE_STATUS]);
186
187         handshake_complete(req, status, info);
188         fput(sock->file);
189         return 0;
190
191 out_status:
192         trace_handshake_cmd_done_err(net, req, sock->sk, err);
193         return err;
194 }
195
196 static unsigned int handshake_net_id;
197
198 static int __net_init handshake_net_init(struct net *net)
199 {
200         struct handshake_net *hn = net_generic(net, handshake_net_id);
201         unsigned long tmp;
202         struct sysinfo si;
203
204         /*
205          * Arbitrary limit to prevent handshakes that do not make
206          * progress from clogging up the system. The cap scales up
207          * with the amount of physical memory on the system.
208          */
209         si_meminfo(&si);
210         tmp = si.totalram / (25 * si.mem_unit);
211         hn->hn_pending_max = clamp(tmp, 3UL, 50UL);
212
213         spin_lock_init(&hn->hn_lock);
214         hn->hn_pending = 0;
215         hn->hn_flags = 0;
216         INIT_LIST_HEAD(&hn->hn_requests);
217         return 0;
218 }
219
220 static void __net_exit handshake_net_exit(struct net *net)
221 {
222         struct handshake_net *hn = net_generic(net, handshake_net_id);
223         struct handshake_req *req;
224         LIST_HEAD(requests);
225
226         /*
227          * Drain the net's pending list. Requests that have been
228          * accepted and are in progress will be destroyed when
229          * the socket is closed.
230          */
231         spin_lock(&hn->hn_lock);
232         set_bit(HANDSHAKE_F_NET_DRAINING, &hn->hn_flags);
233         list_splice_init(&requests, &hn->hn_requests);
234         spin_unlock(&hn->hn_lock);
235
236         while (!list_empty(&requests)) {
237                 req = list_first_entry(&requests, struct handshake_req, hr_list);
238                 list_del(&req->hr_list);
239
240                 /*
241                  * Requests on this list have not yet been
242                  * accepted, so they do not have an fd to put.
243                  */
244
245                 handshake_complete(req, -ETIMEDOUT, NULL);
246         }
247 }
248
249 static struct pernet_operations handshake_genl_net_ops = {
250         .init           = handshake_net_init,
251         .exit           = handshake_net_exit,
252         .id             = &handshake_net_id,
253         .size           = sizeof(struct handshake_net),
254 };
255
256 /**
257  * handshake_pernet - Get the handshake private per-net structure
258  * @net: network namespace
259  *
260  * Returns a pointer to the net's private per-net structure for the
261  * handshake module, or NULL if handshake_init() failed.
262  */
263 struct handshake_net *handshake_pernet(struct net *net)
264 {
265         return handshake_net_id ?
266                 net_generic(net, handshake_net_id) : NULL;
267 }
268 EXPORT_SYMBOL_IF_KUNIT(handshake_pernet);
269
270 static int __init handshake_init(void)
271 {
272         int ret;
273
274         ret = handshake_req_hash_init();
275         if (ret) {
276                 pr_warn("handshake: hash initialization failed (%d)\n", ret);
277                 return ret;
278         }
279
280         ret = genl_register_family(&handshake_nl_family);
281         if (ret) {
282                 pr_warn("handshake: netlink registration failed (%d)\n", ret);
283                 handshake_req_hash_destroy();
284                 return ret;
285         }
286
287         /*
288          * ORDER: register_pernet_subsys must be done last.
289          *
290          *      If initialization does not make it past pernet_subsys
291          *      registration, then handshake_net_id will remain 0. That
292          *      shunts the handshake consumer API to return ENOTSUPP
293          *      to prevent it from dereferencing something that hasn't
294          *      been allocated.
295          */
296         ret = register_pernet_subsys(&handshake_genl_net_ops);
297         if (ret) {
298                 pr_warn("handshake: pernet registration failed (%d)\n", ret);
299                 genl_unregister_family(&handshake_nl_family);
300                 handshake_req_hash_destroy();
301         }
302
303         return ret;
304 }
305
306 static void __exit handshake_exit(void)
307 {
308         unregister_pernet_subsys(&handshake_genl_net_ops);
309         handshake_net_id = 0;
310
311         handshake_req_hash_destroy();
312         genl_unregister_family(&handshake_nl_family);
313 }
314
315 module_init(handshake_init);
316 module_exit(handshake_exit);