Merge tag 'audit-pr-20211019' of git://git.kernel.org/pub/scm/linux/kernel/git/pcmoor...
[platform/kernel/linux-starfive.git] / net / 9p / protocol.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * net/9p/protocol.c
4  *
5  * 9P Protocol Support Code
6  *
7  *  Copyright (C) 2008 by Eric Van Hensbergen <ericvh@gmail.com>
8  *
 *  Based on code from Anthony Liguori <aliguori@us.ibm.com>
10  *  Copyright (C) 2008 by IBM, Corp.
11  */
12
13 #include <linux/module.h>
14 #include <linux/errno.h>
15 #include <linux/kernel.h>
16 #include <linux/uaccess.h>
17 #include <linux/slab.h>
18 #include <linux/sched.h>
19 #include <linux/stddef.h>
20 #include <linux/types.h>
21 #include <linux/uio.h>
22 #include <net/9p/9p.h>
23 #include <net/9p/client.h>
24 #include "protocol.h"
25
26 #include <trace/events/9p.h>
27
/*
 * Forward declaration: p9pdu_vwritef() below recurses through
 * p9pdu_writef() for nested specifiers, but the varargs wrapper is only
 * defined further down the file.
 */
static int p9pdu_writef(struct p9_fcall *pdu, int proto_version,
			const char *fmt, ...);
30
31 void p9stat_free(struct p9_wstat *stbuf)
32 {
33         kfree(stbuf->name);
34         stbuf->name = NULL;
35         kfree(stbuf->uid);
36         stbuf->uid = NULL;
37         kfree(stbuf->gid);
38         stbuf->gid = NULL;
39         kfree(stbuf->muid);
40         stbuf->muid = NULL;
41         kfree(stbuf->extension);
42         stbuf->extension = NULL;
43 }
44 EXPORT_SYMBOL(p9stat_free);
45
46 size_t pdu_read(struct p9_fcall *pdu, void *data, size_t size)
47 {
48         size_t len = min(pdu->size - pdu->offset, size);
49         memcpy(data, &pdu->sdata[pdu->offset], len);
50         pdu->offset += len;
51         return size - len;
52 }
53
54 static size_t pdu_write(struct p9_fcall *pdu, const void *data, size_t size)
55 {
56         size_t len = min(pdu->capacity - pdu->size, size);
57         memcpy(&pdu->sdata[pdu->size], data, len);
58         pdu->size += len;
59         return size - len;
60 }
61
62 static size_t
63 pdu_write_u(struct p9_fcall *pdu, struct iov_iter *from, size_t size)
64 {
65         size_t len = min(pdu->capacity - pdu->size, size);
66         struct iov_iter i = *from;
67         if (!copy_from_iter_full(&pdu->sdata[pdu->size], len, &i))
68                 len = 0;
69
70         pdu->size += len;
71         return size - len;
72 }
73
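/*
	Format specifiers for p9pdu_readf() / p9pdu_writef():

	b - int8_t
	w - int16_t
	d - int32_t
	q - int64_t
	s - string (16-bit length followed by that many bytes; no NUL on
	    the wire, the reader allocates and NUL-terminates)
	u - numeric uid
	g - numeric gid
	S - stat
	Q - qid
	D - data blob (int32_t size followed by void *; read only, the
	    returned pointer aliases the pdu buffer and nothing is allocated)
	V - data blob (int32_t size followed by struct iov_iter *; write only)
	T - array of strings (int16_t count, followed by strings)
	R - array of qids (int16_t count, followed by qids)
	A - stat for 9p2000.L (p9_stat_dotl; read only)
	I - iattr for 9p2000.L (p9_iattr_dotl; write only)
	? - continue only if the protocol version is 9P2000.u or 9P2000.L;
	    otherwise stop here and return success
*/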
90
91 static int
92 p9pdu_vreadf(struct p9_fcall *pdu, int proto_version, const char *fmt,
93         va_list ap)
94 {
95         const char *ptr;
96         int errcode = 0;
97
98         for (ptr = fmt; *ptr; ptr++) {
99                 switch (*ptr) {
100                 case 'b':{
101                                 int8_t *val = va_arg(ap, int8_t *);
102                                 if (pdu_read(pdu, val, sizeof(*val))) {
103                                         errcode = -EFAULT;
104                                         break;
105                                 }
106                         }
107                         break;
108                 case 'w':{
109                                 int16_t *val = va_arg(ap, int16_t *);
110                                 __le16 le_val;
111                                 if (pdu_read(pdu, &le_val, sizeof(le_val))) {
112                                         errcode = -EFAULT;
113                                         break;
114                                 }
115                                 *val = le16_to_cpu(le_val);
116                         }
117                         break;
118                 case 'd':{
119                                 int32_t *val = va_arg(ap, int32_t *);
120                                 __le32 le_val;
121                                 if (pdu_read(pdu, &le_val, sizeof(le_val))) {
122                                         errcode = -EFAULT;
123                                         break;
124                                 }
125                                 *val = le32_to_cpu(le_val);
126                         }
127                         break;
128                 case 'q':{
129                                 int64_t *val = va_arg(ap, int64_t *);
130                                 __le64 le_val;
131                                 if (pdu_read(pdu, &le_val, sizeof(le_val))) {
132                                         errcode = -EFAULT;
133                                         break;
134                                 }
135                                 *val = le64_to_cpu(le_val);
136                         }
137                         break;
138                 case 's':{
139                                 char **sptr = va_arg(ap, char **);
140                                 uint16_t len;
141
142                                 errcode = p9pdu_readf(pdu, proto_version,
143                                                                 "w", &len);
144                                 if (errcode)
145                                         break;
146
147                                 *sptr = kmalloc(len + 1, GFP_NOFS);
148                                 if (*sptr == NULL) {
149                                         errcode = -ENOMEM;
150                                         break;
151                                 }
152                                 if (pdu_read(pdu, *sptr, len)) {
153                                         errcode = -EFAULT;
154                                         kfree(*sptr);
155                                         *sptr = NULL;
156                                 } else
157                                         (*sptr)[len] = 0;
158                         }
159                         break;
160                 case 'u': {
161                                 kuid_t *uid = va_arg(ap, kuid_t *);
162                                 __le32 le_val;
163                                 if (pdu_read(pdu, &le_val, sizeof(le_val))) {
164                                         errcode = -EFAULT;
165                                         break;
166                                 }
167                                 *uid = make_kuid(&init_user_ns,
168                                                  le32_to_cpu(le_val));
169                         } break;
170                 case 'g': {
171                                 kgid_t *gid = va_arg(ap, kgid_t *);
172                                 __le32 le_val;
173                                 if (pdu_read(pdu, &le_val, sizeof(le_val))) {
174                                         errcode = -EFAULT;
175                                         break;
176                                 }
177                                 *gid = make_kgid(&init_user_ns,
178                                                  le32_to_cpu(le_val));
179                         } break;
180                 case 'Q':{
181                                 struct p9_qid *qid =
182                                     va_arg(ap, struct p9_qid *);
183
184                                 errcode = p9pdu_readf(pdu, proto_version, "bdq",
185                                                       &qid->type, &qid->version,
186                                                       &qid->path);
187                         }
188                         break;
189                 case 'S':{
190                                 struct p9_wstat *stbuf =
191                                     va_arg(ap, struct p9_wstat *);
192
193                                 memset(stbuf, 0, sizeof(struct p9_wstat));
194                                 stbuf->n_uid = stbuf->n_muid = INVALID_UID;
195                                 stbuf->n_gid = INVALID_GID;
196
197                                 errcode =
198                                     p9pdu_readf(pdu, proto_version,
199                                                 "wwdQdddqssss?sugu",
200                                                 &stbuf->size, &stbuf->type,
201                                                 &stbuf->dev, &stbuf->qid,
202                                                 &stbuf->mode, &stbuf->atime,
203                                                 &stbuf->mtime, &stbuf->length,
204                                                 &stbuf->name, &stbuf->uid,
205                                                 &stbuf->gid, &stbuf->muid,
206                                                 &stbuf->extension,
207                                                 &stbuf->n_uid, &stbuf->n_gid,
208                                                 &stbuf->n_muid);
209                                 if (errcode)
210                                         p9stat_free(stbuf);
211                         }
212                         break;
213                 case 'D':{
214                                 uint32_t *count = va_arg(ap, uint32_t *);
215                                 void **data = va_arg(ap, void **);
216
217                                 errcode =
218                                     p9pdu_readf(pdu, proto_version, "d", count);
219                                 if (!errcode) {
220                                         *count =
221                                             min_t(uint32_t, *count,
222                                                   pdu->size - pdu->offset);
223                                         *data = &pdu->sdata[pdu->offset];
224                                 }
225                         }
226                         break;
227                 case 'T':{
228                                 uint16_t *nwname = va_arg(ap, uint16_t *);
229                                 char ***wnames = va_arg(ap, char ***);
230
231                                 errcode = p9pdu_readf(pdu, proto_version,
232                                                                 "w", nwname);
233                                 if (!errcode) {
234                                         *wnames =
235                                             kmalloc_array(*nwname,
236                                                           sizeof(char *),
237                                                           GFP_NOFS);
238                                         if (!*wnames)
239                                                 errcode = -ENOMEM;
240                                 }
241
242                                 if (!errcode) {
243                                         int i;
244
245                                         for (i = 0; i < *nwname; i++) {
246                                                 errcode =
247                                                     p9pdu_readf(pdu,
248                                                                 proto_version,
249                                                                 "s",
250                                                                 &(*wnames)[i]);
251                                                 if (errcode)
252                                                         break;
253                                         }
254                                 }
255
256                                 if (errcode) {
257                                         if (*wnames) {
258                                                 int i;
259
260                                                 for (i = 0; i < *nwname; i++)
261                                                         kfree((*wnames)[i]);
262                                         }
263                                         kfree(*wnames);
264                                         *wnames = NULL;
265                                 }
266                         }
267                         break;
268                 case 'R':{
269                                 uint16_t *nwqid = va_arg(ap, uint16_t *);
270                                 struct p9_qid **wqids =
271                                     va_arg(ap, struct p9_qid **);
272
273                                 *wqids = NULL;
274
275                                 errcode =
276                                     p9pdu_readf(pdu, proto_version, "w", nwqid);
277                                 if (!errcode) {
278                                         *wqids =
279                                             kmalloc_array(*nwqid,
280                                                           sizeof(struct p9_qid),
281                                                           GFP_NOFS);
282                                         if (*wqids == NULL)
283                                                 errcode = -ENOMEM;
284                                 }
285
286                                 if (!errcode) {
287                                         int i;
288
289                                         for (i = 0; i < *nwqid; i++) {
290                                                 errcode =
291                                                     p9pdu_readf(pdu,
292                                                                 proto_version,
293                                                                 "Q",
294                                                                 &(*wqids)[i]);
295                                                 if (errcode)
296                                                         break;
297                                         }
298                                 }
299
300                                 if (errcode) {
301                                         kfree(*wqids);
302                                         *wqids = NULL;
303                                 }
304                         }
305                         break;
306                 case 'A': {
307                                 struct p9_stat_dotl *stbuf =
308                                     va_arg(ap, struct p9_stat_dotl *);
309
310                                 memset(stbuf, 0, sizeof(struct p9_stat_dotl));
311                                 errcode =
312                                     p9pdu_readf(pdu, proto_version,
313                                         "qQdugqqqqqqqqqqqqqqq",
314                                         &stbuf->st_result_mask,
315                                         &stbuf->qid,
316                                         &stbuf->st_mode,
317                                         &stbuf->st_uid, &stbuf->st_gid,
318                                         &stbuf->st_nlink,
319                                         &stbuf->st_rdev, &stbuf->st_size,
320                                         &stbuf->st_blksize, &stbuf->st_blocks,
321                                         &stbuf->st_atime_sec,
322                                         &stbuf->st_atime_nsec,
323                                         &stbuf->st_mtime_sec,
324                                         &stbuf->st_mtime_nsec,
325                                         &stbuf->st_ctime_sec,
326                                         &stbuf->st_ctime_nsec,
327                                         &stbuf->st_btime_sec,
328                                         &stbuf->st_btime_nsec,
329                                         &stbuf->st_gen,
330                                         &stbuf->st_data_version);
331                         }
332                         break;
333                 case '?':
334                         if ((proto_version != p9_proto_2000u) &&
335                                 (proto_version != p9_proto_2000L))
336                                 return 0;
337                         break;
338                 default:
339                         BUG();
340                         break;
341                 }
342
343                 if (errcode)
344                         break;
345         }
346
347         return errcode;
348 }
349
350 int
351 p9pdu_vwritef(struct p9_fcall *pdu, int proto_version, const char *fmt,
352         va_list ap)
353 {
354         const char *ptr;
355         int errcode = 0;
356
357         for (ptr = fmt; *ptr; ptr++) {
358                 switch (*ptr) {
359                 case 'b':{
360                                 int8_t val = va_arg(ap, int);
361                                 if (pdu_write(pdu, &val, sizeof(val)))
362                                         errcode = -EFAULT;
363                         }
364                         break;
365                 case 'w':{
366                                 __le16 val = cpu_to_le16(va_arg(ap, int));
367                                 if (pdu_write(pdu, &val, sizeof(val)))
368                                         errcode = -EFAULT;
369                         }
370                         break;
371                 case 'd':{
372                                 __le32 val = cpu_to_le32(va_arg(ap, int32_t));
373                                 if (pdu_write(pdu, &val, sizeof(val)))
374                                         errcode = -EFAULT;
375                         }
376                         break;
377                 case 'q':{
378                                 __le64 val = cpu_to_le64(va_arg(ap, int64_t));
379                                 if (pdu_write(pdu, &val, sizeof(val)))
380                                         errcode = -EFAULT;
381                         }
382                         break;
383                 case 's':{
384                                 const char *sptr = va_arg(ap, const char *);
385                                 uint16_t len = 0;
386                                 if (sptr)
387                                         len = min_t(size_t, strlen(sptr),
388                                                                 USHRT_MAX);
389
390                                 errcode = p9pdu_writef(pdu, proto_version,
391                                                                 "w", len);
392                                 if (!errcode && pdu_write(pdu, sptr, len))
393                                         errcode = -EFAULT;
394                         }
395                         break;
396                 case 'u': {
397                                 kuid_t uid = va_arg(ap, kuid_t);
398                                 __le32 val = cpu_to_le32(
399                                                 from_kuid(&init_user_ns, uid));
400                                 if (pdu_write(pdu, &val, sizeof(val)))
401                                         errcode = -EFAULT;
402                         } break;
403                 case 'g': {
404                                 kgid_t gid = va_arg(ap, kgid_t);
405                                 __le32 val = cpu_to_le32(
406                                                 from_kgid(&init_user_ns, gid));
407                                 if (pdu_write(pdu, &val, sizeof(val)))
408                                         errcode = -EFAULT;
409                         } break;
410                 case 'Q':{
411                                 const struct p9_qid *qid =
412                                     va_arg(ap, const struct p9_qid *);
413                                 errcode =
414                                     p9pdu_writef(pdu, proto_version, "bdq",
415                                                  qid->type, qid->version,
416                                                  qid->path);
417                         } break;
418                 case 'S':{
419                                 const struct p9_wstat *stbuf =
420                                     va_arg(ap, const struct p9_wstat *);
421                                 errcode =
422                                     p9pdu_writef(pdu, proto_version,
423                                                  "wwdQdddqssss?sugu",
424                                                  stbuf->size, stbuf->type,
425                                                  stbuf->dev, &stbuf->qid,
426                                                  stbuf->mode, stbuf->atime,
427                                                  stbuf->mtime, stbuf->length,
428                                                  stbuf->name, stbuf->uid,
429                                                  stbuf->gid, stbuf->muid,
430                                                  stbuf->extension, stbuf->n_uid,
431                                                  stbuf->n_gid, stbuf->n_muid);
432                         } break;
433                 case 'V':{
434                                 uint32_t count = va_arg(ap, uint32_t);
435                                 struct iov_iter *from =
436                                                 va_arg(ap, struct iov_iter *);
437                                 errcode = p9pdu_writef(pdu, proto_version, "d",
438                                                                         count);
439                                 if (!errcode && pdu_write_u(pdu, from, count))
440                                         errcode = -EFAULT;
441                         }
442                         break;
443                 case 'T':{
444                                 uint16_t nwname = va_arg(ap, int);
445                                 const char **wnames = va_arg(ap, const char **);
446
447                                 errcode = p9pdu_writef(pdu, proto_version, "w",
448                                                                         nwname);
449                                 if (!errcode) {
450                                         int i;
451
452                                         for (i = 0; i < nwname; i++) {
453                                                 errcode =
454                                                     p9pdu_writef(pdu,
455                                                                 proto_version,
456                                                                  "s",
457                                                                  wnames[i]);
458                                                 if (errcode)
459                                                         break;
460                                         }
461                                 }
462                         }
463                         break;
464                 case 'R':{
465                                 uint16_t nwqid = va_arg(ap, int);
466                                 struct p9_qid *wqids =
467                                     va_arg(ap, struct p9_qid *);
468
469                                 errcode = p9pdu_writef(pdu, proto_version, "w",
470                                                                         nwqid);
471                                 if (!errcode) {
472                                         int i;
473
474                                         for (i = 0; i < nwqid; i++) {
475                                                 errcode =
476                                                     p9pdu_writef(pdu,
477                                                                 proto_version,
478                                                                  "Q",
479                                                                  &wqids[i]);
480                                                 if (errcode)
481                                                         break;
482                                         }
483                                 }
484                         }
485                         break;
486                 case 'I':{
487                                 struct p9_iattr_dotl *p9attr = va_arg(ap,
488                                                         struct p9_iattr_dotl *);
489
490                                 errcode = p9pdu_writef(pdu, proto_version,
491                                                         "ddugqqqqq",
492                                                         p9attr->valid,
493                                                         p9attr->mode,
494                                                         p9attr->uid,
495                                                         p9attr->gid,
496                                                         p9attr->size,
497                                                         p9attr->atime_sec,
498                                                         p9attr->atime_nsec,
499                                                         p9attr->mtime_sec,
500                                                         p9attr->mtime_nsec);
501                         }
502                         break;
503                 case '?':
504                         if ((proto_version != p9_proto_2000u) &&
505                                 (proto_version != p9_proto_2000L))
506                                 return 0;
507                         break;
508                 default:
509                         BUG();
510                         break;
511                 }
512
513                 if (errcode)
514                         break;
515         }
516
517         return errcode;
518 }
519
/*
 * p9pdu_readf - varargs front end to p9pdu_vreadf()
 * @pdu: message to decode from
 * @proto_version: negotiated protocol version
 * @fmt: format string describing the fields to read
 *
 * Return: whatever p9pdu_vreadf() returns (0 or a negative errno).
 */
int p9pdu_readf(struct p9_fcall *pdu, int proto_version, const char *fmt, ...)
{
	va_list args;
	int rc;

	va_start(args, fmt);
	rc = p9pdu_vreadf(pdu, proto_version, fmt, args);
	va_end(args);

	return rc;
}
531
/*
 * p9pdu_writef - varargs front end to p9pdu_vwritef()
 * @pdu: message to append to
 * @proto_version: negotiated protocol version
 * @fmt: format string describing the fields to write
 *
 * Return: whatever p9pdu_vwritef() returns (0 or a negative errno).
 */
static int
p9pdu_writef(struct p9_fcall *pdu, int proto_version, const char *fmt, ...)
{
	va_list args;
	int rc;

	va_start(args, fmt);
	rc = p9pdu_vwritef(pdu, proto_version, fmt, args);
	va_end(args);

	return rc;
}
544
545 int p9stat_read(struct p9_client *clnt, char *buf, int len, struct p9_wstat *st)
546 {
547         struct p9_fcall fake_pdu;
548         int ret;
549
550         fake_pdu.size = len;
551         fake_pdu.capacity = len;
552         fake_pdu.sdata = buf;
553         fake_pdu.offset = 0;
554
555         ret = p9pdu_readf(&fake_pdu, clnt->proto_version, "S", st);
556         if (ret) {
557                 p9_debug(P9_DEBUG_9P, "<<< p9stat_read failed: %d\n", ret);
558                 trace_9p_protocol_dump(clnt, &fake_pdu);
559                 return ret;
560         }
561
562         return fake_pdu.offset;
563 }
564 EXPORT_SYMBOL(p9stat_read);
565
566 int p9pdu_prepare(struct p9_fcall *pdu, int16_t tag, int8_t type)
567 {
568         pdu->id = type;
569         return p9pdu_writef(pdu, 0, "dbw", 0, type, tag);
570 }
571
572 int p9pdu_finalize(struct p9_client *clnt, struct p9_fcall *pdu)
573 {
574         int size = pdu->size;
575         int err;
576
577         pdu->size = 0;
578         err = p9pdu_writef(pdu, 0, "d", size);
579         pdu->size = size;
580
581         trace_9p_protocol_dump(clnt, pdu);
582         p9_debug(P9_DEBUG_9P, ">>> size=%d type: %d tag: %d\n",
583                  pdu->size, pdu->id, pdu->tag);
584
585         return err;
586 }
587
588 void p9pdu_reset(struct p9_fcall *pdu)
589 {
590         pdu->offset = 0;
591         pdu->size = 0;
592 }
593
594 int p9dirent_read(struct p9_client *clnt, char *buf, int len,
595                   struct p9_dirent *dirent)
596 {
597         struct p9_fcall fake_pdu;
598         int ret;
599         char *nameptr;
600
601         fake_pdu.size = len;
602         fake_pdu.capacity = len;
603         fake_pdu.sdata = buf;
604         fake_pdu.offset = 0;
605
606         ret = p9pdu_readf(&fake_pdu, clnt->proto_version, "Qqbs", &dirent->qid,
607                           &dirent->d_off, &dirent->d_type, &nameptr);
608         if (ret) {
609                 p9_debug(P9_DEBUG_9P, "<<< p9dirent_read failed: %d\n", ret);
610                 trace_9p_protocol_dump(clnt, &fake_pdu);
611                 return ret;
612         }
613
614         ret = strscpy(dirent->d_name, nameptr, sizeof(dirent->d_name));
615         if (ret < 0) {
616                 p9_debug(P9_DEBUG_ERROR,
617                          "On the wire dirent name too long: %s\n",
618                          nameptr);
619                 kfree(nameptr);
620                 return ret;
621         }
622         kfree(nameptr);
623
624         return fake_pdu.offset;
625 }
626 EXPORT_SYMBOL(p9dirent_read);