Merge tag 'at91-fixes-6.1' of https://git.kernel.org/pub/scm/linux/kernel/git/at91...
[platform/kernel/linux-starfive.git] / fs / ksmbd / transport_ipc.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  *   Copyright (C) 2018 Samsung Electronics Co., Ltd.
4  */
5
6 #include <linux/jhash.h>
7 #include <linux/slab.h>
8 #include <linux/rwsem.h>
9 #include <linux/mutex.h>
10 #include <linux/wait.h>
11 #include <linux/hashtable.h>
12 #include <net/net_namespace.h>
13 #include <net/genetlink.h>
14 #include <linux/socket.h>
15 #include <linux/workqueue.h>
16
17 #include "vfs_cache.h"
18 #include "transport_ipc.h"
19 #include "server.h"
20 #include "smb_common.h"
21
22 #include "mgmt/user_config.h"
23 #include "mgmt/share_config.h"
24 #include "mgmt/user_session.h"
25 #include "mgmt/tree_connect.h"
26 #include "mgmt/ksmbd_ida.h"
27 #include "connection.h"
28 #include "transport_tcp.h"
29 #include "transport_rdma.h"
30
31 #define IPC_WAIT_TIMEOUT        (2 * HZ)
32
33 #define IPC_MSG_HASH_BITS       3
34 static DEFINE_HASHTABLE(ipc_msg_table, IPC_MSG_HASH_BITS);
35 static DECLARE_RWSEM(ipc_msg_table_lock);
36 static DEFINE_MUTEX(startup_lock);
37
38 static DEFINE_IDA(ipc_ida);
39
40 static unsigned int ksmbd_tools_pid;
41
42 static bool ksmbd_ipc_validate_version(struct genl_info *m)
43 {
44         if (m->genlhdr->version != KSMBD_GENL_VERSION) {
45                 pr_err("%s. ksmbd: %d, kernel module: %d. %s.\n",
46                        "Daemon and kernel module version mismatch",
47                        m->genlhdr->version,
48                        KSMBD_GENL_VERSION,
49                        "User-space ksmbd should terminate");
50                 return false;
51         }
52         return true;
53 }
54
55 struct ksmbd_ipc_msg {
56         unsigned int            type;
57         unsigned int            sz;
58         unsigned char           payload[];
59 };
60
61 struct ipc_msg_table_entry {
62         unsigned int            handle;
63         unsigned int            type;
64         wait_queue_head_t       wait;
65         struct hlist_node       ipc_table_hlist;
66
67         void                    *response;
68 };
69
70 static struct delayed_work ipc_timer_work;
71
72 static int handle_startup_event(struct sk_buff *skb, struct genl_info *info);
73 static int handle_unsupported_event(struct sk_buff *skb, struct genl_info *info);
74 static int handle_generic_event(struct sk_buff *skb, struct genl_info *info);
75 static int ksmbd_ipc_heartbeat_request(void);
76
77 static const struct nla_policy ksmbd_nl_policy[KSMBD_EVENT_MAX] = {
78         [KSMBD_EVENT_UNSPEC] = {
79                 .len = 0,
80         },
81         [KSMBD_EVENT_HEARTBEAT_REQUEST] = {
82                 .len = sizeof(struct ksmbd_heartbeat),
83         },
84         [KSMBD_EVENT_STARTING_UP] = {
85                 .len = sizeof(struct ksmbd_startup_request),
86         },
87         [KSMBD_EVENT_SHUTTING_DOWN] = {
88                 .len = sizeof(struct ksmbd_shutdown_request),
89         },
90         [KSMBD_EVENT_LOGIN_REQUEST] = {
91                 .len = sizeof(struct ksmbd_login_request),
92         },
93         [KSMBD_EVENT_LOGIN_RESPONSE] = {
94                 .len = sizeof(struct ksmbd_login_response),
95         },
96         [KSMBD_EVENT_SHARE_CONFIG_REQUEST] = {
97                 .len = sizeof(struct ksmbd_share_config_request),
98         },
99         [KSMBD_EVENT_SHARE_CONFIG_RESPONSE] = {
100                 .len = sizeof(struct ksmbd_share_config_response),
101         },
102         [KSMBD_EVENT_TREE_CONNECT_REQUEST] = {
103                 .len = sizeof(struct ksmbd_tree_connect_request),
104         },
105         [KSMBD_EVENT_TREE_CONNECT_RESPONSE] = {
106                 .len = sizeof(struct ksmbd_tree_connect_response),
107         },
108         [KSMBD_EVENT_TREE_DISCONNECT_REQUEST] = {
109                 .len = sizeof(struct ksmbd_tree_disconnect_request),
110         },
111         [KSMBD_EVENT_LOGOUT_REQUEST] = {
112                 .len = sizeof(struct ksmbd_logout_request),
113         },
114         [KSMBD_EVENT_RPC_REQUEST] = {
115         },
116         [KSMBD_EVENT_RPC_RESPONSE] = {
117         },
118         [KSMBD_EVENT_SPNEGO_AUTHEN_REQUEST] = {
119         },
120         [KSMBD_EVENT_SPNEGO_AUTHEN_RESPONSE] = {
121         },
122 };
123
124 static struct genl_ops ksmbd_genl_ops[] = {
125         {
126                 .cmd    = KSMBD_EVENT_UNSPEC,
127                 .doit   = handle_unsupported_event,
128         },
129         {
130                 .cmd    = KSMBD_EVENT_HEARTBEAT_REQUEST,
131                 .doit   = handle_unsupported_event,
132         },
133         {
134                 .cmd    = KSMBD_EVENT_STARTING_UP,
135                 .doit   = handle_startup_event,
136         },
137         {
138                 .cmd    = KSMBD_EVENT_SHUTTING_DOWN,
139                 .doit   = handle_unsupported_event,
140         },
141         {
142                 .cmd    = KSMBD_EVENT_LOGIN_REQUEST,
143                 .doit   = handle_unsupported_event,
144         },
145         {
146                 .cmd    = KSMBD_EVENT_LOGIN_RESPONSE,
147                 .doit   = handle_generic_event,
148         },
149         {
150                 .cmd    = KSMBD_EVENT_SHARE_CONFIG_REQUEST,
151                 .doit   = handle_unsupported_event,
152         },
153         {
154                 .cmd    = KSMBD_EVENT_SHARE_CONFIG_RESPONSE,
155                 .doit   = handle_generic_event,
156         },
157         {
158                 .cmd    = KSMBD_EVENT_TREE_CONNECT_REQUEST,
159                 .doit   = handle_unsupported_event,
160         },
161         {
162                 .cmd    = KSMBD_EVENT_TREE_CONNECT_RESPONSE,
163                 .doit   = handle_generic_event,
164         },
165         {
166                 .cmd    = KSMBD_EVENT_TREE_DISCONNECT_REQUEST,
167                 .doit   = handle_unsupported_event,
168         },
169         {
170                 .cmd    = KSMBD_EVENT_LOGOUT_REQUEST,
171                 .doit   = handle_unsupported_event,
172         },
173         {
174                 .cmd    = KSMBD_EVENT_RPC_REQUEST,
175                 .doit   = handle_unsupported_event,
176         },
177         {
178                 .cmd    = KSMBD_EVENT_RPC_RESPONSE,
179                 .doit   = handle_generic_event,
180         },
181         {
182                 .cmd    = KSMBD_EVENT_SPNEGO_AUTHEN_REQUEST,
183                 .doit   = handle_unsupported_event,
184         },
185         {
186                 .cmd    = KSMBD_EVENT_SPNEGO_AUTHEN_RESPONSE,
187                 .doit   = handle_generic_event,
188         },
189 };
190
191 static struct genl_family ksmbd_genl_family = {
192         .name           = KSMBD_GENL_NAME,
193         .version        = KSMBD_GENL_VERSION,
194         .hdrsize        = 0,
195         .maxattr        = KSMBD_EVENT_MAX,
196         .netnsok        = true,
197         .module         = THIS_MODULE,
198         .ops            = ksmbd_genl_ops,
199         .n_ops          = ARRAY_SIZE(ksmbd_genl_ops),
200         .resv_start_op  = KSMBD_EVENT_SPNEGO_AUTHEN_RESPONSE + 1,
201 };
202
203 static void ksmbd_nl_init_fixup(void)
204 {
205         int i;
206
207         for (i = 0; i < ARRAY_SIZE(ksmbd_genl_ops); i++)
208                 ksmbd_genl_ops[i].validate = GENL_DONT_VALIDATE_STRICT |
209                                                 GENL_DONT_VALIDATE_DUMP;
210
211         ksmbd_genl_family.policy = ksmbd_nl_policy;
212 }
213
214 static int rpc_context_flags(struct ksmbd_session *sess)
215 {
216         if (user_guest(sess->user))
217                 return KSMBD_RPC_RESTRICTED_CONTEXT;
218         return 0;
219 }
220
221 static void ipc_update_last_active(void)
222 {
223         if (server_conf.ipc_timeout)
224                 server_conf.ipc_last_active = jiffies;
225 }
226
227 static struct ksmbd_ipc_msg *ipc_msg_alloc(size_t sz)
228 {
229         struct ksmbd_ipc_msg *msg;
230         size_t msg_sz = sz + sizeof(struct ksmbd_ipc_msg);
231
232         msg = kvmalloc(msg_sz, GFP_KERNEL | __GFP_ZERO);
233         if (msg)
234                 msg->sz = sz;
235         return msg;
236 }
237
238 static void ipc_msg_free(struct ksmbd_ipc_msg *msg)
239 {
240         kvfree(msg);
241 }
242
243 static void ipc_msg_handle_free(int handle)
244 {
245         if (handle >= 0)
246                 ksmbd_release_id(&ipc_ida, handle);
247 }
248
249 static int handle_response(int type, void *payload, size_t sz)
250 {
251         unsigned int handle = *(unsigned int *)payload;
252         struct ipc_msg_table_entry *entry;
253         int ret = 0;
254
255         ipc_update_last_active();
256         down_read(&ipc_msg_table_lock);
257         hash_for_each_possible(ipc_msg_table, entry, ipc_table_hlist, handle) {
258                 if (handle != entry->handle)
259                         continue;
260
261                 entry->response = NULL;
262                 /*
263                  * Response message type value should be equal to
264                  * request message type + 1.
265                  */
266                 if (entry->type + 1 != type) {
267                         pr_err("Waiting for IPC type %d, got %d. Ignore.\n",
268                                entry->type + 1, type);
269                 }
270
271                 entry->response = kvmalloc(sz, GFP_KERNEL | __GFP_ZERO);
272                 if (!entry->response) {
273                         ret = -ENOMEM;
274                         break;
275                 }
276
277                 memcpy(entry->response, payload, sz);
278                 wake_up_interruptible(&entry->wait);
279                 ret = 0;
280                 break;
281         }
282         up_read(&ipc_msg_table_lock);
283
284         return ret;
285 }
286
287 static int ipc_server_config_on_startup(struct ksmbd_startup_request *req)
288 {
289         int ret;
290
291         ksmbd_set_fd_limit(req->file_max);
292         server_conf.flags = req->flags;
293         server_conf.signing = req->signing;
294         server_conf.tcp_port = req->tcp_port;
295         server_conf.ipc_timeout = req->ipc_timeout * HZ;
296         server_conf.deadtime = req->deadtime * SMB_ECHO_INTERVAL;
297         server_conf.share_fake_fscaps = req->share_fake_fscaps;
298         ksmbd_init_domain(req->sub_auth);
299
300         if (req->smb2_max_read)
301                 init_smb2_max_read_size(req->smb2_max_read);
302         if (req->smb2_max_write)
303                 init_smb2_max_write_size(req->smb2_max_write);
304         if (req->smb2_max_trans)
305                 init_smb2_max_trans_size(req->smb2_max_trans);
306         if (req->smb2_max_credits)
307                 init_smb2_max_credits(req->smb2_max_credits);
308         if (req->smbd_max_io_size)
309                 init_smbd_max_io_size(req->smbd_max_io_size);
310
311         ret = ksmbd_set_netbios_name(req->netbios_name);
312         ret |= ksmbd_set_server_string(req->server_string);
313         ret |= ksmbd_set_work_group(req->work_group);
314         ret |= ksmbd_tcp_set_interfaces(KSMBD_STARTUP_CONFIG_INTERFACES(req),
315                                         req->ifc_list_sz);
316         if (ret) {
317                 pr_err("Server configuration error: %s %s %s\n",
318                        req->netbios_name, req->server_string,
319                        req->work_group);
320                 return ret;
321         }
322
323         if (req->min_prot[0]) {
324                 ret = ksmbd_lookup_protocol_idx(req->min_prot);
325                 if (ret >= 0)
326                         server_conf.min_protocol = ret;
327         }
328         if (req->max_prot[0]) {
329                 ret = ksmbd_lookup_protocol_idx(req->max_prot);
330                 if (ret >= 0)
331                         server_conf.max_protocol = ret;
332         }
333
334         if (server_conf.ipc_timeout)
335                 schedule_delayed_work(&ipc_timer_work, server_conf.ipc_timeout);
336         return 0;
337 }
338
339 static int handle_startup_event(struct sk_buff *skb, struct genl_info *info)
340 {
341         int ret = 0;
342
343 #ifdef CONFIG_SMB_SERVER_CHECK_CAP_NET_ADMIN
344         if (!netlink_capable(skb, CAP_NET_ADMIN))
345                 return -EPERM;
346 #endif
347
348         if (!ksmbd_ipc_validate_version(info))
349                 return -EINVAL;
350
351         if (!info->attrs[KSMBD_EVENT_STARTING_UP])
352                 return -EINVAL;
353
354         mutex_lock(&startup_lock);
355         if (!ksmbd_server_configurable()) {
356                 mutex_unlock(&startup_lock);
357                 pr_err("Server reset is in progress, can't start daemon\n");
358                 return -EINVAL;
359         }
360
361         if (ksmbd_tools_pid) {
362                 if (ksmbd_ipc_heartbeat_request() == 0) {
363                         ret = -EINVAL;
364                         goto out;
365                 }
366
367                 pr_err("Reconnect to a new user space daemon\n");
368         } else {
369                 struct ksmbd_startup_request *req;
370
371                 req = nla_data(info->attrs[info->genlhdr->cmd]);
372                 ret = ipc_server_config_on_startup(req);
373                 if (ret)
374                         goto out;
375                 server_queue_ctrl_init_work();
376         }
377
378         ksmbd_tools_pid = info->snd_portid;
379         ipc_update_last_active();
380
381 out:
382         mutex_unlock(&startup_lock);
383         return ret;
384 }
385
386 static int handle_unsupported_event(struct sk_buff *skb, struct genl_info *info)
387 {
388         pr_err("Unknown IPC event: %d, ignore.\n", info->genlhdr->cmd);
389         return -EINVAL;
390 }
391
392 static int handle_generic_event(struct sk_buff *skb, struct genl_info *info)
393 {
394         void *payload;
395         int sz;
396         int type = info->genlhdr->cmd;
397
398 #ifdef CONFIG_SMB_SERVER_CHECK_CAP_NET_ADMIN
399         if (!netlink_capable(skb, CAP_NET_ADMIN))
400                 return -EPERM;
401 #endif
402
403         if (type >= KSMBD_EVENT_MAX) {
404                 WARN_ON(1);
405                 return -EINVAL;
406         }
407
408         if (!ksmbd_ipc_validate_version(info))
409                 return -EINVAL;
410
411         if (!info->attrs[type])
412                 return -EINVAL;
413
414         payload = nla_data(info->attrs[info->genlhdr->cmd]);
415         sz = nla_len(info->attrs[info->genlhdr->cmd]);
416         return handle_response(type, payload, sz);
417 }
418
419 static int ipc_msg_send(struct ksmbd_ipc_msg *msg)
420 {
421         struct genlmsghdr *nlh;
422         struct sk_buff *skb;
423         int ret = -EINVAL;
424
425         if (!ksmbd_tools_pid)
426                 return ret;
427
428         skb = genlmsg_new(msg->sz, GFP_KERNEL);
429         if (!skb)
430                 return -ENOMEM;
431
432         nlh = genlmsg_put(skb, 0, 0, &ksmbd_genl_family, 0, msg->type);
433         if (!nlh)
434                 goto out;
435
436         ret = nla_put(skb, msg->type, msg->sz, msg->payload);
437         if (ret) {
438                 genlmsg_cancel(skb, nlh);
439                 goto out;
440         }
441
442         genlmsg_end(skb, nlh);
443         ret = genlmsg_unicast(&init_net, skb, ksmbd_tools_pid);
444         if (!ret)
445                 ipc_update_last_active();
446         return ret;
447
448 out:
449         nlmsg_free(skb);
450         return ret;
451 }
452
453 static void *ipc_msg_send_request(struct ksmbd_ipc_msg *msg, unsigned int handle)
454 {
455         struct ipc_msg_table_entry entry;
456         int ret;
457
458         if ((int)handle < 0)
459                 return NULL;
460
461         entry.type = msg->type;
462         entry.response = NULL;
463         init_waitqueue_head(&entry.wait);
464
465         down_write(&ipc_msg_table_lock);
466         entry.handle = handle;
467         hash_add(ipc_msg_table, &entry.ipc_table_hlist, entry.handle);
468         up_write(&ipc_msg_table_lock);
469
470         ret = ipc_msg_send(msg);
471         if (ret)
472                 goto out;
473
474         ret = wait_event_interruptible_timeout(entry.wait,
475                                                entry.response != NULL,
476                                                IPC_WAIT_TIMEOUT);
477 out:
478         down_write(&ipc_msg_table_lock);
479         hash_del(&entry.ipc_table_hlist);
480         up_write(&ipc_msg_table_lock);
481         return entry.response;
482 }
483
484 static int ksmbd_ipc_heartbeat_request(void)
485 {
486         struct ksmbd_ipc_msg *msg;
487         int ret;
488
489         msg = ipc_msg_alloc(sizeof(struct ksmbd_heartbeat));
490         if (!msg)
491                 return -EINVAL;
492
493         msg->type = KSMBD_EVENT_HEARTBEAT_REQUEST;
494         ret = ipc_msg_send(msg);
495         ipc_msg_free(msg);
496         return ret;
497 }
498
499 struct ksmbd_login_response *ksmbd_ipc_login_request(const char *account)
500 {
501         struct ksmbd_ipc_msg *msg;
502         struct ksmbd_login_request *req;
503         struct ksmbd_login_response *resp;
504
505         if (strlen(account) >= KSMBD_REQ_MAX_ACCOUNT_NAME_SZ)
506                 return NULL;
507
508         msg = ipc_msg_alloc(sizeof(struct ksmbd_login_request));
509         if (!msg)
510                 return NULL;
511
512         msg->type = KSMBD_EVENT_LOGIN_REQUEST;
513         req = (struct ksmbd_login_request *)msg->payload;
514         req->handle = ksmbd_acquire_id(&ipc_ida);
515         strscpy(req->account, account, KSMBD_REQ_MAX_ACCOUNT_NAME_SZ);
516
517         resp = ipc_msg_send_request(msg, req->handle);
518         ipc_msg_handle_free(req->handle);
519         ipc_msg_free(msg);
520         return resp;
521 }
522
523 struct ksmbd_spnego_authen_response *
524 ksmbd_ipc_spnego_authen_request(const char *spnego_blob, int blob_len)
525 {
526         struct ksmbd_ipc_msg *msg;
527         struct ksmbd_spnego_authen_request *req;
528         struct ksmbd_spnego_authen_response *resp;
529
530         msg = ipc_msg_alloc(sizeof(struct ksmbd_spnego_authen_request) +
531                         blob_len + 1);
532         if (!msg)
533                 return NULL;
534
535         msg->type = KSMBD_EVENT_SPNEGO_AUTHEN_REQUEST;
536         req = (struct ksmbd_spnego_authen_request *)msg->payload;
537         req->handle = ksmbd_acquire_id(&ipc_ida);
538         req->spnego_blob_len = blob_len;
539         memcpy(req->spnego_blob, spnego_blob, blob_len);
540
541         resp = ipc_msg_send_request(msg, req->handle);
542         ipc_msg_handle_free(req->handle);
543         ipc_msg_free(msg);
544         return resp;
545 }
546
547 struct ksmbd_tree_connect_response *
548 ksmbd_ipc_tree_connect_request(struct ksmbd_session *sess,
549                                struct ksmbd_share_config *share,
550                                struct ksmbd_tree_connect *tree_conn,
551                                struct sockaddr *peer_addr)
552 {
553         struct ksmbd_ipc_msg *msg;
554         struct ksmbd_tree_connect_request *req;
555         struct ksmbd_tree_connect_response *resp;
556
557         if (strlen(user_name(sess->user)) >= KSMBD_REQ_MAX_ACCOUNT_NAME_SZ)
558                 return NULL;
559
560         if (strlen(share->name) >= KSMBD_REQ_MAX_SHARE_NAME)
561                 return NULL;
562
563         msg = ipc_msg_alloc(sizeof(struct ksmbd_tree_connect_request));
564         if (!msg)
565                 return NULL;
566
567         msg->type = KSMBD_EVENT_TREE_CONNECT_REQUEST;
568         req = (struct ksmbd_tree_connect_request *)msg->payload;
569
570         req->handle = ksmbd_acquire_id(&ipc_ida);
571         req->account_flags = sess->user->flags;
572         req->session_id = sess->id;
573         req->connect_id = tree_conn->id;
574         strscpy(req->account, user_name(sess->user), KSMBD_REQ_MAX_ACCOUNT_NAME_SZ);
575         strscpy(req->share, share->name, KSMBD_REQ_MAX_SHARE_NAME);
576         snprintf(req->peer_addr, sizeof(req->peer_addr), "%pIS", peer_addr);
577
578         if (peer_addr->sa_family == AF_INET6)
579                 req->flags |= KSMBD_TREE_CONN_FLAG_REQUEST_IPV6;
580         if (test_session_flag(sess, CIFDS_SESSION_FLAG_SMB2))
581                 req->flags |= KSMBD_TREE_CONN_FLAG_REQUEST_SMB2;
582
583         resp = ipc_msg_send_request(msg, req->handle);
584         ipc_msg_handle_free(req->handle);
585         ipc_msg_free(msg);
586         return resp;
587 }
588
589 int ksmbd_ipc_tree_disconnect_request(unsigned long long session_id,
590                                       unsigned long long connect_id)
591 {
592         struct ksmbd_ipc_msg *msg;
593         struct ksmbd_tree_disconnect_request *req;
594         int ret;
595
596         msg = ipc_msg_alloc(sizeof(struct ksmbd_tree_disconnect_request));
597         if (!msg)
598                 return -ENOMEM;
599
600         msg->type = KSMBD_EVENT_TREE_DISCONNECT_REQUEST;
601         req = (struct ksmbd_tree_disconnect_request *)msg->payload;
602         req->session_id = session_id;
603         req->connect_id = connect_id;
604
605         ret = ipc_msg_send(msg);
606         ipc_msg_free(msg);
607         return ret;
608 }
609
610 int ksmbd_ipc_logout_request(const char *account, int flags)
611 {
612         struct ksmbd_ipc_msg *msg;
613         struct ksmbd_logout_request *req;
614         int ret;
615
616         if (strlen(account) >= KSMBD_REQ_MAX_ACCOUNT_NAME_SZ)
617                 return -EINVAL;
618
619         msg = ipc_msg_alloc(sizeof(struct ksmbd_logout_request));
620         if (!msg)
621                 return -ENOMEM;
622
623         msg->type = KSMBD_EVENT_LOGOUT_REQUEST;
624         req = (struct ksmbd_logout_request *)msg->payload;
625         req->account_flags = flags;
626         strscpy(req->account, account, KSMBD_REQ_MAX_ACCOUNT_NAME_SZ);
627
628         ret = ipc_msg_send(msg);
629         ipc_msg_free(msg);
630         return ret;
631 }
632
633 struct ksmbd_share_config_response *
634 ksmbd_ipc_share_config_request(const char *name)
635 {
636         struct ksmbd_ipc_msg *msg;
637         struct ksmbd_share_config_request *req;
638         struct ksmbd_share_config_response *resp;
639
640         if (strlen(name) >= KSMBD_REQ_MAX_SHARE_NAME)
641                 return NULL;
642
643         msg = ipc_msg_alloc(sizeof(struct ksmbd_share_config_request));
644         if (!msg)
645                 return NULL;
646
647         msg->type = KSMBD_EVENT_SHARE_CONFIG_REQUEST;
648         req = (struct ksmbd_share_config_request *)msg->payload;
649         req->handle = ksmbd_acquire_id(&ipc_ida);
650         strscpy(req->share_name, name, KSMBD_REQ_MAX_SHARE_NAME);
651
652         resp = ipc_msg_send_request(msg, req->handle);
653         ipc_msg_handle_free(req->handle);
654         ipc_msg_free(msg);
655         return resp;
656 }
657
658 struct ksmbd_rpc_command *ksmbd_rpc_open(struct ksmbd_session *sess, int handle)
659 {
660         struct ksmbd_ipc_msg *msg;
661         struct ksmbd_rpc_command *req;
662         struct ksmbd_rpc_command *resp;
663
664         msg = ipc_msg_alloc(sizeof(struct ksmbd_rpc_command));
665         if (!msg)
666                 return NULL;
667
668         msg->type = KSMBD_EVENT_RPC_REQUEST;
669         req = (struct ksmbd_rpc_command *)msg->payload;
670         req->handle = handle;
671         req->flags = ksmbd_session_rpc_method(sess, handle);
672         req->flags |= KSMBD_RPC_OPEN_METHOD;
673         req->payload_sz = 0;
674
675         resp = ipc_msg_send_request(msg, req->handle);
676         ipc_msg_free(msg);
677         return resp;
678 }
679
680 struct ksmbd_rpc_command *ksmbd_rpc_close(struct ksmbd_session *sess, int handle)
681 {
682         struct ksmbd_ipc_msg *msg;
683         struct ksmbd_rpc_command *req;
684         struct ksmbd_rpc_command *resp;
685
686         msg = ipc_msg_alloc(sizeof(struct ksmbd_rpc_command));
687         if (!msg)
688                 return NULL;
689
690         msg->type = KSMBD_EVENT_RPC_REQUEST;
691         req = (struct ksmbd_rpc_command *)msg->payload;
692         req->handle = handle;
693         req->flags = ksmbd_session_rpc_method(sess, handle);
694         req->flags |= KSMBD_RPC_CLOSE_METHOD;
695         req->payload_sz = 0;
696
697         resp = ipc_msg_send_request(msg, req->handle);
698         ipc_msg_free(msg);
699         return resp;
700 }
701
702 struct ksmbd_rpc_command *ksmbd_rpc_write(struct ksmbd_session *sess, int handle,
703                                           void *payload, size_t payload_sz)
704 {
705         struct ksmbd_ipc_msg *msg;
706         struct ksmbd_rpc_command *req;
707         struct ksmbd_rpc_command *resp;
708
709         msg = ipc_msg_alloc(sizeof(struct ksmbd_rpc_command) + payload_sz + 1);
710         if (!msg)
711                 return NULL;
712
713         msg->type = KSMBD_EVENT_RPC_REQUEST;
714         req = (struct ksmbd_rpc_command *)msg->payload;
715         req->handle = handle;
716         req->flags = ksmbd_session_rpc_method(sess, handle);
717         req->flags |= rpc_context_flags(sess);
718         req->flags |= KSMBD_RPC_WRITE_METHOD;
719         req->payload_sz = payload_sz;
720         memcpy(req->payload, payload, payload_sz);
721
722         resp = ipc_msg_send_request(msg, req->handle);
723         ipc_msg_free(msg);
724         return resp;
725 }
726
727 struct ksmbd_rpc_command *ksmbd_rpc_read(struct ksmbd_session *sess, int handle)
728 {
729         struct ksmbd_ipc_msg *msg;
730         struct ksmbd_rpc_command *req;
731         struct ksmbd_rpc_command *resp;
732
733         msg = ipc_msg_alloc(sizeof(struct ksmbd_rpc_command));
734         if (!msg)
735                 return NULL;
736
737         msg->type = KSMBD_EVENT_RPC_REQUEST;
738         req = (struct ksmbd_rpc_command *)msg->payload;
739         req->handle = handle;
740         req->flags = ksmbd_session_rpc_method(sess, handle);
741         req->flags |= rpc_context_flags(sess);
742         req->flags |= KSMBD_RPC_READ_METHOD;
743         req->payload_sz = 0;
744
745         resp = ipc_msg_send_request(msg, req->handle);
746         ipc_msg_free(msg);
747         return resp;
748 }
749
750 struct ksmbd_rpc_command *ksmbd_rpc_ioctl(struct ksmbd_session *sess, int handle,
751                                           void *payload, size_t payload_sz)
752 {
753         struct ksmbd_ipc_msg *msg;
754         struct ksmbd_rpc_command *req;
755         struct ksmbd_rpc_command *resp;
756
757         msg = ipc_msg_alloc(sizeof(struct ksmbd_rpc_command) + payload_sz + 1);
758         if (!msg)
759                 return NULL;
760
761         msg->type = KSMBD_EVENT_RPC_REQUEST;
762         req = (struct ksmbd_rpc_command *)msg->payload;
763         req->handle = handle;
764         req->flags = ksmbd_session_rpc_method(sess, handle);
765         req->flags |= rpc_context_flags(sess);
766         req->flags |= KSMBD_RPC_IOCTL_METHOD;
767         req->payload_sz = payload_sz;
768         memcpy(req->payload, payload, payload_sz);
769
770         resp = ipc_msg_send_request(msg, req->handle);
771         ipc_msg_free(msg);
772         return resp;
773 }
774
775 struct ksmbd_rpc_command *ksmbd_rpc_rap(struct ksmbd_session *sess, void *payload,
776                                         size_t payload_sz)
777 {
778         struct ksmbd_ipc_msg *msg;
779         struct ksmbd_rpc_command *req;
780         struct ksmbd_rpc_command *resp;
781
782         msg = ipc_msg_alloc(sizeof(struct ksmbd_rpc_command) + payload_sz + 1);
783         if (!msg)
784                 return NULL;
785
786         msg->type = KSMBD_EVENT_RPC_REQUEST;
787         req = (struct ksmbd_rpc_command *)msg->payload;
788         req->handle = ksmbd_acquire_id(&ipc_ida);
789         req->flags = rpc_context_flags(sess);
790         req->flags |= KSMBD_RPC_RAP_METHOD;
791         req->payload_sz = payload_sz;
792         memcpy(req->payload, payload, payload_sz);
793
794         resp = ipc_msg_send_request(msg, req->handle);
795         ipc_msg_handle_free(req->handle);
796         ipc_msg_free(msg);
797         return resp;
798 }
799
800 static int __ipc_heartbeat(void)
801 {
802         unsigned long delta;
803
804         if (!ksmbd_server_running())
805                 return 0;
806
807         if (time_after(jiffies, server_conf.ipc_last_active)) {
808                 delta = (jiffies - server_conf.ipc_last_active);
809         } else {
810                 ipc_update_last_active();
811                 schedule_delayed_work(&ipc_timer_work,
812                                       server_conf.ipc_timeout);
813                 return 0;
814         }
815
816         if (delta < server_conf.ipc_timeout) {
817                 schedule_delayed_work(&ipc_timer_work,
818                                       server_conf.ipc_timeout - delta);
819                 return 0;
820         }
821
822         if (ksmbd_ipc_heartbeat_request() == 0) {
823                 schedule_delayed_work(&ipc_timer_work,
824                                       server_conf.ipc_timeout);
825                 return 0;
826         }
827
828         mutex_lock(&startup_lock);
829         WRITE_ONCE(server_conf.state, SERVER_STATE_RESETTING);
830         server_conf.ipc_last_active = 0;
831         ksmbd_tools_pid = 0;
832         pr_err("No IPC daemon response for %lus\n", delta / HZ);
833         mutex_unlock(&startup_lock);
834         return -EINVAL;
835 }
836
/* Delayed-work callback: queue a server reset once the daemon stops responding. */
static void ipc_timer_heartbeat(struct work_struct *w)
{
	int ret = __ipc_heartbeat();

	if (ret)
		server_queue_ctrl_reset_work();
}
842
843 int ksmbd_ipc_id_alloc(void)
844 {
845         return ksmbd_acquire_id(&ipc_ida);
846 }
847
848 void ksmbd_rpc_id_free(int handle)
849 {
850         ksmbd_release_id(&ipc_ida, handle);
851 }
852
853 void ksmbd_ipc_release(void)
854 {
855         cancel_delayed_work_sync(&ipc_timer_work);
856         genl_unregister_family(&ksmbd_genl_family);
857 }
858
859 void ksmbd_ipc_soft_reset(void)
860 {
861         mutex_lock(&startup_lock);
862         ksmbd_tools_pid = 0;
863         cancel_delayed_work_sync(&ipc_timer_work);
864         mutex_unlock(&startup_lock);
865 }
866
867 int ksmbd_ipc_init(void)
868 {
869         int ret = 0;
870
871         ksmbd_nl_init_fixup();
872         INIT_DELAYED_WORK(&ipc_timer_work, ipc_timer_heartbeat);
873
874         ret = genl_register_family(&ksmbd_genl_family);
875         if (ret) {
876                 pr_err("Failed to register KSMBD netlink interface %d\n", ret);
877                 cancel_delayed_work_sync(&ipc_timer_work);
878         }
879
880         return ret;
881 }