1 // SPDX-License-Identifier: GPL-2.0-only
3 * Generic netlink handshake service
5 * Author: Chuck Lever <chuck.lever@oracle.com>
7 * Copyright (c) 2023, Oracle and/or its affiliates.
#include <linux/types.h>
#include <linux/socket.h>
#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/skbuff.h>
#include <linux/mm.h>
#include <linux/file.h>

#include <net/sock.h>
#include <net/genetlink.h>
#include <net/netns/generic.h>

#include <kunit/visibility.h>

#include <uapi/linux/handshake.h>
#include "handshake.h"
#include "genl.h"

#include <trace/events/handshake.h>
30 * handshake_genl_notify - Notify handlers that a request is waiting
31 * @net: target network namespace
32 * @proto: handshake protocol
33 * @flags: memory allocation control flags
35 * Returns zero on success or a negative errno if notification failed.
37 int handshake_genl_notify(struct net *net, const struct handshake_proto *proto,
43 /* Disable notifications during unit testing */
44 if (!test_bit(HANDSHAKE_F_PROTO_NOTIFY, &proto->hp_flags))
47 if (!genl_has_listeners(&handshake_nl_family, net,
48 proto->hp_handler_class))
51 msg = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
55 hdr = genlmsg_put(msg, 0, 0, &handshake_nl_family, 0,
60 if (nla_put_u32(msg, HANDSHAKE_A_ACCEPT_HANDLER_CLASS,
61 proto->hp_handler_class) < 0) {
62 genlmsg_cancel(msg, hdr);
66 genlmsg_end(msg, hdr);
67 return genlmsg_multicast_netns(&handshake_nl_family, net, msg,
68 0, proto->hp_handler_class, flags);
76 * handshake_genl_put - Create a generic netlink message header
77 * @msg: buffer in which to create the header
78 * @info: generic netlink message context
80 * Returns a ready-to-use header, or NULL.
82 struct nlmsghdr *handshake_genl_put(struct sk_buff *msg,
83 struct genl_info *info)
85 return genlmsg_put(msg, info->snd_portid, info->snd_seq,
86 &handshake_nl_family, 0, info->genlhdr->cmd);
88 EXPORT_SYMBOL(handshake_genl_put);
91 * dup() a kernel socket for use as a user space file descriptor
92 * in the current process. The kernel socket must have an
93 * instatiated struct file.
95 * Implicit argument: "current()"
97 static int handshake_dup(struct socket *sock)
102 file = get_file(sock->file);
103 newfd = get_unused_fd_flags(O_CLOEXEC);
109 fd_install(newfd, file);
113 int handshake_nl_accept_doit(struct sk_buff *skb, struct genl_info *info)
115 struct net *net = sock_net(skb->sk);
116 struct handshake_net *hn = handshake_pernet(net);
117 struct handshake_req *req = NULL;
126 if (GENL_REQ_ATTR_CHECK(info, HANDSHAKE_A_ACCEPT_HANDLER_CLASS))
128 class = nla_get_u32(info->attrs[HANDSHAKE_A_ACCEPT_HANDLER_CLASS]);
131 req = handshake_req_next(hn, class);
135 sock = req->hr_sk->sk_socket;
136 fd = handshake_dup(sock);
141 err = req->hr_proto->hp_accept(req, info, fd);
145 trace_handshake_cmd_accept(net, req, req->hr_sk, fd);
149 handshake_complete(req, -EIO, NULL);
152 trace_handshake_cmd_accept_err(net, req, NULL, err);
156 int handshake_nl_done_doit(struct sk_buff *skb, struct genl_info *info)
158 struct net *net = sock_net(skb->sk);
159 struct socket *sock = NULL;
160 struct handshake_req *req;
163 if (GENL_REQ_ATTR_CHECK(info, HANDSHAKE_A_DONE_SOCKFD))
165 fd = nla_get_u32(info->attrs[HANDSHAKE_A_DONE_SOCKFD]);
168 sock = sockfd_lookup(fd, &err);
174 req = handshake_req_hash_lookup(sock->sk);
181 trace_handshake_cmd_done(net, req, sock->sk, fd);
184 if (info->attrs[HANDSHAKE_A_DONE_STATUS])
185 status = nla_get_u32(info->attrs[HANDSHAKE_A_DONE_STATUS]);
187 handshake_complete(req, status, info);
192 trace_handshake_cmd_done_err(net, req, sock->sk, err);
196 static unsigned int handshake_net_id;
198 static int __net_init handshake_net_init(struct net *net)
200 struct handshake_net *hn = net_generic(net, handshake_net_id);
205 * Arbitrary limit to prevent handshakes that do not make
206 * progress from clogging up the system. The cap scales up
207 * with the amount of physical memory on the system.
210 tmp = si.totalram / (25 * si.mem_unit);
211 hn->hn_pending_max = clamp(tmp, 3UL, 50UL);
213 spin_lock_init(&hn->hn_lock);
216 INIT_LIST_HEAD(&hn->hn_requests);
220 static void __net_exit handshake_net_exit(struct net *net)
222 struct handshake_net *hn = net_generic(net, handshake_net_id);
223 struct handshake_req *req;
227 * Drain the net's pending list. Requests that have been
228 * accepted and are in progress will be destroyed when
229 * the socket is closed.
231 spin_lock(&hn->hn_lock);
232 set_bit(HANDSHAKE_F_NET_DRAINING, &hn->hn_flags);
233 list_splice_init(&requests, &hn->hn_requests);
234 spin_unlock(&hn->hn_lock);
236 while (!list_empty(&requests)) {
237 req = list_first_entry(&requests, struct handshake_req, hr_list);
238 list_del(&req->hr_list);
241 * Requests on this list have not yet been
242 * accepted, so they do not have an fd to put.
245 handshake_complete(req, -ETIMEDOUT, NULL);
249 static struct pernet_operations handshake_genl_net_ops = {
250 .init = handshake_net_init,
251 .exit = handshake_net_exit,
252 .id = &handshake_net_id,
253 .size = sizeof(struct handshake_net),
257 * handshake_pernet - Get the handshake private per-net structure
258 * @net: network namespace
260 * Returns a pointer to the net's private per-net structure for the
261 * handshake module, or NULL if handshake_init() failed.
263 struct handshake_net *handshake_pernet(struct net *net)
265 return handshake_net_id ?
266 net_generic(net, handshake_net_id) : NULL;
268 EXPORT_SYMBOL_IF_KUNIT(handshake_pernet);
270 static int __init handshake_init(void)
274 ret = handshake_req_hash_init();
276 pr_warn("handshake: hash initialization failed (%d)\n", ret);
280 ret = genl_register_family(&handshake_nl_family);
282 pr_warn("handshake: netlink registration failed (%d)\n", ret);
283 handshake_req_hash_destroy();
288 * ORDER: register_pernet_subsys must be done last.
290 * If initialization does not make it past pernet_subsys
291 * registration, then handshake_net_id will remain 0. That
292 * shunts the handshake consumer API to return ENOTSUPP
293 * to prevent it from dereferencing something that hasn't
296 ret = register_pernet_subsys(&handshake_genl_net_ops);
298 pr_warn("handshake: pernet registration failed (%d)\n", ret);
299 genl_unregister_family(&handshake_nl_family);
300 handshake_req_hash_destroy();
306 static void __exit handshake_exit(void)
308 unregister_pernet_subsys(&handshake_genl_net_ops);
309 handshake_net_id = 0;
311 handshake_req_hash_destroy();
312 genl_unregister_family(&handshake_nl_family);
315 module_init(handshake_init);
316 module_exit(handshake_exit);