Merge branch 'for-linus' of git://git.kernel.org/pub/scm/linux/kernel/git/dtor/input
[platform/kernel/linux-starfive.git] / net / mctp / af_mctp.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Management Component Transport Protocol (MCTP)
4  *
5  * Copyright (c) 2021 Code Construct
6  * Copyright (c) 2021 Google
7  */
8
9 #include <linux/if_arp.h>
10 #include <linux/net.h>
11 #include <linux/mctp.h>
12 #include <linux/module.h>
13 #include <linux/socket.h>
14
15 #include <net/mctp.h>
16 #include <net/mctpdevice.h>
17 #include <net/sock.h>
18
19 /* socket implementation */
20
21 static int mctp_release(struct socket *sock)
22 {
23         struct sock *sk = sock->sk;
24
25         if (sk) {
26                 sock->sk = NULL;
27                 sk->sk_prot->close(sk, 0);
28         }
29
30         return 0;
31 }
32
33 static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen)
34 {
35         struct sock *sk = sock->sk;
36         struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
37         struct sockaddr_mctp *smctp;
38         int rc;
39
40         if (addrlen < sizeof(*smctp))
41                 return -EINVAL;
42
43         if (addr->sa_family != AF_MCTP)
44                 return -EAFNOSUPPORT;
45
46         if (!capable(CAP_NET_BIND_SERVICE))
47                 return -EACCES;
48
49         /* it's a valid sockaddr for MCTP, cast and do protocol checks */
50         smctp = (struct sockaddr_mctp *)addr;
51
52         lock_sock(sk);
53
54         /* TODO: allow rebind */
55         if (sk_hashed(sk)) {
56                 rc = -EADDRINUSE;
57                 goto out_release;
58         }
59         msk->bind_net = smctp->smctp_network;
60         msk->bind_addr = smctp->smctp_addr.s_addr;
61         msk->bind_type = smctp->smctp_type & 0x7f; /* ignore the IC bit */
62
63         rc = sk->sk_prot->hash(sk);
64
65 out_release:
66         release_sock(sk);
67
68         return rc;
69 }
70
71 static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
72 {
73         DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
74         const int hlen = MCTP_HEADER_MAXLEN + sizeof(struct mctp_hdr);
75         int rc, addrlen = msg->msg_namelen;
76         struct sock *sk = sock->sk;
77         struct mctp_skb_cb *cb;
78         struct mctp_route *rt;
79         struct sk_buff *skb;
80
81         if (addr) {
82                 if (addrlen < sizeof(struct sockaddr_mctp))
83                         return -EINVAL;
84                 if (addr->smctp_family != AF_MCTP)
85                         return -EINVAL;
86                 if (addr->smctp_tag & ~(MCTP_TAG_MASK | MCTP_TAG_OWNER))
87                         return -EINVAL;
88
89         } else {
90                 /* TODO: connect()ed sockets */
91                 return -EDESTADDRREQ;
92         }
93
94         if (!capable(CAP_NET_RAW))
95                 return -EACCES;
96
97         if (addr->smctp_network == MCTP_NET_ANY)
98                 addr->smctp_network = mctp_default_net(sock_net(sk));
99
100         rt = mctp_route_lookup(sock_net(sk), addr->smctp_network,
101                                addr->smctp_addr.s_addr);
102         if (!rt)
103                 return -EHOSTUNREACH;
104
105         skb = sock_alloc_send_skb(sk, hlen + 1 + len,
106                                   msg->msg_flags & MSG_DONTWAIT, &rc);
107         if (!skb)
108                 return rc;
109
110         skb_reserve(skb, hlen);
111
112         /* set type as fist byte in payload */
113         *(u8 *)skb_put(skb, 1) = addr->smctp_type;
114
115         rc = memcpy_from_msg((void *)skb_put(skb, len), msg, len);
116         if (rc < 0) {
117                 kfree_skb(skb);
118                 return rc;
119         }
120
121         /* set up cb */
122         cb = __mctp_cb(skb);
123         cb->net = addr->smctp_network;
124
125         rc = mctp_local_output(sk, rt, skb, addr->smctp_addr.s_addr,
126                                addr->smctp_tag);
127
128         return rc ? : len;
129 }
130
131 static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
132                         int flags)
133 {
134         DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
135         struct sock *sk = sock->sk;
136         struct sk_buff *skb;
137         size_t msglen;
138         u8 type;
139         int rc;
140
141         if (flags & ~(MSG_DONTWAIT | MSG_TRUNC | MSG_PEEK))
142                 return -EOPNOTSUPP;
143
144         skb = skb_recv_datagram(sk, flags, flags & MSG_DONTWAIT, &rc);
145         if (!skb)
146                 return rc;
147
148         if (!skb->len) {
149                 rc = 0;
150                 goto out_free;
151         }
152
153         /* extract message type, remove from data */
154         type = *((u8 *)skb->data);
155         msglen = skb->len - 1;
156
157         if (len < msglen)
158                 msg->msg_flags |= MSG_TRUNC;
159         else
160                 len = msglen;
161
162         rc = skb_copy_datagram_msg(skb, 1, msg, len);
163         if (rc < 0)
164                 goto out_free;
165
166         sock_recv_ts_and_drops(msg, sk, skb);
167
168         if (addr) {
169                 struct mctp_skb_cb *cb = mctp_cb(skb);
170                 /* TODO: expand mctp_skb_cb for header fields? */
171                 struct mctp_hdr *hdr = mctp_hdr(skb);
172
173                 addr = msg->msg_name;
174                 addr->smctp_family = AF_MCTP;
175                 addr->smctp_network = cb->net;
176                 addr->smctp_addr.s_addr = hdr->src;
177                 addr->smctp_type = type;
178                 addr->smctp_tag = hdr->flags_seq_tag &
179                                         (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
180                 msg->msg_namelen = sizeof(*addr);
181         }
182
183         rc = len;
184
185         if (flags & MSG_TRUNC)
186                 rc = msglen;
187
188 out_free:
189         skb_free_datagram(sk, skb);
190         return rc;
191 }
192
193 static int mctp_setsockopt(struct socket *sock, int level, int optname,
194                            sockptr_t optval, unsigned int optlen)
195 {
196         return -EINVAL;
197 }
198
199 static int mctp_getsockopt(struct socket *sock, int level, int optname,
200                            char __user *optval, int __user *optlen)
201 {
202         return -EINVAL;
203 }
204
205 static const struct proto_ops mctp_dgram_ops = {
206         .family         = PF_MCTP,
207         .release        = mctp_release,
208         .bind           = mctp_bind,
209         .connect        = sock_no_connect,
210         .socketpair     = sock_no_socketpair,
211         .accept         = sock_no_accept,
212         .getname        = sock_no_getname,
213         .poll           = datagram_poll,
214         .ioctl          = sock_no_ioctl,
215         .gettstamp      = sock_gettstamp,
216         .listen         = sock_no_listen,
217         .shutdown       = sock_no_shutdown,
218         .setsockopt     = mctp_setsockopt,
219         .getsockopt     = mctp_getsockopt,
220         .sendmsg        = mctp_sendmsg,
221         .recvmsg        = mctp_recvmsg,
222         .mmap           = sock_no_mmap,
223         .sendpage       = sock_no_sendpage,
224 };
225
226 static int mctp_sk_init(struct sock *sk)
227 {
228         struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
229
230         INIT_HLIST_HEAD(&msk->keys);
231         return 0;
232 }
233
234 static void mctp_sk_close(struct sock *sk, long timeout)
235 {
236         sk_common_release(sk);
237 }
238
239 static int mctp_sk_hash(struct sock *sk)
240 {
241         struct net *net = sock_net(sk);
242
243         mutex_lock(&net->mctp.bind_lock);
244         sk_add_node_rcu(sk, &net->mctp.binds);
245         mutex_unlock(&net->mctp.bind_lock);
246
247         return 0;
248 }
249
250 static void mctp_sk_unhash(struct sock *sk)
251 {
252         struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
253         struct net *net = sock_net(sk);
254         struct mctp_sk_key *key;
255         struct hlist_node *tmp;
256         unsigned long flags;
257
258         /* remove from any type-based binds */
259         mutex_lock(&net->mctp.bind_lock);
260         sk_del_node_init_rcu(sk);
261         mutex_unlock(&net->mctp.bind_lock);
262
263         /* remove tag allocations */
264         spin_lock_irqsave(&net->mctp.keys_lock, flags);
265         hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
266                 hlist_del_rcu(&key->sklist);
267                 hlist_del_rcu(&key->hlist);
268
269                 spin_lock(&key->reasm_lock);
270                 if (key->reasm_head)
271                         kfree_skb(key->reasm_head);
272                 key->reasm_head = NULL;
273                 key->reasm_dead = true;
274                 spin_unlock(&key->reasm_lock);
275
276                 kfree_rcu(key, rcu);
277         }
278         spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
279
280         synchronize_rcu();
281 }
282
283 static struct proto mctp_proto = {
284         .name           = "MCTP",
285         .owner          = THIS_MODULE,
286         .obj_size       = sizeof(struct mctp_sock),
287         .init           = mctp_sk_init,
288         .close          = mctp_sk_close,
289         .hash           = mctp_sk_hash,
290         .unhash         = mctp_sk_unhash,
291 };
292
293 static int mctp_pf_create(struct net *net, struct socket *sock,
294                           int protocol, int kern)
295 {
296         const struct proto_ops *ops;
297         struct proto *proto;
298         struct sock *sk;
299         int rc;
300
301         if (protocol)
302                 return -EPROTONOSUPPORT;
303
304         /* only datagram sockets are supported */
305         if (sock->type != SOCK_DGRAM)
306                 return -ESOCKTNOSUPPORT;
307
308         proto = &mctp_proto;
309         ops = &mctp_dgram_ops;
310
311         sock->state = SS_UNCONNECTED;
312         sock->ops = ops;
313
314         sk = sk_alloc(net, PF_MCTP, GFP_KERNEL, proto, kern);
315         if (!sk)
316                 return -ENOMEM;
317
318         sock_init_data(sock, sk);
319
320         rc = 0;
321         if (sk->sk_prot->init)
322                 rc = sk->sk_prot->init(sk);
323
324         if (rc)
325                 goto err_sk_put;
326
327         return 0;
328
329 err_sk_put:
330         sock_orphan(sk);
331         sock_put(sk);
332         return rc;
333 }
334
335 static struct net_proto_family mctp_pf = {
336         .family = PF_MCTP,
337         .create = mctp_pf_create,
338         .owner = THIS_MODULE,
339 };
340
341 static __init int mctp_init(void)
342 {
343         int rc;
344
345         /* ensure our uapi tag definitions match the header format */
346         BUILD_BUG_ON(MCTP_TAG_OWNER != MCTP_HDR_FLAG_TO);
347         BUILD_BUG_ON(MCTP_TAG_MASK != MCTP_HDR_TAG_MASK);
348
349         pr_info("mctp: management component transport protocol core\n");
350
351         rc = sock_register(&mctp_pf);
352         if (rc)
353                 return rc;
354
355         rc = proto_register(&mctp_proto, 0);
356         if (rc)
357                 goto err_unreg_sock;
358
359         rc = mctp_routes_init();
360         if (rc)
361                 goto err_unreg_proto;
362
363         rc = mctp_neigh_init();
364         if (rc)
365                 goto err_unreg_proto;
366
367         mctp_device_init();
368
369         return 0;
370
371 err_unreg_proto:
372         proto_unregister(&mctp_proto);
373 err_unreg_sock:
374         sock_unregister(PF_MCTP);
375
376         return rc;
377 }
378
379 static __exit void mctp_exit(void)
380 {
381         mctp_device_exit();
382         mctp_neigh_exit();
383         mctp_routes_exit();
384         proto_unregister(&mctp_proto);
385         sock_unregister(PF_MCTP);
386 }
387
388 module_init(mctp_init);
389 module_exit(mctp_exit);
390
391 MODULE_DESCRIPTION("MCTP core");
392 MODULE_LICENSE("GPL v2");
393 MODULE_AUTHOR("Jeremy Kerr <jk@codeconstruct.com.au>");
394
395 MODULE_ALIAS_NETPROTO(PF_MCTP);